Fixed issues with RMI, fixed how rmi-async is handled

This commit is contained in:
nathan 2020-08-19 01:34:31 +02:00
parent c6ad8e9829
commit 0c9949225f
15 changed files with 512 additions and 387 deletions

View File

@ -30,7 +30,7 @@ import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.RmiManagerForConnections
import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException
import kotlinx.coroutines.launch
@ -72,7 +72,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val previousClosedConnectionActivity: Long = 0
private val handshake = ClientHandshake(logger, config, crypto, listenerManager)
private val rmiConnectionSupport = RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch)
private val rmiConnectionSupport = RmiManagerForConnections(logger, rmiGlobalSupport, serialization, actionDispatch)
init {
// have to do some basic validation of our configuration
@ -93,7 +93,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
/**
* So the client class can get remote objects that are THE SAME OBJECT as if called from a connection
*/
override fun getRmiConnectionSupport(): RmiSupportConnection {
override fun getRmiConnectionSupport(): RmiManagerForConnections {
return rmiConnectionSupport
}

View File

@ -22,8 +22,8 @@ import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.rmi.RmiSupport
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.RmiManagerForConnections
import dorkbox.network.rmi.RmiMessageManager
import dorkbox.network.rmi.messages.RmiMessage
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.NetworkSerializationManager
@ -133,7 +133,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// we only want one instance of these created. These will be called appropriately
val settingsStore: SettingsStore
internal val rmiGlobalSupport = RmiSupport(logger, actionDispatch, config.serialization)
internal val rmiGlobalSupport = RmiMessageManager(logger, actionDispatch, config.serialization)
init {
logger.error("NETWORK STACK IS ONLY IPV4 AT THE MOMENT. IPV6 is in progress!")
@ -318,8 +318,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* Used for the client, because the client only has ONE ever support connection, and it allows us to create connection specific objects
* from a "global" context
*/
internal open fun getRmiConnectionSupport() : RmiSupportConnection {
return RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch)
internal open fun getRmiConnectionSupport() : RmiManagerForConnections {
return RmiManagerForConnections(logger, rmiGlobalSupport, serialization, actionDispatch)
}
/**
@ -595,9 +595,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
}
autoClosableObjects.clear()
runBlocking {
rmiGlobalSupport.close()
rmiGlobalSupport.close()
runBlocking {
// don't need anything fast or fancy here, because this method will only be called once
connections.forEach {
it.close()

View File

@ -1,4 +1,4 @@
package dorkbox.network.other;
package dorkbox.network.other.coroutines;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -13,11 +13,15 @@ import kotlin.jvm.functions.Function1;
* discarded at compile time.
*/
public
class SuspendFunctionAccess {
class SuspendFunctionTrampoline {
/**
* trampoline so we can access suspend functions correctly using reflection
*/
@SuppressWarnings("unchecked")
@Nullable
public static
Object invokeSuspendFunction(@NotNull final Object suspendFunction, @NotNull final Continuation<?> continuation) {
Object invoke(@NotNull final Continuation<?> continuation, @NotNull final Object suspendFunction) throws Throwable {
Function1<? super Continuation<? super Object>, ?> suspendFunction1 = (Function1<? super Continuation<? super Object>, ?>) suspendFunction;
return suspendFunction1.invoke((Continuation<? super Object>) continuation);
}

View File

@ -0,0 +1,34 @@
package dorkbox.network.other.coroutines
import kotlinx.coroutines.channels.Channel
// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting
// https://kotlinlang.org/docs/reference/coroutines/channels.html
inline class SuspendWaiter(private val channel: Channel<Unit> = Channel()) {
// "receive' suspends until another coroutine invokes "send"
// and
// "send" suspends until another coroutine invokes "receive".
suspend fun doWait() {
try {
channel.receive()
} catch (ignored: Exception) {
}
}
suspend fun doNotify() {
try {
channel.send(Unit)
} catch (ignored: Exception) {
}
}
fun cancel() {
try {
channel.cancel()
} catch (ignored: Exception) {
}
}
fun isCancelled(): Boolean {
// once the channel is cancelled, it can never work again
@Suppress("EXPERIMENTAL_API_USAGE")
return channel.isClosedForReceive && channel.isClosedForSend
}
}

View File

@ -16,11 +16,10 @@
package dorkbox.network.rmi
import dorkbox.network.connection.Connection
import dorkbox.network.other.SuspendFunctionAccess
import dorkbox.network.other.coroutines.SuspendFunctionTrampoline
import dorkbox.network.rmi.messages.MethodRequest
import kotlinx.coroutines.runBlocking
import java.lang.reflect.InvocationHandler
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
import java.util.*
import kotlin.coroutines.Continuation
@ -36,14 +35,14 @@ import kotlin.coroutines.Continuation
* @param rmiObjectId this is the remote object ID (assigned by RMI). This is NOT the kryo registration ID
* @param connection this is really the network client -- there is ONLY ever 1 connection
* @param proxyString this is the name assigned to the proxy [toString] method
* @param rmiSupportCache is used to provide RMI support
* @param rmiObjectCache is used to provide RMI support
* @param cachedMethods this is the methods available for the specified class
*/
internal class RmiClient(val isGlobal: Boolean,
val rmiObjectId: Int,
private val connection: Connection,
private val proxyString: String,
private val rmiSupportCache: RmiSupportCache,
private val rmiObjectCache: RmiObjectCache,
private val cachedMethods: Array<CachedMethod>) : InvocationHandler {
companion object {
@ -77,7 +76,6 @@ internal class RmiClient(val isGlobal: Boolean,
// if we are ASYNC, then this method immediately returns
private suspend fun sendRequest(method: Method, args: Array<Any>): Any? {
// there is a STRANGE problem, where if we DO NOT respond/reply to method invocation, and immediate invoke multiple methods --
// the "server" side can have out-of-order method invocation. There are 2 ways to solve this
// 1) make the "server" side single threaded
@ -87,17 +85,24 @@ internal class RmiClient(val isGlobal: Boolean,
// response (even if it is a void response). This simplifies our response mask, and lets us use more bits for storing the
// response ID
val responseStorage = rmiSupportCache.getResponseStorage()
val responseStorage = rmiObjectCache.getResponseStorage()
// NOTE: we ALWAYS send a response from the remote end.
//
// 'async' -> DO NOT WAIT
// 'timeout > 0' -> WAIT
// 'timeout == 0' -> same as async (DO NOT WAIT)
val isAsync = isAsync || timeoutMillis <= 0L
// If we are async, we ignore the response....
// The response, even if there is NOT one (ie: not void) will always return a thing (so we will know when to stop blocking).
val rmiWaiter = responseStorage.prep(rmiObjectId)
// The response, even if there is NOT one (ie: not void) will always return a thing (so our code excution is in lockstep
val rmiWaiter = responseStorage.prep(isAsync)
val invokeMethod = MethodRequest()
invokeMethod.isGlobal = isGlobal
invokeMethod.objectId = rmiObjectId
invokeMethod.responseId = rmiWaiter.id
invokeMethod.args = args
invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, rmiWaiter.id)
invokeMethod.args = args // if this is a kotlin suspend function, the suspend arg will NOT be here!
// which method do we access? We always want to access the IMPLEMENTATION (if available!). we know that this will always succeed
// this should be accessed via the KRYO class ID + method index (both are SHORT, and can be packed)
@ -107,10 +112,12 @@ internal class RmiClient(val isGlobal: Boolean,
// if we are async, then this will immediately return
val result = responseStorage.waitForReply(isAsync, rmiObjectId, rmiWaiter, timeoutMillis)
@Suppress("MoveVariableDeclarationIntoWhen")
val result = responseStorage.waitForReply(isAsync, rmiWaiter, timeoutMillis)
when (result) {
RmiResponseStorage.TIMEOUT_EXCEPTION -> {
throw TimeoutException("Response timed out: ${method.declaringClass.name}.${method.name}")
RmiResponseManager.TIMEOUT_EXCEPTION -> {
// TODO: from top down, clean up the coroutine stack
throw TimeoutException("Response timed out: ${method.declaringClass.name}.${method.name}(${method.parameterTypes.map{it.simpleName}})")
}
is Exception -> {
// reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
@ -152,7 +159,7 @@ internal class RmiClient(val isGlobal: Boolean,
// manage all of the RemoteObject proxy methods
when (method) {
closeMethod -> {
rmiSupportCache.removeProxyObject(rmiObjectId)
rmiObjectCache.removeProxyObject(rmiObjectId)
return null
}
@ -210,76 +217,62 @@ internal class RmiClient(val isGlobal: Boolean,
// We will use this for our coroutine context instead of running on a new coroutine
val maybeContinuation = args?.lastOrNull()
if (isAsync) {
// return immediately, without suspends
if (maybeContinuation is Continuation<*>) {
val argsWithoutContinuation = args.take(args.size - 1)
invokeSuspendFunction(maybeContinuation) {
sendRequest(method, argsWithoutContinuation.toTypedArray())
}
} else {
runBlocking {
sendRequest(method, args ?: EMPTY_ARRAY)
}
// async will return immediately
val returnValue = if (maybeContinuation is Continuation<*>) {
invokeSuspendFunction(maybeContinuation) {
sendRequest(method, args)
}
// if we are async then we return immediately. If you want the response value, you MUST use
// 'waitForLastResponse()' or 'waitForResponse'('getLastResponseID()')
val returnType = method.returnType
if (returnType.isPrimitive) {
return when (returnType) {
Int::class.javaPrimitiveType -> {
0
}
Boolean::class.javaPrimitiveType -> {
java.lang.Boolean.FALSE
}
Float::class.javaPrimitiveType -> {
0.0f
}
Char::class.javaPrimitiveType -> {
0.toChar()
}
Long::class.javaPrimitiveType -> {
0L
}
Short::class.javaPrimitiveType -> {
0.toShort()
}
Byte::class.javaPrimitiveType -> {
0.toByte()
}
Double::class.javaPrimitiveType -> {
0.0
}
else -> {
null
}
}
}
return null
} else {
// non-async code, so we will be blocking/suspending!
return if (maybeContinuation is Continuation<*>) {
val argsWithoutContinuation = args.take(args.size - 1)
invokeSuspendFunction(maybeContinuation) {
sendRequest(method, argsWithoutContinuation.toTypedArray())
runBlocking {
sendRequest(method, args ?: EMPTY_ARRAY)
}
}
if (!isAsync) {
return returnValue
}
// if we are async then we return immediately.
// If you want the response value, disable async!
val returnType = method.returnType
if (returnType.isPrimitive) {
return when (returnType) {
Int::class.javaPrimitiveType -> {
0
}
} else {
runBlocking {
sendRequest(method, args ?: EMPTY_ARRAY)
Boolean::class.javaPrimitiveType -> {
java.lang.Boolean.FALSE
}
Float::class.javaPrimitiveType -> {
0.0f
}
Char::class.javaPrimitiveType -> {
0.toChar()
}
Long::class.javaPrimitiveType -> {
0L
}
Short::class.javaPrimitiveType -> {
0.toShort()
}
Byte::class.javaPrimitiveType -> {
0.toByte()
}
Double::class.javaPrimitiveType -> {
0.0
}
else -> {
null
}
}
}
return null
}
// trampoline so we can access suspend functions correctly and (if suspend) get the coroutine connection parameter)
private fun invokeSuspendFunction(continuation: Continuation<*>, suspendFunction: suspend () -> Any?): Any? {
return try {
SuspendFunctionAccess.invokeSuspendFunction(suspendFunction, continuation)
} catch (e: InvocationTargetException) {
throw e.cause!!
}
private fun invokeSuspendFunction(continuation: Continuation<*>, suspendFunction: suspend () -> Any?): Any {
return SuspendFunctionTrampoline.invoke(continuation, suspendFunction) as Any
}
override fun hashCode(): Int {

View File

@ -8,10 +8,10 @@ import dorkbox.network.serialization.NetworkSerializationManager
import kotlinx.coroutines.CoroutineScope
import mu.KLogger
internal class RmiSupportConnection(logger: KLogger,
val rmiGlobalSupport: RmiSupport,
private val serialization: NetworkSerializationManager,
actionDispatch: CoroutineScope) : RmiSupportCache(logger, actionDispatch) {
internal class RmiManagerForConnections(logger: KLogger,
val rmiGlobalSupport: RmiMessageManager,
private val serialization: NetworkSerializationManager,
actionDispatch: CoroutineScope) : RmiObjectCache(logger, actionDispatch) {
private fun <Iface> createProxyObject(isGlobalObject: Boolean,
connection: Connection,
@ -22,7 +22,7 @@ internal class RmiSupportConnection(logger: KLogger,
// so we can just instantly create the proxy object (or get the cached one)
var proxyObject = getProxyObject(objectId)
if (proxyObject == null) {
proxyObject = RmiSupport.createProxyObject(isGlobalObject, connection, serialization, rmiGlobalSupport, endPoint.type.simpleName, objectId, interfaceClass)
proxyObject = RmiMessageManager.createProxyObject(isGlobalObject, connection, serialization, rmiGlobalSupport, endPoint.type.simpleName, objectId, interfaceClass)
saveProxyObject(objectId, proxyObject)
}
@ -80,7 +80,7 @@ internal class RmiSupportConnection(logger: KLogger,
// this means we could register this object.
// next, scan this object to see if there are any RMI fields
RmiSupport.scanImplForRmiFields(logger, implObject) {
RmiMessageManager.scanImplForRmiFields(logger, implObject) {
saveImplObject(it)
}
} else {

View File

@ -17,7 +17,12 @@ package dorkbox.network.rmi
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.rmi.messages.*
import dorkbox.network.rmi.messages.ConnectionObjectCreateRequest
import dorkbox.network.rmi.messages.ConnectionObjectCreateResponse
import dorkbox.network.rmi.messages.GlobalObjectCreateRequest
import dorkbox.network.rmi.messages.GlobalObjectCreateResponse
import dorkbox.network.rmi.messages.MethodRequest
import dorkbox.network.rmi.messages.MethodResponse
import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.util.classes.ClassHelper
import kotlinx.coroutines.CoroutineScope
@ -26,10 +31,11 @@ import mu.KLogger
import java.lang.reflect.Proxy
import java.util.*
internal class RmiSupport(logger: KLogger,
actionDispatch: CoroutineScope,
internal val serialization: NetworkSerializationManager) : RmiSupportCache(logger, actionDispatch) {
internal class RmiMessageManager(logger: KLogger,
actionDispatch: CoroutineScope,
internal val serialization: NetworkSerializationManager) : RmiObjectCache(logger, actionDispatch) {
companion object {
/**
* Returns a proxy object that implements the specified interface, and the methods invoked on the proxy object will be invoked
* remotely.
@ -47,7 +53,7 @@ internal class RmiSupport(logger: KLogger,
*/
internal fun createProxyObject(isGlobalObject: Boolean,
connection: Connection, serialization: NetworkSerializationManager,
rmiSupportCache: RmiSupportCache, namePrefix: String,
rmiObjectCache: RmiObjectCache, namePrefix: String,
rmiId: Int, interfaceClass: Class<*>): RemoteObject {
require(interfaceClass.isInterface) { "iface must be an interface." }
@ -62,12 +68,12 @@ internal class RmiSupport(logger: KLogger,
// the ACTUAL proxy is created in the connection impl. Our proxy handler MUST BE suspending because of:
// 1) how we send data on the wire
// 2) how we must (sometimes) wait for a response
val proxyObject = RmiClient(isGlobalObject, rmiId, connection, name, rmiSupportCache, cachedMethods)
val proxyObject = RmiClient(isGlobalObject, rmiId, connection, name, rmiObjectCache, cachedMethods)
// This is the interface inheritance by the proxy object
val interfaces: Array<Class<*>> = arrayOf(RemoteObject::class.java, interfaceClass)
return Proxy.newProxyInstance(RmiSupport::class.java.classLoader, interfaces, proxyObject) as RemoteObject
return Proxy.newProxyInstance(RmiMessageManager::class.java.classLoader, interfaces, proxyObject) as RemoteObject
}
/**
@ -128,7 +134,7 @@ internal class RmiSupport(logger: KLogger,
*/
private fun onGenericObjectResponse(endPoint: EndPoint<*>, connection: Connection, logger: KLogger,
isGlobal: Boolean, rmiId: Int, callback: suspend (Int, Any) -> Unit,
rmiSupportCache: RmiSupportCache, serialization: NetworkSerializationManager) {
rmiObjectCache: RmiObjectCache, serialization: NetworkSerializationManager) {
// we only create the proxy + execute the callback if the RMI id is valid!
if (rmiId == RemoteObjectStorage.INVALID_RMI) {
@ -141,10 +147,10 @@ internal class RmiSupport(logger: KLogger,
val interfaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(RemoteObjectCallback::class.java, callback.javaClass, 1)
// create the client-side proxy object, if possible
var proxyObject = rmiSupportCache.getProxyObject(rmiId)
var proxyObject = rmiObjectCache.getProxyObject(rmiId)
if (proxyObject == null) {
proxyObject = createProxyObject(isGlobal, connection, serialization, rmiSupportCache, endPoint.type.simpleName, rmiId, interfaceClass)
rmiSupportCache.saveProxyObject(rmiId, proxyObject)
proxyObject = createProxyObject(isGlobal, connection, serialization, rmiObjectCache, endPoint.type.simpleName, rmiId, interfaceClass)
rmiObjectCache.saveProxyObject(rmiId, proxyObject)
}
// this should be executed on a NEW coroutine!
@ -162,8 +168,6 @@ internal class RmiSupport(logger: KLogger,
private val remoteObjectCreationCallbacks = RemoteObjectStorage(logger)
internal fun <Iface> registerCallback(callback: suspend (Int, Iface) -> Unit): Int {
return remoteObjectCreationCallbacks.register(callback)
}
@ -176,8 +180,8 @@ internal class RmiSupport(logger: KLogger,
/**
* @return the implementation object based on if it is global, or not global
*/
fun <T> getImplObject(isGlobal: Boolean, rmiObjectId: Int, connection: Connection): T? {
return if (isGlobal) getImplObject(rmiObjectId) else connection.rmiConnectionSupport.getImplObject(rmiObjectId)
fun <T> getImplObject(isGlobal: Boolean, rmiId: Int, connection: Connection): T? {
return if (isGlobal) getImplObject(rmiId) else connection.rmiConnectionSupport.getImplObject(rmiId)
}
/**
@ -233,9 +237,6 @@ internal class RmiSupport(logger: KLogger,
return success as T?
}
override fun close() {
super.close()
remoteObjectCreationCallbacks.close()
@ -261,6 +262,7 @@ internal class RmiSupport(logger: KLogger,
/**
* Manages ALL OF THE RMI stuff!
*/
@Suppress("DuplicatedCode")
@Throws(IllegalArgumentException::class)
suspend fun manage(endPoint: EndPoint<*>, connection: Connection, message: Any, logger: KLogger) {
when (message) {
@ -302,36 +304,40 @@ internal class RmiSupport(logger: KLogger,
*
* The remote side of this connection requested the invocation.
*/
val objectId: Int = message.objectId
val isGlobal: Boolean = message.isGlobal
val isGlobal = message.isGlobal
val isCoroutine = message.isCoroutine
val rmiObjectId = RmiUtils.unpackLeft(message.packedId)
val rmiId = RmiUtils.unpackRight(message.packedId)
val cachedMethod = message.cachedMethod
val args = message.args
val sendResponse = rmiId != 1 // async is always with a '1', and we should NOT send a message back if it is '1'
val implObject = getImplObject<Any>(isGlobal, objectId, connection)
logger.trace { "RMI received: $rmiId" }
val implObject = getImplObject<Any>(isGlobal, rmiObjectId, connection)
if (implObject == null) {
logger.error("Unable to resolve implementation object for [global=$isGlobal, objectID=$objectId, connection=$connection")
logger.error("Unable to resolve implementation object for [global=$isGlobal, objectID=$rmiObjectId, connection=$connection")
val invokeMethodResult = MethodResponse()
invokeMethodResult.objectId = objectId
invokeMethodResult.responseId = message.responseId
invokeMethodResult.result = NullPointerException("Remote object for proxy [global=$isGlobal, objectID=$objectId] does not exist.")
connection.send(invokeMethodResult)
if (sendResponse) {
returnRmiMessage(connection,
message,
NullPointerException("Remote object for proxy [global=$isGlobal, rmiObjectID=$rmiObjectId] does not exist."),
logger)
}
return
}
logger.trace {
var argString = ""
if (message.args != null) {
argString = Arrays.deepToString(message.args)
if (args != null) {
argString = Arrays.deepToString(args)
argString = argString.substring(1, argString.length - 1)
}
val stringBuilder = StringBuilder(128)
stringBuilder.append(connection.toString())
.append(" received: ")
.append(implObject.javaClass.simpleName)
stringBuilder.append(":").append(objectId)
stringBuilder.append(connection.toString()).append(" received: ").append(implObject.javaClass.simpleName)
stringBuilder.append(":").append(rmiObjectId)
stringBuilder.append("#").append(cachedMethod.method.name)
stringBuilder.append("(").append(argString).append(")")
@ -342,32 +348,82 @@ internal class RmiSupport(logger: KLogger,
stringBuilder.toString()
}
var result: Any?
try {
// args!! is safe to do here (even though it doesn't make sense)
result = cachedMethod.invoke(connection, implObject, message.args!!)
} catch (ex: Exception) {
result = ex.cause
// added to prevent a stack overflow when references is false, (because 'cause' == "this").
// See:
// https://groups.google.com/forum/?fromgroups=#!topic/kryo-users/6PDs71M1e9Y
if (result == null) {
result = ex
} else {
result.initCause(null)
if (isCoroutine) {
// https://stackoverflow.com/questions/47654537/how-to-run-suspend-method-via-reflection
// https://discuss.kotlinlang.org/t/calling-coroutines-suspend-functions-via-reflection/4672
var suspendResult = kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn<Any?> { cont ->
// if we are a coroutine, we have to replace the LAST arg with the coroutine object
// we KNOW this is OK, because a continuation arg will always be there!
args!![args.size - 1] = cont
var insideResult: Any?
try {
// args!! is safe to do here (even though it doesn't make sense)
insideResult = cachedMethod.invoke(connection, implObject, args)
} catch (ex: Exception) {
insideResult = ex.cause
// added to prevent a stack overflow when references is false, (because 'cause' == "this").
// See:
// https://groups.google.com/forum/?fromgroups=#!topic/kryo-users/6PDs71M1e9Y
if (insideResult == null) {
insideResult = ex
}
else {
insideResult.initCause(null)
}
ListenerManager.cleanStackTrace(insideResult as Throwable)
logger.error("Error invoking method: ${cachedMethod.method.declaringClass.name}.${cachedMethod.method.name}",
insideResult)
}
insideResult
}
ListenerManager.cleanStackTrace(result as Throwable)
logger.error("Error invoking method: ${cachedMethod.method.declaringClass.name}.${cachedMethod.method.name}", result)
if (suspendResult === kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED) {
// we were suspending, and the stack will resume when possible, then it will call the response below
}
else {
if (suspendResult === Unit) {
// kotlin suspend returns, that DO NOT have a return value, REALLY return kotlin.Unit. This means there is no
// return value!
suspendResult = null
}
if (sendResponse) {
returnRmiMessage(connection, message, suspendResult, logger)
}
}
}
else {
// not a suspend (coroutine) call
try {
// args!! is safe to do here (even though it doesn't make sense)
result = cachedMethod.invoke(connection, implObject, message.args!!)
} catch (ex: Exception) {
result = ex.cause
// added to prevent a stack overflow when references is false, (because 'cause' == "this").
// See:
// https://groups.google.com/forum/?fromgroups=#!topic/kryo-users/6PDs71M1e9Y
if (result == null) {
result = ex
}
else {
result.initCause(null)
}
val invokeMethodResult = MethodResponse()
invokeMethodResult.objectId = objectId
invokeMethodResult.responseId = message.responseId
invokeMethodResult.result = result
ListenerManager.cleanStackTrace(result as Throwable)
logger.error("Error invoking method: ${cachedMethod.method.declaringClass.name}.${cachedMethod.method.name}",
result)
}
connection.send(invokeMethodResult)
if (sendResponse) {
returnRmiMessage(connection, message, result, logger)
}
}
}
is MethodResponse -> {
// notify the pending proxy requests that we have a response!
@ -376,6 +432,16 @@ internal class RmiSupport(logger: KLogger,
}
}
private suspend fun returnRmiMessage(connection: Connection, message: MethodRequest, result: Any?, logger: KLogger) {
logger.trace { "RMI returned: ${RmiUtils.unpackRight(message.packedId)}" }
val rmiMessage = MethodResponse()
rmiMessage.packedId = message.packedId
rmiMessage.result = result
connection.send(rmiMessage)
}
/**
* called on "server"
*/

View File

@ -10,9 +10,9 @@ import mu.KLogger
* The impl/proxy objects CANNOT be stored in the same data structure, because their IDs are not tied to the same ID source (and there
* would be conflicts in the data structure)
*/
internal open class RmiSupportCache(logger: KLogger, actionDispatch: CoroutineScope) {
internal open class RmiObjectCache(logger: KLogger, actionDispatch: CoroutineScope) {
private val responseStorage = RmiResponseStorage(actionDispatch)
private val responseStorage = RmiResponseManager(logger, actionDispatch)
private val implObjects = RemoteObjectStorage(logger)
private val proxyObjects = LockFreeIntMap<RemoteObject>()
@ -48,13 +48,13 @@ internal open class RmiSupportCache(logger: KLogger, actionDispatch: CoroutineSc
proxyObjects.put(rmiId, remoteObject)
}
fun getResponseStorage(): RmiResponseStorage {
fun getResponseStorage(): RmiResponseManager {
return responseStorage
}
open fun close() {
responseStorage.close()
implObjects.close()
proxyObjects.clear()
responseStorage.close()
}
}

View File

@ -0,0 +1,177 @@
package dorkbox.network.rmi
import dorkbox.network.rmi.messages.MethodResponse
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import mu.KLogger
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
/**
* Manages the "pending response" from method invocation.
*
* response ID's and the memory they hold will leak if the response never arrives!
*/
internal class RmiResponseManager(private val logger: KLogger, private val actionDispatch: CoroutineScope) {
companion object {
val TIMEOUT_EXCEPTION = Exception()
val ASYNC_WAITER = RmiWaiter(1) // this is never waited on, we just need this to optimize how we assigned waiters.
const val MAX = Short.MAX_VALUE.toInt()
}
// Response ID's are for ALL in-flight RMI on the network stack. instead of limited to (originally) 64, we are now limited to 65,535
// these are just looped around in a ring buffer.
// These are stored here as int, however these are REALLY shorts and are int-packed when transferring data on the wire
// 64,000 IN FLIGHT RMI method invocations is plenty
private val maxValuesInCache = (MAX * 2) - 1 // -1 because 0 is a reserved number
private val rmiWaiterCache = Channel<RmiWaiter>(maxValuesInCache)
private val pendingLock = ReentrantReadWriteLock()
private val pending = arrayOfNulls<Any>(maxValuesInCache)
init {
// create a shuffled list of ID's. This operation is ONLY performed ONE TIME per endpoint!
val ids = mutableListOf<Int>()
for (id in Short.MIN_VALUE..-1) {
ids.add(id)
}
// ZERO is special, and is never added!
// ONE is special, and is used for ASYNC (the response will never be sent back)
for (id in 1..Short.MAX_VALUE) {
ids.add(id)
}
ids.shuffle()
// populate the array of randomly assigned ID's + waiters.
ids.forEach {
rmiWaiterCache.offer(RmiWaiter(it))
}
}
// resume any pending remote object method invocations (if they are not async, or not manually waiting)
// async RMI will never get here!
suspend fun onMessage(message: MethodResponse) {
val rmiId = RmiUtils.unpackRight(message.packedId)
val adjustedRmiId = rmiId + MAX
val result = message.result
logger.trace { "RMI return: $rmiId" }
val previous = pendingLock.write {
val previous = pending[adjustedRmiId]
pending[adjustedRmiId] = result
previous
}
// if NULL, since either we don't exist (because we were async), or it was cancelled
if (previous is RmiWaiter) {
logger.trace { "RMI valid-cancel: $rmiId" }
// this means we were NOT timed out! (we cannot be timed out here)
previous.doNotify()
// since this was the FIRST one to trigger, return it to the cache.
rmiWaiterCache.send(previous)
}
}
/**
* gets the RmiWaiter (id + waiter).
*
* We ONLY care about the ID to get the correct response info. If there is no response, the ID can be ignored.
*/
internal suspend fun prep(isAsync: Boolean): RmiWaiter {
return if (isAsync) {
ASYNC_WAITER
} else {
val responseRmi = rmiWaiterCache.receive()
// this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
responseRmi.prep()
pendingLock.write {
pending[responseRmi.id + MAX] = responseRmi
}
responseRmi
}
}
/**
* @return the result (can be null) or timeout exception
*/
suspend fun waitForReply(isAsync: Boolean, rmiWaiter: RmiWaiter, timeoutMillis: Long): Any? {
if (isAsync) {
return null
}
val rmiId = rmiWaiter.id
val adjustedRmiId = rmiWaiter.id + MAX
// NOTE: we ALWAYS send a response from the remote end.
//
// 'async' -> DO NOT WAIT (and no response)
// 'timeout > 0' -> WAIT
// 'timeout <= 0' -> same as async (DO NOT WAIT)
val responseTimeoutJob = actionDispatch.launch {
delay(timeoutMillis) // this will always wait. if this job is cancelled, this will immediately stop waiting
// check if we have a result or not
val maybeResult = pendingLock.read { pending[adjustedRmiId] }
if (maybeResult is RmiWaiter) {
logger.trace { "RMI timeout ($timeoutMillis) cancel: $rmiId" }
maybeResult.cancel()
}
}
logger.trace {
"RMI waiting: $rmiId"
}
// wait for the response.
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will suspend and either
// A) get response
// B) timeout
rmiWaiter.doWait()
val resultOrWaiter = pendingLock.write {
val previous = pending[adjustedRmiId]
pending[adjustedRmiId] = null
previous
}
// always cancel the timeout
responseTimeoutJob.cancel()
if (resultOrWaiter is RmiWaiter) {
logger.trace { "RMI was canceled ($timeoutMillis): $rmiId" }
// since this was the FIRST one to trigger, return it to the cache.
rmiWaiterCache.send(resultOrWaiter)
return TIMEOUT_EXCEPTION
}
return resultOrWaiter
}
fun close() {
pendingLock.write {
pending.forEachIndexed { index, any ->
pending[index] = null
}
}
rmiWaiterCache.close()
}
}

View File

@ -1,191 +0,0 @@
package dorkbox.network.rmi
import dorkbox.network.rmi.messages.MethodResponse
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.agrona.collections.Hashing
import org.agrona.collections.Int2NullableObjectHashMap
import org.agrona.collections.IntArrayList
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
internal data class RmiWaiter(val id: Int) {
// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting
// https://stackoverflow.com/questions/55421710/how-to-suspend-kotlin-coroutine-until-notified
// https://kotlinlang.org/docs/reference/coroutines/channels.html
// "receive' suspends until another coroutine invokes "send"
// and
// "send" suspends until another coroutine invokes "receive".
//
// these are wrapped in a try/catch, because cancel will cause exceptions to be thrown (which we DO NOT want)
@Volatile
var channel: Channel<Unit> = Channel(0)
/**
* this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
*/
fun prep() {
@Suppress("EXPERIMENTAL_API_USAGE")
if (channel.isClosedForReceive && channel.isClosedForSend) {
channel = Channel(0)
}
}
suspend fun doNotify() {
try {
channel.send(Unit)
} catch (ignored: Exception) {
}
}
suspend fun doWait() {
try {
channel.receive()
} catch (ignored: Exception) {
}
}
fun cancel() {
try {
channel.cancel()
} catch (ignored: Exception) {
}
}
}
/**
* Manages the "pending response" from method invocation.
*
* response ID's and the memory they hold will leak if the response never arrives!
*/
internal class RmiResponseStorage(private val actionDispatch: CoroutineScope) {
companion object {
val TIMEOUT_EXCEPTION = Exception()
}
// Response ID's are for ALL in-flight RMI on the network stack. instead of limited to (originally) 64, we are now limited to 65,535
// these are just looped around in a ring buffer.
// These are stored here as int, however these are REALLY shorts and are int-packed when transferring data on the wire
// 64,000 IN FLIGHT RMI method invocations is PLENTY
private val maxValuesInCache = (Short.MAX_VALUE.toInt() * 2) - 1 // -1 because 0 is reserved
private val rmiWaiterCache = Channel<RmiWaiter>(maxValuesInCache)
private val pendingLock = ReentrantReadWriteLock()
private val pending = Int2NullableObjectHashMap<Any>(32, Hashing.DEFAULT_LOAD_FACTOR, true)
init {
// create a shuffled list of ID's. This operation is ONLY performed ONE TIME per endpoint!
val ids = IntArrayList(maxValuesInCache, Integer.MIN_VALUE)
for (id in Short.MIN_VALUE..-1) {
ids.addInt(id)
}
// ZERO is special, and is never added!
for (id in 1..Short.MAX_VALUE) {
ids.addInt(id)
}
ids.shuffle()
// populate the array of randomly assigned ID's + waiters.
ids.forEach {
rmiWaiterCache.offer(RmiWaiter(it))
}
}
// resume any pending remote object method invocations (if they are not async, or not manually waiting)
suspend fun onMessage(message: MethodResponse) {
val objectId = message.objectId
val responseId = message.responseId
val result = message.result
val pendingId = RmiUtils.packShorts(objectId, responseId)
val previous = pendingLock.write { pending.put(pendingId, result) }
// if NULL, since either we don't exist, or it was cancelled
if (previous is RmiWaiter) {
// this means we were NOT timed out! If we were cancelled, then this does nothing.
previous.doNotify()
// since this was the FIRST one to trigger, return it to the cache.
rmiWaiterCache.send(previous)
}
}
/**
* gets the RmiWaiter (id + waiter)
*/
internal suspend fun prep(rmiObjectId: Int): RmiWaiter {
val responseRmi = rmiWaiterCache.receive()
// this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
responseRmi.prep()
// we pack them together so we can fully use the range of ints, so we can service ALL rmi requests in a single spot
pendingLock.write { pending[RmiUtils.packShorts(rmiObjectId, responseRmi.id)] = responseRmi }
return responseRmi
}
/**
* @return the result (can be null) or timeout exception
*/
suspend fun waitForReply(isAsync: Boolean, rmiObjectId: Int, rmiWaiter: RmiWaiter, timeoutMillis: Long): Any? {
val pendingId = RmiUtils.packShorts(rmiObjectId, rmiWaiter.id)
// NOTE: we ALWAYS send a response from the remote end.
//
// 'async' -> DO NOT WAIT
// 'timeout > 0' -> WAIT
// 'timeout == 0' -> same as async (DO NOT WAIT)
val returnImmediately = isAsync || timeoutMillis <= 0L
if (returnImmediately) {
return null
}
val responseTimeoutJob = actionDispatch.launch {
delay(timeoutMillis) // this will always wait
// check if we have a result or not
val maybeResult = pendingLock.read { pending[pendingId] }
if (maybeResult is RmiWaiter) {
maybeResult.cancel()
}
}
// wait for the response.
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will suspend and either
// A) get response
// B) timeout
rmiWaiter.doWait()
// always cancel the timeout
responseTimeoutJob.cancel()
val resultOrWaiter = pendingLock.write { pending.remove(pendingId) }
if (resultOrWaiter is RmiWaiter) {
// since this was the FIRST one to trigger, return it to the cache.
rmiWaiterCache.send(resultOrWaiter)
return TIMEOUT_EXCEPTION
}
return resultOrWaiter
}
fun close() {
rmiWaiterCache.close()
pendingLock.write { pending.clear() }
}
}

View File

@ -0,0 +1,50 @@
package dorkbox.network.rmi
import kotlinx.coroutines.channels.Channel
internal data class RmiWaiter(val id: Int) {
// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting
// https://stackoverflow.com/questions/55421710/how-to-suspend-kotlin-coroutine-until-notified
// https://kotlinlang.org/docs/reference/coroutines/channels.html
// "receive' suspends until another coroutine invokes "send"
// and
// "send" suspends until another coroutine invokes "receive".
//
// these are wrapped in a try/catch, because cancel will cause exceptions to be thrown (which we DO NOT want)
var channel: Channel<Unit> = Channel(0)
var isCancelled = false
/**
* this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
*/
fun prep() {
if (isCancelled) {
isCancelled = false
channel = Channel(0)
}
}
suspend fun doNotify() {
try {
channel.send(Unit)
} catch (ignored: Exception) {
}
}
suspend fun doWait() {
try {
channel.receive()
} catch (ignored: Exception) {
}
}
fun cancel() {
try {
isCancelled = true
channel.cancel()
} catch (ignored: Exception) {
}
}
}

View File

@ -35,25 +35,31 @@
package dorkbox.network.rmi.messages
import dorkbox.network.rmi.CachedMethod
import dorkbox.network.rmi.RmiUtils
/**
* Internal message to invoke methods remotely.
*/
class MethodRequest : RmiMessage {
// if this object was a global or connection specific object
/**
* true if this method is invoked on a global object, false if it is connection scoped
*/
var isGlobal: Boolean = false
// the registered kryo ID for the object
// NOTE: this is REALLY a short, but is represented as an int to make life easier. It is also packed with the responseId for serialization
var objectId: Int = 0
/**
* true if this method is a suspend function (with coroutine) or a normal method
*/
var isCoroutine: Boolean = false
// A value of 0 means to not respond, otherwise it is an ID to match requests <-> responses
// NOTE: this is REALLY a short, but is represented as an int to make life easier. It is also packed with the objectId for serialization
var responseId: Int = 0
// this is packed
// LEFT -> rmiObjectId (the registered rmi ID)
// RIGHT -> rmiId (ID to match requests <-> responses)
var packedId: Int = 0
// This field is NOT sent across the wire (but some of it's contents are).
// We use a custom serializer to manage this because we have to ALSO be able to serialize the invocation arguments.
// NOTE: the info we serialze is REALLY a short, but is represented as an int to make life easier. It is also packed!
// NOTE: the info we serialize is REALLY a short, but is represented as an int to make life easier. It is also packed!
lateinit var cachedMethod: CachedMethod
// these are the arguments for executing the method (they are serialized using the info from the cachedMethod field
@ -61,6 +67,6 @@ class MethodRequest : RmiMessage {
override fun toString(): String {
return "MethodRequest(isGlobal=$isGlobal, objectId=$objectId, responseId=$responseId, cachedMethod=$cachedMethod, args=${args?.contentToString()})"
return "MethodRequest(isGlobal=$isGlobal, rmiObjectId=${RmiUtils.unpackLeft(packedId)}, rmiId=${RmiUtils.unpackRight(packedId)}, cachedMethod=$cachedMethod, args=${args?.contentToString()})"
}
}

View File

@ -48,29 +48,16 @@ import java.lang.reflect.Method
*/
@Suppress("ConstantConditionIf")
class MethodRequestSerializer : Serializer<MethodRequest>() {
companion object {
private const val DEBUG = false
}
override fun write(kryo: Kryo, output: Output, methodRequest: MethodRequest) {
val method = methodRequest.cachedMethod
if (DEBUG) {
System.err.println("WRITING")
System.err.println(":: isGlobal ${methodRequest.isGlobal}")
System.err.println(":: objectID ${methodRequest.objectId}")
System.err.println(":: methodClassID ${method.methodClassId}")
System.err.println(":: methodIndex ${method.methodIndex}")
}
// we pack objectId + responseId into the same "int", since they are both really shorts (but are represented as ints to make
// working with them a lot easier
output.writeInt(RmiUtils.packShorts(methodRequest.objectId, methodRequest.responseId), true)
output.writeInt(methodRequest.packedId)
output.writeInt(RmiUtils.packShorts(method.methodClassId, method.methodIndex), true)
output.writeBoolean(methodRequest.isGlobal)
val serializers = method.serializers
if (serializers.isNotEmpty()) {
val args = methodRequest.args!!
@ -87,25 +74,12 @@ class MethodRequestSerializer : Serializer<MethodRequest>() {
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<out MethodRequest>): MethodRequest {
val objectIdRmiId = input.readInt(true)
val objectId = RmiUtils.unpackLeft(objectIdRmiId)
val responseId = RmiUtils.unpackRight(objectIdRmiId)
val packedId = input.readInt()
val methodInfo = input.readInt(true)
val methodClassId = RmiUtils.unpackLeft(methodInfo)
val methodIndex = RmiUtils.unpackRight(methodInfo)
val isGlobal = input.readBoolean()
if (DEBUG) {
System.err.println("READING")
System.err.println(":: isGlobal $isGlobal")
System.err.println(":: objectID $objectId")
System.err.println(":: methodClassID $methodClassId")
System.err.println(":: methodIndex $methodIndex")
}
(kryo as KryoExtra)
val cachedMethod = try {
@ -138,15 +112,24 @@ class MethodRequestSerializer : Serializer<MethodRequest>() {
val parameterTypes = method.parameterTypes
var isCoroutine = false
// we don't start at 0 for the arguments, in case we have an overwritten method, in which case, the 1st arg is always "Connection.class"
var index = 0
val size = serializers.size
var argStart = argStartIndex
while (index < size) {
val serializer = serializers[index]
if (serializer != null) {
args[argStart] = kryo.readObjectOrNull(input, parameterTypes[index], serializer)
if (serializer is ContinuationSerializer) {
isCoroutine = true
// have to check if it's a coroutine or not!
args[argStart] = ContinuationSerializer::class.java
} else {
args[argStart] = kryo.readObjectOrNull(input, parameterTypes[index], serializer)
}
} else {
args[argStart] = kryo.readClassAndObject(input)
}
@ -157,10 +140,10 @@ class MethodRequestSerializer : Serializer<MethodRequest>() {
val invokeMethod = MethodRequest()
invokeMethod.isGlobal = isGlobal
invokeMethod.objectId = objectId
invokeMethod.isCoroutine = isCoroutine
invokeMethod.packedId = packedId
invokeMethod.cachedMethod = cachedMethod
invokeMethod.args = args
invokeMethod.responseId = responseId
return invokeMethod
}

View File

@ -48,7 +48,6 @@ class MethodResponse : RmiMessage {
// this is the result of the invoked method
var result: Any? = null
override fun toString(): String {
return "MethodResponse(rmiObjectId=${RmiUtils.unpackLeft(packedId)}, rmiId=${RmiUtils.unpackRight(packedId)}, result=$result)"
}

View File

@ -39,9 +39,13 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.IdentityMap
import dorkbox.network.connection.KryoExtra
import dorkbox.network.serialization.KryoExtra
/**
* this is to manage serializing proxy object objects across the wire...
*
* SO the "rmi client" sends an RMI proxy object, and the "rmi server" reads an actual object
*
* Serializes an object registered with the RmiBridge so the receiving side gets a [RemoteObject] proxy rather than the bytes for the
* serialized object.
*
@ -69,4 +73,4 @@ class ObjectResponseSerializer(private val rmiImplToIface: IdentityMap<Class<*>,
}
}
// TODO: FIX THIS CLASS MAYBE!
/// TODO: FIX THIS CLASS MAYBE!