Cleaned up RMI coroutine suspension
This commit is contained in:
parent
0a9ae32595
commit
4a033651cc
@ -319,7 +319,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
|
|||||||
* from a "global" context
|
* from a "global" context
|
||||||
*/
|
*/
|
||||||
internal open fun getRmiConnectionSupport() : RmiManagerConnections {
|
internal open fun getRmiConnectionSupport() : RmiManagerConnections {
|
||||||
return RmiManagerConnections(logger, rmiGlobalSupport, serialization, actionDispatch)
|
return RmiManagerConnections(logger, rmiGlobalSupport, serialization)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
package dorkbox.network.other.coroutines
|
|
||||||
|
|
||||||
import kotlinx.coroutines.channels.Channel
|
|
||||||
|
|
||||||
// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting
|
|
||||||
// https://kotlinlang.org/docs/reference/coroutines/channels.html
|
|
||||||
inline class SuspendWaiter(private val channel: Channel<Unit> = Channel()) {
|
|
||||||
// "receive' suspends until another coroutine invokes "send"
|
|
||||||
// and
|
|
||||||
// "send" suspends until another coroutine invokes "receive".
|
|
||||||
suspend fun doWait() {
|
|
||||||
try {
|
|
||||||
channel.receive()
|
|
||||||
} catch (ignored: Exception) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
suspend fun doNotify() {
|
|
||||||
try {
|
|
||||||
channel.send(Unit)
|
|
||||||
} catch (ignored: Exception) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fun cancel() {
|
|
||||||
try {
|
|
||||||
channel.cancel()
|
|
||||||
} catch (ignored: Exception) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fun isCancelled(): Boolean {
|
|
||||||
// once the channel is cancelled, it can never work again
|
|
||||||
@Suppress("EXPERIMENTAL_API_USAGE")
|
|
||||||
return channel.isClosedForReceive && channel.isClosedForSend
|
|
||||||
}
|
|
||||||
}
|
|
@ -17,10 +17,8 @@ package dorkbox.network.rmi
|
|||||||
|
|
||||||
import dorkbox.network.connection.Connection
|
import dorkbox.network.connection.Connection
|
||||||
import dorkbox.network.connection.ListenerManager
|
import dorkbox.network.connection.ListenerManager
|
||||||
import dorkbox.network.other.coroutines.SuspendFunctionTrampoline
|
|
||||||
import dorkbox.network.rmi.messages.MethodRequest
|
import dorkbox.network.rmi.messages.MethodRequest
|
||||||
import kotlinx.coroutines.CoroutineDispatcher
|
import kotlinx.coroutines.CoroutineDispatcher
|
||||||
import kotlinx.coroutines.Dispatchers
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import java.lang.reflect.InvocationHandler
|
import java.lang.reflect.InvocationHandler
|
||||||
import java.lang.reflect.Method
|
import java.lang.reflect.Method
|
||||||
@ -114,11 +112,53 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||||||
|
|
||||||
connection.send(invokeMethod)
|
connection.send(invokeMethod)
|
||||||
|
|
||||||
|
|
||||||
// if we are async, then this will immediately return
|
// if we are async, then this will immediately return
|
||||||
return responseManager.waitForReply(isAsync, rmiWaiter, timeoutMillis)
|
return responseManager.waitForReply(isAsync, rmiWaiter, timeoutMillis)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun returnAsyncOrSync(method: Method, returnValue: Any?): Any? {
|
||||||
|
if (isAsync) {
|
||||||
|
// if we are async then we return immediately.
|
||||||
|
// If you want the response value, disable async!
|
||||||
|
val returnType = method.returnType
|
||||||
|
if (returnType.isPrimitive) {
|
||||||
|
return when (returnType) {
|
||||||
|
Int::class.javaPrimitiveType -> {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
Boolean::class.javaPrimitiveType -> {
|
||||||
|
java.lang.Boolean.FALSE
|
||||||
|
}
|
||||||
|
Float::class.javaPrimitiveType -> {
|
||||||
|
0.0f
|
||||||
|
}
|
||||||
|
Char::class.javaPrimitiveType -> {
|
||||||
|
0.toChar()
|
||||||
|
}
|
||||||
|
Long::class.javaPrimitiveType -> {
|
||||||
|
0L
|
||||||
|
}
|
||||||
|
Short::class.javaPrimitiveType -> {
|
||||||
|
0.toShort()
|
||||||
|
}
|
||||||
|
Byte::class.javaPrimitiveType -> {
|
||||||
|
0.toByte()
|
||||||
|
}
|
||||||
|
Double::class.javaPrimitiveType -> {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return returnValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Suppress("DuplicatedCode")
|
@Suppress("DuplicatedCode")
|
||||||
/**
|
/**
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
@ -195,152 +235,47 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
|
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
|
||||||
// We will use this for our coroutine context instead of running on a new coroutine
|
// We will use this for our coroutine context instead of running on a new coroutine
|
||||||
val suspendCoroutineObject = args?.lastOrNull()
|
val suspendCoroutineArg = args?.lastOrNull()
|
||||||
|
|
||||||
// async will return immediately
|
// async will return immediately
|
||||||
var returnValue: Any? = null
|
if (suspendCoroutineArg is Continuation<*>) {
|
||||||
if (suspendCoroutineObject is Continuation<*>) {
|
@Suppress("UNCHECKED_CAST")
|
||||||
// val continuation = suspendCoroutineObject as Continuation<Any?>
|
val continuation = suspendCoroutineArg as Continuation<Any?>
|
||||||
//
|
|
||||||
//
|
|
||||||
// val suspendFunction: suspend () -> Any? = {
|
|
||||||
// val rmiResult = sendRequest(invokeMethod)
|
|
||||||
// println("RMI: ${rmiResult?.javaClass}")
|
|
||||||
// println(1)
|
|
||||||
// delay(3000)
|
|
||||||
// println(2)
|
|
||||||
// }
|
|
||||||
// val suspendFunction1: Function1<Continuation<Any?>, *> = suspendFunction as Function1<Continuation<Any?>?, *>
|
|
||||||
// returnValue = suspendFunction1.invoke(Continuation(EmptyCoroutineContext) {
|
|
||||||
// it.getOrNull().apply {
|
|
||||||
// continuation.resume(this)
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// System.err.println("have suspend ret value ${returnValue?.javaClass}")
|
|
||||||
//
|
|
||||||
//// returnValue = invokeSuspendFunction(continuation, suspendFunction)
|
|
||||||
//
|
|
||||||
// // https://stackoverflow.com/questions/57230869/how-to-propagate-kotlin-coroutine-context-through-reflective-invocation-of-suspe
|
|
||||||
// // https://stackoverflow.com/questions/52869672/call-kotlin-suspend-function-in-java-class
|
|
||||||
// // https://discuss.kotlinlang.org/t/how-to-continue-a-suspend-function-in-a-dynamic-proxy-in-the-same-coroutine/11391
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// // NOTE:
|
|
||||||
// // Calls to OkHttp Call.enqueue() like those inside await and awaitNullable can sometimes
|
|
||||||
// // invoke the supplied callback with an exception before the invoking stack frame can return.
|
|
||||||
// // Coroutines will intercept the subsequent invocation of the Continuation and throw the
|
|
||||||
// // exception synchronously. A Java Proxy cannot throw checked exceptions without them being
|
|
||||||
// // declared on the interface method. To avoid the synchronous checked exception being wrapped
|
|
||||||
// // in an UndeclaredThrowableException, it is intercepted and supplied to a helper which will
|
|
||||||
// // force suspension to occur so that it can be instead delivered to the continuation to
|
|
||||||
// // bypass this restriction.
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * https://jakewharton.com/exceptions-and-proxies-and-coroutines-oh-my/
|
|
||||||
// * https://github.com/Kotlin/kotlinx.coroutines/pull/1667
|
|
||||||
// * https://github.com/square/retrofit/blob/master/retrofit/src/main/java/retrofit2/KotlinExtensions.kt
|
|
||||||
// * https://github.com/square/retrofit/blob/master/retrofit/src/main/java/retrofit2/HttpServiceMethod.java
|
|
||||||
// * https://github.com/square/retrofit/blob/master/retrofit/src/main/java/retrofit2/Utils.java
|
|
||||||
// * https://github.com/square/retrofit/blob/master/retrofit/src/main/java/retrofit2
|
|
||||||
// */
|
|
||||||
//// returnValue = try {
|
|
||||||
//// val actualContinuation = suspendCoroutineObject.intercepted() as Continuation<Any?>
|
|
||||||
////// suspend {
|
|
||||||
////// try {
|
|
||||||
////// delay(100)
|
|
||||||
////// sendRequest(invokeMethod)
|
|
||||||
////// } catch (e: Exception) {
|
|
||||||
////// yield()
|
|
||||||
////// throw e
|
|
||||||
////// }
|
|
||||||
////// }.startCoroutineUninterceptedOrReturn(actualContinuation)
|
|
||||||
////
|
|
||||||
//// invokeSuspendFunction(actualContinuation) {
|
|
||||||
////
|
|
||||||
////
|
|
||||||
//////// kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn<Any?> {
|
|
||||||
//// delay(100)
|
|
||||||
//// sendRequest(invokeMethod)
|
|
||||||
//////// }
|
|
||||||
////////
|
|
||||||
////// kotlinx.coroutines.suspendCancellableCoroutine<Any?> { continuation: Any? ->
|
|
||||||
////// resume(body)
|
|
||||||
////// }
|
|
||||||
////// withContext(MyUnconfined) {
|
|
||||||
//////
|
|
||||||
////// }
|
|
||||||
//// }
|
|
||||||
////
|
|
||||||
////// MyUnconfined.dispatch(suspendCoroutineObject.context, Runnable {
|
|
||||||
////// invokeSuspendFunction(suspendCoroutineObject) {
|
|
||||||
//////
|
|
||||||
////// }
|
|
||||||
////// })
|
|
||||||
////
|
|
||||||
//// } catch (e: Exception) {
|
|
||||||
//// e.printStackTrace()
|
|
||||||
//// }
|
|
||||||
////
|
|
||||||
//// if (returnValue == kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED) {
|
|
||||||
//// // we were suspend, and when we unsuspend, we will pick up where we left off
|
|
||||||
//// return returnValue
|
|
||||||
//// }
|
|
||||||
|
|
||||||
// if this was an exception, we want to get it out!
|
val suspendFunction: suspend () -> Any? = {
|
||||||
returnValue = runBlocking {
|
|
||||||
sendRequest(invokeMethod)
|
sendRequest(invokeMethod)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
returnValue = runBlocking {
|
|
||||||
sendRequest(invokeMethod)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isAsync) {
|
// function suspension works differently !!
|
||||||
// if we are async then we return immediately.
|
@Suppress("UNCHECKED_CAST")
|
||||||
// If you want the response value, disable async!
|
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(Continuation(EmptyCoroutineContext) {
|
||||||
val returnType = method.returnType
|
val any = it.getOrNull()
|
||||||
if (returnType.isPrimitive) {
|
when (any) {
|
||||||
return when (returnType) {
|
RmiResponseManager.TIMEOUT_EXCEPTION -> {
|
||||||
Int::class.javaPrimitiveType -> {
|
val fancyName = RmiUtils.makeFancyMethodName(method)
|
||||||
0
|
val exception = TimeoutException("Response timed out: $fancyName")
|
||||||
|
// from top down, clean up the coroutine stack
|
||||||
|
ListenerManager.cleanStackTrace(exception, RmiClient::class.java)
|
||||||
|
continuation.resumeWithException(exception)
|
||||||
}
|
}
|
||||||
Boolean::class.javaPrimitiveType -> {
|
is Exception -> {
|
||||||
java.lang.Boolean.FALSE
|
// reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
|
||||||
}
|
// this stack will ALWAYS run up to this method (so we remove from the top->down, to get to the call site)
|
||||||
Float::class.javaPrimitiveType -> {
|
ListenerManager.cleanStackTrace(Exception(), RmiClient::class.java, any)
|
||||||
0.0f
|
continuation.resumeWithException(any)
|
||||||
}
|
|
||||||
Char::class.javaPrimitiveType -> {
|
|
||||||
0.toChar()
|
|
||||||
}
|
|
||||||
Long::class.javaPrimitiveType -> {
|
|
||||||
0L
|
|
||||||
}
|
|
||||||
Short::class.javaPrimitiveType -> {
|
|
||||||
0.toShort()
|
|
||||||
}
|
|
||||||
Byte::class.javaPrimitiveType -> {
|
|
||||||
0.toByte()
|
|
||||||
}
|
|
||||||
Double::class.javaPrimitiveType -> {
|
|
||||||
0.0
|
|
||||||
}
|
}
|
||||||
else -> {
|
else -> {
|
||||||
null
|
continuation.resume(returnAsyncOrSync(method, any))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
val any = runBlocking {
|
||||||
|
sendRequest(invokeMethod)
|
||||||
}
|
}
|
||||||
return null
|
when (any) {
|
||||||
}
|
|
||||||
else {
|
|
||||||
// this will not return immediately. This will be suspended until there is a response
|
|
||||||
when (returnValue) {
|
|
||||||
RmiResponseManager.TIMEOUT_EXCEPTION -> {
|
RmiResponseManager.TIMEOUT_EXCEPTION -> {
|
||||||
val fancyName = RmiUtils.makeFancyMethodName(method)
|
val fancyName = RmiUtils.makeFancyMethodName(method)
|
||||||
val exception = TimeoutException("Response timed out: $fancyName")
|
val exception = TimeoutException("Response timed out: $fancyName")
|
||||||
@ -348,48 +283,19 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||||||
ListenerManager.cleanStackTrace(exception, RmiClient::class.java)
|
ListenerManager.cleanStackTrace(exception, RmiClient::class.java)
|
||||||
throw exception
|
throw exception
|
||||||
}
|
}
|
||||||
// is Exception -> {
|
is Exception -> {
|
||||||
// // reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
|
// reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
|
||||||
// // this stack will ALWAYS run up to this method (so we remove from the top->down, to get to the call site)
|
// this stack will ALWAYS run up to this method (so we remove from the top->down, to get to the call site)
|
||||||
// ListenerManager.cleanStackTrace(Exception(), RmiClient::class.java, returnValue)
|
ListenerManager.cleanStackTrace(Exception(), RmiClient::class.java, any)
|
||||||
// throw returnValue
|
throw any
|
||||||
// }
|
}
|
||||||
else -> {
|
else -> {
|
||||||
return returnValue
|
return returnAsyncOrSync(method, any)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Force the calling coroutine to suspend before throwing [this].
|
|
||||||
*
|
|
||||||
* This is needed when a checked exception is synchronously caught in a [java.lang.reflect.Proxy]
|
|
||||||
* invocation to avoid being wrapped in [java.lang.reflect.UndeclaredThrowableException].
|
|
||||||
*
|
|
||||||
* The implementation is derived from:
|
|
||||||
* https://github.com/Kotlin/kotlinx.coroutines/pull/1667#issuecomment-556106349
|
|
||||||
*/
|
|
||||||
suspend fun suspendAndThrow(e: Throwable): Nothing {
|
|
||||||
kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn<Nothing> { continuation ->
|
|
||||||
Dispatchers.Default.dispatch(continuation.context, Runnable {
|
|
||||||
continuation.resumeWithException(e)
|
|
||||||
})
|
|
||||||
kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// trampoline so we can access suspend functions correctly and (if suspend) get the coroutine connection parameter)
|
|
||||||
private fun invokeSuspendFunction(continuation: Continuation<Any?>, suspendFunction: suspend () -> Any?): Any {
|
|
||||||
return SuspendFunctionTrampoline.invoke(Continuation<Any?>(EmptyCoroutineContext) {
|
|
||||||
it.getOrNull().apply {
|
|
||||||
continuation.resume(this)
|
|
||||||
}
|
|
||||||
}, suspendFunction) as Any
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
override fun hashCode(): Int {
|
||||||
val prime = 31
|
val prime = 31
|
||||||
var result = 1
|
var result = 1
|
||||||
|
@ -99,10 +99,10 @@ class RmiTest : BaseTest() {
|
|||||||
e.printStackTrace()
|
e.printStackTrace()
|
||||||
caught = true
|
caught = true
|
||||||
}
|
}
|
||||||
|
|
||||||
Assert.assertTrue(caught)
|
Assert.assertTrue(caught)
|
||||||
caught = false
|
caught = false
|
||||||
|
|
||||||
|
|
||||||
// Non-blocking call tests
|
// Non-blocking call tests
|
||||||
// Non-blocking call tests
|
// Non-blocking call tests
|
||||||
// Non-blocking call tests
|
// Non-blocking call tests
|
||||||
@ -140,6 +140,7 @@ class RmiTest : BaseTest() {
|
|||||||
}
|
}
|
||||||
// exceptions are not caught when async = true!
|
// exceptions are not caught when async = true!
|
||||||
Assert.assertFalse(caught)
|
Assert.assertFalse(caught)
|
||||||
|
caught = false
|
||||||
|
|
||||||
|
|
||||||
// Call will time out if non-blocking isn't working properly
|
// Call will time out if non-blocking isn't working properly
|
||||||
|
Loading…
Reference in New Issue
Block a user