Fixed issues with restarting a client, notifyConnect/Disconnect now ALWAYS happen on a new coroutine

This commit is contained in:
nathan 2020-08-25 17:45:08 +02:00
parent b24e4ae710
commit 42664bfdfe
20 changed files with 468 additions and 299 deletions

View File

@ -24,6 +24,7 @@ import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.Ping import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
@ -34,7 +35,6 @@ import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException import dorkbox.util.exceptions.SecurityException
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/** /**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -71,7 +71,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val previousClosedConnectionActivity: Long = 0 private val previousClosedConnectionActivity: Long = 0
private val handshake = ClientHandshake(logger, config, crypto, listenerManager)
private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization) private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization)
init { init {
@ -119,12 +118,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
@Suppress("DuplicatedCode")
suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
if (isConnected) { if (isConnected) {
logger.error("Unable to connect when already connected!") logger.error("Unable to connect when already connected!")
return return
} }
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState()
this.connectionTimeoutMS = connectionTimeoutMS this.connectionTimeoutMS = connectionTimeoutMS
// localhost/loopback IP might not always be 127.0.0.1 or ::1 // localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) { when (remoteAddress) {
@ -155,7 +158,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
handshake.init(this) val handshake = ClientHandshake(logger, config, crypto, this)
if (this.remoteAddress.isEmpty()) { if (this.remoteAddress.isEmpty()) {
// this is an IPC address // this is an IPC address
@ -171,7 +174,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
sessionId = RESERVED_SESSION_ID_INVALID sessionId = RESERVED_SESSION_ID_INVALID
) )
autoClosableObjects.add(handshakeConnection)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
@ -186,6 +188,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// @Throws(ConnectTimedOutException::class, ClientRejectedException::class) // @Throws(ConnectTimedOutException::class, ClientRejectedException::class)
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS) val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
println("CO23232232323NASD") println("CO23232232323NASD")
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
} }
else { else {
// THIS IS A NETWORK ADDRESS // THIS IS A NETWORK ADDRESS
@ -210,16 +215,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS) val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
// VALIDATE:: check to see if the remote connection's public key has changed! // VALIDATE:: check to see if the remote connection's public key has changed!
val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) { if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -234,6 +229,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// is rogue, we do not want to carelessly provide info. // is rogue, we do not want to carelessly provide info.
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
// only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object) // only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object)
// does not need to do anything // does not need to do anything
// //
@ -268,17 +274,38 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val pollIdleStrategy = config.pollIdleStrategy val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) { while (!isShutdown()) {
val pollCount = newConnection.pollSubscriptions() // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) if (newConnection.isExpired()) {
pollIdleStrategy.idle(pollCount) logger.debug {"[${newConnection.sessionId}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.sessionId}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Otherwise, poll the connection for messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
} }
} }
// tell the server our connection handshake is done, and the connection can now listen for data. // tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold this connection open // no longer necessary to hold the handshake connection open
handshakeConnection.close() handshakeConnection.close()
if (canFinishConnecting) { if (canFinishConnecting) {
@ -290,6 +317,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} else { } else {
close() close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
throw exception throw exception
} }
@ -389,19 +417,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
override fun close() { override fun close() {
val con = connection
connection = null connection = null
if (con != null) {
connections.remove(con)
runBlocking {
con.close()
listenerManager.notifyDisconnect(con)
}
}
super.close()
isConnected = false isConnected = false
super.close()
} }

View File

@ -20,7 +20,6 @@ import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionProperties import dorkbox.network.connection.connectionType.ConnectionProperties
import dorkbox.network.connection.connectionType.ConnectionRule import dorkbox.network.connection.connectionType.ConnectionRule
@ -29,7 +28,7 @@ import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler import io.aeron.Image
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
@ -134,13 +133,16 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param blockUntilTerminate if true, will BLOCK until the server [close] method is called, and if you want to continue running code * @param blockUntilTerminate if true, will BLOCK until the server [close] method is called, and if you want to continue running code
* after this pass in false * after this pass in false
*/ */
@Suppress("DuplicatedCode")
@JvmOverloads @JvmOverloads
fun bind(blockUntilTerminate: Boolean = true) { suspend fun bind(blockUntilTerminate: Boolean = true) {
if (bindAlreadyCalled) { if (bindAlreadyCalled) {
logger.error("Unable to bind when the server is already running!") logger.error("Unable to bind when the server is already running!")
return return
} }
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState()
bindAlreadyCalled = true bindAlreadyCalled = true
config as ServerConfiguration config as ServerConfiguration
@ -162,15 +164,15 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
logger.info(handshakeDriver.serverInfo()) logger.info(handshakeDriver.serverInfo())
val ipcHandshakeDriver = IpcMediaDriverConnection( // val ipcHandshakeDriver = IpcMediaDriverConnection(
streamId = IPC_HANDSHAKE_STREAM_ID_PUB, // streamId = IPC_HANDSHAKE_STREAM_ID_PUB,
streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB, // streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB,
sessionId = RESERVED_SESSION_ID_INVALID // sessionId = RESERVED_SESSION_ID_INVALID
) // )
ipcHandshakeDriver.buildServer(aeron) // ipcHandshakeDriver.buildServer(aeron)
//
val ipcHandshakePublication = ipcHandshakeDriver.publication // val ipcHandshakePublication = ipcHandshakeDriver.publication
val ipcHandshakeSubscription = ipcHandshakeDriver.subscription // val ipcHandshakeSubscription = ipcHandshakeDriver.subscription
/** /**
@ -182,16 +184,41 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy. * properties from failure and streams with mechanical sympathy.
*/ */
val initialConnectionHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> val initialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header)
actionDispatch.launch { actionDispatch.launch {
handshake.receiveHandshakeMessageServer(handshakePublication, buffer, offset, length, header, this@Server) handshake.processHandshakeMessageServer(handshakePublication,
sessionId,
clientAddressString,
clientAddress,
message,
this@Server,
aeron)
} }
}) }
val ipcInitialConnectionHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> val ipcInitialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
actionDispatch.launch { actionDispatch.launch {
println("GOT MESSAGE!") println("GOT MESSAGE!")
} }
}) }
actionDispatch.launch { actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy val pollIdleStrategy = config.pollIdleStrategy
@ -241,7 +268,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
handshake.cleanup(connectionToClean) handshake.cleanup(connectionToClean)
connectionToClean.close() connectionToClean.close()
listenerManager.notifyDisconnect(connectionToClean)
}) })
@ -252,8 +278,8 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
handshakePublication.close() handshakePublication.close()
handshakeSubscription.close() handshakeSubscription.close()
ipcHandshakePublication.close() // ipcHandshakePublication.close()
ipcHandshakeSubscription.close() // ipcHandshakeSubscription.close()
} }
} }
@ -268,8 +294,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
var pollCount = 0 var pollCount = 0
return pollCount return pollCount
} }
@ -382,16 +406,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
connections.remove(connection) connections.remove(connection)
} }
/**
* Checks to see if a server (using the specified configuration) is running.
*
* @return true if the server is active and running
*/
fun isRunning(): Boolean {
return mediaDriver.context().isDriverActive(10_000, logger::debug)
}
override fun close() { override fun close() {
super.close() super.close()
bindAlreadyCalled = false bindAlreadyCalled = false

View File

@ -32,7 +32,6 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicLong
import javax.crypto.SecretKey import javax.crypto.SecretKey
@ -120,9 +119,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small) // counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
private val aes_gcm_iv = AtomicLong(0) private val aes_gcm_iv = AtomicLong(0)
// when closing this connection, HOW MANY endpoints need to be closed?
private var closeLatch: CountDownLatch? = null
// RMI support for this connection // RMI support for this connection
internal val rmiConnectionSupport = endPoint.getRmiConnectionSupport() internal val rmiConnectionSupport = endPoint.getRmiConnectionSupport()
@ -146,16 +142,13 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server! sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server!
messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads don't work. // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// this is worked around by having RMI always return (unless async), even with a null value, so the CALLING side of RMI will always
// go in "lock step"
endPoint.actionDispatch.launch {
endPoint.readMessage(buffer, offset, length, header, this@Connection)
}
}
// when closing this connection, HOW MANY endpoints need to be closed? // NOTE: subscriptions (ie: reading from buffers, etc) are not thread safe! Because it is ambiguous HOW EXACTLY they are unsafe,
closeLatch = CountDownLatch(1) // we exclusively read from the DirectBuffer on a single thread.
endPoint.processMessage(buffer, offset, length, header, this@Connection)
}
} }
@ -291,7 +284,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
return isClosed.value return isClosed.value
} }
/** /**
* Closes the connection, and removes all connection specific listeners * Closes the connection, and removes all connection specific listeners
*/ */
@ -309,8 +301,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
publication.close() publication.close()
// a connection might have also registered for disconnect events // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
notifyDisconnect() endPoint.actionDispatch.launch {
// a connection might have also registered for disconnect events
listenerManager.value?.notifyDisconnect(this@Connection)
}
} }
} }
@ -345,15 +340,10 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
listenerManager.value!!.onMessage(function) listenerManager.value!!.onMessage(function)
} }
/**
* Invoked when a connection is disconnected from the remote endpoint
*/
internal suspend fun notifyDisconnect() {
listenerManager.value?.notifyDisconnect(this)
}
/** /**
* Invoked when a message object was received from a remote peer. * Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/ */
internal suspend fun notifyOnMessage(message: Any): Boolean { internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false return listenerManager.value?.notifyOnMessage(this, message) ?: false

View File

@ -92,13 +92,6 @@ internal open class ConnectionManager<CONNECTION: Connection>() {
return connections.size() return connections.size()
} }
/**
* Closes all associated resources/threads/connections
*/
fun close() {
connections.clear()
}
/** /**
* Safely sends objects to a destination (such as a custom object or a standard ping). This will automatically choose which protocol * Safely sends objects to a destination (such as a custom object or a standard ping). This will automatically choose which protocol
* is available to use. * is available to use.

View File

@ -37,14 +37,12 @@ import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.io.File import java.io.File
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
@ -114,16 +112,14 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val logger: KLogger = KotlinLogging.logger(type.simpleName) val logger: KLogger = KotlinLogging.logger(type.simpleName)
internal val autoClosableObjects = CopyOnWriteArrayList<AutoCloseable>()
internal val actionDispatch = CoroutineScope(Dispatchers.Default) internal val actionDispatch = CoroutineScope(Dispatchers.Default)
internal val listenerManager = ListenerManager<CONNECTION>() internal val listenerManager = ListenerManager<CONNECTION>()
internal val connections = ConnectionManager<CONNECTION>() internal val connections = ConnectionManager<CONNECTION>()
internal val mediaDriverContext: MediaDriver.Context private var mediaDriverContext: MediaDriver.Context? = null
internal val mediaDriver: MediaDriver private var mediaDriver: MediaDriver? = null
internal val aeron: Aeron private var aeron: Aeron? = null
/** /**
* Returns the serialization wrapper if there is an object type that needs to be added outside of the basics. * Returns the serialization wrapper if there is an object type that needs to be added outside of the basics.
@ -138,11 +134,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val crypto: CryptoManagement internal val crypto: CryptoManagement
private val shutdown = atomic(false) private val shutdown = atomic(false)
private val shutdownLatch = CountDownLatch(1)
@Volatile
private var shutdownLatch: CountDownLatch = CountDownLatch(1)
// we only want one instance of these created. These will be called appropriately // we only want one instance of these created. These will be called appropriately
val settingsStore: SettingsStore val settingsStore: SettingsStore
internal val globalThreadUnsafeKryo: KryoExtra = config.serialization.takeKryo()
internal val rmiGlobalSupport = RmiManagerGlobal<CONNECTION>(logger, actionDispatch, config.serialization) internal val rmiGlobalSupport = RmiManagerGlobal<CONNECTION>(logger, actionDispatch, config.serialization)
init { init {
@ -231,6 +231,26 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
logger.info("Aeron log directory already exists! This might not be what you want!") logger.info("Aeron log directory already exists! This might not be what you want!")
} }
// serialization stuff
serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy
// we have to be able to specify WHAT property store we want to use, since it can change!
settingsStore = config.settingsStore
settingsStore.init(serialization, config.settingsStorageSystem.build())
crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization
runBlocking {
serialization.finishInit(type)
}
}
internal suspend fun initEndpointState(): Aeron {
val aeronDirectory = config.aeronLogDirectory!!.absolutePath
val threadFactory = NamedThreadFactory("Aeron", false) val threadFactory = NamedThreadFactory("Aeron", false)
// LOW-LATENCY SETTINGS // LOW-LATENCY SETTINGS
@ -241,26 +261,27 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE) // .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE)
// .senderIdleStrategy(NoOpIdleStrategy.INSTANCE); // .senderIdleStrategy(NoOpIdleStrategy.INSTANCE);
mediaDriverContext = MediaDriver.Context() mediaDriverContext = MediaDriver.Context()
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW) .publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH) .publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true) .dirDeleteOnStart(true)
.dirDeleteOnShutdown(true) .dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory) .conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory) .receiverThreadFactory(threadFactory)
.senderThreadFactory(threadFactory) .senderThreadFactory(threadFactory)
.sharedNetworkThreadFactory(threadFactory) .sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory) .sharedThreadFactory(threadFactory)
.threadingMode(config.threadingMode) .threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize) .mtuLength(config.networkMtuSize)
.socketSndbufLength(config.sendBufferSize) .socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize) .socketRcvbufLength(config.receiveBufferSize)
.aeronDirectoryName(config.aeronLogDirectory!!.absolutePath) .aeronDirectoryName(aeronDirectory)
val aeronContext = Aeron.Context().aeronDirectoryName(mediaDriverContext.aeronDirectoryName()) val aeronContext = Aeron.Context().aeronDirectoryName(aeronDirectory)
try { mediaDriver = try {
mediaDriver = MediaDriver.launch(mediaDriverContext) MediaDriver.launch(mediaDriverContext)
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(e)
throw e throw e
} }
@ -268,36 +289,24 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
aeron = Aeron.connect(aeronContext) aeron = Aeron.connect(aeronContext)
} catch (e: Exception) { } catch (e: Exception) {
try { try {
mediaDriver.close() mediaDriver!!.close()
} catch (secondaryException: Exception) { } catch (secondaryException: Exception) {
e.addSuppressed(secondaryException) e.addSuppressed(secondaryException)
} }
listenerManager.notifyError(e)
throw e throw e
} }
autoClosableObjects.add(aeron) shutdown.getAndSet(false)
autoClosableObjects.add(mediaDriver)
shutdownLatch.countDown()
shutdownLatch = CountDownLatch(1)
// serialization stuff return aeron!!
serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy
// we have to be able to specify WHAT property store we want to use, since it can change!
settingsStore = config.settingsStore
settingsStore.init(serialization, config.settingsStorageSystem.build())
// the storage is closed via this as well
autoClosableObjects.add(settingsStore)
crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization
runBlocking {
serialization.finishInit(type)
}
} }
abstract fun newException(message: String, cause: Throwable? = null): Throwable abstract fun newException(message: String, cause: Throwable? = null): Throwable
/** /**
@ -419,7 +428,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
/** /**
* Runs an action for each connection inside of a read-lock * Runs an action for each connection
*/ */
suspend fun forEachConnection(function: suspend (connection: CONNECTION) -> Unit) { suspend fun forEachConnection(function: suspend (connection: CONNECTION) -> Unit) {
connections.forEach { connections.forEach {
@ -433,11 +442,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message" "[${publication.sessionId()}] send: $message"
} }
val kryo: KryoExtra = serialization.takeKryo()
try { try {
kryo.write(message) globalThreadUnsafeKryo.write(message)
val buffer = kryo.writerBuffer val buffer = globalThreadUnsafeKryo.writerBuffer
val objectSize = buffer.position() val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer val internalBuffer = buffer.internalBuffer
@ -463,7 +471,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
listenerManager.notifyError(newException("Error serializing message $message", e)) listenerManager.notifyError(newException("Error serializing message $message", e))
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnKryo(kryo)
} }
} }
@ -471,58 +478,68 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* @param buffer The buffer * @param buffer The buffer
* @param offset The offset from the start of the buffer * @param offset The offset from the start of the buffer
* @param length The number of bytes to extract * @param length The number of bytes to extract
* @param header The aeron header information
* *
* @return A string * @return the message
*/ */
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
val kryo: KryoExtra = serialization.takeKryo()
try { try {
val message = kryo.read(buffer, offset, length) val message = globalThreadUnsafeKryo.read(buffer, offset, length)
logger.trace { logger.trace {
"[${header.sessionId()}] received handshake: $message" "[${header.sessionId()}] received: $message"
} }
return message return message
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e) // The sessionId is globally unique, and is assigned by the server.
} finally { val sessionId = header.sessionId()
serialization.returnKryo(kryo)
}
return null val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(exception)
}
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e)
return null
}
} }
// This is on the action dispatch! /**
suspend fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) { * read the message from the aeron buffer
// The sessionId is globally unique, and is assigned by the server. *
val sessionId = header.sessionId() * @param buffer The buffer
* @param offset The offset from the start of the buffer
* @param length The number of bytes to extract
* @param header The aeron header information
* @param connection The connection this message happened on
*/
fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
@Suppress("UNCHECKED_CAST")
connection as CONNECTION
// note: this address will ALWAYS be an IP:PORT combo val message: Any?
// val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
// val splitPoint = remoteIpAndPort.lastIndexOf(':')
// val ip = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// val ipAsInt = NetworkUtil.IP.toInt(ip)
var message: Any? = null
val kryo: KryoExtra = serialization.takeKryo()
try { try {
message = kryo.read(buffer, offset, length, connection) message = globalThreadUnsafeKryo.read(buffer, offset, length, connection)
logger.trace { logger.trace {
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
"[${sessionId}] received: $message" "[${sessionId}] received: $message"
} }
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(newException("[${sessionId}] Error de-serializing message", e)) // The sessionId is globally unique, and is assigned by the server.
} finally { val sessionId = header.sessionId()
serialization.returnKryo(kryo)
val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(connection, exception)
}
return // don't do anything!
} }
connection as CONNECTION
when (message) { when (message) {
is PingMessage -> { is PingMessage -> {
@ -537,34 +554,44 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// ping0(ping) // ping0(ping)
// } // }
} }
// small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads don't work.
// this is worked around by having RMI always return (unless async), even with a null value, so the CALLING side of RMI will always
// go in "lock step"
is RmiMessage -> { is RmiMessage -> {
// if we are an RMI message/registration, we have very specific, defined behavior. actionDispatch.launch {
// We do not use the "normal" listener callback pattern because this require special functionality // if we are an RMI message/registration, we have very specific, defined behavior.
rmiGlobalSupport.manage(this, connection, message, logger) // We do not use the "normal" listener callback pattern because this require special functionality
rmiGlobalSupport.manage(this@EndPoint, connection, message, logger)
}
} }
is Any -> { is Any -> {
@Suppress("UNCHECKED_CAST") actionDispatch.launch {
var hasListeners = listenerManager.notifyOnMessage(connection, message) @Suppress("UNCHECKED_CAST")
var hasListeners = listenerManager.notifyOnMessage(connection, message)
// each connection registers, and is polled INDEPENDENTLY for messages. // each connection registers, and is polled INDEPENDENTLY for messages.
hasListeners = hasListeners or connection.notifyOnMessage(message) hasListeners = hasListeners or connection.notifyOnMessage(message)
if (!hasListeners) { if (!hasListeners) {
val exception = MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}") val exception = MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception) listenerManager.notifyError(connection, exception)
}
} }
} }
else -> { else -> {
// do nothing, there were problems with the message actionDispatch.launch {
val exception = if (message != null) { // do nothing, there were problems with the message
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}") val exception = if (message != null) {
} else { MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
MessageNotRegisteredException("Unknown message received!!") } else {
} MessageNotRegisteredException("Unknown message received!!")
}
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
}
} }
} }
} }
@ -576,6 +603,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message" "[${publication.sessionId()}] send: $message"
} }
// since ANY thread can call 'send', we have to take kryo instances in a safe way
val kryo: KryoExtra = serialization.takeKryo() val kryo: KryoExtra = serialization.takeKryo()
try { try {
kryo.write(connection, message) kryo.write(connection, message)
@ -652,12 +680,22 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
shutdownLatch.await() shutdownLatch.await()
} }
/**
* Checks to see if an endpoint (using the specified configuration) is running.
*
* @return true if the client/server is active and running
*/
fun isRunning(): Boolean {
return mediaDriverContext?.isDriverActive(10_000, logger::debug) ?: false
}
override fun close() { override fun close() {
if (shutdown.compareAndSet(expect = false, update = true)) { if (shutdown.compareAndSet(expect = false, update = true)) {
autoClosableObjects.forEach { aeron?.close()
it.close() mediaDriver?.close()
}
autoClosableObjects.clear() // the storage is closed via this as well
settingsStore.close()
rmiGlobalSupport.close() rmiGlobalSupport.close()
@ -665,12 +703,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// don't need anything fast or fancy here, because this method will only be called once // don't need anything fast or fancy here, because this method will only be called once
connections.forEach { connections.forEach {
it.close() it.close()
listenerManager.notifyDisconnect(it) listenerManager.notifyDisconnect(it) // if disconnect has a "connect" in it, this will case SO MANY PROBLEMS
} }
} }
connections.close()
actionDispatch.cancel()
shutdownLatch.countDown() shutdownLatch.countDown()
} }
} }

View File

@ -21,13 +21,11 @@ import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.MediaDriverConnection import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch
import mu.KLogger import mu.KLogger
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.security.SecureRandom import java.security.SecureRandom
@ -35,78 +33,71 @@ import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogger, internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogger,
private val config: Configuration, private val config: Configuration,
private val crypto: CryptoManagement, private val crypto: CryptoManagement,
private val listenerManager: ListenerManager<CONNECTION>) { private val endPoint: EndPoint<CONNECTION>) {
// a one-time key for connecting // a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt() private val oneTimePad = SecureRandom().nextInt()
@Volatile @Volatile
var connectionHelloInfo: ClientConnectionInfo? = null private var connectionHelloInfo: ClientConnectionInfo? = null
@Volatile @Volatile
var connectionDone = false private var connectionDone = false
@Volatile @Volatile
private var failed: Exception? = null private var failed: Exception? = null
lateinit var handler: FragmentHandler private var handler: FragmentHandler
lateinit var endPoint: EndPoint<CONNECTION> private var sessionId: Int = 0
var sessionId: Int = 0
fun init(endPoint: EndPoint<CONNECTION>) {
this.endPoint = endPoint
init {
// now we have a bi-directional connection with the server on the handshake "socket". // now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
endPoint.actionDispatch.launch { // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
val sessionId = header.sessionId()
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
val sessionId = header.sessionId()
// it must be a registration message
if (message !is HandshakeMessage) {
failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@FragmentAssembler
}
// this is an error message
if (message.sessionId == 0) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@FragmentAssembler
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@FragmentAssembler
}
// it must be the correct state
when (message.state) {
HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger)
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
logger.trace {
"[$sessionId] handshake response: $message"
} }
HandshakeMessage.DONE_ACK -> {
// it must be a registration message connectionDone = true
if (message !is HandshakeMessage) {
failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@launch
} }
else -> {
// this is an error message if (message.state != HandshakeMessage.HELLO_ACK) {
if (message.sessionId == 0) { failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@launch
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@launch
}
// it must be the correct state
when (message.state) {
HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger)
} }
HandshakeMessage.DONE_ACK -> { else if (message.state != HandshakeMessage.DONE_ACK) {
connectionDone = true failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
else -> {
if (message.state != HandshakeMessage.HELLO_ACK) {
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
} else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
return@launch
} }
} }
} }
}) }
} }
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {

View File

@ -15,7 +15,6 @@
*/ */
package dorkbox.network.handshake package dorkbox.network.handshake
import dorkbox.netUtil.IPv4
import dorkbox.network.Server import dorkbox.network.Server
import dorkbox.network.ServerConfiguration import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.client.ClientRejectedException import dorkbox.network.aeron.client.ClientRejectedException
@ -28,12 +27,10 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.Image import io.aeron.Aeron
import io.aeron.Publication import io.aeron.Publication
import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import mu.KLogger import mu.KLogger
import org.agrona.DirectBuffer
import org.agrona.collections.Int2IntCounterMap import org.agrona.collections.Int2IntCounterMap
import org.agrona.collections.Int2ObjectHashMap import org.agrona.collections.Int2ObjectHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.concurrent.locks.ReentrantReadWriteLock
@ -58,23 +55,13 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
// note: this is called in action dispatch // note: this is called in action dispatch
suspend fun receiveHandshakeMessageServer(handshakePublication: Publication, suspend fun processHandshakeMessageServer(handshakePublication: Publication,
buffer: DirectBuffer, offset: Int, length: Int, header: Header, sessionId: Int,
server: Server<CONNECTION>) { clientAddressString: String,
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. clientAddress: Int,
// ONLY for the handshake, the sessionId IS NOT GLOBALLY UNIQUE message: Any?,
val sessionId = header.sessionId() server: Server<CONNECTION>,
aeron: Aeron) {
// note: this address will ALWAYS be an IP:PORT combo
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
val message = server.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
@ -86,6 +73,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val clientPublicKeyBytes = message.publicKey val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState val validateRemoteAddress: PublicKeyValidationState
// check to see if this is a pending connection // check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) { if (message.state == HandshakeMessage.DONE) {
pendingConnectionsLock.write { pendingConnectionsLock.write {
@ -109,6 +97,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
try { try {
// VALIDATE:: Check to see if there are already too many clients connected. // VALIDATE:: Check to see if there are already too many clients connected.
if (server.connections.connectionCount() >= config.maxClientCount) { if (server.connections.connectionCount() >= config.maxClientCount) {
@ -148,7 +137,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return return
} }
// VALIDATE:: TODO: ?? check to see if this session is ALREADY connected??. It should not be!
///// /////
@ -208,7 +196,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
message.isReliable) message.isReliable)
// we have to construct how the connection will communicate! // we have to construct how the connection will communicate!
clientConnection.buildServer(server.aeron) clientConnection.buildServer(aeron)
logger.trace { logger.trace {
"Creating new connection $clientConnection" "Creating new connection $clientConnection"

View File

@ -112,7 +112,9 @@ object AeronServer {
} }
} }
server.bind() runBlocking {
server.bind()
}
} }
init { init {

View File

@ -83,9 +83,9 @@ abstract class BaseTest {
// rootLogger.setLevel(Level.OFF); // rootLogger.setLevel(Level.OFF);
rootLogger.level = Level.INFO; // rootLogger.level = Level.INFO;
// rootLogger.level = Level.DEBUG // rootLogger.level = Level.DEBUG
// rootLogger.level = Level.TRACE; rootLogger.level = Level.TRACE;
// rootLogger.level = Level.ALL; // rootLogger.level = Level.ALL;

View File

@ -0,0 +1,73 @@
package dorkboxTest.network
import dorkbox.network.Client
import dorkbox.network.Server
import dorkbox.network.connection.Connection
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Test
import java.io.IOException
import java.util.*
class DisconnectReconnectTest : BaseTest() {
private val timer = Timer()
private val reconnectCount = atomic(0)
@Test
fun reconnectClient() {
run {
val configuration = serverConfig()
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
runBlocking {
server.bind(false)
}
server.onConnect { connection ->
println("Disconnecting after 2 seconds.")
delay(2000)
println("Disconnecting....")
connection.close()
}
}
run {
val config = clientConfig()
val client: Client<Connection> = Client(config)
addEndPoint(client)
client.onDisconnect { connection ->
println("Disconnected!")
val count = reconnectCount.getAndIncrement()
if (count == 3) {
println("Shutting down")
stopEndPoints()
}
else {
println("Reconnecting: $count")
try {
client.connect(LOOPBACK)
} catch (e: IOException) {
e.printStackTrace()
}
}
}
runBlocking {
client.connect(LOOPBACK)
}
}
waitForThreads()
System.err.println("Connection count (after reconnecting) is: " + reconnectCount.value)
Assert.assertEquals(4, reconnectCount.value)
}
}

View File

@ -112,7 +112,9 @@ class ListenerTest : BaseTest() {
} }
server.bind(false) runBlocking {
server.bind(false)
}

View File

@ -83,7 +83,9 @@ class MultipleServerTest : BaseTest() {
} }
} }
server.bind(false) runBlocking {
server.bind(false)
}
serverAeronDir = File(configuration.aeronLogDirectory.toString() + count) serverAeronDir = File(configuration.aeronLogDirectory.toString() + count)
} }

View File

@ -64,12 +64,20 @@ class PingPongTest : BaseTest() {
val server: Server<Connection> = Server(configuration) val server: Server<Connection> = Server(configuration)
addEndPoint(server) addEndPoint(server)
server.bind(false) runBlocking {
server.bind(false)
}
server.onError { _, throwable -> server.onError { _, throwable ->
fail = "Error during processing. $throwable" fail = "Error during processing. $throwable"
} }
server.onConnect { connection ->
server.forEachConnection { connection ->
println("server connection: $connection")
}
}
server.onMessage<Data> { connection, message -> server.onMessage<Data> { connection, message ->
connection.send(message) connection.send(message)
} }
@ -85,6 +93,10 @@ class PingPongTest : BaseTest() {
client.onConnect { connection -> client.onConnect { connection ->
client.forEachConnection { connection ->
println("client connection: $connection")
}
fail = null fail = null
connection.send(data) connection.send(data)
} }

View File

@ -72,7 +72,9 @@ class RmiDelayedInvocationSpamTest : BaseTest() {
server.saveGlobalObject(TestObjectImpl(counter), RMI_ID) server.saveGlobalObject(TestObjectImpl(counter), RMI_ID)
server.bind(false) runBlocking {
server.bind(false)
}
} }

View File

@ -66,7 +66,9 @@ class RmiDelayedInvocationTest : BaseTest() {
server.saveGlobalObject(TestObjectImpl(iterateLock), OBJ_ID) server.saveGlobalObject(TestObjectImpl(iterateLock), OBJ_ID)
server.bind(false) runBlocking {
server.bind(false)
}
} }
run { run {

View File

@ -263,7 +263,9 @@ class RmiInitValidationTest : BaseTest() {
stopEndPoints() stopEndPoints()
} }
server.bind(false) runBlocking {
server.bind(false)
}
} }

View File

@ -100,8 +100,9 @@ class RmiOverrideAndProxyTest : BaseTest() {
} }
} }
runBlocking {
server.bind(false) server.bind(false)
}
} }

View File

@ -40,7 +40,6 @@ import dorkbox.network.Server
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.serialization.NetworkSerializationManager import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.util.exceptions.SecurityException
import dorkboxTest.network.BaseTest import dorkboxTest.network.BaseTest
import dorkboxTest.network.rmi.classes.MessageWithTestCow import dorkboxTest.network.rmi.classes.MessageWithTestCow
import dorkboxTest.network.rmi.classes.TestCow import dorkboxTest.network.rmi.classes.TestCow
@ -48,7 +47,6 @@ import dorkboxTest.network.rmi.classes.TestCowImpl
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.Assert import org.junit.Assert
import org.junit.Test import org.junit.Test
import java.io.IOException
class RmiTest : BaseTest() { class RmiTest : BaseTest() {
@ -94,8 +92,7 @@ class RmiTest : BaseTest() {
try { try {
test.throwSuspendException() test.throwSuspendException()
} catch (e: UnsupportedOperationException) { } catch (e: UnsupportedOperationException) {
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).") connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
e.printStackTrace()
caught = true caught = true
} }
@ -191,7 +188,6 @@ class RmiTest : BaseTest() {
// } // }
// } // }
@Throws(SecurityException::class, IOException::class)
fun rmi(config: (Configuration) -> Unit = {}) { fun rmi(config: (Configuration) -> Unit = {}) {
run { run {
val configuration = serverConfig() val configuration = serverConfig()
@ -204,7 +200,9 @@ class RmiTest : BaseTest() {
val server = Server<Connection>(configuration) val server = Server<Connection>(configuration)
addEndPoint(server) addEndPoint(server)
server.bind(false) runBlocking {
server.bind(false)
}
server.onMessage<MessageWithTestCow> { connection, m -> server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server") System.err.println("Received finish signal for test for: Client -> Server")
@ -259,7 +257,6 @@ class RmiTest : BaseTest() {
waitForThreads(99999999) waitForThreads(99999999)
} }
@Throws(SecurityException::class, IOException::class)
fun rmiGlobal(config: (Configuration) -> Unit = {}) { fun rmiGlobal(config: (Configuration) -> Unit = {}) {
run { run {
val configuration = serverConfig() val configuration = serverConfig()
@ -272,7 +269,9 @@ class RmiTest : BaseTest() {
val server = Server<Connection>(configuration) val server = Server<Connection>(configuration)
addEndPoint(server) addEndPoint(server)
server.bind(false) runBlocking {
server.bind(false)
}
server.onMessage<MessageWithTestCow> { connection, m -> server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server") System.err.println("Received finish signal for test for: Client -> Server")

View File

@ -78,7 +78,7 @@ object TestClient {
RmiTest.runTests(connection, remoteObject, 124123) RmiTest.runTests(connection, remoteObject, 124123)
System.err.println("DONE") System.err.println("DONE")
// now send this remote object ACROSS the wire to the server. // now send this remote object ACROSS the wire to the server (on the server, this is where the IMPLEMENTATION lives)
connection.send(remoteObject) connection.send(remoteObject)
client.close() client.close()

View File

@ -19,9 +19,12 @@ import dorkbox.network.Server
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkboxTest.network.BaseTest import dorkboxTest.network.BaseTest
import dorkboxTest.network.rmi.RmiTest import dorkboxTest.network.rmi.RmiTest
import dorkboxTest.network.rmi.classes.MessageWithTestCow
import dorkboxTest.network.rmi.classes.TestCow import dorkboxTest.network.rmi.classes.TestCow
import dorkboxTest.network.rmi.classes.TestCowImpl import dorkboxTest.network.rmi.classes.TestCowImpl
import dorkboxTest.network.rmi.multiJVM.TestClient.setup import dorkboxTest.network.rmi.multiJVM.TestClient.setup
import kotlinx.coroutines.runBlocking
import org.junit.Assert
/** /**
* *
@ -39,6 +42,45 @@ object TestServer {
val server = Server<Connection>(configuration) val server = Server<Connection>(configuration)
server.bind(false) server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(124123, id.toLong())
System.err.println("Finished test for: Client -> Server")
//
// System.err.println("Starting test for: Server -> Client")
// connection.createObject<TestCow>(123) { rmiId, remoteObject ->
// System.err.println("Running test for: Server -> Client")
// RmiTest.runTests(connection, remoteObject, 123)
// System.err.println("Done with test for: Server -> Client")
// }
}
server.onMessage<TestCow> { connection, test ->
System.err.println("Received test cow from client")
// this object LIVES on the server.
test.moo()
test.moo("Cow")
Assert.assertEquals(123123, test.id())
// Test that RMI correctly waits for the remotely invoked method to exit
test.moo("You should see this two seconds before...", 2000)
connection.logger.error("...This")
//
// System.err.println("Starting test for: Server -> Client")
// connection.createObject<TestCow>(123) { rmiId, remoteObject ->
// System.err.println("Running test for: Server -> Client")
// RmiTest.runTests(connection, remoteObject, 123)
// System.err.println("Done with test for: Server -> Client")
// }
}
runBlocking {
server.bind(false)
}
} }
} }