diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 16639faf..a5f99c2e 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -249,8 +249,8 @@ open class Client(config: Configuration = Configuration // we have to construct how the connection will communicate! reliableClientConnection.buildClient(aeron) - logger.trace { - "Creating new connection $reliableClientConnection" + logger.info { + "Creating new connection to $reliableClientConnection" } val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress)) @@ -265,9 +265,17 @@ open class Client(config: Configuration = Configuration throw exception } - // before we do anything else, we have to correct the RMI serializers, as necessary. - val rmiModificationIds = connectionInfo.kryoIdsForRmi - updateKryoIdsForRmi(newConnection, rmiModificationIds) + /////////////// + //// RMI + /////////////// + + // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information + serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage -> + listenerManager.notifyError(newConnection, + ClientRejectedException(errorMessage)) + } + + connection = newConnection connections.add(newConnection) @@ -542,13 +550,13 @@ open class Client(config: Configuration = Configuration suspend inline fun createObject(vararg objectParameters: Any?, noinline callback: suspend (Int, Iface) -> Unit) { // NOTE: It's not possible to have reified inside a virtual function // https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function - val classId = serialization.getClassId(Iface::class.java) + val kryoId = serialization.getKryoIdForRmi(Iface::class.java) @Suppress("UNCHECKED_CAST") objectParameters as Array @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") - rmiConnectionSupport.createRemoteObject(getConnection(), classId, objectParameters, callback) + rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, objectParameters, callback) } /** @@ -572,7 +580,7 @@ open class Client(config: Configuration = Configuration suspend inline fun createObject(noinline callback: suspend (Int, Iface) -> Unit) { // NOTE: It's not possible to have reified inside a virtual function // https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function - val classId = serialization.getClassId(Iface::class.java) + val classId = serialization.getKryoIdForRmi(Iface::class.java) @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") rmiConnectionSupport.createRemoteObject(getConnection(), classId, null, callback) diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 52a1de1f..720eaeeb 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -537,12 +537,12 @@ open class Connection(connectionParameters: ConnectionParams<*>) { */ suspend fun createObject(vararg objectParameters: Any?, callback: suspend (Int, Iface) -> Unit) { val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function2::class.java, callback.javaClass, 1) - val interfaceClassId = endPoint.serialization.getClassId(iFaceClass) + val kryoId = endPoint.serialization.getKryoIdForRmi(iFaceClass) @Suppress("UNCHECKED_CAST") objectParameters as Array - rmiConnectionSupport.createRemoteObject(this, interfaceClassId, objectParameters, callback) + rmiConnectionSupport.createRemoteObject(this, kryoId, objectParameters, callback) } /** @@ -564,7 +564,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { */ suspend fun createObject(callback: suspend (Int, Iface) -> Unit) { val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function2::class.java, callback.javaClass, 1) - val interfaceClassId = endPoint.serialization.getClassId(iFaceClass) + val interfaceClassId = endPoint.serialization.getKryoIdForRmi(iFaceClass) rmiConnectionSupport.createRemoteObject(this, interfaceClassId, null, callback) } diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 093d65e6..9bb16360 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -20,7 +20,6 @@ import dorkbox.network.Configuration import dorkbox.network.Server import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.CoroutineIdleStrategy -import dorkbox.network.aeron.client.ClientRejectedException import dorkbox.network.connection.ping.PingMessage import dorkbox.network.ipFilter.IpFilterRule import dorkbox.network.rmi.RmiManagerConnections @@ -44,7 +43,6 @@ import mu.KLogger import mu.KotlinLogging import org.agrona.DirectBuffer import java.io.File -import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CountDownLatch @@ -143,11 +141,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A // we only want one instance of these created. These will be called appropriately val settingsStore: SettingsStore - // list of already seen client RMI ids (which the server might not have registered as RMI types). - private var alreadySeenClientRmiIds = CopyOnWriteArrayList() - - private val networkReadKryo: KryoExtra = config.serialization.takeKryo() - internal val rmiGlobalSupport = RmiManagerGlobal(logger, actionDispatch, config.serialization) init { @@ -233,7 +226,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A logger.info("Aeron log directory: ${config.aeronLogDirectory}") if (aeronDirAlreadyExists) { - logger.info("Aeron log directory already exists! This might not be what you want!") + logger.warn("Aeron log directory already exists! This might not be what you want!") } // serialization stuff @@ -448,9 +441,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } try { - networkReadKryo.write(message) - - val buffer = networkReadKryo.writerBuffer + // NOTE: it is safe to use the global, single-threaded kryo instance! + val buffer = serialization.writeMessage(message) val objectSize = buffer.position() val internalBuffer = buffer.internalBuffer @@ -489,7 +481,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A */ fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { try { - val message = networkReadKryo.read(buffer, offset, length) + val message = serialization.readMessage(buffer, offset, length) logger.trace { "[${header.sessionId()}] received: $message" } @@ -526,11 +518,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A val message: Any? try { - message = networkReadKryo.read(buffer, offset, length, connection) + message = serialization.readMessage(buffer, offset, length, connection) logger.trace { - // The sessionId is globally unique, and is assigned by the server. - val sessionId = header.sessionId() - "[${sessionId}] received: $message" + "[${header.sessionId()}] received: $message" } } catch (e: Exception) { // The sessionId is globally unique, and is assigned by the server. @@ -715,29 +705,4 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A shutdownLatch.countDown() } } - - suspend fun updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray) { - rmiModificationIds.forEach { - if (!alreadySeenClientRmiIds.contains(it)) { - alreadySeenClientRmiIds.add(it) - - // have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on - // a single thread. - // NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct - - val registration = networkReadKryo.getRegistration(it) - val regMessage = "${type.simpleName}-side RMI serializer for registration $it -> ${registration.type}" - if (registration.type.isInterface) { - logger.debug { - "Modifying $regMessage" - } - // RMI must be with an interface. If it's not an interface then something is wrong - registration.serializer = serialization.rmiClientReverseSerializer - } else { - listenerManager.notifyError(connection, - ClientRejectedException("Attempting an unsafe modification of $regMessage")) - } - } - } - } } diff --git a/src/dorkbox/network/serialization/Serialization.kt b/src/dorkbox/network/serialization/Serialization.kt index 571f079a..48e67f87 100644 --- a/src/dorkbox/network/serialization/Serialization.kt +++ b/src/dorkbox/network/serialization/Serialization.kt @@ -17,12 +17,14 @@ package dorkbox.network.serialization import com.esotericsoftware.kryo.ClassResolver import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.SerializerFactory import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy import com.esotericsoftware.kryo.util.IdentityMap +import dorkbox.network.connection.Connection import dorkbox.network.connection.ping.PingMessage import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.rmi.CachedMethod @@ -53,6 +55,7 @@ import org.objenesis.strategy.StdInstantiatorStrategy import java.io.IOException import java.lang.reflect.Constructor import java.lang.reflect.InvocationHandler +import java.util.concurrent.CopyOnWriteArrayList import kotlin.coroutines.Continuation /** @@ -119,7 +122,7 @@ class Serialization(private val references: Boolean, // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // Object checking is performed during actual registration. private val classesToRegister = mutableListOf() - private lateinit var savedKryoIdsForRmi: IntArray + internal lateinit var savedKryoIdsForRmi: IntArray private lateinit var savedRegistrationDetails: ByteArray /// RMI things @@ -127,6 +130,8 @@ class Serialization(private val references: Boolean, private val rmiIfaceToImpl = IdentityMap, Class<*>>() private val rmiImplToIface = IdentityMap, Class<*>>() + // This is a GLOBAL, single threaded only kryo instance. + private val globalKryo: KryoExtra by lazy { initKryo() } // BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM // StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK! @@ -138,7 +143,10 @@ class Serialization(private val references: Boolean, private val continuationSerializer = ContinuationSerializer() private val rmiClientSerializer = RmiClientSerializer() - internal val rmiClientReverseSerializer = RmiClientReverseSerializer(rmiImplToIface) + private val rmiClientReverseSerializer = RmiClientReverseSerializer(rmiImplToIface) + + // list of already seen client RMI ids (which the server might not have registered as RMI types). + private var existingRmiIds = CopyOnWriteArrayList() @@ -154,6 +162,7 @@ class Serialization(private val references: Boolean, val KRYO_COUNT = 64 kryoPool = Channel(KRYO_COUNT) + } @Synchronized @@ -344,113 +353,110 @@ class Serialization(private val references: Boolean, // initialize the kryo pool with at least 1 kryo instance. This ALSO makes sure that all of our class registration is done // correctly and (if not) we are are notified on the initial thread (instead of on the network update thread) - val kryo = initKryo() - // save off the class-resolver, so we can lookup the class <-> id relationships - classResolver = kryo.classResolver + classResolver = globalKryo.classResolver + // now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID + // in order to get the ID's, these have to be registered with a kryo instance! + val mergedRegistrations = mutableListOf() + classesToRegister.forEach { registration -> + val id = registration.id + + // if we ALREADY contain this registration (based ONLY on ID), then overwrite the existing one and REMOVE the current one + var found = false + mergedRegistrations.forEachIndexed { index, classRegistration -> + if (classRegistration.id == id) { + mergedRegistrations[index] = registration + found = true + return@forEachIndexed + } + } + + if (!found) { + mergedRegistrations.add(registration) + } + } + + // sort these by ID, because that is what they should be registered as... + mergedRegistrations.sortBy { it.id } + + + // now all of the registrations are IN ORDER and MERGED (save back to original array) + + + // set 'classesToRegister' to our mergedRegistrations, because this is now the correct order + classesToRegister.clear() + classesToRegister.addAll(mergedRegistrations) + + + // now create the registration details, used to validate that the client/server have the EXACT same class registration setup + val registrationDetails = arrayListOf>() + + if (logger.isDebugEnabled) { + // log the in-order output first + classesToRegister.forEach { classRegistration -> + logger.debug(classRegistration.info()) + } + } + + val kryoIdsForRmi = mutableListOf() + + classesToRegister.forEach { classRegistration -> + // now save all of the registration IDs for quick verification/access + registrationDetails.add(classRegistration.getInfoArray()) + + // we should cache RMI methods! We don't always know if something is RMI or not (from just how things are registered...) + // so it is super trivial to map out all possible, relevant types + val kryoId = classRegistration.id + + if (classRegistration is ClassRegistrationIfaceAndImpl) { + // on the "RMI server" (aka, where the object lives) side, there will be an interface + implementation! + + // RMI method caching + methodCache[kryoId] = + RmiUtils.getCachedMethods(logger, globalKryo, useAsm, classRegistration.ifaceClass, classRegistration.implClass, kryoId) + + // we ALSO have to cache the instantiator for these, since these are used to create remote objects + val instantiator = globalKryo.instantiatorStrategy.newInstantiatorOf(classRegistration.implClass) + + @Suppress("UNCHECKED_CAST") + rmiIfaceToInstantiator[kryoId] = instantiator as ObjectInstantiator + + // finally, we must save this ID, to tell the remote connection that their interface serializer must change to support + // receiving an RMI impl object as a proxy object + kryoIdsForRmi.add(kryoId) + } else if (classRegistration.clazz.isInterface) { + // non-RMI method caching + methodCache[kryoId] = + RmiUtils.getCachedMethods(logger, globalKryo, useAsm, classRegistration.clazz, null, kryoId) + } + + if (kryoId > 65000) { + throw RuntimeException("There are too many kryo class registrations!!") + } + } + + // save as an array to make it faster to send this info to the remote connection + savedKryoIdsForRmi = kryoIdsForRmi.toIntArray() + + // have to add all of our EXISTING RMI id's, so we don't try to duplicate them (in case RMI registration is duplicated) + existingRmiIds.addAllAbsent(kryoIdsForRmi) + + + // save this as a byte array (so class registration validation during connection handshake is faster) + val output = AeronOutput() + try { - // now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID - // in order to get the ID's, these have to be registered with a kryo instance! - val mergedRegistrations = mutableListOf() - classesToRegister.forEach { registration -> - val id = registration.id - - // if we ALREADY contain this registration (based ONLY on ID), then overwrite the existing one and REMOVE the current one - var found = false - mergedRegistrations.forEachIndexed { index, classRegistration -> - if (classRegistration.id == id) { - mergedRegistrations[index] = registration - found = true - return@forEachIndexed - } - } - - if (!found) { - mergedRegistrations.add(registration) - } - } - - // sort these by ID, because that is what they should be registered as... - mergedRegistrations.sortBy { it.id } - - - // now all of the registrations are IN ORDER and MERGED (save back to original array) - - - // set 'classesToRegister' to our mergedRegistrations, because this is now the correct order - classesToRegister.clear() - classesToRegister.addAll(mergedRegistrations) - - - // now create the registration details, used to validate that the client/server have the EXACT same class registration setup - val registrationDetails = arrayListOf>() - - if (logger.isDebugEnabled) { - // log the in-order output first - classesToRegister.forEach { classRegistration -> - logger.debug(classRegistration.info()) - } - } - - val kryoIdsForRmi = mutableListOf() - - classesToRegister.forEach { classRegistration -> - // now save all of the registration IDs for quick verification/access - registrationDetails.add(classRegistration.getInfoArray()) - - // we should cache RMI methods! We don't always know if something is RMI or not (from just how things are registered...) - // so it is super trivial to map out all possible, relevant types - val kryoId = classRegistration.id - - if (classRegistration is ClassRegistrationIfaceAndImpl) { - // on the "RMI server" (aka, where the object lives) side, there will be an interface + implementation! - - // RMI method caching - methodCache[kryoId] = - RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.ifaceClass, classRegistration.implClass, kryoId) - - // we ALSO have to cache the instantiator for these, since these are used to create remote objects - val instantiator = kryo.instantiatorStrategy.newInstantiatorOf(classRegistration.implClass) - - @Suppress("UNCHECKED_CAST") - rmiIfaceToInstantiator[kryoId] = instantiator as ObjectInstantiator - - // finally, we must save this ID, to tell the remote connection that their interface serializer must change to support - // receiving an RMI impl object as a proxy object - kryoIdsForRmi.add(kryoId) - - } else if (classRegistration.clazz.isInterface) { - // non-RMI method caching - methodCache[kryoId] = - RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) - } - - if (kryoId > 65000) { - throw RuntimeException("There are too many kryo class registrations!!") - } - } - - savedKryoIdsForRmi = kryoIdsForRmi.toIntArray() - // save as an array to make it faster to send this info to the remote connection - - // save this as a byte array (so class registration validation during connection handshake is faster) - val output = AeronOutput() - - try { - kryo.writeCompressed(logger, output, registrationDetails.toTypedArray()) - } catch (e: Exception) { - logger.error("Unable to write compressed data for registration details", e) - } - - val length = output.position() - savedRegistrationDetails = ByteArray(length) - output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length) - output.close() - } finally { - returnKryo(kryo) + globalKryo.writeCompressed(logger, output, registrationDetails.toTypedArray()) + } catch (e: Exception) { + logger.error("Unable to write compressed data for registration details", e) } + + val length = output.position() + savedRegistrationDetails = ByteArray(length) + output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length) + output.close() } /** @@ -617,17 +623,16 @@ class Serialization(private val references: Boolean, /** * Returns the Kryo class registration ID */ - fun getClassId(iFace: Class<*>): Int { - return classResolver.getRegistration(iFace).id - } + fun getKryoIdForRmi(interfaceClass: Class<*>): Int { + if (!interfaceClass.isInterface) { + throw KryoException("Can only get the kryo IDs for RMI on an interface!") + } - /** - * Returns the Kryo class from a registration ID - */ - fun getClassFromId(kryoId: Int): Class<*> { - return classResolver.getRegistration(kryoId).type - } + val implClass = rmiIfaceToImpl[interfaceClass] + // for RMI, we store the IMPL class in the class registration -- not the iface! + return classResolver.getRegistration(implClass).id + } /** * Creates a NEW object implementation based on the KRYO interface ID. @@ -643,7 +648,7 @@ class Serialization(private val references: Boolean, val size = objectParameters.size // we have to get the constructor for this object. - val clazz = getClassFromId(interfaceClassId) + val clazz = classResolver.getRegistration(interfaceClassId).type val constructors = clazz.declaredConstructors // now have to find the closest match. @@ -781,6 +786,49 @@ class Serialization(private val references: Boolean, } } + suspend fun updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray, onError: suspend (String) -> Unit) { + val endPoint = connection.endPoint() + val typeName = endPoint.type.simpleName + + rmiModificationIds.forEach { + if (!existingRmiIds.contains(it)) { + existingRmiIds.add(it) + + // have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on + // a single thread. + // NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct + + val registration = globalKryo.getRegistration(it) + val regMessage = "$typeName-side RMI serializer for registration $it -> ${registration.type}" + + if (registration.type.isInterface) { + logger.debug { + "Modifying $regMessage" + } + // RMI must be with an interface. If it's not an interface then something is wrong + registration.serializer = rmiClientReverseSerializer + } else { + // note: one way that this can be called is when BOTH the client + server register the same way for RMI IDs. When + // the endpoint serialization is initialized, we also add the RMI IDs to this list, so we don't have to worry about this specific + // scenario + onError("Attempting an unsafe modification of $regMessage") + } + } + } + } + + // NOTE: These following functions are ONLY called on a single thread! + fun readMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? { + return globalKryo.read(buffer, offset, length) + } + fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, connection: Connection): Any? { + return globalKryo.read(buffer, offset, length, connection) + } + fun writeMessage(message: Any): AeronOutput { + globalKryo.write(message) + return globalKryo.writerBuffer + } + // /** // * Waits until a kryo is available to write, using CAS operations to prevent having to synchronize. // *