Fixed issues with getting RMI interface IDS and late id registration

This commit is contained in:
nathan 2020-08-27 00:47:37 +02:00
parent 006e597867
commit 661c978b07
4 changed files with 187 additions and 166 deletions

View File

@ -249,8 +249,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// we have to construct how the connection will communicate!
reliableClientConnection.buildClient(aeron)
logger.trace {
"Creating new connection $reliableClientConnection"
logger.info {
"Creating new connection to $reliableClientConnection"
}
val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
@ -265,9 +265,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
throw exception
}
// before we do anything else, we have to correct the RMI serializers, as necessary.
val rmiModificationIds = connectionInfo.kryoIdsForRmi
updateKryoIdsForRmi(newConnection, rmiModificationIds)
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage ->
listenerManager.notifyError(newConnection,
ClientRejectedException(errorMessage))
}
connection = newConnection
connections.add(newConnection)
@ -542,13 +550,13 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
suspend inline fun <reified Iface> createObject(vararg objectParameters: Any?, noinline callback: suspend (Int, Iface) -> Unit) {
// NOTE: It's not possible to have reified inside a virtual function
// https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function
val classId = serialization.getClassId(Iface::class.java)
val kryoId = serialization.getKryoIdForRmi(Iface::class.java)
@Suppress("UNCHECKED_CAST")
objectParameters as Array<Any?>
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), classId, objectParameters, callback)
rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, objectParameters, callback)
}
/**
@ -572,7 +580,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
suspend inline fun <reified Iface> createObject(noinline callback: suspend (Int, Iface) -> Unit) {
// NOTE: It's not possible to have reified inside a virtual function
// https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function
val classId = serialization.getClassId(Iface::class.java)
val classId = serialization.getKryoIdForRmi(Iface::class.java)
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), classId, null, callback)

View File

@ -537,12 +537,12 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*/
suspend fun <Iface> createObject(vararg objectParameters: Any?, callback: suspend (Int, Iface) -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function2::class.java, callback.javaClass, 1)
val interfaceClassId = endPoint.serialization.getClassId(iFaceClass)
val kryoId = endPoint.serialization.getKryoIdForRmi(iFaceClass)
@Suppress("UNCHECKED_CAST")
objectParameters as Array<Any?>
rmiConnectionSupport.createRemoteObject(this, interfaceClassId, objectParameters, callback)
rmiConnectionSupport.createRemoteObject(this, kryoId, objectParameters, callback)
}
/**
@ -564,7 +564,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*/
suspend fun <Iface> createObject(callback: suspend (Int, Iface) -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function2::class.java, callback.javaClass, 1)
val interfaceClassId = endPoint.serialization.getClassId(iFaceClass)
val interfaceClassId = endPoint.serialization.getKryoIdForRmi(iFaceClass)
rmiConnectionSupport.createRemoteObject(this, interfaceClassId, null, callback)
}

View File

@ -20,7 +20,6 @@ import dorkbox.network.Configuration
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.rmi.RmiManagerConnections
@ -44,7 +43,6 @@ import mu.KLogger
import mu.KotlinLogging
import org.agrona.DirectBuffer
import java.io.File
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch
@ -143,11 +141,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// we only want one instance of these created. These will be called appropriately
val settingsStore: SettingsStore
// list of already seen client RMI ids (which the server might not have registered as RMI types).
private var alreadySeenClientRmiIds = CopyOnWriteArrayList<Int>()
private val networkReadKryo: KryoExtra = config.serialization.takeKryo()
internal val rmiGlobalSupport = RmiManagerGlobal<CONNECTION>(logger, actionDispatch, config.serialization)
init {
@ -233,7 +226,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
logger.info("Aeron log directory: ${config.aeronLogDirectory}")
if (aeronDirAlreadyExists) {
logger.info("Aeron log directory already exists! This might not be what you want!")
logger.warn("Aeron log directory already exists! This might not be what you want!")
}
// serialization stuff
@ -448,9 +441,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
}
try {
networkReadKryo.write(message)
val buffer = networkReadKryo.writerBuffer
// NOTE: it is safe to use the global, single-threaded kryo instance!
val buffer = serialization.writeMessage(message)
val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer
@ -489,7 +481,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*/
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
try {
val message = networkReadKryo.read(buffer, offset, length)
val message = serialization.readMessage(buffer, offset, length)
logger.trace {
"[${header.sessionId()}] received: $message"
}
@ -526,11 +518,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val message: Any?
try {
message = networkReadKryo.read(buffer, offset, length, connection)
message = serialization.readMessage(buffer, offset, length, connection)
logger.trace {
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
"[${sessionId}] received: $message"
"[${header.sessionId()}] received: $message"
}
} catch (e: Exception) {
// The sessionId is globally unique, and is assigned by the server.
@ -715,29 +705,4 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
shutdownLatch.countDown()
}
}
suspend fun updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray) {
rmiModificationIds.forEach {
if (!alreadySeenClientRmiIds.contains(it)) {
alreadySeenClientRmiIds.add(it)
// have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on
// a single thread.
// NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct
val registration = networkReadKryo.getRegistration(it)
val regMessage = "${type.simpleName}-side RMI serializer for registration $it -> ${registration.type}"
if (registration.type.isInterface) {
logger.debug {
"Modifying $regMessage"
}
// RMI must be with an interface. If it's not an interface then something is wrong
registration.serializer = serialization.rmiClientReverseSerializer
} else {
listenerManager.notifyError(connection,
ClientRejectedException("Attempting an unsafe modification of $regMessage"))
}
}
}
}
}

View File

@ -17,12 +17,14 @@ package dorkbox.network.serialization
import com.esotericsoftware.kryo.ClassResolver
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.SerializerFactory
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy
import com.esotericsoftware.kryo.util.IdentityMap
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.rmi.CachedMethod
@ -53,6 +55,7 @@ import org.objenesis.strategy.StdInstantiatorStrategy
import java.io.IOException
import java.lang.reflect.Constructor
import java.lang.reflect.InvocationHandler
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.coroutines.Continuation
/**
@ -119,7 +122,7 @@ class Serialization(private val references: Boolean,
// All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems.
// Object checking is performed during actual registration.
private val classesToRegister = mutableListOf<ClassRegistration>()
private lateinit var savedKryoIdsForRmi: IntArray
internal lateinit var savedKryoIdsForRmi: IntArray
private lateinit var savedRegistrationDetails: ByteArray
/// RMI things
@ -127,6 +130,8 @@ class Serialization(private val references: Boolean,
private val rmiIfaceToImpl = IdentityMap<Class<*>, Class<*>>()
private val rmiImplToIface = IdentityMap<Class<*>, Class<*>>()
// This is a GLOBAL, single threaded only kryo instance.
private val globalKryo: KryoExtra by lazy { initKryo() }
// BY DEFAULT, DefaultInstantiatorStrategy() will use ReflectASM
// StdInstantiatorStrategy will create classes bypasses the constructor (which can be useful in some cases) THIS IS A FALLBACK!
@ -138,7 +143,10 @@ class Serialization(private val references: Boolean,
private val continuationSerializer = ContinuationSerializer()
private val rmiClientSerializer = RmiClientSerializer()
internal val rmiClientReverseSerializer = RmiClientReverseSerializer(rmiImplToIface)
private val rmiClientReverseSerializer = RmiClientReverseSerializer(rmiImplToIface)
// list of already seen client RMI ids (which the server might not have registered as RMI types).
private var existingRmiIds = CopyOnWriteArrayList<Int>()
@ -154,6 +162,7 @@ class Serialization(private val references: Boolean,
val KRYO_COUNT = 64
kryoPool = Channel(KRYO_COUNT)
}
@Synchronized
@ -344,113 +353,110 @@ class Serialization(private val references: Boolean,
// initialize the kryo pool with at least 1 kryo instance. This ALSO makes sure that all of our class registration is done
// correctly and (if not) we are are notified on the initial thread (instead of on the network update thread)
val kryo = initKryo()
// save off the class-resolver, so we can lookup the class <-> id relationships
classResolver = kryo.classResolver
classResolver = globalKryo.classResolver
// now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID
// in order to get the ID's, these have to be registered with a kryo instance!
val mergedRegistrations = mutableListOf<ClassRegistration>()
classesToRegister.forEach { registration ->
val id = registration.id
// if we ALREADY contain this registration (based ONLY on ID), then overwrite the existing one and REMOVE the current one
var found = false
mergedRegistrations.forEachIndexed { index, classRegistration ->
if (classRegistration.id == id) {
mergedRegistrations[index] = registration
found = true
return@forEachIndexed
}
}
if (!found) {
mergedRegistrations.add(registration)
}
}
// sort these by ID, because that is what they should be registered as...
mergedRegistrations.sortBy { it.id }
// now all of the registrations are IN ORDER and MERGED (save back to original array)
// set 'classesToRegister' to our mergedRegistrations, because this is now the correct order
classesToRegister.clear()
classesToRegister.addAll(mergedRegistrations)
// now create the registration details, used to validate that the client/server have the EXACT same class registration setup
val registrationDetails = arrayListOf<Array<Any>>()
if (logger.isDebugEnabled) {
// log the in-order output first
classesToRegister.forEach { classRegistration ->
logger.debug(classRegistration.info())
}
}
val kryoIdsForRmi = mutableListOf<Int>()
classesToRegister.forEach { classRegistration ->
// now save all of the registration IDs for quick verification/access
registrationDetails.add(classRegistration.getInfoArray())
// we should cache RMI methods! We don't always know if something is RMI or not (from just how things are registered...)
// so it is super trivial to map out all possible, relevant types
val kryoId = classRegistration.id
if (classRegistration is ClassRegistrationIfaceAndImpl) {
// on the "RMI server" (aka, where the object lives) side, there will be an interface + implementation!
// RMI method caching
methodCache[kryoId] =
RmiUtils.getCachedMethods(logger, globalKryo, useAsm, classRegistration.ifaceClass, classRegistration.implClass, kryoId)
// we ALSO have to cache the instantiator for these, since these are used to create remote objects
val instantiator = globalKryo.instantiatorStrategy.newInstantiatorOf(classRegistration.implClass)
@Suppress("UNCHECKED_CAST")
rmiIfaceToInstantiator[kryoId] = instantiator as ObjectInstantiator<Any>
// finally, we must save this ID, to tell the remote connection that their interface serializer must change to support
// receiving an RMI impl object as a proxy object
kryoIdsForRmi.add(kryoId)
} else if (classRegistration.clazz.isInterface) {
// non-RMI method caching
methodCache[kryoId] =
RmiUtils.getCachedMethods(logger, globalKryo, useAsm, classRegistration.clazz, null, kryoId)
}
if (kryoId > 65000) {
throw RuntimeException("There are too many kryo class registrations!!")
}
}
// save as an array to make it faster to send this info to the remote connection
savedKryoIdsForRmi = kryoIdsForRmi.toIntArray()
// have to add all of our EXISTING RMI id's, so we don't try to duplicate them (in case RMI registration is duplicated)
existingRmiIds.addAllAbsent(kryoIdsForRmi)
// save this as a byte array (so class registration validation during connection handshake is faster)
val output = AeronOutput()
try {
// now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations based on ID
// in order to get the ID's, these have to be registered with a kryo instance!
val mergedRegistrations = mutableListOf<ClassRegistration>()
classesToRegister.forEach { registration ->
val id = registration.id
// if we ALREADY contain this registration (based ONLY on ID), then overwrite the existing one and REMOVE the current one
var found = false
mergedRegistrations.forEachIndexed { index, classRegistration ->
if (classRegistration.id == id) {
mergedRegistrations[index] = registration
found = true
return@forEachIndexed
}
}
if (!found) {
mergedRegistrations.add(registration)
}
}
// sort these by ID, because that is what they should be registered as...
mergedRegistrations.sortBy { it.id }
// now all of the registrations are IN ORDER and MERGED (save back to original array)
// set 'classesToRegister' to our mergedRegistrations, because this is now the correct order
classesToRegister.clear()
classesToRegister.addAll(mergedRegistrations)
// now create the registration details, used to validate that the client/server have the EXACT same class registration setup
val registrationDetails = arrayListOf<Array<Any>>()
if (logger.isDebugEnabled) {
// log the in-order output first
classesToRegister.forEach { classRegistration ->
logger.debug(classRegistration.info())
}
}
val kryoIdsForRmi = mutableListOf<Int>()
classesToRegister.forEach { classRegistration ->
// now save all of the registration IDs for quick verification/access
registrationDetails.add(classRegistration.getInfoArray())
// we should cache RMI methods! We don't always know if something is RMI or not (from just how things are registered...)
// so it is super trivial to map out all possible, relevant types
val kryoId = classRegistration.id
if (classRegistration is ClassRegistrationIfaceAndImpl) {
// on the "RMI server" (aka, where the object lives) side, there will be an interface + implementation!
// RMI method caching
methodCache[kryoId] =
RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.ifaceClass, classRegistration.implClass, kryoId)
// we ALSO have to cache the instantiator for these, since these are used to create remote objects
val instantiator = kryo.instantiatorStrategy.newInstantiatorOf(classRegistration.implClass)
@Suppress("UNCHECKED_CAST")
rmiIfaceToInstantiator[kryoId] = instantiator as ObjectInstantiator<Any>
// finally, we must save this ID, to tell the remote connection that their interface serializer must change to support
// receiving an RMI impl object as a proxy object
kryoIdsForRmi.add(kryoId)
} else if (classRegistration.clazz.isInterface) {
// non-RMI method caching
methodCache[kryoId] =
RmiUtils.getCachedMethods(logger, kryo, useAsm, classRegistration.clazz, null, kryoId)
}
if (kryoId > 65000) {
throw RuntimeException("There are too many kryo class registrations!!")
}
}
savedKryoIdsForRmi = kryoIdsForRmi.toIntArray()
// save as an array to make it faster to send this info to the remote connection
// save this as a byte array (so class registration validation during connection handshake is faster)
val output = AeronOutput()
try {
kryo.writeCompressed(logger, output, registrationDetails.toTypedArray())
} catch (e: Exception) {
logger.error("Unable to write compressed data for registration details", e)
}
val length = output.position()
savedRegistrationDetails = ByteArray(length)
output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length)
output.close()
} finally {
returnKryo(kryo)
globalKryo.writeCompressed(logger, output, registrationDetails.toTypedArray())
} catch (e: Exception) {
logger.error("Unable to write compressed data for registration details", e)
}
val length = output.position()
savedRegistrationDetails = ByteArray(length)
output.toBytes().copyInto(savedRegistrationDetails, 0, 0, length)
output.close()
}
/**
@ -617,17 +623,16 @@ class Serialization(private val references: Boolean,
/**
* Returns the Kryo class registration ID
*/
fun getClassId(iFace: Class<*>): Int {
return classResolver.getRegistration(iFace).id
}
fun getKryoIdForRmi(interfaceClass: Class<*>): Int {
if (!interfaceClass.isInterface) {
throw KryoException("Can only get the kryo IDs for RMI on an interface!")
}
/**
* Returns the Kryo class from a registration ID
*/
fun getClassFromId(kryoId: Int): Class<*> {
return classResolver.getRegistration(kryoId).type
}
val implClass = rmiIfaceToImpl[interfaceClass]
// for RMI, we store the IMPL class in the class registration -- not the iface!
return classResolver.getRegistration(implClass).id
}
/**
* Creates a NEW object implementation based on the KRYO interface ID.
@ -643,7 +648,7 @@ class Serialization(private val references: Boolean,
val size = objectParameters.size
// we have to get the constructor for this object.
val clazz = getClassFromId(interfaceClassId)
val clazz = classResolver.getRegistration(interfaceClassId).type
val constructors = clazz.declaredConstructors
// now have to find the closest match.
@ -781,6 +786,49 @@ class Serialization(private val references: Boolean,
}
}
suspend fun <CONNECTION: Connection> updateKryoIdsForRmi(connection: CONNECTION, rmiModificationIds: IntArray, onError: suspend (String) -> Unit) {
val endPoint = connection.endPoint()
val typeName = endPoint.type.simpleName
rmiModificationIds.forEach {
if (!existingRmiIds.contains(it)) {
existingRmiIds.add(it)
// have to modify the network read kryo with the correct registration id -> serializer info. This is a GLOBAL change made on
// a single thread.
// NOTE: This change will ONLY modify the network-read kryo. This is all we need to modify. The write kryo's will already be correct
val registration = globalKryo.getRegistration(it)
val regMessage = "$typeName-side RMI serializer for registration $it -> ${registration.type}"
if (registration.type.isInterface) {
logger.debug {
"Modifying $regMessage"
}
// RMI must be with an interface. If it's not an interface then something is wrong
registration.serializer = rmiClientReverseSerializer
} else {
// note: one way that this can be called is when BOTH the client + server register the same way for RMI IDs. When
// the endpoint serialization is initialized, we also add the RMI IDs to this list, so we don't have to worry about this specific
// scenario
onError("Attempting an unsafe modification of $regMessage")
}
}
}
}
// NOTE: These following functions are ONLY called on a single thread!
fun readMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? {
return globalKryo.read(buffer, offset, length)
}
fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, connection: Connection): Any? {
return globalKryo.read(buffer, offset, length, connection)
}
fun writeMessage(message: Any): AeronOutput {
globalKryo.write(message)
return globalKryo.writerBuffer
}
// /**
// * Waits until a kryo is available to write, using CAS operations to prevent having to synchronize.
// *