Fixed issues with restarting a client, notifyConnect/Disconnect now ALWAYS happen on a new coroutine

This commit is contained in:
nathan 2020-08-25 17:45:08 +02:00
parent b24e4ae710
commit 42664bfdfe
20 changed files with 468 additions and 299 deletions

View File

@ -24,6 +24,7 @@ import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
@ -34,7 +35,6 @@ import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -71,7 +71,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val previousClosedConnectionActivity: Long = 0
private val handshake = ClientHandshake(logger, config, crypto, listenerManager)
private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization)
init {
@ -119,12 +118,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("DuplicatedCode")
suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
if (isConnected) {
logger.error("Unable to connect when already connected!")
return
}
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState()
this.connectionTimeoutMS = connectionTimeoutMS
// localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) {
@ -155,7 +158,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
handshake.init(this)
val handshake = ClientHandshake(logger, config, crypto, this)
if (this.remoteAddress.isEmpty()) {
// this is an IPC address
@ -171,7 +174,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
sessionId = RESERVED_SESSION_ID_INVALID
)
autoClosableObjects.add(handshakeConnection)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
@ -186,6 +188,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// @Throws(ConnectTimedOutException::class, ClientRejectedException::class)
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
println("CO23232232323NASD")
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
}
else {
// THIS IS A NETWORK ADDRESS
@ -210,16 +215,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
// VALIDATE:: check to see if the remote connection's public key has changed!
val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -234,6 +229,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// is rogue, we do not want to carelessly provide info.
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
// only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object)
// does not need to do anything
//
@ -268,17 +274,38 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
val pollCount = newConnection.pollSubscriptions()
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.sessionId}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.sessionId}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Otherwise, poll the connection for messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
// tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold this connection open
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
if (canFinishConnecting) {
@ -290,6 +317,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
throw exception
}
@ -389,19 +417,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
override fun close() {
val con = connection
connection = null
if (con != null) {
connections.remove(con)
runBlocking {
con.close()
listenerManager.notifyDisconnect(con)
}
}
super.close()
isConnected = false
super.close()
}

View File

@ -20,7 +20,6 @@ import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionProperties
import dorkbox.network.connection.connectionType.ConnectionRule
@ -29,7 +28,7 @@ import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.TimeoutException
import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler
import io.aeron.Image
import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch
import org.agrona.DirectBuffer
@ -134,13 +133,16 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param blockUntilTerminate if true, will BLOCK until the server [close] method is called, and if you want to continue running code
* after this pass in false
*/
@Suppress("DuplicatedCode")
@JvmOverloads
fun bind(blockUntilTerminate: Boolean = true) {
suspend fun bind(blockUntilTerminate: Boolean = true) {
if (bindAlreadyCalled) {
logger.error("Unable to bind when the server is already running!")
return
}
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState()
bindAlreadyCalled = true
config as ServerConfiguration
@ -162,15 +164,15 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
logger.info(handshakeDriver.serverInfo())
val ipcHandshakeDriver = IpcMediaDriverConnection(
streamId = IPC_HANDSHAKE_STREAM_ID_PUB,
streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB,
sessionId = RESERVED_SESSION_ID_INVALID
)
ipcHandshakeDriver.buildServer(aeron)
val ipcHandshakePublication = ipcHandshakeDriver.publication
val ipcHandshakeSubscription = ipcHandshakeDriver.subscription
// val ipcHandshakeDriver = IpcMediaDriverConnection(
// streamId = IPC_HANDSHAKE_STREAM_ID_PUB,
// streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB,
// sessionId = RESERVED_SESSION_ID_INVALID
// )
// ipcHandshakeDriver.buildServer(aeron)
//
// val ipcHandshakePublication = ipcHandshakeDriver.publication
// val ipcHandshakeSubscription = ipcHandshakeDriver.subscription
/**
@ -182,16 +184,41 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val initialConnectionHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
val initialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header)
actionDispatch.launch {
handshake.receiveHandshakeMessageServer(handshakePublication, buffer, offset, length, header, this@Server)
handshake.processHandshakeMessageServer(handshakePublication,
sessionId,
clientAddressString,
clientAddress,
message,
this@Server,
aeron)
}
})
val ipcInitialConnectionHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
}
val ipcInitialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
actionDispatch.launch {
println("GOT MESSAGE!")
}
})
}
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
@ -241,7 +268,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
handshake.cleanup(connectionToClean)
connectionToClean.close()
listenerManager.notifyDisconnect(connectionToClean)
})
@ -252,8 +278,8 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
handshakePublication.close()
handshakeSubscription.close()
ipcHandshakePublication.close()
ipcHandshakeSubscription.close()
// ipcHandshakePublication.close()
// ipcHandshakeSubscription.close()
}
}
@ -268,8 +294,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
var pollCount = 0
return pollCount
}
@ -382,16 +406,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
connections.remove(connection)
}
/**
* Checks to see if a server (using the specified configuration) is running.
*
* @return true if the server is active and running
*/
fun isRunning(): Boolean {
return mediaDriver.context().isDriverActive(10_000, logger::debug)
}
override fun close() {
super.close()
bindAlreadyCalled = false

View File

@ -32,7 +32,6 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import javax.crypto.SecretKey
@ -120,9 +119,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
private val aes_gcm_iv = AtomicLong(0)
// when closing this connection, HOW MANY endpoints need to be closed?
private var closeLatch: CountDownLatch? = null
// RMI support for this connection
internal val rmiConnectionSupport = endPoint.getRmiConnectionSupport()
@ -146,16 +142,13 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server!
messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads don't work.
// this is worked around by having RMI always return (unless async), even with a null value, so the CALLING side of RMI will always
// go in "lock step"
endPoint.actionDispatch.launch {
endPoint.readMessage(buffer, offset, length, header, this@Connection)
}
}
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// when closing this connection, HOW MANY endpoints need to be closed?
closeLatch = CountDownLatch(1)
// NOTE: subscriptions (ie: reading from buffers, etc) are not thread safe! Because it is ambiguous HOW EXACTLY they are unsafe,
// we exclusively read from the DirectBuffer on a single thread.
endPoint.processMessage(buffer, offset, length, header, this@Connection)
}
}
@ -291,7 +284,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
return isClosed.value
}
/**
* Closes the connection, and removes all connection specific listeners
*/
@ -309,8 +301,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
publication.close()
// a connection might have also registered for disconnect events
notifyDisconnect()
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
endPoint.actionDispatch.launch {
// a connection might have also registered for disconnect events
listenerManager.value?.notifyDisconnect(this@Connection)
}
}
}
@ -345,15 +340,10 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a connection is disconnected from the remote endpoint
*/
internal suspend fun notifyDisconnect() {
listenerManager.value?.notifyDisconnect(this)
}
/**
* Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false

View File

@ -92,13 +92,6 @@ internal open class ConnectionManager<CONNECTION: Connection>() {
return connections.size()
}
/**
* Closes all associated resources/threads/connections
*/
fun close() {
connections.clear()
}
/**
* Safely sends objects to a destination (such as a custom object or a standard ping). This will automatically choose which protocol
* is available to use.

View File

@ -37,14 +37,12 @@ import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger
import mu.KotlinLogging
import org.agrona.DirectBuffer
import java.io.File
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch
@ -114,16 +112,14 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val logger: KLogger = KotlinLogging.logger(type.simpleName)
internal val autoClosableObjects = CopyOnWriteArrayList<AutoCloseable>()
internal val actionDispatch = CoroutineScope(Dispatchers.Default)
internal val listenerManager = ListenerManager<CONNECTION>()
internal val connections = ConnectionManager<CONNECTION>()
internal val mediaDriverContext: MediaDriver.Context
internal val mediaDriver: MediaDriver
internal val aeron: Aeron
private var mediaDriverContext: MediaDriver.Context? = null
private var mediaDriver: MediaDriver? = null
private var aeron: Aeron? = null
/**
* Returns the serialization wrapper if there is an object type that needs to be added outside of the basics.
@ -138,11 +134,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val crypto: CryptoManagement
private val shutdown = atomic(false)
private val shutdownLatch = CountDownLatch(1)
@Volatile
private var shutdownLatch: CountDownLatch = CountDownLatch(1)
// we only want one instance of these created. These will be called appropriately
val settingsStore: SettingsStore
internal val globalThreadUnsafeKryo: KryoExtra = config.serialization.takeKryo()
internal val rmiGlobalSupport = RmiManagerGlobal<CONNECTION>(logger, actionDispatch, config.serialization)
init {
@ -231,6 +231,26 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
logger.info("Aeron log directory already exists! This might not be what you want!")
}
// serialization stuff
serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy
// we have to be able to specify WHAT property store we want to use, since it can change!
settingsStore = config.settingsStore
settingsStore.init(serialization, config.settingsStorageSystem.build())
crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization
runBlocking {
serialization.finishInit(type)
}
}
internal suspend fun initEndpointState(): Aeron {
val aeronDirectory = config.aeronLogDirectory!!.absolutePath
val threadFactory = NamedThreadFactory("Aeron", false)
// LOW-LATENCY SETTINGS
@ -241,26 +261,27 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE)
// .senderIdleStrategy(NoOpIdleStrategy.INSTANCE);
mediaDriverContext = MediaDriver.Context()
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true)
.dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory)
.senderThreadFactory(threadFactory)
.sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory)
.threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize)
.socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize)
.aeronDirectoryName(config.aeronLogDirectory!!.absolutePath)
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true)
.dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory)
.senderThreadFactory(threadFactory)
.sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory)
.threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize)
.socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize)
.aeronDirectoryName(aeronDirectory)
val aeronContext = Aeron.Context().aeronDirectoryName(mediaDriverContext.aeronDirectoryName())
val aeronContext = Aeron.Context().aeronDirectoryName(aeronDirectory)
try {
mediaDriver = MediaDriver.launch(mediaDriverContext)
mediaDriver = try {
MediaDriver.launch(mediaDriverContext)
} catch (e: Exception) {
listenerManager.notifyError(e)
throw e
}
@ -268,36 +289,24 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
aeron = Aeron.connect(aeronContext)
} catch (e: Exception) {
try {
mediaDriver.close()
mediaDriver!!.close()
} catch (secondaryException: Exception) {
e.addSuppressed(secondaryException)
}
listenerManager.notifyError(e)
throw e
}
autoClosableObjects.add(aeron)
autoClosableObjects.add(mediaDriver)
shutdown.getAndSet(false)
shutdownLatch.countDown()
shutdownLatch = CountDownLatch(1)
// serialization stuff
serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy
// we have to be able to specify WHAT property store we want to use, since it can change!
settingsStore = config.settingsStore
settingsStore.init(serialization, config.settingsStorageSystem.build())
// the storage is closed via this as well
autoClosableObjects.add(settingsStore)
crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization
runBlocking {
serialization.finishInit(type)
}
return aeron!!
}
abstract fun newException(message: String, cause: Throwable? = null): Throwable
/**
@ -419,7 +428,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
}
/**
* Runs an action for each connection inside of a read-lock
* Runs an action for each connection
*/
suspend fun forEachConnection(function: suspend (connection: CONNECTION) -> Unit) {
connections.forEach {
@ -433,11 +442,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message"
}
val kryo: KryoExtra = serialization.takeKryo()
try {
kryo.write(message)
globalThreadUnsafeKryo.write(message)
val buffer = kryo.writerBuffer
val buffer = globalThreadUnsafeKryo.writerBuffer
val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer
@ -463,7 +471,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
listenerManager.notifyError(newException("Error serializing message $message", e))
} finally {
sendIdleStrategy.reset()
serialization.returnKryo(kryo)
}
}
@ -471,58 +478,68 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* @param buffer The buffer
* @param offset The offset from the start of the buffer
* @param length The number of bytes to extract
* @param header The aeron header information
*
* @return A string
* @return the message
*/
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
val kryo: KryoExtra = serialization.takeKryo()
try {
val message = kryo.read(buffer, offset, length)
val message = globalThreadUnsafeKryo.read(buffer, offset, length)
logger.trace {
"[${header.sessionId()}] received handshake: $message"
"[${header.sessionId()}] received: $message"
}
return message
} catch (e: Exception) {
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e)
} finally {
serialization.returnKryo(kryo)
}
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
return null
val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(exception)
}
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e)
return null
}
}
// This is on the action dispatch!
suspend fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
/**
* read the message from the aeron buffer
*
* @param buffer The buffer
* @param offset The offset from the start of the buffer
* @param length The number of bytes to extract
* @param header The aeron header information
* @param connection The connection this message happened on
*/
fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
@Suppress("UNCHECKED_CAST")
connection as CONNECTION
// note: this address will ALWAYS be an IP:PORT combo
// val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
// val splitPoint = remoteIpAndPort.lastIndexOf(':')
// val ip = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// val ipAsInt = NetworkUtil.IP.toInt(ip)
var message: Any? = null
val kryo: KryoExtra = serialization.takeKryo()
val message: Any?
try {
message = kryo.read(buffer, offset, length, connection)
message = globalThreadUnsafeKryo.read(buffer, offset, length, connection)
logger.trace {
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
"[${sessionId}] received: $message"
}
} catch (e: Exception) {
listenerManager.notifyError(newException("[${sessionId}] Error de-serializing message", e))
} finally {
serialization.returnKryo(kryo)
// The sessionId is globally unique, and is assigned by the server.
val sessionId = header.sessionId()
val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(connection, exception)
}
return // don't do anything!
}
connection as CONNECTION
when (message) {
is PingMessage -> {
@ -537,34 +554,44 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// ping0(ping)
// }
}
// small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads don't work.
// this is worked around by having RMI always return (unless async), even with a null value, so the CALLING side of RMI will always
// go in "lock step"
is RmiMessage -> {
// if we are an RMI message/registration, we have very specific, defined behavior.
// We do not use the "normal" listener callback pattern because this require special functionality
rmiGlobalSupport.manage(this, connection, message, logger)
actionDispatch.launch {
// if we are an RMI message/registration, we have very specific, defined behavior.
// We do not use the "normal" listener callback pattern because this require special functionality
rmiGlobalSupport.manage(this@EndPoint, connection, message, logger)
}
}
is Any -> {
@Suppress("UNCHECKED_CAST")
var hasListeners = listenerManager.notifyOnMessage(connection, message)
actionDispatch.launch {
@Suppress("UNCHECKED_CAST")
var hasListeners = listenerManager.notifyOnMessage(connection, message)
// each connection registers, and is polled INDEPENDENTLY for messages.
hasListeners = hasListeners or connection.notifyOnMessage(message)
// each connection registers, and is polled INDEPENDENTLY for messages.
hasListeners = hasListeners or connection.notifyOnMessage(message)
if (!hasListeners) {
val exception = MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
if (!hasListeners) {
val exception = MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
}
}
}
else -> {
// do nothing, there were problems with the message
val exception = if (message != null) {
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
} else {
MessageNotRegisteredException("Unknown message received!!")
}
actionDispatch.launch {
// do nothing, there were problems with the message
val exception = if (message != null) {
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
} else {
MessageNotRegisteredException("Unknown message received!!")
}
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
}
}
}
}
@ -576,6 +603,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message"
}
// since ANY thread can call 'send', we have to take kryo instances in a safe way
val kryo: KryoExtra = serialization.takeKryo()
try {
kryo.write(connection, message)
@ -652,12 +680,22 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
shutdownLatch.await()
}
/**
* Checks to see if an endpoint (using the specified configuration) is running.
*
* @return true if the client/server is active and running
*/
fun isRunning(): Boolean {
return mediaDriverContext?.isDriverActive(10_000, logger::debug) ?: false
}
override fun close() {
if (shutdown.compareAndSet(expect = false, update = true)) {
autoClosableObjects.forEach {
it.close()
}
autoClosableObjects.clear()
aeron?.close()
mediaDriver?.close()
// the storage is closed via this as well
settingsStore.close()
rmiGlobalSupport.close()
@ -665,12 +703,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// don't need anything fast or fancy here, because this method will only be called once
connections.forEach {
it.close()
listenerManager.notifyDisconnect(it)
listenerManager.notifyDisconnect(it) // if disconnect has a "connect" in it, this will case SO MANY PROBLEMS
}
}
connections.close()
actionDispatch.cancel()
shutdownLatch.countDown()
}
}

View File

@ -21,13 +21,11 @@ import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch
import mu.KLogger
import org.agrona.DirectBuffer
import java.security.SecureRandom
@ -35,78 +33,71 @@ import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogger,
private val config: Configuration,
private val crypto: CryptoManagement,
private val listenerManager: ListenerManager<CONNECTION>) {
private val endPoint: EndPoint<CONNECTION>) {
// a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt()
@Volatile
var connectionHelloInfo: ClientConnectionInfo? = null
private var connectionHelloInfo: ClientConnectionInfo? = null
@Volatile
var connectionDone = false
private var connectionDone = false
@Volatile
private var failed: Exception? = null
lateinit var handler: FragmentHandler
lateinit var endPoint: EndPoint<CONNECTION>
var sessionId: Int = 0
fun init(endPoint: EndPoint<CONNECTION>) {
this.endPoint = endPoint
private var handler: FragmentHandler
private var sessionId: Int = 0
init {
// now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
endPoint.actionDispatch.launch {
val sessionId = header.sessionId()
handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
val sessionId = header.sessionId()
// it must be a registration message
if (message !is HandshakeMessage) {
failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@FragmentAssembler
}
// this is an error message
if (message.sessionId == 0) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@FragmentAssembler
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@FragmentAssembler
}
// it must be the correct state
when (message.state) {
HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger)
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
logger.trace {
"[$sessionId] handshake response: $message"
}
// it must be a registration message
if (message !is HandshakeMessage) {
failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@launch
HandshakeMessage.DONE_ACK -> {
connectionDone = true
}
// this is an error message
if (message.sessionId == 0) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@launch
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@launch
}
// it must be the correct state
when (message.state) {
HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger)
else -> {
if (message.state != HandshakeMessage.HELLO_ACK) {
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
}
HandshakeMessage.DONE_ACK -> {
connectionDone = true
}
else -> {
if (message.state != HandshakeMessage.HELLO_ACK) {
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
} else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
return@launch
else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
}
}
})
}
}
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {

View File

@ -15,7 +15,6 @@
*/
package dorkbox.network.handshake
import dorkbox.netUtil.IPv4
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.client.ClientRejectedException
@ -28,12 +27,10 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.Image
import io.aeron.Aeron
import io.aeron.Publication
import io.aeron.logbuffer.Header
import kotlinx.coroutines.launch
import mu.KLogger
import org.agrona.DirectBuffer
import org.agrona.collections.Int2IntCounterMap
import org.agrona.collections.Int2ObjectHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock
@ -58,23 +55,13 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
// note: this is called in action dispatch
suspend fun receiveHandshakeMessageServer(handshakePublication: Publication,
buffer: DirectBuffer, offset: Int, length: Int, header: Header,
server: Server<CONNECTION>) {
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// ONLY for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
val message = server.readHandshakeMessage(buffer, offset, length, header)
suspend fun processHandshakeMessageServer(handshakePublication: Publication,
sessionId: Int,
clientAddressString: String,
clientAddress: Int,
message: Any?,
server: Server<CONNECTION>,
aeron: Aeron) {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
@ -86,6 +73,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
// check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) {
pendingConnectionsLock.write {
@ -109,6 +97,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
try {
// VALIDATE:: Check to see if there are already too many clients connected.
if (server.connections.connectionCount() >= config.maxClientCount) {
@ -148,7 +137,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return
}
// VALIDATE:: TODO: ?? check to see if this session is ALREADY connected??. It should not be!
/////
@ -208,7 +196,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
message.isReliable)
// we have to construct how the connection will communicate!
clientConnection.buildServer(server.aeron)
clientConnection.buildServer(aeron)
logger.trace {
"Creating new connection $clientConnection"

View File

@ -112,7 +112,9 @@ object AeronServer {
}
}
server.bind()
runBlocking {
server.bind()
}
}
init {

View File

@ -83,9 +83,9 @@ abstract class BaseTest {
// rootLogger.setLevel(Level.OFF);
rootLogger.level = Level.INFO;
// rootLogger.level = Level.INFO;
// rootLogger.level = Level.DEBUG
// rootLogger.level = Level.TRACE;
rootLogger.level = Level.TRACE;
// rootLogger.level = Level.ALL;

View File

@ -0,0 +1,73 @@
package dorkboxTest.network
import dorkbox.network.Client
import dorkbox.network.Server
import dorkbox.network.connection.Connection
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Test
import java.io.IOException
import java.util.*
class DisconnectReconnectTest : BaseTest() {
private val timer = Timer()
private val reconnectCount = atomic(0)
@Test
fun reconnectClient() {
run {
val configuration = serverConfig()
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
runBlocking {
server.bind(false)
}
server.onConnect { connection ->
println("Disconnecting after 2 seconds.")
delay(2000)
println("Disconnecting....")
connection.close()
}
}
run {
val config = clientConfig()
val client: Client<Connection> = Client(config)
addEndPoint(client)
client.onDisconnect { connection ->
println("Disconnected!")
val count = reconnectCount.getAndIncrement()
if (count == 3) {
println("Shutting down")
stopEndPoints()
}
else {
println("Reconnecting: $count")
try {
client.connect(LOOPBACK)
} catch (e: IOException) {
e.printStackTrace()
}
}
}
runBlocking {
client.connect(LOOPBACK)
}
}
waitForThreads()
System.err.println("Connection count (after reconnecting) is: " + reconnectCount.value)
Assert.assertEquals(4, reconnectCount.value)
}
}

View File

@ -112,7 +112,9 @@ class ListenerTest : BaseTest() {
}
server.bind(false)
runBlocking {
server.bind(false)
}

View File

@ -83,7 +83,9 @@ class MultipleServerTest : BaseTest() {
}
}
server.bind(false)
runBlocking {
server.bind(false)
}
serverAeronDir = File(configuration.aeronLogDirectory.toString() + count)
}

View File

@ -64,12 +64,20 @@ class PingPongTest : BaseTest() {
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.bind(false)
runBlocking {
server.bind(false)
}
server.onError { _, throwable ->
fail = "Error during processing. $throwable"
}
server.onConnect { connection ->
server.forEachConnection { connection ->
println("server connection: $connection")
}
}
server.onMessage<Data> { connection, message ->
connection.send(message)
}
@ -85,6 +93,10 @@ class PingPongTest : BaseTest() {
client.onConnect { connection ->
client.forEachConnection { connection ->
println("client connection: $connection")
}
fail = null
connection.send(data)
}

View File

@ -72,7 +72,9 @@ class RmiDelayedInvocationSpamTest : BaseTest() {
server.saveGlobalObject(TestObjectImpl(counter), RMI_ID)
server.bind(false)
runBlocking {
server.bind(false)
}
}

View File

@ -66,7 +66,9 @@ class RmiDelayedInvocationTest : BaseTest() {
server.saveGlobalObject(TestObjectImpl(iterateLock), OBJ_ID)
server.bind(false)
runBlocking {
server.bind(false)
}
}
run {

View File

@ -263,7 +263,9 @@ class RmiInitValidationTest : BaseTest() {
stopEndPoints()
}
server.bind(false)
runBlocking {
server.bind(false)
}
}

View File

@ -100,8 +100,9 @@ class RmiOverrideAndProxyTest : BaseTest() {
}
}
server.bind(false)
runBlocking {
server.bind(false)
}
}

View File

@ -40,7 +40,6 @@ import dorkbox.network.Server
import dorkbox.network.connection.Connection
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.util.exceptions.SecurityException
import dorkboxTest.network.BaseTest
import dorkboxTest.network.rmi.classes.MessageWithTestCow
import dorkboxTest.network.rmi.classes.TestCow
@ -48,7 +47,6 @@ import dorkboxTest.network.rmi.classes.TestCowImpl
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Test
import java.io.IOException
class RmiTest : BaseTest() {
@ -94,8 +92,7 @@ class RmiTest : BaseTest() {
try {
test.throwSuspendException()
} catch (e: UnsupportedOperationException) {
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).")
e.printStackTrace()
connection.logger.error("\tExpected exception (exception log should also be on the object impl side).", e)
caught = true
}
@ -191,7 +188,6 @@ class RmiTest : BaseTest() {
// }
// }
@Throws(SecurityException::class, IOException::class)
fun rmi(config: (Configuration) -> Unit = {}) {
run {
val configuration = serverConfig()
@ -204,7 +200,9 @@ class RmiTest : BaseTest() {
val server = Server<Connection>(configuration)
addEndPoint(server)
server.bind(false)
runBlocking {
server.bind(false)
}
server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")
@ -259,7 +257,6 @@ class RmiTest : BaseTest() {
waitForThreads(99999999)
}
@Throws(SecurityException::class, IOException::class)
fun rmiGlobal(config: (Configuration) -> Unit = {}) {
run {
val configuration = serverConfig()
@ -272,7 +269,9 @@ class RmiTest : BaseTest() {
val server = Server<Connection>(configuration)
addEndPoint(server)
server.bind(false)
runBlocking {
server.bind(false)
}
server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")

View File

@ -78,7 +78,7 @@ object TestClient {
RmiTest.runTests(connection, remoteObject, 124123)
System.err.println("DONE")
// now send this remote object ACROSS the wire to the server.
// now send this remote object ACROSS the wire to the server (on the server, this is where the IMPLEMENTATION lives)
connection.send(remoteObject)
client.close()

View File

@ -19,9 +19,12 @@ import dorkbox.network.Server
import dorkbox.network.connection.Connection
import dorkboxTest.network.BaseTest
import dorkboxTest.network.rmi.RmiTest
import dorkboxTest.network.rmi.classes.MessageWithTestCow
import dorkboxTest.network.rmi.classes.TestCow
import dorkboxTest.network.rmi.classes.TestCowImpl
import dorkboxTest.network.rmi.multiJVM.TestClient.setup
import kotlinx.coroutines.runBlocking
import org.junit.Assert
/**
*
@ -39,6 +42,45 @@ object TestServer {
val server = Server<Connection>(configuration)
server.bind(false)
server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(124123, id.toLong())
System.err.println("Finished test for: Client -> Server")
//
// System.err.println("Starting test for: Server -> Client")
// connection.createObject<TestCow>(123) { rmiId, remoteObject ->
// System.err.println("Running test for: Server -> Client")
// RmiTest.runTests(connection, remoteObject, 123)
// System.err.println("Done with test for: Server -> Client")
// }
}
server.onMessage<TestCow> { connection, test ->
System.err.println("Received test cow from client")
// this object LIVES on the server.
test.moo()
test.moo("Cow")
Assert.assertEquals(123123, test.id())
// Test that RMI correctly waits for the remotely invoked method to exit
test.moo("You should see this two seconds before...", 2000)
connection.logger.error("...This")
//
// System.err.println("Starting test for: Server -> Client")
// connection.createObject<TestCow>(123) { rmiId, remoteObject ->
// System.err.println("Running test for: Server -> Client")
// RmiTest.runTests(connection, remoteObject, 123)
// System.err.println("Done with test for: Server -> Client")
// }
}
runBlocking {
server.bind(false)
}
}
}