From c9f725ea2d87a891b2133d0c0652412a408e2d4f Mon Sep 17 00:00:00 2001 From: nathan Date: Thu, 3 Sep 2020 14:37:09 +0200 Subject: [PATCH] Fixed race condition when initialize kryo instances during startup --- src/dorkbox/network/connection/EndPoint.kt | 6 +- .../network/handshake/ClientHandshake.kt | 1 - .../rmi/messages/MethodRequestSerializer.kt | 6 +- .../serialization/ClassRegistration.kt | 16 +- .../serialization/ClassRegistrationForRmi.kt | 17 +- .../network/serialization/KryoExtra.kt | 8 +- .../network/serialization/Serialization.kt | 188 +++++++++++++----- 7 files changed, 181 insertions(+), 61 deletions(-) diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index b6aa63cf..150ff26f 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -466,7 +466,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } // we are not thread-safe! - val kryo = serialization.takeKryo() + val kryo = serialization.takeHandshakeKryo() try { kryo.write(message) @@ -499,7 +499,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A listenerManager.notifyError(newException("[${publication.sessionId()}] Error serializing handshake message $message", e)) } finally { sendIdleStrategy.reset() - serialization.returnKryo(kryo) + serialization.returnHandshakeKryo(kryo) } } @@ -513,7 +513,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A */ internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { try { - val message = serialization.readMessage(buffer, offset, length) + val message = serialization.readHandshakeMessage(buffer, offset, length) logger.trace { "[${header.sessionId()}] received: $message" } diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index c3c27b8e..71372b18 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -116,7 +116,6 @@ internal class ClientHandshake(private val logger: KLogg // Send the one-time pad to the server. endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) - endPoint.serialization.takeKryo() // TAKE THE KRYO BACK OFF! We don't want it on the pool yet, since this kryo hasn't had all of the classes registered yet! sessionId = handshakeConnection.publication.sessionId() diff --git a/src/dorkbox/network/rmi/messages/MethodRequestSerializer.kt b/src/dorkbox/network/rmi/messages/MethodRequestSerializer.kt index b2b30d05..a7ba2a9e 100644 --- a/src/dorkbox/network/rmi/messages/MethodRequestSerializer.kt +++ b/src/dorkbox/network/rmi/messages/MethodRequestSerializer.kt @@ -39,15 +39,17 @@ import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import dorkbox.network.rmi.CachedMethod import dorkbox.network.rmi.RmiUtils import dorkbox.network.serialization.KryoExtra +import org.agrona.collections.Int2ObjectHashMap import java.lang.reflect.Method /** * Internal message to invoke methods remotely. */ @Suppress("ConstantConditionIf") -class MethodRequestSerializer : Serializer() { +class MethodRequestSerializer(private val methodCache: Int2ObjectHashMap>) : Serializer() { override fun write(kryo: Kryo, output: Output, methodRequest: MethodRequest) { val method = methodRequest.cachedMethod @@ -83,7 +85,7 @@ class MethodRequestSerializer : Serializer() { (kryo as KryoExtra) val cachedMethod = try { - kryo.getMethods(methodClassId)[methodIndex] + methodCache[methodClassId][methodIndex] } catch (ex: Exception) { val methodClass = kryo.getRegistration(methodClassId).type throw KryoException("Invalid method index " + methodIndex + " for class: " + methodClass.name) diff --git a/src/dorkbox/network/serialization/ClassRegistration.kt b/src/dorkbox/network/serialization/ClassRegistration.kt index 8e83a4a9..889caa0e 100644 --- a/src/dorkbox/network/serialization/ClassRegistration.kt +++ b/src/dorkbox/network/serialization/ClassRegistration.kt @@ -34,8 +34,22 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S open fun register(kryo: KryoExtra, rmi: RmiHolder) { // ClassRegistrationForRmi overrides this method - val savedKryoId: Int? = rmi.implToId[clazz] // ALL registrations MUST BE IMPL! + if (id != 0) { + // our ID will always be > 0 + // this means that this registration was PREVIOUSLY registered on a different kryo. Shortcut the logic. + if (serializer != null) { + kryo.register(clazz, serializer, id) + } else { + kryo.register(clazz, id) + } + + return + } + + + + val savedKryoId: Int? = rmi.implToId[clazz] // ALL registrations MUST BE IMPL! var overriddenSerializer: Serializer? = null // did we already process this class? We permit overwriting serializers, etc! diff --git a/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt b/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt index 5233badd..c68b4984 100644 --- a/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt +++ b/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt @@ -105,6 +105,21 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>, */ override fun register(kryo: KryoExtra, rmi: RmiHolder) { // we override this, because we ALWAYS will call our RMI registration! + if (id != 0) { + // our ID will always be > 0 + // this means that this registration was PREVIOUSLY registered on a different kryo. Shortcut the logic. + + if (implClass != null) { + // RMI-SERVER + kryo.register(implClass, serializer, id) + } else { + // RMI-CLIENT + kryo.register(clazz, serializer, id) + } + return + } + + // EVERY time initKryo() is called, this will happen. We have to ensure that every call produces the same results @@ -114,7 +129,7 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>, // Both IFACE+IMPL must be checked when deciding if something needs to be overloaded, but ONLY for registerRmi // check to see if we have already registered this as RMI-CLIENT - val alreadyRegistered = rmi.ifaceToId[clazz] != null + val alreadyRegistered = rmi.ifaceToId[clazz] != null if (alreadyRegistered) { // if we are ALREADY registered, then we have to make sure that RMI-CLIENT doesn't override RMI-SERVER... diff --git a/src/dorkbox/network/serialization/KryoExtra.kt b/src/dorkbox/network/serialization/KryoExtra.kt index f92cccfa..923feae7 100644 --- a/src/dorkbox/network/serialization/KryoExtra.kt +++ b/src/dorkbox/network/serialization/KryoExtra.kt @@ -19,20 +19,18 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import dorkbox.network.connection.Connection -import dorkbox.network.rmi.CachedMethod import dorkbox.os.OS import dorkbox.util.Sys import dorkbox.util.bytes.OptimizeUtilsByteArray import net.jpountz.lz4.LZ4Factory import org.agrona.DirectBuffer -import org.agrona.collections.Int2ObjectHashMap import org.slf4j.Logger import java.io.IOException /** * Nothing in this class is thread safe */ -class KryoExtra(private val methodCache: Int2ObjectHashMap>) : Kryo() { +class KryoExtra() : Kryo() { // for kryo serialization private val readerBuffer = AeronInput() val writerBuffer = AeronOutput() @@ -75,10 +73,6 @@ class KryoExtra(private val methodCache: Int2ObjectHashMap>) // } // } - fun getMethods(classId: Int): Array { - return methodCache[classId] - } - /** * NOTE: THIS CANNOT BE USED FOR ANYTHING RELATED TO RMI! * diff --git a/src/dorkbox/network/serialization/Serialization.kt b/src/dorkbox/network/serialization/Serialization.kt index 519ad7f9..3e60ceab 100644 --- a/src/dorkbox/network/serialization/Serialization.kt +++ b/src/dorkbox/network/serialization/Serialization.kt @@ -93,6 +93,7 @@ open class Serialization(private val references: Boolean = true, private val fac private var initialized = atomic(false) private val kryoPool = MultithreadConcurrentQueue(1024) // reasonable size of available kryo's + private val kryoHandshakePool = MultithreadConcurrentQueue(1024) // reasonable size of available kryo's // used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class) // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. @@ -100,11 +101,15 @@ open class Serialization(private val references: Boolean = true, private val fac private val classesToRegister = mutableListOf() private lateinit var savedRegistrationDetails: ByteArray + // the purpose of the method cache, is to accelerate looking up methods for specific class + private val methodCache : Int2ObjectHashMap> = Int2ObjectHashMap() + + // BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM // StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK! private val instantiatorStrategy = DefaultInstantiatorStrategy(StdInstantiatorStrategy()) - private val methodRequestSerializer = MethodRequestSerializer() + private val methodRequestSerializer = MethodRequestSerializer(methodCache) // note: the methodCache is configured BEFORE anything reads from it! private val methodResponseSerializer = MethodResponseSerializer() private val continuationSerializer = ContinuationSerializer() @@ -113,15 +118,14 @@ open class Serialization(private val references: Boolean = true, private val fac val rmiHolder = RmiHolder() - // the purpose of the method cache, is to accelerate looking up methods for specific class - private val methodCache : Int2ObjectHashMap> = Int2ObjectHashMap() // reflectASM doesn't work on android private val useAsm = !OS.isAndroid() - // This is a GLOBAL, single threaded only kryo instance. - // This kryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem) - private var readKryo = initKryo() + // These are GLOBAL, single threaded only kryo instances. + // The readKryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem) + private var readKryo = initGlobalKryo() + private var readHandshakeKryo = initHandshakeKryo() /** * Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer]. @@ -261,14 +265,41 @@ open class Serialization(private val references: Boolean = true, private val fac } /** - * called as the first thing inside when initializing the classesToRegister + * Kryo specifically for handshakes */ - private fun initKryo(): KryoExtra { - val kryo = KryoExtra(methodCache) + private fun initHandshakeKryo(): KryoExtra { + val kryo = KryoExtra() kryo.instantiatorStrategy = instantiatorStrategy kryo.references = references + if (factory != null) { + kryo.setDefaultSerializer(factory) + } + + // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. + SerializationDefaults.register(kryo) + + kryo.register(HandshakeMessage::class.java) + + return kryo + } + + /** + * called as the first thing inside when initializing the classesToRegister + */ + private fun initGlobalKryo(): KryoExtra { + // NOTE: classesToRegister.forEach will be called after serialization init! + + val kryo = KryoExtra() + + kryo.instantiatorStrategy = instantiatorStrategy + kryo.references = references + + if (factory != null) { + kryo.setDefaultSerializer(factory) + } + // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. SerializationDefaults.register(kryo) @@ -282,9 +313,51 @@ open class Serialization(private val references: Boolean = true, private val fac // serialization.register(XECPrivateKey::class.java, XECPrivateKeySerializer()) // serialization.register(Message::class.java) // must use full package name! + // RMI stuff! + kryo.register(GlobalObjectCreateRequest::class.java) + kryo.register(GlobalObjectCreateResponse::class.java) + + kryo.register(ConnectionObjectCreateRequest::class.java) + kryo.register(ConnectionObjectCreateResponse::class.java) + + kryo.register(MethodRequest::class.java, methodRequestSerializer) + kryo.register(MethodResponse::class.java, methodResponseSerializer) + + @Suppress("UNCHECKED_CAST") + kryo.register(InvocationHandler::class.java as Class, rmiClientSerializer) + + kryo.register(Continuation::class.java, continuationSerializer) + + return kryo + } + + /** + * called as the first thing inside when initializing the classesToRegister + */ + private fun initKryo(): KryoExtra { + val kryo = KryoExtra() + + kryo.instantiatorStrategy = instantiatorStrategy + kryo.references = references + + if (factory != null) { + kryo.setDefaultSerializer(factory) + } + + // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. + SerializationDefaults.register(kryo) + +// serialization.register(PingMessage::class.java) // TODO this is built into aeron!??!?!?! + + // TODO: this is for diffie hellmen handshake stuff! +// serialization.register(IESParameters::class.java, IesParametersSerializer()) +// serialization.register(IESWithCipherParameters::class.java, IesWithCipherParametersSerializer()) + // TODO: fix kryo to work the way we want, so we can register interfaces + serializers with kryo +// serialization.register(XECPublicKey::class.java, XECPublicKeySerializer()) +// serialization.register(XECPrivateKey::class.java, XECPrivateKeySerializer()) +// serialization.register(Message::class.java) // must use full package name! // RMI stuff! - kryo.register(HandshakeMessage::class.java) kryo.register(GlobalObjectCreateRequest::class.java) kryo.register(GlobalObjectCreateResponse::class.java) @@ -305,10 +378,6 @@ open class Serialization(private val references: Boolean = true, private val fac registration.register(kryo, rmiHolder) } - if (factory != null) { - kryo.setDefaultSerializer(factory) - } - return kryo } @@ -321,7 +390,10 @@ open class Serialization(private val references: Boolean = true, private val fac */ internal fun finishInit(type: Class<*>, settingsStore: SettingsStore, kryoRegistrationDetailsFromServer: ByteArray = ByteArray(0)): Boolean { logger = KotlinLogging.logger(type.simpleName) - +logger.error("*********************ININT") +logger.error("*********************ININT") +logger.error("*********************ININT") +logger.error("*********************ININT") // this will set up the class registration information return if (type == Server::class.java) { if (!initialized.compareAndSet(expect = false, update = true)) { @@ -333,7 +405,7 @@ open class Serialization(private val references: Boolean = true, private val fac classesToRegister.add(ClassRegistration3(it)) } - val kryo = initKryo() // this will initialize the class registrations + val kryo = initKryo() initializeClassRegistrations(kryo) } else { if (!initialized.compareAndSet(expect = false, update = true)) { @@ -341,7 +413,18 @@ open class Serialization(private val references: Boolean = true, private val fac return true } - initializeClient(kryoRegistrationDetailsFromServer) + // we have to allow CUSTOM classes to register (where the order does not matter), so that if the CLIENT is the RMI-SERVER, it can + // specify IMPL classes for RMI. + classesToRegister.forEach { registration -> + require(registration is ClassRegistrationForRmi) { "Unable to initialize a class registrations for anything OTHER than RMI!! To fix this, remove ${registration.clazz}" } + } + val classesToRegisterForRmi = listOf(*classesToRegister.toTypedArray()) as List + classesToRegister.clear() + + // NOTE: to be clear, the "client" can ONLY registerRmi(IFACE, IMPL), to have extra info as the RMI-SERVER!! + + val kryo = initKryo() // this will initialize the class registrations + initializeClient(kryoRegistrationDetailsFromServer, classesToRegisterForRmi, kryo) } } @@ -435,11 +518,18 @@ open class Serialization(private val references: Boolean = true, private val fac RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) } - if (kryoId > 65000) { + if (kryoId >= 65535) { throw RuntimeException("There are too many kryo class registrations!!") } } + + // we have to check to make sure all classes are registered on the GLOBAL READ KRYO !!! + // Because our classes are registered LAST, this will always be correct. + classesToRegister.forEach { registration -> + registration.register(readKryo, rmiHolder) + } + // save this as a byte array (so class registration validation during connection handshake is faster) val output = AeronOutput() try { @@ -454,30 +544,13 @@ open class Serialization(private val references: Boolean = true, private val fac output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length) output.close() - - // note, we have to check to make sure all classes are registered! Because our classes are registered LAST, this will always be correct. - classesToRegister.forEach { registration -> - registration.register(readKryo, rmiHolder) - } - return true } @Suppress("UNCHECKED_CAST") - private fun initializeClient(kryoRegistrationDetailsFromServer: ByteArray): Boolean { - // we have to allow CUSTOM classes to register (where the order does not matter), so that if the CLIENT is the RMI-SERVER, it can - // specify IMPL classes for RMI. - classesToRegister.forEach { registration -> - require(registration is ClassRegistrationForRmi) { "Unable to initialize a class registrations for anything OTHER than RMI!! To fix this, remove ${registration.clazz}" } - } - val classesToRegisterForRmi = listOf(*classesToRegister.toTypedArray()) as List - classesToRegister.clear() - - // NOTE: to be clear, the "client" can ONLY registerRmi(IFACE, IMPL), to have extra info as the RMI-SERVER!! - - - - var kryo = initKryo() + private fun initializeClient(kryoRegistrationDetailsFromServer: ByteArray, + classesToRegisterForRmi: List, + kryo: KryoExtra): Boolean { val input = AeronInput(kryoRegistrationDetailsFromServer) val clientClassRegistrations = kryo.readCompressed(logger, input, kryoRegistrationDetailsFromServer.size) as Array> @@ -515,6 +588,8 @@ open class Serialization(private val references: Boolean = true, private val fac } } + logger.trace("CLIENT RMI REG $clazz $implClass") + // implClass MIGHT BE NULL! classesToRegister.add(ClassRegistrationForRmi(clazz, implClass, rmiServerSerializer)) @@ -529,13 +604,32 @@ open class Serialization(private val references: Boolean = true, private val fac return false } - // we have to re-init so the registrations are set! - kryo = initKryo() + // so far, our CURRENT kryo instance was 'registered' with everything, EXCEPT our classesToRegister. + // fortunately for us, this always happens LAST, so we can "do it" here instead of having to reInit kryo all over + classesToRegister.forEach { registration -> + registration.register(kryo, rmiHolder) + } // now do a round-trip through the class registrations return initializeClassRegistrations(kryo) } + /** + * @return takes a kryo instance from the pool, or creates one if the pool was empty + */ + fun takeHandshakeKryo(): KryoExtra { + // ALWAYS get as many as needed. Recycle them to prevent too many getting created + return kryoHandshakePool.poll() ?: initHandshakeKryo() + } + + /** + * Returns a kryo instance to the pool for re-use later on + */ + fun returnHandshakeKryo(kryo: KryoExtra) { + // return as much as we can. don't suspend if the pool is full, we just throw it away. + kryoHandshakePool.offer(kryo) + } + /** * @return takes a kryo instance from the pool, or creates one if the pool was empty */ @@ -563,8 +657,8 @@ open class Serialization(private val references: Boolean = true, private val fac // the rmi-server will have iface+impl id's // the rmi-client will have iface id's - val id = rmiHolder.ifaceToId[interfaceClass]!! - require(id != INVALID_KRYO_ID) { "Registration for $interfaceClass is invalid!!" } + val id = rmiHolder.ifaceToId[interfaceClass] + require(id != null) { "Registration for $interfaceClass is invalid!!" } return id } @@ -577,12 +671,14 @@ open class Serialization(private val references: Boolean = true, private val fac try { if (objectParameters.isNullOrEmpty()) { // simple, easy, fast. - return rmiHolder.idToInstantiator[interfaceClassId].newInstance() + val objectInstantiator = rmiHolder.idToInstantiator[interfaceClassId] ?: + throw NullPointerException("Object instantiator for ID $interfaceClassId is null") + return objectInstantiator.newInstance() } // we have to get the constructor for this object. val clazz: Class<*> = rmiHolder.idToImpl[interfaceClassId] ?: - return IllegalArgumentException("Cannot create RMI object for kryo interfaceClassId: $interfaceClassId (no class exists)") + return NullPointerException("Cannot create RMI object for kryo interfaceClassId: $interfaceClassId (no class exists)") // now have to find the closest match. @@ -714,8 +810,8 @@ open class Serialization(private val references: Boolean = true, private val fac } // NOTE: These following functions are ONLY called on a single thread! - fun readMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? { - return readKryo.read(buffer, offset, length) + fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? { + return readHandshakeKryo.read(buffer, offset, length) } fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, connection: Connection): Any? { return readKryo.read(buffer, offset, length, connection)