From 07b8b1002a44837a2beedc190a8a506cc13e01da Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 2 Sep 2020 15:03:57 +0200 Subject: [PATCH] Serialization registration (class, serializer, id, etc) is now only necessary on the server. The client receives the serialization information during the handshake. --- src/dorkbox/network/Client.kt | 35 +- src/dorkbox/network/Server.kt | 3 + .../network/connection/CryptoManagement.kt | 18 +- src/dorkbox/network/connection/EndPoint.kt | 10 +- .../network/handshake/ClientConnectionInfo.kt | 2 +- .../network/handshake/ClientHandshake.kt | 18 +- .../network/handshake/HandshakeMessage.kt | 4 +- .../network/handshake/ServerHandshake.kt | 135 ++++---- .../serialization/ClassRegistration.kt | 8 +- .../serialization/ClassRegistration0.kt | 2 +- .../serialization/ClassRegistration1.kt | 2 +- .../serialization/ClassRegistration2.kt | 2 +- .../serialization/ClassRegistration3.kt | 2 +- .../serialization/ClassRegistrationForRmi.kt | 15 +- .../network/serialization/Serialization.kt | 304 +++++++----------- 15 files changed, 224 insertions(+), 336 deletions(-) diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index e778d645..f50a8612 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -281,6 +281,23 @@ open class Client(config: Configuration = Configuration // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports logger.info(reliableClientConnection.clientInfo()) + + /////////////// + //// RMI + /////////////// + + // we setup our kryo information once we connect to a server (using the server's kryo registration details) + if (!serialization.finishInit(type, settingsStore, connectionInfo.kryoRegistrationDetails)) { + handshakeConnection.close() + + // because we are getting the class registration details from the SERVER, this should never be the case. + // It is still and edge case where the reconstruction of the registration details fails (maybe because of custom serializers) + val exception = ClientRejectedException("Connection to $remoteAddress has incorrect class registration details!!") + listenerManager.notifyError(exception) + throw exception + } + + val newConnection = if (isIpcConnection) { newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID)) } else { @@ -297,16 +314,6 @@ open class Client(config: Configuration = Configuration throw exception } - /////////////// - //// RMI - /////////////// - - // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information - serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage -> - listenerManager.notifyError(newConnection, - ClientRejectedException(errorMessage)) - } - ////////////// /// Extra Close action ////////////// @@ -515,7 +522,7 @@ open class Client(config: Configuration = Configuration * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveObject(`object`: Any): Int { + fun saveObject(`object`: Any): Int { val rmiId = rmiConnectionSupport.saveImplObject(`object`) if (rmiId == RemoteObjectStorage.INVALID_RMI) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") @@ -545,7 +552,7 @@ open class Client(config: Configuration = Configuration * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveObject(`object`: Any, objectId: Int): Boolean { + fun saveObject(`object`: Any, objectId: Int): Boolean { val success = rmiConnectionSupport.saveImplObject(`object`, objectId) if (!success) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") @@ -663,7 +670,7 @@ open class Client(config: Configuration = Configuration * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveGlobalObject(`object`: Any): Int { + fun saveGlobalObject(`object`: Any): Int { val rmiId = rmiGlobalSupport.saveImplObject(`object`) if (rmiId == RemoteObjectStorage.INVALID_RMI) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") @@ -691,7 +698,7 @@ open class Client(config: Configuration = Configuration * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { + fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { val success = rmiGlobalSupport.saveImplObject(`object`, objectId) if (!success) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index 6487d301..5f5a4bc3 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -125,6 +125,9 @@ open class Server(config: ServerConfiguration = ServerC if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") } if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount} + + // we are done with initial configuration, now finish serialization + serialization.finishInit(type, settingsStore) } override fun newException(message: String, cause: Throwable?): Throwable { diff --git a/src/dorkbox/network/connection/CryptoManagement.kt b/src/dorkbox/network/connection/CryptoManagement.kt index 1484e1b6..05dc0a48 100644 --- a/src/dorkbox/network/connection/CryptoManagement.kt +++ b/src/dorkbox/network/connection/CryptoManagement.kt @@ -188,7 +188,7 @@ internal class CryptoManagement(val logger: KLogger, subscriptionPort: Int, connectionSessionId: Int, connectionStreamId: Int, - kryoRmiIds: IntArray): ByteArray { + kryoRegDetails: ByteArray): ByteArray { val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) secureRandom.nextBytes(iv) @@ -200,10 +200,8 @@ internal class CryptoManagement(val logger: KLogger, cryptOutput.writeInt(connectionStreamId) cryptOutput.writeInt(publicationPort) cryptOutput.writeInt(subscriptionPort) - cryptOutput.writeInt(kryoRmiIds.size) - kryoRmiIds.forEach { - cryptOutput.writeInt(it) - } + cryptOutput.writeInt(kryoRegDetails.size) + cryptOutput.writeBytes(kryoRegDetails) return iv + aesCipher.doFinal(cryptOutput.toBytes()) } @@ -234,12 +232,8 @@ internal class CryptoManagement(val logger: KLogger, val streamId = cryptInput.readInt() val publicationPort = cryptInput.readInt() val subscriptionPort = cryptInput.readInt() - - val rmiIds = mutableListOf() - val rmiIdSize = cryptInput.readInt() - for (i in 0 until rmiIdSize) { - rmiIds.add(cryptInput.readInt()) - } + val regDetailsSize = cryptInput.readInt() + val regDetails = cryptInput.readBytes(regDetailsSize) // now read data off return ClientConnectionInfo(sessionId = sessionId, @@ -247,7 +241,7 @@ internal class CryptoManagement(val logger: KLogger, publicationPort = publicationPort, subscriptionPort = subscriptionPort, publicKey = serverPublicKeyBytes, - kryoIdsForRmi = rmiIds.toIntArray()) + kryoRegistrationDetails = regDetails) } override fun hashCode(): Int { diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 96148aca..62eed87b 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -21,6 +21,7 @@ import dorkbox.network.Server import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.connection.ping.PingMessage +import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.ipFilter.IpFilterRule import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.rmi.RmiManagerConnections @@ -286,14 +287,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A settingsStore = config.settingsStore settingsStore.init(serialization, config.settingsStorageSystem.build()) - settingsStore.getSerializationTypes().forEach { - serialization.register(it) - } - crypto = CryptoManagement(logger, settingsStore, type, config) - - // we are done with initial configuration, now finish serialization - serialization.finishInit(type) } internal fun initEndpointState(): Aeron { @@ -464,7 +458,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } } - internal suspend fun writeHandshakeMessage(publication: Publication, message: Any) { + internal suspend fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) { // The sessionId is globally unique, and is assigned by the server. logger.trace { "[${publication.sessionId()}] send: $message" diff --git a/src/dorkbox/network/handshake/ClientConnectionInfo.kt b/src/dorkbox/network/handshake/ClientConnectionInfo.kt index 10312aa7..ce8a1cd8 100644 --- a/src/dorkbox/network/handshake/ClientConnectionInfo.kt +++ b/src/dorkbox/network/handshake/ClientConnectionInfo.kt @@ -20,5 +20,5 @@ internal class ClientConnectionInfo(val subscriptionPort: Int = 0, val sessionId: Int, val streamId: Int = 0, val publicKey: ByteArray = ByteArray(0), - val kryoIdsForRmi: IntArray) { + val kryoRegistrationDetails: ByteArray) { } diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index ea56b622..1c05fe1a 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -91,18 +91,14 @@ internal class ClientHandshake(private val logger: KLogg val sessionId = cryptInput.readInt() val streamSubId = cryptInput.readInt() val streamPubId = cryptInput.readInt() - - val rmiIds = mutableListOf() - val rmiIdSize = cryptInput.readInt() - for (i in 0 until rmiIdSize) { - rmiIds.add(cryptInput.readInt()) - } + val regDetailsSize = cryptInput.readInt() + val regDetails = cryptInput.readBytes(regDetailsSize) // now read data off connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId, subscriptionPort = streamSubId, publicationPort = streamPubId, - kryoIdsForRmi = rmiIds.toIntArray()) + kryoRegistrationDetails = regDetails) } HandshakeMessage.DONE_ACK -> { connectionDone = true @@ -120,16 +116,12 @@ internal class ClientHandshake(private val logger: KLogg } suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { - val registrationMessage = HandshakeMessage.helloFromClient( - oneTimePad = oneTimePad, - publicKey = config.settingsStore.getPublicKey()!!, - registrationData = config.serialization.getKryoRegistrationDetails(), - registrationRmiIdData = config.serialization.getKryoRmiIds() - ) + val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, config.settingsStore.getPublicKey()!!) // Send the one-time pad to the server. endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) + endPoint.serialization.takeKryo() // TAKE THE KRYO BACK OFF! We don't want it on the pool yet, since this kryo hasn't had all of the classes registered yet! sessionId = handshakeConnection.publication.sessionId() diff --git a/src/dorkbox/network/handshake/HandshakeMessage.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt index b8b79453..319612cc 100644 --- a/src/dorkbox/network/handshake/HandshakeMessage.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -57,13 +57,11 @@ internal class HandshakeMessage private constructor() { const val DONE = 3 const val DONE_ACK = 4 - fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray, registrationRmiIdData: IntArray): HandshakeMessage { + fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage { val hello = HandshakeMessage() hello.state = HELLO hello.oneTimePad = oneTimePad hello.publicKey = publicKey - hello.registrationData = registrationData - hello.registrationRmiIdData = registrationRmiIdData return hello } diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index e4a8517b..84255e59 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -123,6 +123,50 @@ internal class ServerHandshake(private val logger: KLog return true } + /** + * @return true if we should continue parsing the incoming message, false if we should abort + */ + private fun validateConnectionInfo(server: Server, + handshakePublication: Publication, + config: ServerConfiguration, + clientAddressString: String, + clientAddress: Int): Boolean { + + try { + // VALIDATE:: Check to see if there are already too many clients connected. + if (server.connections.connectionCount() >= config.maxClientCount) { + listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) + + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) + } + return false + } + + // VALIDATE:: we are now connected to the client and are going to create a new connection. + val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) + if (currentCountForIp >= config.maxConnectionsPerIpAddress) { + // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) + connectionsPerIpCounts.getAndDecrement(clientAddress) + + listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) + } + + return false + } + } catch (e: Exception) { + listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) + } + return false + } + + return true + } + // note: CANNOT be called in action dispatch fun processHandshakeMessageServer(server: Server, @@ -140,12 +184,6 @@ internal class ServerHandshake(private val logger: KLog val serialization = config.serialization - // VALIDATE:: make sure the serialization matches between the client/server! - if (!serialization.verifyKryoRegistration(message.registrationData!!)) { - listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Registration data mismatch.")) - return - } - ///// ///// @@ -235,17 +273,6 @@ internal class ServerHandshake(private val logger: KLog } - /////////////// - //// RMI - /////////////// - - // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information - // NOTE: This modifies the readKryo! This cannot be on a different thread! - serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage -> - listenerManager.notifyError(connection, - ClientRejectedException(errorMessage)) - } - /////////////// @@ -267,11 +294,9 @@ internal class ServerHandshake(private val logger: KLog cryptOutput.writeInt(connectionStreamSubId) cryptOutput.writeInt(connectionStreamPubId) - val kryoRmiIds = serialization.getKryoRmiIds() - cryptOutput.writeInt(kryoRmiIds.size) - kryoRmiIds.forEach { - cryptOutput.writeInt(it) - } + val regDetails = serialization.getKryoRegistrationDetails() + cryptOutput.writeInt(regDetails.size) + cryptOutput.writeBytes(regDetails) successMessage.registrationData = cryptOutput.toBytes() @@ -314,51 +339,16 @@ internal class ServerHandshake(private val logger: KLog val validateRemoteAddress: PublicKeyValidationState val serialization = config.serialization - try { - // VALIDATE:: Check to see if there are already too many clients connected. - if (server.connections.connectionCount() >= config.maxClientCount) { - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) - - server.actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) - } - return - } - - // VALIDATE:: check to see if the remote connection's public key has changed! - validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes) - if (validateRemoteAddress == PublicKeyValidationState.INVALID) { - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch.")) - return - } - - // VALIDATE:: make sure the serialization matches between the client/server! - if (!serialization.verifyKryoRegistration(message.registrationData!!)) { - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Registration data mismatch.")) - return - } - - // VALIDATE:: we are now connected to the client and are going to create a new connection. - val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) - if (currentCountForIp >= config.maxConnectionsPerIpAddress) { - // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) - connectionsPerIpCounts.getAndDecrement(clientAddress) - - listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) - server.actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) - } - - return - } - } catch (e: Exception) { - listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) - server.actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) - } + // VALIDATE:: check to see if the remote connection's public key has changed! + validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes) + if (validateRemoteAddress == PublicKeyValidationState.INVALID) { + listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch.")) return } + if (!validateConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) { + return + } ///// @@ -448,19 +438,6 @@ internal class ServerHandshake(private val logger: KLog } - /////////////// - //// RMI - /////////////// - - // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information - // NOTE: This modifies the readKryo! This cannot be on a different thread! - serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage -> - listenerManager.notifyError(connection, - ClientRejectedException(errorMessage)) - } - - - /////////////// /// HANDSHAKE /////////////// @@ -471,7 +448,7 @@ internal class ServerHandshake(private val logger: KLog val successMessage = HandshakeMessage.helloAckToClient(sessionId) - // if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint + // Also send the RMI registration data to the client (so the client doesn't register anything) // now create the encrypted payload, using ECDH successMessage.registrationData = server.crypto.encrypt(clientPublicKeyBytes!!, @@ -479,7 +456,7 @@ internal class ServerHandshake(private val logger: KLog subscriptionPort, connectionSessionId, connectionStreamId, - serialization.getKryoRmiIds()) + serialization.getKryoRegistrationDetails()) successMessage.publicKey = server.crypto.publicKeyBytes diff --git a/src/dorkbox/network/serialization/ClassRegistration.kt b/src/dorkbox/network/serialization/ClassRegistration.kt index dfc0a746..d7c307da 100644 --- a/src/dorkbox/network/serialization/ClassRegistration.kt +++ b/src/dorkbox/network/serialization/ClassRegistration.kt @@ -32,7 +32,9 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S * If this class registration will EVENTUALLY be for RMI, then [ClassRegistrationForRmi] will reassign the serializer */ open fun register(kryo: KryoExtra, rmi: RmiHolder) { - val savedKryoId: Int? = rmi.ifaceToId[clazz] + // ClassRegistrationForRmi overrides this method + + val savedKryoId: Int? = rmi.implToId[clazz] // ALL registrations MUST BE IMPL! var overriddenSerializer: Serializer? = null @@ -53,7 +55,7 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S return } else -> { - // mark that this was overridden! + // We didn't do anything. } } } @@ -61,7 +63,7 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S // otherwise, we are OK to continue to register this register(kryo) - if (overriddenSerializer != null) { + if (serializer != null && overriddenSerializer != serializer) { info = "$info (Replaced $overriddenSerializer)" } diff --git a/src/dorkbox/network/serialization/ClassRegistration0.kt b/src/dorkbox/network/serialization/ClassRegistration0.kt index ec1cb1ac..97bf78d4 100644 --- a/src/dorkbox/network/serialization/ClassRegistration0.kt +++ b/src/dorkbox/network/serialization/ClassRegistration0.kt @@ -24,6 +24,6 @@ internal class ClassRegistration0(clazz: Class<*>, serializer: Serializer<*>) : } override fun getInfoArray(): Array { - return arrayOf(id, clazz.name, serializer!!::class.java.name) + return arrayOf(0, id, clazz.name, serializer!!::class.java.name) } } diff --git a/src/dorkbox/network/serialization/ClassRegistration1.kt b/src/dorkbox/network/serialization/ClassRegistration1.kt index 09072c41..b18b9d90 100644 --- a/src/dorkbox/network/serialization/ClassRegistration1.kt +++ b/src/dorkbox/network/serialization/ClassRegistration1.kt @@ -22,6 +22,6 @@ internal class ClassRegistration1(clazz: Class<*>, id: Int) : ClassRegistration( } override fun getInfoArray(): Array { - return arrayOf(id, clazz.name, "") + return arrayOf(1, id, clazz.name, "") } } diff --git a/src/dorkbox/network/serialization/ClassRegistration2.kt b/src/dorkbox/network/serialization/ClassRegistration2.kt index fc73e50b..dbb7481b 100644 --- a/src/dorkbox/network/serialization/ClassRegistration2.kt +++ b/src/dorkbox/network/serialization/ClassRegistration2.kt @@ -25,6 +25,6 @@ internal class ClassRegistration2(clazz: Class<*>, serializer: Serializer<*>, id } override fun getInfoArray(): Array { - return arrayOf(id, clazz.name, serializer!!::class.java.name) + return arrayOf(2, id, clazz.name, serializer!!::class.java.name) } } diff --git a/src/dorkbox/network/serialization/ClassRegistration3.kt b/src/dorkbox/network/serialization/ClassRegistration3.kt index 75e335a0..e10ec638 100644 --- a/src/dorkbox/network/serialization/ClassRegistration3.kt +++ b/src/dorkbox/network/serialization/ClassRegistration3.kt @@ -23,6 +23,6 @@ internal open class ClassRegistration3(clazz: Class<*>) : ClassRegistration(claz } override fun getInfoArray(): Array { - return arrayOf(id, clazz.name, "") + return arrayOf(3, id, clazz.name, "") } } diff --git a/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt b/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt index 8e2b894d..b6d503ab 100644 --- a/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt +++ b/src/dorkbox/network/serialization/ClassRegistrationForRmi.kt @@ -45,7 +45,7 @@ import dorkbox.network.rmi.messages.RmiServerSerializer * If the impl object 'lives' on the SERVER, then the server must tell the client about the iface ID */ internal class ClassRegistrationForRmi(ifaceClass: Class<*>, - val implClass: Class<*>, + val implClass: Class<*>?, serializer: RmiServerSerializer) : ClassRegistration(ifaceClass, serializer) { /** * In general: @@ -117,8 +117,11 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>, // now register the impl class id = kryo.register(implClass, serializer).id } - info = "Registered $id -> (RMI) ${implClass.name}" - + info = if (implClass == null) { + "Registered $id -> (RMI-CLIENT) ${clazz.name}" + } else { + "Registered $id -> (RMI-SERVER) ${clazz.name} -> ${implClass.name}" + } // now, we want to save the relationship between classes and kryoId rmi.ifaceToId[clazz] = id @@ -131,6 +134,10 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>, override fun getInfoArray(): Array { // the info array has to match for the INTERFACE (not the impl!) - return arrayOf(id, clazz.name, serializer!!::class.java.name) + return if (implClass == null) { + arrayOf(4, id, clazz.name, serializer!!::class.java.name, "") + } else { + arrayOf(4, id, clazz.name, serializer!!::class.java.name, implClass.name) + } } } diff --git a/src/dorkbox/network/serialization/Serialization.kt b/src/dorkbox/network/serialization/Serialization.kt index 6c85ae12..43fd31f8 100644 --- a/src/dorkbox/network/serialization/Serialization.kt +++ b/src/dorkbox/network/serialization/Serialization.kt @@ -23,6 +23,7 @@ import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy import com.esotericsoftware.minlog.Log +import dorkbox.network.Server import dorkbox.network.connection.Connection import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.rmi.CachedMethod @@ -38,6 +39,7 @@ import dorkbox.network.rmi.messages.MethodResponse import dorkbox.network.rmi.messages.MethodResponseSerializer import dorkbox.network.rmi.messages.RmiClientSerializer import dorkbox.network.rmi.messages.RmiServerSerializer +import dorkbox.network.storage.SettingsStore import dorkbox.os.OS import dorkbox.util.serialization.SerializationDefaults import dorkbox.util.serialization.SerializationManager @@ -53,7 +55,6 @@ import org.objenesis.strategy.StdInstantiatorStrategy import java.io.IOException import java.lang.reflect.Constructor import java.lang.reflect.InvocationHandler -import java.util.concurrent.CopyOnWriteArrayList import kotlin.coroutines.Continuation @@ -97,14 +98,8 @@ open class Serialization(private val references: Boolean = true, private val fac // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // Object checking is performed during actual registration. private val classesToRegister = mutableListOf() - private lateinit var savedKryoIdsForRmi: IntArray private lateinit var savedRegistrationDetails: ByteArray - // This is a GLOBAL, single threaded only kryo instance. - // This is to make sure that we have an instance of class registration done correctly and (if not) we are - // are notified on the initial thread (instead of on the network update thread) - private val readKryo: KryoExtra by lazy { initKryo() } - // BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM // StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK! private val instantiatorStrategy = DefaultInstantiatorStrategy(StdInstantiatorStrategy()) @@ -118,15 +113,16 @@ open class Serialization(private val references: Boolean = true, private val fac val rmiHolder = RmiHolder() - // list of already seen client RMI ids (which the server might not have registered as RMI types). - private var existingRmiIds = CopyOnWriteArrayList() - // the purpose of the method cache, is to accelerate looking up methods for specific class private val methodCache : Int2ObjectHashMap> = Int2ObjectHashMap() // reflectASM doesn't work on android private val useAsm = !OS.isAndroid() + // This is a GLOBAL, single threaded only kryo instance. + // This kryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem) + private var readKryo = initKryo() + /** * Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer]. * If the class is already registered, the existing entry is updated with the new serializer. @@ -144,9 +140,9 @@ open class Serialization(private val references: Boolean = true, private val fac override fun register(clazz: Class): Serialization { require(!initialized.value) { "Serialization 'register(class)' cannot happen after client/server initialization!" } -// // The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather -// // with object types... EVEN IF THERE IS A SERIALIZER -// require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." } + // The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather + // with object types... EVEN IF THERE IS A SERIALIZER + require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." } classesToRegister.add(ClassRegistration3(clazz)) return this @@ -257,14 +253,7 @@ open class Serialization(private val references: Boolean = true, private val fac } /** - * @return the details of all registration IDs for RMI iface serializer rewrites - */ - fun getKryoRmiIds(): IntArray { - return savedKryoIdsForRmi - } - - /** - * called as the first think inside [finishInit] + * called as the first thing inside when initializing the classesToRegister */ private fun initKryo(): KryoExtra { val kryo = KryoExtra(methodCache) @@ -315,19 +304,40 @@ open class Serialization(private val references: Boolean = true, private val fac return kryo } - /** - * Called when initialization is complete. This is to prevent (and recognize) out-of-order class/serializer registration. If an ID - * is already in use by a different type, an exception is thrown. - */ - fun finishInit(endPointClass: Class<*>) { - logger = KotlinLogging.logger(endPointClass.simpleName) - if (!initialized.compareAndSet(expect = false, update = true)) { - logger.error("Unable to initialize serialization more than once!") - return - } + /** + * Called when server initialization is complete. + * Called when client connection receives kryo registration details + * + * This is to prevent (and recognize) out-of-order class/serializer registration. If an ID is already in use by a different type, an exception is thrown. + */ + internal fun finishInit(type: Class<*>, settingsStore: SettingsStore, kryoRegistrationDetailsFromServer: ByteArray = ByteArray(0)): Boolean { + logger = KotlinLogging.logger(type.simpleName) // this will set up the class registration information + return if (type == Server::class.java) { + if (!initialized.compareAndSet(expect = false, update = true)) { + logger.error("Unable to initialize serialization more than once!") + return false + } + + settingsStore.getSerializationTypes().forEach { + classesToRegister.add(ClassRegistration3(it)) + } + + initializeClassRegistrations() + } else { + if (!initialized.compareAndSet(expect = false, update = true)) { + // the client CAN initialize more than once, since initialization happens in the handshake now + return true + } + + require(classesToRegister.isEmpty()) { "Unable to initialize a non-empty class registration state! Make sure there are no serialization registrations for the client!" } + initializeClient(kryoRegistrationDetailsFromServer) + } + } + + private fun initializeClassRegistrations(): Boolean { val kryo = initKryo() // now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID @@ -381,8 +391,6 @@ open class Serialization(private val references: Boolean = true, private val fac } } - val kryoIdsForRmi = mutableListOf() - classesToRegister.forEach { classRegistration -> // now save all of the registration IDs for quick verification/access registrationDetails.add(classRegistration.getInfoArray()) @@ -396,23 +404,29 @@ open class Serialization(private val references: Boolean = true, private val fac val implClass = classRegistration.implClass - // RMI method caching - methodCache[kryoId] = - RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, implClass, kryoId) + // TWO ways to do this. On RMI-SERVER, impl class will actually be an IMPL. On RMI-CLIENT, implClass will be IFACE!! + if (implClass != null && !implClass.isInterface) { + // server - // we ALSO have to cache the instantiator for these, since these are used to create remote objects - @Suppress("UNCHECKED_CAST") - rmiHolder.idToInstantiator[kryoId] = - kryo.instantiatorStrategy.newInstantiatorOf(implClass) as ObjectInstantiator + // RMI-server method caching + methodCache[kryoId] = + RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, implClass, kryoId) + // we ALSO have to cache the instantiator for these, since these are used to create remote objects + @Suppress("UNCHECKED_CAST") + rmiHolder.idToInstantiator[kryoId] = + kryo.instantiatorStrategy.newInstantiatorOf(implClass) as ObjectInstantiator + } else { + // client - // finally, we must save this ID, to tell the remote connection that their interface serializer must change to support - // receiving an RMI impl object as a proxy object - kryoIdsForRmi.add(kryoId) + // RMI-client method caching + methodCache[kryoId] = + RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) + } } else if (classRegistration.clazz.isInterface) { // non-RMI method caching methodCache[kryoId] = - RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) + RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) } if (kryoId > 65000) { @@ -420,189 +434,89 @@ open class Serialization(private val references: Boolean = true, private val fac } } - // save as an array to make it faster to send this info to the remote connection - savedKryoIdsForRmi = kryoIdsForRmi.toIntArray() - - // have to add all of our EXISTING RMI id's, so we don't try to duplicate them (in case RMI registration is duplicated) - existingRmiIds.addAllAbsent(kryoIdsForRmi) - - // save this as a byte array (so class registration validation during connection handshake is faster) val output = AeronOutput() try { kryo.writeCompressed(logger, output, registrationDetails.toTypedArray()) } catch (e: Exception) { logger.error("Unable to write compressed data for registration details", e) + return false } val length = output.position() savedRegistrationDetails = ByteArray(length) output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length) output.close() - } - /** - * NOTE: When this fails, the CLIENT will just time out. We DO NOT want to send an error message to the client - * (it should check for updates or something else). We do not want to give "rogue" clients knowledge of the - * server, thus preventing them from trying to probe the server data structures. - * - * @return true if kryo registration is required for all classes sent over the wire - */ - @Suppress("DuplicatedCode") - fun verifyKryoRegistration(clientBytes: ByteArray): Boolean { - // verify the registration IDs if necessary with our own. The CLIENT does not verify anything, only the server! - val kryoRegistrationDetails = savedRegistrationDetails - val equals = kryoRegistrationDetails.contentEquals(clientBytes) - if (equals) { - return true + // note, we have to check to make sure all classes are registered! Because our classes are registered LAST, this will always be correct. + classesToRegister.forEach { registration -> + registration.register(readKryo, rmiHolder) } - // RMI details might be one reason the arrays are different + return true + } - // now we need to figure out WHAT was screwed up so we know what to fix - // NOTE: it could just be that the byte arrays are different, because java has a non-deterministic iteration of hash maps. - val kryo = takeKryo() - val input = AeronInput(clientBytes) + @Suppress("UNCHECKED_CAST") + private fun initializeClient(kryoRegistrationDetailsFromServer: ByteArray): Boolean { + val kryo = initKryo() + val input = AeronInput(kryoRegistrationDetailsFromServer) + val clientClassRegistrations = kryo.readCompressed(logger, input, kryoRegistrationDetailsFromServer.size) as Array> + + val maker = kryo.instantiatorStrategy try { - var success = true - @Suppress("UNCHECKED_CAST") - val clientClassRegistrations = kryo.readCompressed(logger, input, clientBytes.size) as Array> - val lengthServer = classesToRegister.size - val lengthClient = clientClassRegistrations.size - var index = 0 + // note: this list will be in order by ID! + // We want our "classesToRegister" to be identical (save for RMI stuff) to the server side, so we construct it in the same way + clientClassRegistrations.forEach { bytes -> + val typeId = bytes[0] as Int + val id = bytes[1] as Int + val clazzName = bytes[2] as String + val serializerName = bytes[3] as String - // list all of the registrations that are mis-matched between the server/client - for (i in 0 until lengthServer) { - index = i - val classServer = classesToRegister[index] + val clazz = Class.forName(clazzName) - if (index >= lengthClient) { - success = false - logger.error("Missing client registration for {} -> {}", classServer.id, classServer.clazz.name) - continue - } + when (typeId) { + 0 -> classesToRegister.add(ClassRegistration0(clazz, maker.newInstantiatorOf(Class.forName(serializerName)).newInstance() as Serializer)) + 1 -> classesToRegister.add(ClassRegistration1(clazz, id)) + 2 -> classesToRegister.add(ClassRegistration2(clazz, maker.newInstantiatorOf(Class.forName(serializerName)).newInstance() as Serializer, id)) + 3 -> classesToRegister.add(ClassRegistration3(clazz)) + 4 -> { + // NOTE: when reconstructing, if we have access to the IMPL, we use it. WE MIGHT NOT HAVE ACCESS TO IT ON THE CLIENT! + // we literally want everything to be 100% the same. + // the only WEIRD case is when the client == rmi-server (in which case, the IMPL object is on the client) + // for this, the server (rmi-client) WILL ALSO have the same registration info. (bi-directional RMI, but not really) + val implClazzName = bytes[4] as String + var implClass: Class<*>? = null - - val classClient = clientClassRegistrations[index] - - val idClient = classClient[0] as Int - val nameClient = classClient[1] as String - val serializerClient = classClient[2] as String - - val idServer = classServer.id - val nameServer = classServer.clazz.name - val serializerServer = classServer.serializer?.javaClass?.name ?: "" - - // JUST MAYBE this is a serializer for RMI. The client doesn't have to register for RMI stuff - // this logic is unwrapped, and seemingly complex in order to specifically check for this in a performant way - val idMatches = idClient == idServer - if (!idMatches) { - success = false - logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)") - logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)") - continue - } - - - val nameMatches = nameServer == nameClient - if (!nameMatches) { - success = false - logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)") - logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)") - continue - } - - - val serializerMatches = serializerServer == serializerClient - if (!serializerMatches) { - // JUST MAYBE this is a serializer for RMI. The client doesn't have to register for RMI stuff explicitly - when { - serializerServer == rmiServerSerializer::class.java.name -> { - // this is for when the rmi-server is on the server, and the rmi-client is on client - - // after this check, we tell the client that this ID is for RMI - // This necessary because only 1 side registers RMI iface/impl info + if (implClazzName.isNotEmpty()) { + try { + implClass = Class.forName(implClazzName) + } catch (ignored: Exception) { + } } - serializerClient == rmiServerSerializer::class.java.name -> { - // this is for when the rmi-server is on client, and the rmi-client is on server - // after this check, we tell MYSELF (the server) that this id is for RMI - // This necessary because only 1 side registers RMI iface/impl info - } - else -> { - success = false - logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)") - logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)") - } + classesToRegister.add(ClassRegistrationForRmi(clazz, implClass, rmiServerSerializer)) + } + else -> throw IllegalStateException("Unable to manage class registrations for unkown registration type $typeId") } + + // now all of our classes to register will be the same (except for RMI class registrations } - - // +1 because we are going from index -> length - index++ - - // list all of the registrations that are missing on the server - if (index < lengthClient) { - success = false - for (i in index - 1 until lengthClient) { - val holderClass = clientClassRegistrations[i] - val id = holderClass[0] as Int - val name = holderClass[1] as String - val serializer = holderClass[2] as String - logger.error("Missing server registration : {} -> {} ({})", id, name, serializer) - } - } - - // maybe everything was actually correct, and the byte arrays were different because hashmaps use non-deterministic ordering. - return success } catch (e: Exception) { - logger.error("Error [{}] during registration validation", e.message) - } finally { - returnKryo(kryo) - input.close() + logger.error("Error creating client class registrations using server data!", e) + return false } - return false - } + // we have to re-init so the registrations are set! + initKryo() + // now do a round-trip through the server serialization to make sure our byte arrays are THE SAME. + initializeClassRegistrations() - /** - * Called when the kryo IDs are updated to be the RMI reverse serializer. - * - * NOTE: the IFACE must already be registered!! - */ - fun updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray, onError: (String) -> Unit) { - val typeName = connection.endPoint.type.simpleName - - // store all of the classes + kryo registration IDs - - rmiModificationIds.forEach { - if (!existingRmiIds.contains(it)) { - existingRmiIds.add(it) - - // have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on - // a single thread. - // NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct - // because they are set on initialization - - val registration = readKryo.getRegistration(it) - val regMessage = "$typeName-side RMI serializer for registration $it -> ${registration.type}" - - if (registration.type.isInterface) { - logger.debug { "Modifying $regMessage" } - - // RMI must be with an interface. If it's not an interface then something is wrong - registration.serializer = rmiServerSerializer - } else { - // note: one way that this can be called is when BOTH the client + server register the same way for RMI IDs. When - // the endpoint serialization is initialized, we also add the RMI IDs to this list, so we don't have to worry about this specific - // scenario - onError("Ignoring unsafe modification of $regMessage") - } - } - } + // verify the registration ID data is THE SAME! + return savedRegistrationDetails.contentEquals(kryoRegistrationDetailsFromServer) } /**