Serialization registration (class, serializer, id, etc) is now only necessary on the server. The client receives the serialization information during the handshake.

This commit is contained in:
nathan 2020-09-02 15:03:57 +02:00
parent 4504f7167e
commit 07b8b1002a
15 changed files with 224 additions and 336 deletions

View File

@ -281,6 +281,23 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.info(reliableClientConnection.clientInfo()) logger.info(reliableClientConnection.clientInfo())
///////////////
//// RMI
///////////////
// we setup our kryo information once we connect to a server (using the server's kryo registration details)
if (!serialization.finishInit(type, settingsStore, connectionInfo.kryoRegistrationDetails)) {
handshakeConnection.close()
// because we are getting the class registration details from the SERVER, this should never be the case.
// It is still and edge case where the reconstruction of the registration details fails (maybe because of custom serializers)
val exception = ClientRejectedException("Connection to $remoteAddress has incorrect class registration details!!")
listenerManager.notifyError(exception)
throw exception
}
val newConnection = if (isIpcConnection) { val newConnection = if (isIpcConnection) {
newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID)) newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID))
} else { } else {
@ -297,16 +314,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
throw exception throw exception
} }
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage ->
listenerManager.notifyError(newConnection,
ClientRejectedException(errorMessage))
}
////////////// //////////////
/// Extra Close action /// Extra Close action
////////////// //////////////
@ -515,7 +522,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any): Int { fun saveObject(`object`: Any): Int {
val rmiId = rmiConnectionSupport.saveImplObject(`object`) val rmiId = rmiConnectionSupport.saveImplObject(`object`)
if (rmiId == RemoteObjectStorage.INVALID_RMI) { if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -545,7 +552,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any, objectId: Int): Boolean { fun saveObject(`object`: Any, objectId: Int): Boolean {
val success = rmiConnectionSupport.saveImplObject(`object`, objectId) val success = rmiConnectionSupport.saveImplObject(`object`, objectId)
if (!success) { if (!success) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -663,7 +670,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveGlobalObject(`object`: Any): Int { fun saveGlobalObject(`object`: Any): Int {
val rmiId = rmiGlobalSupport.saveImplObject(`object`) val rmiId = rmiGlobalSupport.saveImplObject(`object`)
if (rmiId == RemoteObjectStorage.INVALID_RMI) { if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -691,7 +698,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { fun saveGlobalObject(`object`: Any, objectId: Int): Boolean {
val success = rmiGlobalSupport.saveImplObject(`object`, objectId) val success = rmiGlobalSupport.saveImplObject(`object`, objectId)
if (!success) { if (!success) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")

View File

@ -125,6 +125,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") } if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") }
if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount} if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount}
// we are done with initial configuration, now finish serialization
serialization.finishInit(type, settingsStore)
} }
override fun newException(message: String, cause: Throwable?): Throwable { override fun newException(message: String, cause: Throwable?): Throwable {

View File

@ -188,7 +188,7 @@ internal class CryptoManagement(val logger: KLogger,
subscriptionPort: Int, subscriptionPort: Int,
connectionSessionId: Int, connectionSessionId: Int,
connectionStreamId: Int, connectionStreamId: Int,
kryoRmiIds: IntArray): ByteArray { kryoRegDetails: ByteArray): ByteArray {
val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes)
secureRandom.nextBytes(iv) secureRandom.nextBytes(iv)
@ -200,10 +200,8 @@ internal class CryptoManagement(val logger: KLogger,
cryptOutput.writeInt(connectionStreamId) cryptOutput.writeInt(connectionStreamId)
cryptOutput.writeInt(publicationPort) cryptOutput.writeInt(publicationPort)
cryptOutput.writeInt(subscriptionPort) cryptOutput.writeInt(subscriptionPort)
cryptOutput.writeInt(kryoRmiIds.size) cryptOutput.writeInt(kryoRegDetails.size)
kryoRmiIds.forEach { cryptOutput.writeBytes(kryoRegDetails)
cryptOutput.writeInt(it)
}
return iv + aesCipher.doFinal(cryptOutput.toBytes()) return iv + aesCipher.doFinal(cryptOutput.toBytes())
} }
@ -234,12 +232,8 @@ internal class CryptoManagement(val logger: KLogger,
val streamId = cryptInput.readInt() val streamId = cryptInput.readInt()
val publicationPort = cryptInput.readInt() val publicationPort = cryptInput.readInt()
val subscriptionPort = cryptInput.readInt() val subscriptionPort = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val rmiIds = mutableListOf<Int>() val regDetails = cryptInput.readBytes(regDetailsSize)
val rmiIdSize = cryptInput.readInt()
for (i in 0 until rmiIdSize) {
rmiIds.add(cryptInput.readInt())
}
// now read data off // now read data off
return ClientConnectionInfo(sessionId = sessionId, return ClientConnectionInfo(sessionId = sessionId,
@ -247,7 +241,7 @@ internal class CryptoManagement(val logger: KLogger,
publicationPort = publicationPort, publicationPort = publicationPort,
subscriptionPort = subscriptionPort, subscriptionPort = subscriptionPort,
publicKey = serverPublicKeyBytes, publicKey = serverPublicKeyBytes,
kryoIdsForRmi = rmiIds.toIntArray()) kryoRegistrationDetails = regDetails)
} }
override fun hashCode(): Int { override fun hashCode(): Int {

View File

@ -21,6 +21,7 @@ import dorkbox.network.Server
import dorkbox.network.ServerConfiguration import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.connection.ping.PingMessage import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.ipFilter.IpFilterRule import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerConnections
@ -286,14 +287,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
settingsStore = config.settingsStore settingsStore = config.settingsStore
settingsStore.init(serialization, config.settingsStorageSystem.build()) settingsStore.init(serialization, config.settingsStorageSystem.build())
settingsStore.getSerializationTypes().forEach {
serialization.register(it)
}
crypto = CryptoManagement(logger, settingsStore, type, config) crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization
serialization.finishInit(type)
} }
internal fun initEndpointState(): Aeron { internal fun initEndpointState(): Aeron {
@ -464,7 +458,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
} }
internal suspend fun writeHandshakeMessage(publication: Publication, message: Any) { internal suspend fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) {
// The sessionId is globally unique, and is assigned by the server. // The sessionId is globally unique, and is assigned by the server.
logger.trace { logger.trace {
"[${publication.sessionId()}] send: $message" "[${publication.sessionId()}] send: $message"

View File

@ -20,5 +20,5 @@ internal class ClientConnectionInfo(val subscriptionPort: Int = 0,
val sessionId: Int, val sessionId: Int,
val streamId: Int = 0, val streamId: Int = 0,
val publicKey: ByteArray = ByteArray(0), val publicKey: ByteArray = ByteArray(0),
val kryoIdsForRmi: IntArray) { val kryoRegistrationDetails: ByteArray) {
} }

View File

@ -91,18 +91,14 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
val sessionId = cryptInput.readInt() val sessionId = cryptInput.readInt()
val streamSubId = cryptInput.readInt() val streamSubId = cryptInput.readInt()
val streamPubId = cryptInput.readInt() val streamPubId = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val rmiIds = mutableListOf<Int>() val regDetails = cryptInput.readBytes(regDetailsSize)
val rmiIdSize = cryptInput.readInt()
for (i in 0 until rmiIdSize) {
rmiIds.add(cryptInput.readInt())
}
// now read data off // now read data off
connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId, connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId,
subscriptionPort = streamSubId, subscriptionPort = streamSubId,
publicationPort = streamPubId, publicationPort = streamPubId,
kryoIdsForRmi = rmiIds.toIntArray()) kryoRegistrationDetails = regDetails)
} }
HandshakeMessage.DONE_ACK -> { HandshakeMessage.DONE_ACK -> {
connectionDone = true connectionDone = true
@ -120,16 +116,12 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
} }
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = HandshakeMessage.helloFromClient( val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, config.settingsStore.getPublicKey()!!)
oneTimePad = oneTimePad,
publicKey = config.settingsStore.getPublicKey()!!,
registrationData = config.serialization.getKryoRegistrationDetails(),
registrationRmiIdData = config.serialization.getKryoRmiIds()
)
// Send the one-time pad to the server. // Send the one-time pad to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
endPoint.serialization.takeKryo() // TAKE THE KRYO BACK OFF! We don't want it on the pool yet, since this kryo hasn't had all of the classes registered yet!
sessionId = handshakeConnection.publication.sessionId() sessionId = handshakeConnection.publication.sessionId()

View File

@ -57,13 +57,11 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3 const val DONE = 3
const val DONE_ACK = 4 const val DONE_ACK = 4
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray, registrationRmiIdData: IntArray): HandshakeMessage { fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO hello.state = HELLO
hello.oneTimePad = oneTimePad hello.oneTimePad = oneTimePad
hello.publicKey = publicKey hello.publicKey = publicKey
hello.registrationData = registrationData
hello.registrationRmiIdData = registrationRmiIdData
return hello return hello
} }

View File

@ -123,6 +123,50 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return true return true
} }
/**
* @return true if we should continue parsing the incoming message, false if we should abort
*/
private fun validateConnectionInfo(server: Server<CONNECTION>,
handshakePublication: Publication,
config: ServerConfiguration,
clientAddressString: String,
clientAddress: Int): Boolean {
try {
// VALIDATE:: Check to see if there are already too many clients connected.
if (server.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
return false
}
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
return false
}
} catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
return false
}
return true
}
// note: CANNOT be called in action dispatch // note: CANNOT be called in action dispatch
fun processHandshakeMessageServer(server: Server<CONNECTION>, fun processHandshakeMessageServer(server: Server<CONNECTION>,
@ -140,12 +184,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val serialization = config.serialization val serialization = config.serialization
// VALIDATE:: make sure the serialization matches between the client/server!
if (!serialization.verifyKryoRegistration(message.registrationData!!)) {
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Registration data mismatch."))
return
}
///// /////
///// /////
@ -235,17 +273,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
// NOTE: This modifies the readKryo! This cannot be on a different thread!
serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage ->
listenerManager.notifyError(connection,
ClientRejectedException(errorMessage))
}
/////////////// ///////////////
@ -267,11 +294,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
cryptOutput.writeInt(connectionStreamSubId) cryptOutput.writeInt(connectionStreamSubId)
cryptOutput.writeInt(connectionStreamPubId) cryptOutput.writeInt(connectionStreamPubId)
val kryoRmiIds = serialization.getKryoRmiIds() val regDetails = serialization.getKryoRegistrationDetails()
cryptOutput.writeInt(kryoRmiIds.size) cryptOutput.writeInt(regDetails.size)
kryoRmiIds.forEach { cryptOutput.writeBytes(regDetails)
cryptOutput.writeInt(it)
}
successMessage.registrationData = cryptOutput.toBytes() successMessage.registrationData = cryptOutput.toBytes()
@ -314,51 +339,16 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val validateRemoteAddress: PublicKeyValidationState val validateRemoteAddress: PublicKeyValidationState
val serialization = config.serialization val serialization = config.serialization
try { // VALIDATE:: check to see if the remote connection's public key has changed!
// VALIDATE:: Check to see if there are already too many clients connected. validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
if (server.connections.connectionCount() >= config.maxClientCount) { if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch."))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
return
}
// VALIDATE:: check to see if the remote connection's public key has changed!
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch."))
return
}
// VALIDATE:: make sure the serialization matches between the client/server!
if (!serialization.verifyKryoRegistration(message.registrationData!!)) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Registration data mismatch."))
return
}
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
return
}
} catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
return return
} }
if (!validateConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) {
return
}
///// /////
@ -448,19 +438,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
// NOTE: This modifies the readKryo! This cannot be on a different thread!
serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage ->
listenerManager.notifyError(connection,
ClientRejectedException(errorMessage))
}
/////////////// ///////////////
/// HANDSHAKE /// HANDSHAKE
/////////////// ///////////////
@ -471,7 +448,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val successMessage = HandshakeMessage.helloAckToClient(sessionId) val successMessage = HandshakeMessage.helloAckToClient(sessionId)
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint // Also send the RMI registration data to the client (so the client doesn't register anything)
// now create the encrypted payload, using ECDH // now create the encrypted payload, using ECDH
successMessage.registrationData = server.crypto.encrypt(clientPublicKeyBytes!!, successMessage.registrationData = server.crypto.encrypt(clientPublicKeyBytes!!,
@ -479,7 +456,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
subscriptionPort, subscriptionPort,
connectionSessionId, connectionSessionId,
connectionStreamId, connectionStreamId,
serialization.getKryoRmiIds()) serialization.getKryoRegistrationDetails())
successMessage.publicKey = server.crypto.publicKeyBytes successMessage.publicKey = server.crypto.publicKeyBytes

View File

@ -32,7 +32,9 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S
* If this class registration will EVENTUALLY be for RMI, then [ClassRegistrationForRmi] will reassign the serializer * If this class registration will EVENTUALLY be for RMI, then [ClassRegistrationForRmi] will reassign the serializer
*/ */
open fun register(kryo: KryoExtra, rmi: RmiHolder) { open fun register(kryo: KryoExtra, rmi: RmiHolder) {
val savedKryoId: Int? = rmi.ifaceToId[clazz] // ClassRegistrationForRmi overrides this method
val savedKryoId: Int? = rmi.implToId[clazz] // ALL registrations MUST BE IMPL!
var overriddenSerializer: Serializer<Any>? = null var overriddenSerializer: Serializer<Any>? = null
@ -53,7 +55,7 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S
return return
} }
else -> { else -> {
// mark that this was overridden! // We didn't do anything.
} }
} }
} }
@ -61,7 +63,7 @@ internal abstract class ClassRegistration(val clazz: Class<*>, val serializer: S
// otherwise, we are OK to continue to register this // otherwise, we are OK to continue to register this
register(kryo) register(kryo)
if (overriddenSerializer != null) { if (serializer != null && overriddenSerializer != serializer) {
info = "$info (Replaced $overriddenSerializer)" info = "$info (Replaced $overriddenSerializer)"
} }

View File

@ -24,6 +24,6 @@ internal class ClassRegistration0(clazz: Class<*>, serializer: Serializer<*>) :
} }
override fun getInfoArray(): Array<Any> { override fun getInfoArray(): Array<Any> {
return arrayOf(id, clazz.name, serializer!!::class.java.name) return arrayOf(0, id, clazz.name, serializer!!::class.java.name)
} }
} }

View File

@ -22,6 +22,6 @@ internal class ClassRegistration1(clazz: Class<*>, id: Int) : ClassRegistration(
} }
override fun getInfoArray(): Array<Any> { override fun getInfoArray(): Array<Any> {
return arrayOf(id, clazz.name, "") return arrayOf(1, id, clazz.name, "")
} }
} }

View File

@ -25,6 +25,6 @@ internal class ClassRegistration2(clazz: Class<*>, serializer: Serializer<*>, id
} }
override fun getInfoArray(): Array<Any> { override fun getInfoArray(): Array<Any> {
return arrayOf(id, clazz.name, serializer!!::class.java.name) return arrayOf(2, id, clazz.name, serializer!!::class.java.name)
} }
} }

View File

@ -23,6 +23,6 @@ internal open class ClassRegistration3(clazz: Class<*>) : ClassRegistration(claz
} }
override fun getInfoArray(): Array<Any> { override fun getInfoArray(): Array<Any> {
return arrayOf(id, clazz.name, "") return arrayOf(3, id, clazz.name, "")
} }
} }

View File

@ -45,7 +45,7 @@ import dorkbox.network.rmi.messages.RmiServerSerializer
* If the impl object 'lives' on the SERVER, then the server must tell the client about the iface ID * If the impl object 'lives' on the SERVER, then the server must tell the client about the iface ID
*/ */
internal class ClassRegistrationForRmi(ifaceClass: Class<*>, internal class ClassRegistrationForRmi(ifaceClass: Class<*>,
val implClass: Class<*>, val implClass: Class<*>?,
serializer: RmiServerSerializer) : ClassRegistration(ifaceClass, serializer) { serializer: RmiServerSerializer) : ClassRegistration(ifaceClass, serializer) {
/** /**
* In general: * In general:
@ -117,8 +117,11 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>,
// now register the impl class // now register the impl class
id = kryo.register(implClass, serializer).id id = kryo.register(implClass, serializer).id
} }
info = "Registered $id -> (RMI) ${implClass.name}" info = if (implClass == null) {
"Registered $id -> (RMI-CLIENT) ${clazz.name}"
} else {
"Registered $id -> (RMI-SERVER) ${clazz.name} -> ${implClass.name}"
}
// now, we want to save the relationship between classes and kryoId // now, we want to save the relationship between classes and kryoId
rmi.ifaceToId[clazz] = id rmi.ifaceToId[clazz] = id
@ -131,6 +134,10 @@ internal class ClassRegistrationForRmi(ifaceClass: Class<*>,
override fun getInfoArray(): Array<Any> { override fun getInfoArray(): Array<Any> {
// the info array has to match for the INTERFACE (not the impl!) // the info array has to match for the INTERFACE (not the impl!)
return arrayOf(id, clazz.name, serializer!!::class.java.name) return if (implClass == null) {
arrayOf(4, id, clazz.name, serializer!!::class.java.name, "")
} else {
arrayOf(4, id, clazz.name, serializer!!::class.java.name, implClass.name)
}
} }
} }

View File

@ -23,6 +23,7 @@ import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy
import com.esotericsoftware.minlog.Log import com.esotericsoftware.minlog.Log
import dorkbox.network.Server
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.rmi.CachedMethod import dorkbox.network.rmi.CachedMethod
@ -38,6 +39,7 @@ import dorkbox.network.rmi.messages.MethodResponse
import dorkbox.network.rmi.messages.MethodResponseSerializer import dorkbox.network.rmi.messages.MethodResponseSerializer
import dorkbox.network.rmi.messages.RmiClientSerializer import dorkbox.network.rmi.messages.RmiClientSerializer
import dorkbox.network.rmi.messages.RmiServerSerializer import dorkbox.network.rmi.messages.RmiServerSerializer
import dorkbox.network.storage.SettingsStore
import dorkbox.os.OS import dorkbox.os.OS
import dorkbox.util.serialization.SerializationDefaults import dorkbox.util.serialization.SerializationDefaults
import dorkbox.util.serialization.SerializationManager import dorkbox.util.serialization.SerializationManager
@ -53,7 +55,6 @@ import org.objenesis.strategy.StdInstantiatorStrategy
import java.io.IOException import java.io.IOException
import java.lang.reflect.Constructor import java.lang.reflect.Constructor
import java.lang.reflect.InvocationHandler import java.lang.reflect.InvocationHandler
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.coroutines.Continuation import kotlin.coroutines.Continuation
@ -97,14 +98,8 @@ open class Serialization(private val references: Boolean = true, private val fac
// All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems.
// Object checking is performed during actual registration. // Object checking is performed during actual registration.
private val classesToRegister = mutableListOf<ClassRegistration>() private val classesToRegister = mutableListOf<ClassRegistration>()
private lateinit var savedKryoIdsForRmi: IntArray
private lateinit var savedRegistrationDetails: ByteArray private lateinit var savedRegistrationDetails: ByteArray
// This is a GLOBAL, single threaded only kryo instance.
// This is to make sure that we have an instance of class registration done correctly and (if not) we are
// are notified on the initial thread (instead of on the network update thread)
private val readKryo: KryoExtra by lazy { initKryo() }
// BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM // BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM
// StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK! // StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK!
private val instantiatorStrategy = DefaultInstantiatorStrategy(StdInstantiatorStrategy()) private val instantiatorStrategy = DefaultInstantiatorStrategy(StdInstantiatorStrategy())
@ -118,15 +113,16 @@ open class Serialization(private val references: Boolean = true, private val fac
val rmiHolder = RmiHolder() val rmiHolder = RmiHolder()
// list of already seen client RMI ids (which the server might not have registered as RMI types).
private var existingRmiIds = CopyOnWriteArrayList<Int>()
// the purpose of the method cache, is to accelerate looking up methods for specific class // the purpose of the method cache, is to accelerate looking up methods for specific class
private val methodCache : Int2ObjectHashMap<Array<CachedMethod>> = Int2ObjectHashMap() private val methodCache : Int2ObjectHashMap<Array<CachedMethod>> = Int2ObjectHashMap()
// reflectASM doesn't work on android // reflectASM doesn't work on android
private val useAsm = !OS.isAndroid() private val useAsm = !OS.isAndroid()
// This is a GLOBAL, single threaded only kryo instance.
// This kryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem)
private var readKryo = initKryo()
/** /**
* Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer]. * Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer].
* If the class is already registered, the existing entry is updated with the new serializer. * If the class is already registered, the existing entry is updated with the new serializer.
@ -144,9 +140,9 @@ open class Serialization(private val references: Boolean = true, private val fac
override fun <T> register(clazz: Class<T>): Serialization { override fun <T> register(clazz: Class<T>): Serialization {
require(!initialized.value) { "Serialization 'register(class)' cannot happen after client/server initialization!" } require(!initialized.value) { "Serialization 'register(class)' cannot happen after client/server initialization!" }
// // The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather // The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather
// // with object types... EVEN IF THERE IS A SERIALIZER // with object types... EVEN IF THERE IS A SERIALIZER
// require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." } require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." }
classesToRegister.add(ClassRegistration3(clazz)) classesToRegister.add(ClassRegistration3(clazz))
return this return this
@ -257,14 +253,7 @@ open class Serialization(private val references: Boolean = true, private val fac
} }
/** /**
* @return the details of all registration IDs for RMI iface serializer rewrites * called as the first thing inside when initializing the classesToRegister
*/
fun getKryoRmiIds(): IntArray {
return savedKryoIdsForRmi
}
/**
* called as the first think inside [finishInit]
*/ */
private fun initKryo(): KryoExtra { private fun initKryo(): KryoExtra {
val kryo = KryoExtra(methodCache) val kryo = KryoExtra(methodCache)
@ -315,19 +304,40 @@ open class Serialization(private val references: Boolean = true, private val fac
return kryo return kryo
} }
/**
* Called when initialization is complete. This is to prevent (and recognize) out-of-order class/serializer registration. If an ID
* is already in use by a different type, an exception is thrown.
*/
fun finishInit(endPointClass: Class<*>) {
logger = KotlinLogging.logger(endPointClass.simpleName)
if (!initialized.compareAndSet(expect = false, update = true)) { /**
logger.error("Unable to initialize serialization more than once!") * Called when server initialization is complete.
return * Called when client connection receives kryo registration details
} *
* This is to prevent (and recognize) out-of-order class/serializer registration. If an ID is already in use by a different type, an exception is thrown.
*/
internal fun finishInit(type: Class<*>, settingsStore: SettingsStore, kryoRegistrationDetailsFromServer: ByteArray = ByteArray(0)): Boolean {
logger = KotlinLogging.logger(type.simpleName)
// this will set up the class registration information // this will set up the class registration information
return if (type == Server::class.java) {
if (!initialized.compareAndSet(expect = false, update = true)) {
logger.error("Unable to initialize serialization more than once!")
return false
}
settingsStore.getSerializationTypes().forEach {
classesToRegister.add(ClassRegistration3(it))
}
initializeClassRegistrations()
} else {
if (!initialized.compareAndSet(expect = false, update = true)) {
// the client CAN initialize more than once, since initialization happens in the handshake now
return true
}
require(classesToRegister.isEmpty()) { "Unable to initialize a non-empty class registration state! Make sure there are no serialization registrations for the client!" }
initializeClient(kryoRegistrationDetailsFromServer)
}
}
private fun initializeClassRegistrations(): Boolean {
val kryo = initKryo() val kryo = initKryo()
// now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID // now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID
@ -381,8 +391,6 @@ open class Serialization(private val references: Boolean = true, private val fac
} }
} }
val kryoIdsForRmi = mutableListOf<Int>()
classesToRegister.forEach { classRegistration -> classesToRegister.forEach { classRegistration ->
// now save all of the registration IDs for quick verification/access // now save all of the registration IDs for quick verification/access
registrationDetails.add(classRegistration.getInfoArray()) registrationDetails.add(classRegistration.getInfoArray())
@ -396,23 +404,29 @@ open class Serialization(private val references: Boolean = true, private val fac
val implClass = classRegistration.implClass val implClass = classRegistration.implClass
// RMI method caching // TWO ways to do this. On RMI-SERVER, impl class will actually be an IMPL. On RMI-CLIENT, implClass will be IFACE!!
methodCache[kryoId] = if (implClass != null && !implClass.isInterface) {
RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, implClass, kryoId) // server
// we ALSO have to cache the instantiator for these, since these are used to create remote objects // RMI-server method caching
@Suppress("UNCHECKED_CAST") methodCache[kryoId] =
rmiHolder.idToInstantiator[kryoId] = RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, implClass, kryoId)
kryo.instantiatorStrategy.newInstantiatorOf(implClass) as ObjectInstantiator<Any>
// we ALSO have to cache the instantiator for these, since these are used to create remote objects
@Suppress("UNCHECKED_CAST")
rmiHolder.idToInstantiator[kryoId] =
kryo.instantiatorStrategy.newInstantiatorOf(implClass) as ObjectInstantiator<Any>
} else {
// client
// finally, we must save this ID, to tell the remote connection that their interface serializer must change to support // RMI-client method caching
// receiving an RMI impl object as a proxy object methodCache[kryoId] =
kryoIdsForRmi.add(kryoId) RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId)
}
} else if (classRegistration.clazz.isInterface) { } else if (classRegistration.clazz.isInterface) {
// non-RMI method caching // non-RMI method caching
methodCache[kryoId] = methodCache[kryoId] =
RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId) RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId)
} }
if (kryoId > 65000) { if (kryoId > 65000) {
@ -420,189 +434,89 @@ open class Serialization(private val references: Boolean = true, private val fac
} }
} }
// save as an array to make it faster to send this info to the remote connection
savedKryoIdsForRmi = kryoIdsForRmi.toIntArray()
// have to add all of our EXISTING RMI id's, so we don't try to duplicate them (in case RMI registration is duplicated)
existingRmiIds.addAllAbsent(kryoIdsForRmi)
// save this as a byte array (so class registration validation during connection handshake is faster) // save this as a byte array (so class registration validation during connection handshake is faster)
val output = AeronOutput() val output = AeronOutput()
try { try {
kryo.writeCompressed(logger, output, registrationDetails.toTypedArray()) kryo.writeCompressed(logger, output, registrationDetails.toTypedArray())
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Unable to write compressed data for registration details", e) logger.error("Unable to write compressed data for registration details", e)
return false
} }
val length = output.position() val length = output.position()
savedRegistrationDetails = ByteArray(length) savedRegistrationDetails = ByteArray(length)
output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length) output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length)
output.close() output.close()
}
/** // note, we have to check to make sure all classes are registered! Because our classes are registered LAST, this will always be correct.
* NOTE: When this fails, the CLIENT will just time out. We DO NOT want to send an error message to the client classesToRegister.forEach { registration ->
* (it should check for updates or something else). We do not want to give "rogue" clients knowledge of the registration.register(readKryo, rmiHolder)
* server, thus preventing them from trying to probe the server data structures.
*
* @return true if kryo registration is required for all classes sent over the wire
*/
@Suppress("DuplicatedCode")
fun verifyKryoRegistration(clientBytes: ByteArray): Boolean {
// verify the registration IDs if necessary with our own. The CLIENT does not verify anything, only the server!
val kryoRegistrationDetails = savedRegistrationDetails
val equals = kryoRegistrationDetails.contentEquals(clientBytes)
if (equals) {
return true
} }
// RMI details might be one reason the arrays are different return true
}
// now we need to figure out WHAT was screwed up so we know what to fix @Suppress("UNCHECKED_CAST")
// NOTE: it could just be that the byte arrays are different, because java has a non-deterministic iteration of hash maps. private fun initializeClient(kryoRegistrationDetailsFromServer: ByteArray): Boolean {
val kryo = takeKryo() val kryo = initKryo()
val input = AeronInput(clientBytes) val input = AeronInput(kryoRegistrationDetailsFromServer)
val clientClassRegistrations = kryo.readCompressed(logger, input, kryoRegistrationDetailsFromServer.size) as Array<Array<Any>>
val maker = kryo.instantiatorStrategy
try { try {
var success = true // note: this list will be in order by ID!
@Suppress("UNCHECKED_CAST") // We want our "classesToRegister" to be identical (save for RMI stuff) to the server side, so we construct it in the same way
val clientClassRegistrations = kryo.readCompressed(logger, input, clientBytes.size) as Array<Array<Any>> clientClassRegistrations.forEach { bytes ->
val lengthServer = classesToRegister.size val typeId = bytes[0] as Int
val lengthClient = clientClassRegistrations.size val id = bytes[1] as Int
var index = 0 val clazzName = bytes[2] as String
val serializerName = bytes[3] as String
// list all of the registrations that are mis-matched between the server/client val clazz = Class.forName(clazzName)
for (i in 0 until lengthServer) {
index = i
val classServer = classesToRegister[index]
if (index >= lengthClient) { when (typeId) {
success = false 0 -> classesToRegister.add(ClassRegistration0(clazz, maker.newInstantiatorOf(Class.forName(serializerName)).newInstance() as Serializer<Any>))
logger.error("Missing client registration for {} -> {}", classServer.id, classServer.clazz.name) 1 -> classesToRegister.add(ClassRegistration1(clazz, id))
continue 2 -> classesToRegister.add(ClassRegistration2(clazz, maker.newInstantiatorOf(Class.forName(serializerName)).newInstance() as Serializer<Any>, id))
} 3 -> classesToRegister.add(ClassRegistration3(clazz))
4 -> {
// NOTE: when reconstructing, if we have access to the IMPL, we use it. WE MIGHT NOT HAVE ACCESS TO IT ON THE CLIENT!
// we literally want everything to be 100% the same.
// the only WEIRD case is when the client == rmi-server (in which case, the IMPL object is on the client)
// for this, the server (rmi-client) WILL ALSO have the same registration info. (bi-directional RMI, but not really)
val implClazzName = bytes[4] as String
var implClass: Class<*>? = null
if (implClazzName.isNotEmpty()) {
val classClient = clientClassRegistrations[index] try {
implClass = Class.forName(implClazzName)
val idClient = classClient[0] as Int } catch (ignored: Exception) {
val nameClient = classClient[1] as String }
val serializerClient = classClient[2] as String
val idServer = classServer.id
val nameServer = classServer.clazz.name
val serializerServer = classServer.serializer?.javaClass?.name ?: ""
// JUST MAYBE this is a serializer for RMI. The client doesn't have to register for RMI stuff
// this logic is unwrapped, and seemingly complex in order to specifically check for this in a performant way
val idMatches = idClient == idServer
if (!idMatches) {
success = false
logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)")
logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)")
continue
}
val nameMatches = nameServer == nameClient
if (!nameMatches) {
success = false
logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)")
logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)")
continue
}
val serializerMatches = serializerServer == serializerClient
if (!serializerMatches) {
// JUST MAYBE this is a serializer for RMI. The client doesn't have to register for RMI stuff explicitly
when {
serializerServer == rmiServerSerializer::class.java.name -> {
// this is for when the rmi-server is on the server, and the rmi-client is on client
// after this check, we tell the client that this ID is for RMI
// This necessary because only 1 side registers RMI iface/impl info
} }
serializerClient == rmiServerSerializer::class.java.name -> {
// this is for when the rmi-server is on client, and the rmi-client is on server
// after this check, we tell MYSELF (the server) that this id is for RMI classesToRegister.add(ClassRegistrationForRmi(clazz, implClass, rmiServerSerializer))
// This necessary because only 1 side registers RMI iface/impl info
}
else -> {
success = false
logger.error("MISMATCH: Registration $idClient Client -> $nameClient ($serializerClient)")
logger.error("MISMATCH: Registration $idServer Server -> $nameServer ($serializerServer)")
}
} }
else -> throw IllegalStateException("Unable to manage class registrations for unkown registration type $typeId")
} }
// now all of our classes to register will be the same (except for RMI class registrations
} }
// +1 because we are going from index -> length
index++
// list all of the registrations that are missing on the server
if (index < lengthClient) {
success = false
for (i in index - 1 until lengthClient) {
val holderClass = clientClassRegistrations[i]
val id = holderClass[0] as Int
val name = holderClass[1] as String
val serializer = holderClass[2] as String
logger.error("Missing server registration : {} -> {} ({})", id, name, serializer)
}
}
// maybe everything was actually correct, and the byte arrays were different because hashmaps use non-deterministic ordering.
return success
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error [{}] during registration validation", e.message) logger.error("Error creating client class registrations using server data!", e)
} finally { return false
returnKryo(kryo)
input.close()
} }
return false // we have to re-init so the registrations are set!
} initKryo()
// now do a round-trip through the server serialization to make sure our byte arrays are THE SAME.
initializeClassRegistrations()
/** // verify the registration ID data is THE SAME!
* Called when the kryo IDs are updated to be the RMI reverse serializer. return savedRegistrationDetails.contentEquals(kryoRegistrationDetailsFromServer)
*
* NOTE: the IFACE must already be registered!!
*/
fun <CONNECTION : Connection> updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray, onError: (String) -> Unit) {
val typeName = connection.endPoint.type.simpleName
// store all of the classes + kryo registration IDs
rmiModificationIds.forEach {
if (!existingRmiIds.contains(it)) {
existingRmiIds.add(it)
// have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on
// a single thread.
// NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct
// because they are set on initialization
val registration = readKryo.getRegistration(it)
val regMessage = "$typeName-side RMI serializer for registration $it -> ${registration.type}"
if (registration.type.isInterface) {
logger.debug { "Modifying $regMessage" }
// RMI must be with an interface. If it's not an interface then something is wrong
registration.serializer = rmiServerSerializer
} else {
// note: one way that this can be called is when BOTH the client + server register the same way for RMI IDs. When
// the endpoint serialization is initialized, we also add the RMI IDs to this list, so we don't have to worry about this specific
// scenario
onError("Ignoring unsafe modification of $regMessage")
}
}
}
} }
/** /**