// Network/src/dorkbox/network/rmi/RmiClient.kt
//
// 464 lines
// 20 KiB
// Kotlin

/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi
import com.conversantmedia.util.collection.FixedStack
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.rmi.ResponseManager.Companion.TIMEOUT_EXCEPTION
import dorkbox.network.rmi.messages.MethodRequest
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
import java.util.concurrent.*
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
/**
 * Handles network communication when methods are invoked on a proxy.
 *
 * For NON-BLOCKING performance, the RMI interface must have the 'suspend' keyword added. If this keyword is not
 * present, then all method invocations will be BLOCKING.
 *
 * Methods declared on [RemoteObject] (timeouts, sync/async switches, enableToString/HashCode/Equals) are handled locally
 * by this handler; every other method is packaged into a [MethodRequest] and sent over [connection].
 *
 * @param isGlobal true if this is a global object, or false if it is connection specific
 * @param rmiObjectId this is the remote object ID (assigned by RMI). This is NOT the kryo registration ID
 * @param connection this is really the network client -- there is ONLY ever 1 connection
 * @param proxyString this is the name assigned to the proxy [toString] method
 * @param responseManager is used to provide RMI request/response support
 * @param cachedMethods these are the methods available for the specified class
 */
internal class RmiClient(
    val isGlobal: Boolean,
    val rmiObjectId: Int,
    private val connection: Connection,
    private val proxyString: String,
    private val responseManager: ResponseManager,
    private val cachedMethods: Array<CachedMethod>
) : InvocationHandler {
    companion object {
        // Reflected methods of RemoteObject, looked up once by name, so `invoke` can dispatch with a plain `when (method)`.
        private val methods = RmiUtils.getMethods(RemoteObject::class.java)

        private val toStringMethod = methods.find { it.name == "toString" }
        private val hashCodeMethod = methods.find { it.name == "hashCode" }
        private val equalsMethod = methods.find { it.name == "equals" }

        private val enableToStringMethod = methods.find { it.name == "enableToString" }
        private val enableHashCodeMethod = methods.find { it.name == "enableHashCode" }
        private val enableEqualsMethod = methods.find { it.name == "enableEquals" }

        private val asyncMethod = methods.find { it.name == "async" }
        private val syncMethod = methods.find { it.name == "sync" }
        private val asyncSuspendMethod = methods.find { it.name == "asyncSuspend" }
        private val syncSuspendMethod = methods.find { it.name == "syncSuspend" }

        private val setResponseTimeoutMethod = methods.find { it.name == "setResponseTimeout" }
        private val getResponseTimeoutMethod = methods.find { it.name == "getResponseTimeout" }
        private val setAsyncMethod = methods.find { it.name == "setAsync" }
        private val getAsyncMethod = methods.find { it.name == "getAsync" }

        // Shared zero-length argument array, used when a method is invoked without arguments.
        private val EMPTY_ARRAY: Array<Any> = emptyArray()

        // Per-thread stack of the sync/async overrides pushed by the `sync {}` / `async {}` style blocks.
        // `null` from peek() means "no override"; `invoke` then falls back to the `isAsync` field.
        private val safeAsyncStack: ThreadLocal<FixedStack<Boolean?>> = ThreadLocal.withInitial {
            FixedStack(64)
        }

        // Boxed "zero" values returned from async calls whose return type is a primitive char/short/byte.
        private const val charPrim = 0.toChar()
        private const val shortPrim = 0.toShort()
        private const val bytePrim = 0.toByte()

        /**
         * Resolves [candidate] to the [RmiClient] that backs it.
         *
         * The argument handed to a proxy's `equals` is normally another *proxy instance*, not its invocation handler, so
         * a plain `is RmiClient` check would reject every proxy (including the proxy itself).
         *
         * @return the [RmiClient] if [candidate] is one, or is a `java.lang.reflect.Proxy` whose handler is one; otherwise null
         */
        private fun rmiClientOf(candidate: Any?): RmiClient? {
            return when {
                candidate is RmiClient -> candidate
                candidate != null && java.lang.reflect.Proxy.isProxyClass(candidate.javaClass) ->
                    java.lang.reflect.Proxy.getInvocationHandler(candidate) as? RmiClient
                else -> null
            }
        }

        /**
         * Runs the (non-suspending) block passed as `args[0]` against [proxy], with [isAsync] pushed as the sync/async
         * override for the duration of the block.
         */
        @Suppress("UNCHECKED_CAST")
        private fun syncMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>) {
            val action = args[0] as Any.() -> Unit

            // the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
            // the docs cover that (and say, `don't do this`)
            safeAsyncStack.get().push(isAsync)

            // the `sync` method is always a unit function - we want to execute that unit function directly - this way we can control
            // exactly how sync state is preserved.
            try {
                action(proxy)
            } finally {
                safeAsyncStack.get().pop()
            }
        }

        /**
         * Suspending counterpart of [syncMethodAction]: runs the suspending block passed as `args[0]` against [proxy] with
         * [isAsync] pushed as the sync/async override, using the caller's [Continuation] (the last element of [args]).
         *
         * @return whatever the trampolined suspend invocation returns (the value handed back to the coroutine machinery)
         */
        @Suppress("UNCHECKED_CAST")
        private fun syncSuspendMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>): Any? {
            val action = args[0] as suspend Any.() -> Unit

            // if a 'suspend' function is called, then our last argument is a 'Continuation' object
            // We will use this for our coroutine context instead of running on a new coroutine
            val suspendCoroutineArg = args.last()
            val continuation = suspendCoroutineArg as Continuation<Any?>

            val suspendFunction: suspend () -> Any? = {
                // the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
                // the docs cover that (and say, `don't do this`)
                withContext(safeAsyncStack.asContextElement()) {
                    yield() // must have an actually suspending call here!
                    safeAsyncStack.get().push(isAsync)
                    action(proxy)
                }
            }

            // function suspension works differently !!
            val result = (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
                Continuation(continuation.context) { outcome ->
                    // NOTE(review): this pop runs on whichever thread completes the block; confirm it always matches
                    // the stack that received the push above.
                    safeAsyncStack.get().pop()

                    // A failed Result (the block threw, or the coroutine was cancelled) MUST be propagated. Reading only
                    // `getOrNull()` would turn the failure into `null` and the caller would never see the exception.
                    val failure = outcome.exceptionOrNull()
                    if (failure != null) {
                        continuation.resumeWithException(failure)
                    } else {
                        when (val any = outcome.getOrNull()) {
                            is Exception -> {
                                // for co-routines, it's impossible to get a legit stacktrace without impacting general performance,
                                // so we just don't do it.
                                // RmiUtils.cleanStackTraceForProxy(Exception(), any)
                                continuation.resumeWithException(any)
                            }
                            else -> {
                                continuation.resume(null)
                            }
                        }
                    }
                })

            // NOTE(review): the purpose of this empty runBlocking is not evident from this file -- kept as-is.
            runBlocking(safeAsyncStack.asContextElement()) {}

            return result
        }
    }

    // Default sync/async mode, used when no `sync {}`/`async {}` override is active on the calling thread.
    @Volatile private var isAsync = false

    // How long (in milliseconds) a synchronous call waits for its response. 0 means wait forever.
    @Volatile private var timeoutMillis: Long = if (EndPoint.DEBUG_CONNECTIONS) TimeUnit.HOURS.toMillis(2) else 3_000L

    // When false, toString/hashCode/equals are answered locally instead of doing an RMI round trip.
    @Volatile private var enableToString = false
    @Volatile private var enableHashCode = false
    @Volatile private var enableEquals = false

    /**
     * Dispatches a method call made on the proxy.
     *
     * - [RemoteObject] methods are handled locally and never touch the network.
     * - toString/hashCode/equals are handled locally unless the matching `enableXyz` flag was set.
     * - Everything else is sent as a [MethodRequest]. In async mode the call returns immediately with a default value
     *   for the return type; otherwise it waits for the response (blocking, or via the caller's [Continuation] when the
     *   invoked method is a `suspend` function).
     *
     * @throws Exception whatever the remote invocation produced, a timeout exception if no response arrived in time, or
     *      [RmiException] if the request could not be sent
     */
    @Suppress("DuplicatedCode", "UNCHECKED_CAST")
    override fun invoke(proxy: Any, method: Method, args: Array<Any>?): Any? {
        val localAsync =
            safeAsyncStack.get().peek() // value set via obj.sync {} / obj.async {}
                ?:
            isAsync // the value was set via obj.async = xyz

        if (method.declaringClass == RemoteObject::class.java) {
            // manage all the RemoteObject proxy methods
            when (method) {
                setResponseTimeoutMethod -> {
                    // validate BEFORE storing, so a rejected value can never become the active timeout
                    val requestedMillis = (args!![0] as Int).toLong()
                    require(requestedMillis >= 0) { "ResponseTimeout must be >= 0" }
                    timeoutMillis = requestedMillis
                    return null
                }
                getResponseTimeoutMethod -> {
                    return timeoutMillis.toInt()
                }
                getAsyncMethod -> {
                    return localAsync
                }
                setAsyncMethod -> {
                    isAsync = args!![0] as Boolean
                    return null
                }
                asyncMethod -> {
                    syncMethodAction(true, proxy as RemoteObject<*>, args!!)
                    return null
                }
                syncMethod -> {
                    syncMethodAction(false, proxy as RemoteObject<*>, args!!)
                    return null
                }
                asyncSuspendMethod -> {
                    return syncSuspendMethodAction(true, proxy as RemoteObject<*>, args!!)
                }
                syncSuspendMethod -> {
                    return syncSuspendMethodAction(false, proxy as RemoteObject<*>, args!!)
                }
                enableToStringMethod -> {
                    enableToString = args!![0] as Boolean
                    return null
                }
                enableHashCodeMethod -> {
                    enableHashCode = args!![0] as Boolean
                    return null
                }
                enableEqualsMethod -> {
                    enableEquals = args!![0] as Boolean
                    return null
                }
                else -> throw RmiException("Invocation handler could not find RemoteObject method for ${method.name}")
            }
        } else {
            when (method) {
                toStringMethod -> if (!enableToString) return proxyString // otherwise, the RMI round trip logic is done for toString()
                hashCodeMethod -> if (!enableHashCode) return rmiObjectId // otherwise, the RMI round trip logic is done for hashCode()
                equalsMethod -> {
                    // `other` is usually another proxy (or this very proxy), so unwrap its invocation handler
                    val other: Any? = args!![0]
                    val otherClient = rmiClientOf(other) ?: return false

                    if (!enableEquals) {
                        return rmiObjectId == otherClient.rmiObjectId
                    }

                    // otherwise, the RMI round trip logic is done for equals()
                }
            }
        }

        // setup the RMI request
        val invokeMethod = MethodRequest()

        // if this is a kotlin suspend function, the continuation arg will NOT be here (it's replaced at runtime)!
        invokeMethod.args = args ?: EMPTY_ARRAY

        // which method do we access? We always want to access the IMPLEMENTATION (if available!). we know that this will always succeed
        // this should be accessed via the KRYO class ID + method index (both are SHORT, and can be packed)
        invokeMethod.cachedMethod = cachedMethods.first { it.method == method }

        // there is a STRANGE problem, where if we DO NOT respond/reply to method invocation, and immediate invoke multiple methods --
        // the "server" side can have out-of-order method invocation. There are 2 ways to solve this
        //  1) make the "server" side single threaded
        //  2) make the "client" side wait for execution response (from the "server"). <--- this is what we are using.
        //
        // Because we have to ALWAYS make the client wait (unless 'isAsync' is true), we will always be returning, and will always have a
        // response (even if it is a void response). This simplifies our response mask, and lets us use more bits for storing the
        // response ID

        // NOTE: we ALWAYS send a response from the remote end (except when async).
        //
        // 'async' -> DO NOT WAIT (no response)
        // 'timeout > 0' -> WAIT w/ TIMEOUT
        // 'timeout == 0' -> WAIT FOREVER
        invokeMethod.isGlobal = isGlobal

        if (localAsync) {
            // If we are async, we ignore the response (don't invoke the response manager at all)....
            invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, RemoteObjectStorage.ASYNC_RMI)

            val success = connection.send(invokeMethod)
            if (!success) {
                throw RmiException("Unable to send async message, an error occurred during the send process")
            }

            // if we are async then we return immediately (but must return the correct type!)
            // If you want the response value, disable async!
            val returnType = method.returnType
            if (returnType.isPrimitive) {
                return when (returnType) {
                    Boolean::class.javaPrimitiveType -> java.lang.Boolean.FALSE
                    Int::class.javaPrimitiveType -> 0
                    Float::class.javaPrimitiveType -> 0.0f
                    Char::class.javaPrimitiveType -> charPrim
                    Long::class.javaPrimitiveType -> 0L
                    Short::class.javaPrimitiveType -> shortPrim
                    Byte::class.javaPrimitiveType -> bytePrim
                    Double::class.javaPrimitiveType -> 0.0
                    else -> null // void type
                }
            }
            return null
        }

        val logger = connection.logger

        //
        // this is all SYNC code
        //

        // The response, even if there is NOT one (ie: not void) will always return a thing (so our code execution is in lockstep -- unless it is ASYNC)
        val responseWaiter = responseManager.prep(logger)
        invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, responseWaiter.id)

        val success = connection.send(invokeMethod)
        if (!success) {
            // nothing will ever answer this request, so give the waiter back before failing
            responseManager.abort(responseWaiter, logger)
            throw RmiException("Unable to send message, an error occurred during the send process")
        }

        // if a 'suspend' function is called, then our last argument is a 'Continuation' object
        // We will use this for our coroutine context instead of running on a new coroutine
        val suspendCoroutineArg = args?.lastOrNull()

        // async will return immediately
        if (suspendCoroutineArg is Continuation<*>) {
            val continuation = suspendCoroutineArg as Continuation<Any?>

            val suspendFunction: suspend () -> Any? = {
                // NOTE: once something ELSE is suspending, we can remove the `yield`
                yield() // if this is not here, it will not work (something must actually suspend!)

                // NOTE: this is blocking!
                // NOTE: we ALWAYS send a response from the remote end (except when async).
                //
                // 'async' -> DO NOT WAIT (no response)
                // 'timeout > 0' -> WAIT w/ TIMEOUT
                // 'timeout == 0' -> WAIT FOREVER
                if (timeoutMillis > 0) {
                    // wait for the response.
                    //
                    // If the response is ALREADY here, the doWait() returns instantly (with result)
                    // if no response yet, it will wait for:
                    //   A) get response
                    //   B) timeout
                    if (!responseWaiter.doWait(timeoutMillis)) {
                        // if we timeout, it doesn't matter since we'll be removing the waiter from the array anyways,
                        // so no signal can occur, or a signal won't matter
                        responseManager.abort(responseWaiter, logger)
                        TIMEOUT_EXCEPTION
                    } else {
                        responseManager.getReply(responseWaiter, timeoutMillis, logger)
                    }
                } else {
                    // wait for the response --- THIS WAITS FOREVER (there is no timeout)!
                    //
                    // If the response is ALREADY here, the doWait() returns instantly (with result)
                    // if no response yet, it will wait for one
                    //   A) get response
                    responseWaiter.doWait()
                    responseManager.getReply(responseWaiter, timeoutMillis, logger)
                }
            }

            // function suspension works differently. THIS IS A TRAMPOLINE TO CALL SUSPEND !!
            return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(Continuation(continuation.context) { outcome ->
                // A failed Result (e.g. cancellation raised by `yield()`, or anything thrown while waiting) MUST be
                // propagated. Reading only `getOrNull()` would resume the caller with `null` and hide the failure.
                val failure = outcome.exceptionOrNull()
                if (failure != null) {
                    continuation.resumeWithException(failure)
                } else {
                    when (val any = outcome.getOrNull()) {
                        TIMEOUT_EXCEPTION -> {
                            val fancyName = RmiUtils.makeFancyMethodName(method)
                            val exception = TimeoutException("Response timed out: $fancyName")
                            // from top down, clean up the coroutine stack
                            RmiUtils.cleanStackTraceForProxy(exception)
                            continuation.resumeWithException(exception)
                        }
                        is Throwable -> {
                            // for co-routines, it's impossible to get a legit stacktrace without impacting general performance,
                            // so we just don't do it.
                            // RmiUtils.cleanStackTraceForProxy(Exception(), any)
                            continuation.resumeWithException(any)
                        }
                        else -> {
                            continuation.resume(any)
                        }
                    }
                }
            })
        } else {
            // NOTE: this is blocking!
            // NOTE: we ALWAYS send a response from the remote end (except when async).
            //
            // 'async' -> DO NOT WAIT (no response)
            // 'timeout > 0' -> WAIT w/ TIMEOUT
            // 'timeout == 0' -> WAIT FOREVER
            if (timeoutMillis > 0) {
                // wait for the response.
                //
                // If the response is ALREADY here, the doWait() returns instantly (with result)
                // if no response yet, it will wait for:
                //   A) get response
                //   B) timeout
                if (!responseWaiter.doWait(timeoutMillis)) {
                    // if we timeout, it doesn't matter since we'll be removing the waiter from the array anyways,
                    // so no signal can occur, or a signal won't matter
                    responseManager.abort(responseWaiter, logger)
                    throw TIMEOUT_EXCEPTION
                }
            } else {
                // wait for the response --- THIS WAITS FOREVER (there is no timeout)!
                //
                // If the response is ALREADY here, the doWait() returns instantly (with result)
                // if no response yet, it will wait for one
                //   A) get response
                responseWaiter.doWait()
            }

            val any = responseManager.getReply(responseWaiter, timeoutMillis, logger)
            when (any) {
                TIMEOUT_EXCEPTION -> {
                    val fancyName = RmiUtils.makeFancyMethodName(method)
                    val exception = TimeoutException("Response timed out: $fancyName")
                    // from top down, clean up the coroutine stack
                    RmiUtils.cleanStackTraceForProxy(exception)
                    throw exception
                }
                is Throwable -> {
                    // reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
                    // this stack will ALWAYS run up to this method (so we remove from the top->down, to get to the call site)
                    RmiUtils.cleanStackTraceForProxy(Exception(), any)
                    throw any
                }
                else -> {
                    return any
                }
            }
        }
    }

    /**
     * Hash is derived only from [rmiObjectId], matching [equals].
     */
    override fun hashCode(): Int {
        val prime = 31
        var result = 1
        result = prime * result + rmiObjectId
        return result
    }

    /**
     * Two handlers are equal when they are the exact same class and have the same [rmiObjectId].
     */
    override fun equals(other: Any?): Boolean {
        if (this === other) {
            return true
        }
        if (other == null) {
            return false
        }
        if (javaClass != other.javaClass) {
            return false
        }
        if (other !is RmiClient) {
            return false
        }
        return rmiObjectId == other.rmiObjectId
    }
}