Fixed issue with with RMI sync/async.
This commit is contained in:
parent
3abbdf8825
commit
8e9e0441ed
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
package dorkbox.network.rmi
|
||||
|
||||
import com.conversantmedia.util.collection.FixedStack
|
||||
import dorkbox.network.connection.Connection
|
||||
import dorkbox.network.connection.EndPoint
|
||||
import dorkbox.network.rmi.messages.MethodRequest
|
||||
|
@ -22,7 +23,6 @@ import kotlinx.coroutines.asContextElement
|
|||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import mu.KLogger
|
||||
import java.lang.reflect.InvocationHandler
|
||||
import java.lang.reflect.Method
|
||||
import java.util.*
|
||||
|
@ -75,8 +75,8 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
@Suppress("UNCHECKED_CAST")
|
||||
private val EMPTY_ARRAY: Array<Any> = Collections.EMPTY_LIST.toTypedArray() as Array<Any>
|
||||
|
||||
private val safeAsyncState: ThreadLocal<Boolean?> = ThreadLocal.withInitial {
|
||||
null
|
||||
private val safeAsyncStack: ThreadLocal<FixedStack<Boolean?>> = ThreadLocal.withInitial {
|
||||
FixedStack(64)
|
||||
}
|
||||
|
||||
private const val charPrim = 0.toChar()
|
||||
|
@ -86,20 +86,17 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
@Suppress("UNCHECKED_CAST")
|
||||
private fun syncMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>) {
|
||||
val action = args[0] as Any.() -> Unit
|
||||
val prev = safeAsyncState.get()
|
||||
|
||||
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
|
||||
// the docs cover that (and say, `don't do this`)
|
||||
safeAsyncState.set(false)
|
||||
safeAsyncStack.get().push(isAsync)
|
||||
|
||||
// the `sync` method is always a unit function - we want to execute that unit function directly - this way we can control
|
||||
// exactly how sync state is preserved.
|
||||
try {
|
||||
action(proxy)
|
||||
} finally {
|
||||
if (prev != isAsync) {
|
||||
safeAsyncState.remove()
|
||||
}
|
||||
safeAsyncStack.get().pop()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,8 +104,6 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
private fun syncSuspendMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>): Any? {
|
||||
val action = args[0] as suspend Any.() -> Unit
|
||||
|
||||
val prev = safeAsyncState.get()
|
||||
|
||||
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
|
||||
// We will use this for our coroutine context instead of running on a new coroutine
|
||||
val suspendCoroutineArg = args.last()
|
||||
|
@ -117,20 +112,20 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
val suspendFunction: suspend () -> Any? = {
|
||||
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
|
||||
// the docs cover that (and say, `don't do this`)
|
||||
withContext(safeAsyncState.asContextElement(isAsync)) {
|
||||
withContext(safeAsyncStack.asContextElement()) {
|
||||
yield() // must have an actually suspending call here!
|
||||
safeAsyncStack.get().push(isAsync)
|
||||
action(proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// function suspension works differently !!
|
||||
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
|
||||
val result = (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
|
||||
Continuation(continuation.context) {
|
||||
val any = try {
|
||||
it.getOrNull()
|
||||
} finally {
|
||||
if (prev != isAsync) {
|
||||
safeAsyncState.remove()
|
||||
}
|
||||
safeAsyncStack.get().pop()
|
||||
}
|
||||
when (any) {
|
||||
is Exception -> {
|
||||
|
@ -144,6 +139,10 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
}
|
||||
}
|
||||
})
|
||||
|
||||
runBlocking(safeAsyncStack.asContextElement()) {}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -159,7 +158,7 @@ internal class RmiClient(val isGlobal: Boolean,
|
|||
*/
|
||||
override fun invoke(proxy: Any, method: Method, args: Array<Any>?): Any? {
|
||||
val localAsync =
|
||||
safeAsyncState.get() // value set via obj.sync {}
|
||||
safeAsyncStack.get().peek() // value set via obj.sync {}
|
||||
?:
|
||||
isAsync // the value was set via obj.sync = xyz
|
||||
|
||||
|
|
|
@ -19,20 +19,171 @@ import dorkbox.network.connection.Connection
|
|||
import dorkbox.network.rmi.RemoteObject
|
||||
import dorkboxTest.network.rmi.cows.MessageWithTestCow
|
||||
import dorkboxTest.network.rmi.cows.TestCow
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert
|
||||
|
||||
object RmiCommonTest {
|
||||
suspend fun runTests(connection: Connection, test: TestCow, remoteObjectID: Int) {
|
||||
val remoteObject = RemoteObject.cast<TestCow>(test)
|
||||
fun runTimeoutTest(connection: Connection, test: TestCow) {
|
||||
val remoteObject = RemoteObject.cast(test)
|
||||
|
||||
remoteObject.responseTimeout = 1000
|
||||
try {
|
||||
test.moo("You should see this two seconds before...", 2000)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
|
||||
try {
|
||||
test.moo("You should see this two seconds before...", 200)
|
||||
} catch (ignored: Exception) {
|
||||
Assert.fail("We should NOT be throwing a timeout exception!")
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
try {
|
||||
test.mooSuspend("You should see this two seconds before...", 2000)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
|
||||
try {
|
||||
test.mooSuspend("You should see this two seconds before...", 200)
|
||||
} catch (ignored: Exception) {
|
||||
Assert.fail("We should NOT be throwing a timeout exception!")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Test sending a reference to a remote object.
|
||||
val m = MessageWithTestCow(test)
|
||||
m.number = 678
|
||||
m.text = "sometext"
|
||||
connection.send(m)
|
||||
|
||||
remoteObject.enableHashCode(true)
|
||||
remoteObject.enableEquals(true)
|
||||
|
||||
connection.logger.error("Finished tests")
|
||||
}
|
||||
|
||||
fun runSyncTest(connection: Connection, test: TestCow) {
|
||||
val remoteObject = RemoteObject.cast(test)
|
||||
|
||||
remoteObject.responseTimeout = 1000
|
||||
|
||||
remoteObject.sync {
|
||||
try {
|
||||
test.moo("You should see this two seconds before...", 2000)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
remoteObject.syncSuspend {
|
||||
try {
|
||||
test.mooSuspend("You should see this two seconds before...", 2000)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Test sending a reference to a remote object.
|
||||
val m = MessageWithTestCow(test)
|
||||
m.number = 678
|
||||
m.text = "sometext"
|
||||
connection.send(m)
|
||||
|
||||
remoteObject.enableHashCode(true)
|
||||
remoteObject.enableEquals(true)
|
||||
|
||||
connection.logger.error("Finished tests")
|
||||
}
|
||||
|
||||
fun runASyncTest(connection: Connection, test: TestCow) {
|
||||
val remoteObject = RemoteObject.cast(test)
|
||||
|
||||
remoteObject.responseTimeout = 100
|
||||
remoteObject.async {
|
||||
try {
|
||||
test.moo("You should see this 400 m-seconds before...", 400)
|
||||
} catch (ignored: Exception) {
|
||||
Assert.fail("We should NOT be throwing a timeout exception!")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
remoteObject.sync {
|
||||
try {
|
||||
test.moo("You should see this 400 m-seconds before...", 400)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
|
||||
remoteObject.async {
|
||||
remoteObject.sync {
|
||||
try {
|
||||
test.moo("You should see this 400 m-seconds before...", 400)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
test.moo("You should see this 400 m-seconds before...", 400)
|
||||
} catch (ignored: Exception) {
|
||||
Assert.fail("We should NOT be throwing a timeout exception!")
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
test.moo("You should see this 400 m-seconds before...", 400)
|
||||
Assert.fail("We should be throwing a timeout exception!")
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
runBlocking {
|
||||
remoteObject.asyncSuspend {
|
||||
try {
|
||||
test.mooSuspend("You should see this 400 m-seconds before...", 400)
|
||||
} catch (ignored: Exception) {
|
||||
Assert.fail("We should NOT be throwing a timeout exception!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Test sending a reference to a remote object.
|
||||
val m = MessageWithTestCow(test)
|
||||
m.number = 678
|
||||
m.text = "sometext"
|
||||
connection.send(m)
|
||||
|
||||
remoteObject.enableHashCode(true)
|
||||
remoteObject.enableEquals(true)
|
||||
|
||||
connection.logger.error("Finished tests")
|
||||
}
|
||||
|
||||
fun runTests(connection: Connection, test: TestCow, remoteObjectID: Int) {
|
||||
val remoteObject = RemoteObject.cast(test)
|
||||
|
||||
// Default behavior. RMI is transparent, method calls behave like normal
|
||||
// (return values and exceptions are returned, call is synchronous)
|
||||
connection.logger.error("hashCode: " + test.hashCode())
|
||||
connection.logger.error("toString: $test")
|
||||
|
||||
test.withSuspend("test", 32)
|
||||
val s1 = test.withSuspendAndReturn("test", 32)
|
||||
Assert.assertEquals(s1, 32)
|
||||
runBlocking {
|
||||
test.withSuspend("test", 32)
|
||||
val s1 = test.withSuspendAndReturn("test", 32)
|
||||
Assert.assertEquals(s1, 32)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// see what the "remote" toString() method is
|
||||
|
@ -50,6 +201,14 @@ object RmiCommonTest {
|
|||
connection.logger.error("...This")
|
||||
remoteObject.responseTimeout = 3000
|
||||
|
||||
runBlocking {
|
||||
remoteObject.responseTimeout = 5000
|
||||
test.mooSuspend("You should see this two seconds before...", 2000)
|
||||
connection.logger.error("...This")
|
||||
remoteObject.responseTimeout = 3000
|
||||
}
|
||||
|
||||
|
||||
// Try exception handling
|
||||
try {
|
||||
test.throwException()
|
||||
|
@ -58,11 +217,14 @@ object RmiCommonTest {
|
|||
connection.logger.error("Expected exception (exception log should also be on the object impl side).", e)
|
||||
}
|
||||
|
||||
try {
|
||||
test.throwSuspendException()
|
||||
Assert.fail("sync should be throwing an exception!")
|
||||
} catch (e: UnsupportedOperationException) {
|
||||
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
|
||||
runBlocking {
|
||||
try {
|
||||
test.throwSuspendException()
|
||||
Assert.fail("sync should be throwing an exception!")
|
||||
}
|
||||
catch (e: UnsupportedOperationException) {
|
||||
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -70,8 +232,10 @@ object RmiCommonTest {
|
|||
moo("Bzzzzzz")
|
||||
}
|
||||
|
||||
remoteObject.syncSuspend {
|
||||
moo("Bzzzzzz----MOOO", 22)
|
||||
runBlocking {
|
||||
remoteObject.syncSuspend {
|
||||
moo("Bzzzzzz----MOOO", 22)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -81,13 +245,14 @@ object RmiCommonTest {
|
|||
connection.logger.error("I'm currently async: ${remoteObject.async}. Now testing ASYNC")
|
||||
|
||||
|
||||
runBlocking {
|
||||
remoteObject.asyncSuspend {
|
||||
// calls that ignore the return value
|
||||
mooSuspend("Bark. should wait 4 seconds", 4000) // this should not timeout (because it's async!)
|
||||
|
||||
remoteObject.asyncSuspend {
|
||||
// calls that ignore the return value
|
||||
moo("Bark. should wait 4 seconds", 4000) // this should not timeout (because it's async!)
|
||||
|
||||
// Non-blocking call that ignores the return value
|
||||
Assert.assertEquals(0, test.id().toLong())
|
||||
// Non-blocking call that ignores the return value
|
||||
Assert.assertEquals(0, test.id().toLong())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -110,11 +275,14 @@ object RmiCommonTest {
|
|||
Assert.fail("Async should not be throwing an exception!")
|
||||
}
|
||||
|
||||
try {
|
||||
test.throwSuspendException()
|
||||
} catch (e: IllegalStateException) {
|
||||
// exceptions are not caught when async = true!
|
||||
Assert.fail("Async should not be throwing an exception!")
|
||||
runBlocking {
|
||||
try {
|
||||
test.throwSuspendException()
|
||||
}
|
||||
catch (e: IllegalStateException) {
|
||||
// exceptions are not caught when async = true!
|
||||
Assert.fail("Async should not be throwing an exception!")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -129,9 +297,11 @@ object RmiCommonTest {
|
|||
remoteObject.responseTimeout = 6000
|
||||
connection.logger.error("You should see this 2 seconds before")
|
||||
|
||||
val slow = test.slow()
|
||||
connection.logger.error("...This")
|
||||
Assert.assertEquals(slow.toDouble(), 123.0, 0.0001)
|
||||
runBlocking {
|
||||
val slow = test.slow()
|
||||
connection.logger.error("...This")
|
||||
Assert.assertEquals(slow.toDouble(), 123.0, 0.0001)
|
||||
}
|
||||
|
||||
|
||||
// Test sending a reference to a remote object.
|
||||
|
|
|
@ -338,4 +338,99 @@ class RmiSimpleTest : BaseTest() {
|
|||
|
||||
waitForThreads()
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun rmiTimeoutIpc() {
|
||||
rmiBasicIpc { connection, testCow ->
|
||||
RmiCommonTest.runTimeoutTest(connection, testCow)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun rmiSyncIpc() {
|
||||
rmiBasicIpc { connection, testCow ->
|
||||
RmiCommonTest.runSyncTest(connection, testCow)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun rmiASyncIpc() {
|
||||
rmiBasicIpc { connection, testCow ->
|
||||
RmiCommonTest.runASyncTest(connection, testCow)
|
||||
}
|
||||
}
|
||||
|
||||
fun rmiBasicIpc(runFun: (Connection, TestCow) -> Unit) {
|
||||
val server = run {
|
||||
val configuration = serverConfig()
|
||||
configuration.enableIPv4 = false
|
||||
configuration.enableIPv6 = false
|
||||
configuration.enableIpc = true
|
||||
|
||||
configuration.serialization.rmi.register(TestCow::class.java, TestCowImpl::class.java)
|
||||
configuration.serialization.register(MessageWithTestCow::class.java)
|
||||
configuration.serialization.register(UnsupportedOperationException::class.java)
|
||||
|
||||
|
||||
val server = Server<Connection>(configuration)
|
||||
addEndPoint(server)
|
||||
|
||||
|
||||
server.onMessage<MessageWithTestCow> { m ->
|
||||
server.logger.error("Received finish signal for test for: Client -> Server")
|
||||
val `object` = m.testCow
|
||||
val id = `object`.id()
|
||||
Assert.assertEquals(23, id)
|
||||
server.logger.error("Finished test for: Client -> Server")
|
||||
|
||||
|
||||
server.logger.error("Starting test for: Server -> Client")
|
||||
// NOTE: THIS IS BI-DIRECTIONAL!
|
||||
rmi.create<TestCow>(123) {
|
||||
server.logger.error("Running test for: Server -> Client")
|
||||
runFun(this@onMessage, this@create)
|
||||
server.logger.error("Done with test for: Server -> Client")
|
||||
}
|
||||
}
|
||||
server
|
||||
}
|
||||
|
||||
val client = run {
|
||||
val configuration = clientConfig()
|
||||
configuration.enableIPv4 = false
|
||||
configuration.enableIPv6 = false
|
||||
configuration.enableIpc = true
|
||||
// configuration.serialization.rmi.register(TestCow::class.java, TestCowImpl::class.java)
|
||||
|
||||
val client = Client<Connection>(configuration)
|
||||
addEndPoint(client)
|
||||
|
||||
client.onConnect {
|
||||
rmi.create<TestCow>(23) {
|
||||
runBlocking {
|
||||
client.logger.error("Running test for: Client -> Server")
|
||||
runFun(this@onConnect, this@create)
|
||||
client.logger.error("Done with test for: Client -> Server")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client.onMessage<MessageWithTestCow> { m ->
|
||||
client.logger.error("Received finish signal for test for: Client -> Server")
|
||||
val `object` = m.testCow
|
||||
val id = `object`.id()
|
||||
Assert.assertEquals(123, id)
|
||||
client.logger.error("Finished test for: Client -> Server")
|
||||
stopEndPoints()
|
||||
}
|
||||
|
||||
client
|
||||
}
|
||||
|
||||
server.bindIpc()
|
||||
client.connectIpc()
|
||||
|
||||
waitForThreads()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,19 @@
|
|||
/*
|
||||
* Copyright 2023 dorkbox, llc
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dorkboxTest.network.rmi
|
||||
import junit.framework.TestCase
|
||||
import kotlinx.coroutines.delay
|
||||
|
@ -17,7 +33,7 @@ class SuspendProxyTest : TestCase() {
|
|||
fun addSync(a:Int, b:Int):Int
|
||||
}
|
||||
|
||||
class SuspendHandler(private val delegate:Adder):InvocationHandler {
|
||||
class SuspendHandler(private val delegate:Adder): InvocationHandler {
|
||||
override fun invoke(proxy: Any, method: Method, arguments: Array<Any>): Any {
|
||||
val suspendCoroutineObject = arguments.lastOrNull()
|
||||
return if (suspendCoroutineObject is Continuation<*>) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2020 dorkbox, llc
|
||||
* Copyright 2023 dorkbox, llc
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,7 +22,8 @@ interface TestCow : TestCowBase {
|
|||
fun moo()
|
||||
fun moo(value: String)
|
||||
fun mooTwo(value: String): String
|
||||
suspend fun moo(value: String, delay: Long)
|
||||
fun moo(value: String, delay: Long)
|
||||
suspend fun mooSuspend(value: String, delay: Long)
|
||||
fun id(): Int
|
||||
suspend fun slow(): Float
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2020 dorkbox, llc
|
||||
* Copyright 2023 dorkbox, llc
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,11 +45,21 @@ open class TestCowImpl(val id: Int) : TestCowBaseImpl(), TestCow {
|
|||
connection.logger.error("Moo! $moos: $value")
|
||||
}
|
||||
|
||||
override suspend fun moo(value: String, delay: Long) {
|
||||
override fun moo(value: String, delay: Long) {
|
||||
throw RuntimeException("Should never be executed!")
|
||||
}
|
||||
|
||||
suspend fun moo(connection: Connection, value: String, delay: Long) {
|
||||
fun moo(connection: Connection, value: String, delay: Long) {
|
||||
moos += 4
|
||||
connection.logger.error("Moo! $moos: $value ($delay)")
|
||||
Thread.sleep(delay)
|
||||
}
|
||||
|
||||
|
||||
override suspend fun mooSuspend(value: String, delay: Long) {
|
||||
throw RuntimeException("Should never be executed!")
|
||||
}
|
||||
suspend fun mooSuspend(connection: Connection, value: String, delay: Long) {
|
||||
moos += 4
|
||||
connection.logger.error("Moo! $moos: $value ($delay)")
|
||||
delay(delay)
|
||||
|
|
Loading…
Reference in New Issue