Fixed issue with with RMI sync/async.

This commit is contained in:
Robinson 2023-09-13 13:49:13 +02:00
parent 3abbdf8825
commit 8e9e0441ed
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
6 changed files with 339 additions and 48 deletions

View File

@ -15,6 +15,7 @@
*/
package dorkbox.network.rmi
import com.conversantmedia.util.collection.FixedStack
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.rmi.messages.MethodRequest
@ -22,7 +23,6 @@ import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import mu.KLogger
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
@ -75,8 +75,8 @@ internal class RmiClient(val isGlobal: Boolean,
@Suppress("UNCHECKED_CAST")
private val EMPTY_ARRAY: Array<Any> = Collections.EMPTY_LIST.toTypedArray() as Array<Any>
private val safeAsyncState: ThreadLocal<Boolean?> = ThreadLocal.withInitial {
null
private val safeAsyncStack: ThreadLocal<FixedStack<Boolean?>> = ThreadLocal.withInitial {
FixedStack(64)
}
private const val charPrim = 0.toChar()
@ -86,20 +86,17 @@ internal class RmiClient(val isGlobal: Boolean,
@Suppress("UNCHECKED_CAST")
private fun syncMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>) {
val action = args[0] as Any.() -> Unit
val prev = safeAsyncState.get()
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
// the docs cover that (and say, `don't do this`)
safeAsyncState.set(false)
safeAsyncStack.get().push(isAsync)
// the `sync` method is always a unit function - we want to execute that unit function directly - this way we can control
// exactly how sync state is preserved.
try {
action(proxy)
} finally {
if (prev != isAsync) {
safeAsyncState.remove()
}
safeAsyncStack.get().pop()
}
}
@ -107,8 +104,6 @@ internal class RmiClient(val isGlobal: Boolean,
private fun syncSuspendMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>): Any? {
val action = args[0] as suspend Any.() -> Unit
val prev = safeAsyncState.get()
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
// We will use this for our coroutine context instead of running on a new coroutine
val suspendCoroutineArg = args.last()
@ -117,20 +112,20 @@ internal class RmiClient(val isGlobal: Boolean,
val suspendFunction: suspend () -> Any? = {
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
// the docs cover that (and say, `don't do this`)
withContext(safeAsyncState.asContextElement(isAsync)) {
withContext(safeAsyncStack.asContextElement()) {
yield() // must have an actually suspending call here!
safeAsyncStack.get().push(isAsync)
action(proxy)
}
}
// function suspension works differently !!
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
val result = (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
Continuation(continuation.context) {
val any = try {
it.getOrNull()
} finally {
if (prev != isAsync) {
safeAsyncState.remove()
}
safeAsyncStack.get().pop()
}
when (any) {
is Exception -> {
@ -144,6 +139,10 @@ internal class RmiClient(val isGlobal: Boolean,
}
}
})
runBlocking(safeAsyncStack.asContextElement()) {}
return result
}
}
@ -159,7 +158,7 @@ internal class RmiClient(val isGlobal: Boolean,
*/
override fun invoke(proxy: Any, method: Method, args: Array<Any>?): Any? {
val localAsync =
safeAsyncState.get() // value set via obj.sync {}
safeAsyncStack.get().peek() // value set via obj.sync {}
?:
isAsync // the value was set via obj.sync = xyz

View File

@ -19,20 +19,171 @@ import dorkbox.network.connection.Connection
import dorkbox.network.rmi.RemoteObject
import dorkboxTest.network.rmi.cows.MessageWithTestCow
import dorkboxTest.network.rmi.cows.TestCow
import kotlinx.coroutines.runBlocking
import org.junit.Assert
object RmiCommonTest {
suspend fun runTests(connection: Connection, test: TestCow, remoteObjectID: Int) {
val remoteObject = RemoteObject.cast<TestCow>(test)
fun runTimeoutTest(connection: Connection, test: TestCow) {
val remoteObject = RemoteObject.cast(test)
remoteObject.responseTimeout = 1000
try {
test.moo("You should see this two seconds before...", 2000)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
try {
test.moo("You should see this two seconds before...", 200)
} catch (ignored: Exception) {
Assert.fail("We should NOT be throwing a timeout exception!")
}
runBlocking {
try {
test.mooSuspend("You should see this two seconds before...", 2000)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
try {
test.mooSuspend("You should see this two seconds before...", 200)
} catch (ignored: Exception) {
Assert.fail("We should NOT be throwing a timeout exception!")
}
}
// Test sending a reference to a remote object.
val m = MessageWithTestCow(test)
m.number = 678
m.text = "sometext"
connection.send(m)
remoteObject.enableHashCode(true)
remoteObject.enableEquals(true)
connection.logger.error("Finished tests")
}
fun runSyncTest(connection: Connection, test: TestCow) {
val remoteObject = RemoteObject.cast(test)
remoteObject.responseTimeout = 1000
remoteObject.sync {
try {
test.moo("You should see this two seconds before...", 2000)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
}
runBlocking {
remoteObject.syncSuspend {
try {
test.mooSuspend("You should see this two seconds before...", 2000)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
}
}
// Test sending a reference to a remote object.
val m = MessageWithTestCow(test)
m.number = 678
m.text = "sometext"
connection.send(m)
remoteObject.enableHashCode(true)
remoteObject.enableEquals(true)
connection.logger.error("Finished tests")
}
fun runASyncTest(connection: Connection, test: TestCow) {
val remoteObject = RemoteObject.cast(test)
remoteObject.responseTimeout = 100
remoteObject.async {
try {
test.moo("You should see this 400 m-seconds before...", 400)
} catch (ignored: Exception) {
Assert.fail("We should NOT be throwing a timeout exception!")
}
}
remoteObject.sync {
try {
test.moo("You should see this 400 m-seconds before...", 400)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
remoteObject.async {
remoteObject.sync {
try {
test.moo("You should see this 400 m-seconds before...", 400)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
}
try {
test.moo("You should see this 400 m-seconds before...", 400)
} catch (ignored: Exception) {
Assert.fail("We should NOT be throwing a timeout exception!")
}
}
try {
test.moo("You should see this 400 m-seconds before...", 400)
Assert.fail("We should be throwing a timeout exception!")
} catch (ignored: Exception) {
}
}
runBlocking {
remoteObject.asyncSuspend {
try {
test.mooSuspend("You should see this 400 m-seconds before...", 400)
} catch (ignored: Exception) {
Assert.fail("We should NOT be throwing a timeout exception!")
}
}
}
// Test sending a reference to a remote object.
val m = MessageWithTestCow(test)
m.number = 678
m.text = "sometext"
connection.send(m)
remoteObject.enableHashCode(true)
remoteObject.enableEquals(true)
connection.logger.error("Finished tests")
}
fun runTests(connection: Connection, test: TestCow, remoteObjectID: Int) {
val remoteObject = RemoteObject.cast(test)
// Default behavior. RMI is transparent, method calls behave like normal
// (return values and exceptions are returned, call is synchronous)
connection.logger.error("hashCode: " + test.hashCode())
connection.logger.error("toString: $test")
test.withSuspend("test", 32)
val s1 = test.withSuspendAndReturn("test", 32)
Assert.assertEquals(s1, 32)
runBlocking {
test.withSuspend("test", 32)
val s1 = test.withSuspendAndReturn("test", 32)
Assert.assertEquals(s1, 32)
}
// see what the "remote" toString() method is
@ -50,6 +201,14 @@ object RmiCommonTest {
connection.logger.error("...This")
remoteObject.responseTimeout = 3000
runBlocking {
remoteObject.responseTimeout = 5000
test.mooSuspend("You should see this two seconds before...", 2000)
connection.logger.error("...This")
remoteObject.responseTimeout = 3000
}
// Try exception handling
try {
test.throwException()
@ -58,11 +217,14 @@ object RmiCommonTest {
connection.logger.error("Expected exception (exception log should also be on the object impl side).", e)
}
try {
test.throwSuspendException()
Assert.fail("sync should be throwing an exception!")
} catch (e: UnsupportedOperationException) {
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
runBlocking {
try {
test.throwSuspendException()
Assert.fail("sync should be throwing an exception!")
}
catch (e: UnsupportedOperationException) {
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
}
}
@ -70,8 +232,10 @@ object RmiCommonTest {
moo("Bzzzzzz")
}
remoteObject.syncSuspend {
moo("Bzzzzzz----MOOO", 22)
runBlocking {
remoteObject.syncSuspend {
moo("Bzzzzzz----MOOO", 22)
}
}
@ -81,13 +245,14 @@ object RmiCommonTest {
connection.logger.error("I'm currently async: ${remoteObject.async}. Now testing ASYNC")
runBlocking {
remoteObject.asyncSuspend {
// calls that ignore the return value
mooSuspend("Bark. should wait 4 seconds", 4000) // this should not timeout (because it's async!)
remoteObject.asyncSuspend {
// calls that ignore the return value
moo("Bark. should wait 4 seconds", 4000) // this should not timeout (because it's async!)
// Non-blocking call that ignores the return value
Assert.assertEquals(0, test.id().toLong())
// Non-blocking call that ignores the return value
Assert.assertEquals(0, test.id().toLong())
}
}
@ -110,11 +275,14 @@ object RmiCommonTest {
Assert.fail("Async should not be throwing an exception!")
}
try {
test.throwSuspendException()
} catch (e: IllegalStateException) {
// exceptions are not caught when async = true!
Assert.fail("Async should not be throwing an exception!")
runBlocking {
try {
test.throwSuspendException()
}
catch (e: IllegalStateException) {
// exceptions are not caught when async = true!
Assert.fail("Async should not be throwing an exception!")
}
}
@ -129,9 +297,11 @@ object RmiCommonTest {
remoteObject.responseTimeout = 6000
connection.logger.error("You should see this 2 seconds before")
val slow = test.slow()
connection.logger.error("...This")
Assert.assertEquals(slow.toDouble(), 123.0, 0.0001)
runBlocking {
val slow = test.slow()
connection.logger.error("...This")
Assert.assertEquals(slow.toDouble(), 123.0, 0.0001)
}
// Test sending a reference to a remote object.

View File

@ -338,4 +338,99 @@ class RmiSimpleTest : BaseTest() {
waitForThreads()
}
@Test
fun rmiTimeoutIpc() {
rmiBasicIpc { connection, testCow ->
RmiCommonTest.runTimeoutTest(connection, testCow)
}
}
@Test
fun rmiSyncIpc() {
rmiBasicIpc { connection, testCow ->
RmiCommonTest.runSyncTest(connection, testCow)
}
}
@Test
fun rmiASyncIpc() {
rmiBasicIpc { connection, testCow ->
RmiCommonTest.runASyncTest(connection, testCow)
}
}
fun rmiBasicIpc(runFun: (Connection, TestCow) -> Unit) {
val server = run {
val configuration = serverConfig()
configuration.enableIPv4 = false
configuration.enableIPv6 = false
configuration.enableIpc = true
configuration.serialization.rmi.register(TestCow::class.java, TestCowImpl::class.java)
configuration.serialization.register(MessageWithTestCow::class.java)
configuration.serialization.register(UnsupportedOperationException::class.java)
val server = Server<Connection>(configuration)
addEndPoint(server)
server.onMessage<MessageWithTestCow> { m ->
server.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(23, id)
server.logger.error("Finished test for: Client -> Server")
server.logger.error("Starting test for: Server -> Client")
// NOTE: THIS IS BI-DIRECTIONAL!
rmi.create<TestCow>(123) {
server.logger.error("Running test for: Server -> Client")
runFun(this@onMessage, this@create)
server.logger.error("Done with test for: Server -> Client")
}
}
server
}
val client = run {
val configuration = clientConfig()
configuration.enableIPv4 = false
configuration.enableIPv6 = false
configuration.enableIpc = true
// configuration.serialization.rmi.register(TestCow::class.java, TestCowImpl::class.java)
val client = Client<Connection>(configuration)
addEndPoint(client)
client.onConnect {
rmi.create<TestCow>(23) {
runBlocking {
client.logger.error("Running test for: Client -> Server")
runFun(this@onConnect, this@create)
client.logger.error("Done with test for: Client -> Server")
}
}
}
client.onMessage<MessageWithTestCow> { m ->
client.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(123, id)
client.logger.error("Finished test for: Client -> Server")
stopEndPoints()
}
client
}
server.bindIpc()
client.connectIpc()
waitForThreads()
}
}

View File

@ -1,3 +1,19 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkboxTest.network.rmi
import junit.framework.TestCase
import kotlinx.coroutines.delay
@ -17,7 +33,7 @@ class SuspendProxyTest : TestCase() {
fun addSync(a:Int, b:Int):Int
}
class SuspendHandler(private val delegate:Adder):InvocationHandler {
class SuspendHandler(private val delegate:Adder): InvocationHandler {
override fun invoke(proxy: Any, method: Method, arguments: Array<Any>): Any {
val suspendCoroutineObject = arguments.lastOrNull()
return if (suspendCoroutineObject is Continuation<*>) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,7 +22,8 @@ interface TestCow : TestCowBase {
fun moo()
fun moo(value: String)
fun mooTwo(value: String): String
suspend fun moo(value: String, delay: Long)
fun moo(value: String, delay: Long)
suspend fun mooSuspend(value: String, delay: Long)
fun id(): Int
suspend fun slow(): Float

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,11 +45,21 @@ open class TestCowImpl(val id: Int) : TestCowBaseImpl(), TestCow {
connection.logger.error("Moo! $moos: $value")
}
override suspend fun moo(value: String, delay: Long) {
override fun moo(value: String, delay: Long) {
throw RuntimeException("Should never be executed!")
}
suspend fun moo(connection: Connection, value: String, delay: Long) {
fun moo(connection: Connection, value: String, delay: Long) {
moos += 4
connection.logger.error("Moo! $moos: $value ($delay)")
Thread.sleep(delay)
}
override suspend fun mooSuspend(value: String, delay: Long) {
throw RuntimeException("Should never be executed!")
}
suspend fun mooSuspend(connection: Connection, value: String, delay: Long) {
moos += 4
connection.logger.error("Moo! $moos: $value ($delay)")
delay(delay)