From e20f9b91de5752412c9d9ec58e4097d211ed8a81 Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 2 Sep 2020 02:39:05 +0200 Subject: [PATCH] Added IPC support, filled out more methods. Better support for connect in a nested disconnect callback. Fixed issue with closing connections-with-handshake-errors on the server pending. --- src/dorkbox/network/Client.kt | 495 +++++++++--------- src/dorkbox/network/Server.kt | 165 ++---- src/dorkbox/network/connection/Connection.kt | 118 +++-- .../network/connection/CryptoManagement.kt | 46 +- src/dorkbox/network/connection/EndPoint.kt | 125 +++-- .../connection/MediaDriverConnection.kt | 41 +- .../network/handshake/ClientConnectionInfo.kt | 8 +- .../network/handshake/ClientHandshake.kt | 28 +- .../network/handshake/HandshakeMessage.kt | 13 +- .../network/handshake/ServerHandshake.kt | 276 ++++++++-- .../network/other/coroutines/SuspendWaiter.kt | 29 + 11 files changed, 804 insertions(+), 540 deletions(-) create mode 100644 src/dorkbox/network/other/coroutines/SuspendWaiter.kt diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 7fcd2811..3773463a 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -29,13 +29,14 @@ import dorkbox.network.connection.Ping import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.handshake.ClientHandshake +import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.TimeoutException import dorkbox.util.exceptions.SecurityException +import kotlinx.atomicfu.atomic import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking /** * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's @@ -74,6 +75,8 @@ open class Client(config: Configuration = Configuration private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization) + private val lockStepForReconnect = atomic(null) + init { // have to do some basic validation of our configuration if (config.publicationPort <= 0) { throw ClientException("configuration port must be > 0") } @@ -106,12 +109,12 @@ open class Client(config: Configuration = Configuration * - a network name ("localhost", "loopback", "lo", "bob.example.org") * - an IP address ("127.0.0.1", "123.123.123.123", "::1") * - * ### For the IPC (Inter-Process-Communication) address. it must be: - * - the IPC integer ID, "0x1337c0de", "0x12312312", etc. + * ### For the IPC (Inter-Process-Communication) it must be: + * - EMPTY. ie: just call `connect()` * * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * - * @param remoteAddress The network or IPC address for the client to connect to + * @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely * @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @@ -121,29 +124,55 @@ open class Client(config: Configuration = Configuration */ @Suppress("DuplicatedCode") suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + // this will exist ONLY if we are reconnecting via a "disconnect" callback + lockStepForReconnect.value?.doWait() + if (isConnected) { logger.error("Unable to connect when already connected!") return } + lockStepForReconnect.lazySet(null) + connection = null + // we are done with initial configuration, now initialize aeron and the general state of this endpoint val aeron = initEndpointState() + this.connectionTimeoutMS = connectionTimeoutMS + val isIpcConnection: Boolean + + // NETWORK OR IPC ADDRESS + // if we connect to "loopback", then we substitute if for IPC (with log message) + // localhost/loopback IP might not always be 127.0.0.1 or ::1 when (remoteAddress) { - "loopback", "localhost", "lo", "" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress - else -> when { - IPv4.isLoopback(remoteAddress) -> this.remoteAddress = IPv4.LOCALHOST.hostAddress - IPv6.isLoopback(remoteAddress) -> this.remoteAddress = IPv6.LOCALHOST.hostAddress - else -> this.remoteAddress = remoteAddress // might be IPC address! + "0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") + "loopback", "localhost", "lo", "" -> { + isIpcConnection = true + logger.info("Auto-changing network connection from $remoteAddress -> IPC") + this.remoteAddress = "ipc" + } + "0x" -> { + isIpcConnection = true + this.remoteAddress = "ipc" + } + else -> when { + IPv4.isLoopback(remoteAddress) -> { + logger.info("Auto-changing network connection from $remoteAddress -> IPC") + isIpcConnection = true + this.remoteAddress = "ipc" + } + IPv6.isLoopback(remoteAddress) -> { + logger.info("Auto-changing network connection from $remoteAddress -> IPC") + isIpcConnection = true + this.remoteAddress = "ipc" + } + else -> { + isIpcConnection = false + this.remoteAddress = remoteAddress + } } - } - - - // if we are IPv4 wildcard - if (this.remoteAddress == "0.0.0.0") { - throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") } @@ -158,234 +187,232 @@ open class Client(config: Configuration = Configuration } } - val handshake = ClientHandshake(logger, config, crypto, this) - if (this.remoteAddress.isEmpty()) { - // this is an IPC address - // When conducting IPC transfers, we MUST use the same aeron configuration as the server! -// config.aeronLogDirectory - - - // stream IDs are flipped for a client because we operate from the perspective of the server - val handshakeConnection = IpcMediaDriverConnection( - streamId = IPC_HANDSHAKE_STREAM_ID_SUB, - streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB, - sessionId = RESERVED_SESSION_ID_INVALID - ) - - - - // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports - handshakeConnection.buildClient(aeron) -// logger.debug(handshakeConnection.clientInfo()) - - - println("CONASD") - - // this will block until the connection timeout, and throw an exception if we were unable to connect with the server - - // @Throws(ConnectTimedOutException::class, ClientRejectedException::class) - val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS) - println("CO23232232323NASD") - - // no longer necessary to hold the handshake connection open - handshakeConnection.close() + // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER + val handshakeConnection = if (isIpcConnection) { + IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB, + streamId = IPC_HANDSHAKE_STREAM_ID_SUB, + sessionId = RESERVED_SESSION_ID_INVALID) } else { - // THIS IS A NETWORK ADDRESS - - // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER - val handshakeConnection = UdpMediaDriverConnection(address = this.remoteAddress, - publicationPort = config.subscriptionPort, - subscriptionPort = config.publicationPort, - streamId = UDP_HANDSHAKE_STREAM_ID, - sessionId = RESERVED_SESSION_ID_INVALID, - connectionTimeoutMS = connectionTimeoutMS, - isReliable = reliable) - - // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports - handshakeConnection.buildClient(aeron) - logger.info(handshakeConnection.clientInfo()) + UdpMediaDriverConnection(address = this.remoteAddress, + publicationPort = config.subscriptionPort, + subscriptionPort = config.publicationPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID, + connectionTimeoutMS = connectionTimeoutMS, + isReliable = reliable) + } - // this will block until the connection timeout, and throw an exception if we were unable to connect with the server - - // @Throws(ConnectTimedOutException::class, ClientRejectedException::class) - val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS) + // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports + handshakeConnection.buildClient(aeron) + logger.info(handshakeConnection.clientInfo()) - // VALIDATE:: check to see if the remote connection's public key has changed! - val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) - if (validateRemoteAddress == PublicKeyValidationState.INVALID) { - handshakeConnection.close() - val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.") - listenerManager.notifyError(exception) - throw exception - } + // this will block until the connection timeout, and throw an exception if we were unable to connect with the server - // VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the - // client will timeout. SPECIFICALLY.... we do not give class serialization/registration info to the client (in case the client - // is rogue, we do not want to carelessly provide info. + // @Throws(ConnectTimedOutException::class, ClientRejectedException::class) + val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS) - // we are now connected, so we can connect to the NEW client-specific ports - val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address, - // NOTE: pub/sub must be switched! - publicationPort = connectionInfo.subscriptionPort, - subscriptionPort = connectionInfo.publicationPort, - streamId = connectionInfo.streamId, - sessionId = connectionInfo.sessionId, - connectionTimeoutMS = connectionTimeoutMS, - isReliable = handshakeConnection.isReliable) + // VALIDATE:: check to see if the remote connection's public key has changed! + val validateRemoteAddress = if (isIpcConnection) { + PublicKeyValidationState.VALID + } else { + crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) + } - - // only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object) - // does not need to do anything - // - // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports - logger.info(reliableClientConnection.clientInfo()) - - // we have to construct how the connection will communicate! - reliableClientConnection.buildClient(aeron) - - logger.info { - "Creating new connection to $reliableClientConnection" - } - - val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress)) - - // VALIDATE are we allowed to connect to this server (now that we have the initial server information) - @Suppress("UNCHECKED_CAST") - val permitConnection = listenerManager.notifyFilter(newConnection) - if (!permitConnection) { - handshakeConnection.close() - val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!") - listenerManager.notifyError(exception) - throw exception - } - - /////////////// - //// RMI - /////////////// - - // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information - serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage -> - listenerManager.notifyError(newConnection, - ClientRejectedException(errorMessage)) - } - - connection = newConnection - connections.add(newConnection) - - // have to make a new thread to listen for incoming data! - // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them - actionDispatch.launch { - val pollIdleStrategy = config.pollIdleStrategy - - while (!isShutdown()) { - // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. - var shouldCleanupConnection = false - - if (newConnection.isExpired()) { - logger.debug {"[${newConnection.sessionId}] connection expired"} - shouldCleanupConnection = true - } - - else if (newConnection.isClosed()) { - logger.debug {"[${newConnection.sessionId}] connection closed"} - shouldCleanupConnection = true - } - - - if (shouldCleanupConnection) { - close() - return@launch - } - else { - // Polls the AERON media driver subscription channel for incoming messages - val pollCount = newConnection.pollSubscriptions() - - // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) - pollIdleStrategy.idle(pollCount) - } - } - } - - // tell the server our connection handshake is done, and the connection can now listen for data. - val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) - - // no longer necessary to hold the handshake connection open + if (validateRemoteAddress == PublicKeyValidationState.INVALID) { handshakeConnection.close() + val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.") + listenerManager.notifyError(exception) + throw exception + } - if (canFinishConnecting) { - isConnected = true - actionDispatch.launch { - listenerManager.notifyConnect(newConnection) - } - } else { - close() - val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") - ListenerManager.cleanStackTrace(exception) - listenerManager.notifyError(exception) - throw exception + // VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the + // client will timeout. SPECIFICALLY.... we do not give class serialization/registration info to the client (in case the client + // is rogue, we do not want to carelessly provide info. + + + // we are now connected, so we can connect to the NEW client-specific ports + val reliableClientConnection = if (isIpcConnection) { + IpcMediaDriverConnection(sessionId = connectionInfo.sessionId, + // NOTE: pub/sub must be switched! + streamIdSubscription = connectionInfo.publicationPort, + streamId = connectionInfo.subscriptionPort, + connectionTimeoutMS = connectionTimeoutMS) + } + else { + UdpMediaDriverConnection(address = handshakeConnection.address, + // NOTE: pub/sub must be switched! + subscriptionPort = connectionInfo.publicationPort, + publicationPort = connectionInfo.subscriptionPort, + streamId = connectionInfo.streamId, + sessionId = connectionInfo.sessionId, + connectionTimeoutMS = connectionTimeoutMS, + isReliable = handshakeConnection.isReliable) + } + + // we have to construct how the connection will communicate! + reliableClientConnection.buildClient(aeron) + + // only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object) + // does not need to do anything + // + // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports + logger.info(reliableClientConnection.clientInfo()) + + val newConnection = if (isIpcConnection) { + newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID)) + } else { + newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress)) + } + + // VALIDATE are we allowed to connect to this server (now that we have the initial server information) + @Suppress("UNCHECKED_CAST") + val permitConnection = listenerManager.notifyFilter(newConnection) + if (!permitConnection) { + handshakeConnection.close() + val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!") + listenerManager.notifyError(exception) + throw exception + } + + /////////////// + //// RMI + /////////////// + + // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information + serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage -> + listenerManager.notifyError(newConnection, + ClientRejectedException(errorMessage)) + } + + ////////////// + /// Extra Close action + ////////////// + newConnection.preCloseAction = { + // this is called whenever connection.close() is called by the framework or via client.close() + if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) { + listenerManager.notifyError(getConnection(), IllegalStateException("lockStep for reconnect was in the wrong state!")) } } + newConnection.postCloseAction = { + // this is called whenever connection.close() is called by the framework or via client.close() + + // make sure to call our client.notifyDisconnect() callbacks + + // manually call it. + // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback + actionDispatch.launch { + listenerManager.notifyDisconnect(getConnection()) + } + + // in case notifyDisconnect called client.connect().... cancel them waiting + isConnected = false + lockStepForReconnect.value?.cancel() + } + + connection = newConnection + connections.add(newConnection) + + // have to make a new thread to listen for incoming data! + // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them + actionDispatch.launch { + val pollIdleStrategy = config.pollIdleStrategy + + while (!isShutdown()) { + // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. + var shouldCleanupConnection = false + + if (newConnection.isExpired()) { + logger.debug {"[${newConnection.id}] connection expired"} + shouldCleanupConnection = true + } + + else if (newConnection.isClosed()) { + logger.debug {"[${newConnection.id}] connection closed"} + shouldCleanupConnection = true + } + + + if (shouldCleanupConnection) { + close() + return@launch + } + else { + // Polls the AERON media driver subscription channel for incoming messages + val pollCount = newConnection.pollSubscriptions() + + // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) + pollIdleStrategy.idle(pollCount) + } + } + } + + + // tell the server our connection handshake is done, and the connection can now listen for data. + val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) + + // no longer necessary to hold the handshake connection open + handshakeConnection.close() + + if (canFinishConnecting) { + isConnected = true + + actionDispatch.launch { + listenerManager.notifyConnect(newConnection) + } + } else { + close() + val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") + ListenerManager.cleanStackTrace(exception) + listenerManager.notifyError(exception) + throw exception + } } -// override fun hasRemoteKeyChanged(): Boolean { -// return connection!!.hasRemoteKeyChanged() -// } -// -// /** -// * @return the remote address, as a string. -// */ -// override fun getRemoteHost(): String { -// return connection!!.remoteHost -// } -// -// /** -// * @return true if this connection is established on the loopback interface -// */ -// override fun isLoopback(): Boolean { -// return connection!!.isLoopback -// } -// -// override fun isIPC(): Boolean { -// return false -// } - -// /** -// * @return true if this connection is a network connection -// */ -// override fun isNetwork(): Boolean { -// return false -// } -// -// /** -// * @return the connection (TCP or LOCAL) id of this connection. -// */ -// override fun id(): Int { -// return connection!!.id() -// } -// -// /** -// * @return the connection (TCP or LOCAL) id of this connection as a HEX string. -// */ -// override fun idAsHex(): String { -// return connection!!.idAsHex() -// } - - - - - - + /** + * @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed. + */ + fun hasRemoteKeyChanged(): Boolean { + return getConnection().hasRemoteKeyChanged() + } /** - * Fetches the connection used by the client, this is only valid after the client has connected + * @return the remote address, as a string. + */ + fun getRemoteHost(): String { + return this.remoteAddress + } + + /** + * @return true if this connection is an IPC connection + */ + fun isIPC(): Boolean { + return getConnection().isIpc + } + + /** + * @return true if this connection is a network connection + */ + fun isNetwork(): Boolean { + return getConnection().isNetwork + } + + /** + * @return the connection (TCP or IPC) id of this connection. + */ + fun id(): Int { + return getConnection().id + } + + /** + * @return the connection used by the client, this is only valid after the client has connected */ fun getConnection(): CONNECTION { return connection as CONNECTION @@ -427,32 +454,6 @@ open class Client(config: Configuration = Configuration } } - override fun close() { - val con = connection - connection = null - isConnected = false - super.close() - - // in the client, "client-notifyDisconnect" will NEVER be called, because it's only called on a connection! - // (meaning, 'connection-notifiyDisconnect' is what is called) - - // manually call it. - if (con != null) { - // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback - val job = actionDispatch.launch { - listenerManager.notifyDisconnect(con) - } - - // when we close a client or a server, we want to make sure that ALL notifications are finished. - // when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown - // NOTE: this must be the LAST thing happening! - runBlocking { - job.join() - } - } - } - - // RMI notes (in multiple places, copypasta, because this is confusing if not written down) // // only server can create a global object (in itself, via save) diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index 12c6b70b..6487d301 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -20,6 +20,7 @@ import dorkbox.netUtil.IPv6 import dorkbox.network.aeron.server.ServerException import dorkbox.network.connection.Connection import dorkbox.network.connection.EndPoint +import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.connectionType.ConnectionProperties @@ -88,7 +89,7 @@ open class Server(config: ServerConfiguration = ServerC // localhost/loopback IP might not always be 127.0.0.1 or ::1 when (config.listenIpAddress) { - "loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress + "loopback", "localhost", "lo", "" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress else -> when { IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress @@ -132,13 +133,9 @@ open class Server(config: ServerConfiguration = ServerC /** * Binds the server to AERON configuration - * - * @param blockUntilTerminate if true, will BLOCK until the server [close] method is called, and if you want to continue running code - * after this pass in false */ @Suppress("DuplicatedCode") - @JvmOverloads - suspend fun bind(blockUntilTerminate: Boolean = true) { + fun bind() { if (bindAlreadyCalled) { logger.error("Unable to bind when the server is already running!") return @@ -150,32 +147,29 @@ open class Server(config: ServerConfiguration = ServerC config as ServerConfiguration - // setup the "HANDSHAKE" ports, for initial clients to connect. - // The is how clients then get the new ports to connect to + other configuration options - val handshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress, - publicationPort = config.publicationPort, - subscriptionPort = config.subscriptionPort, - streamId = UDP_HANDSHAKE_STREAM_ID, - sessionId = RESERVED_SESSION_ID_INVALID) - - handshakeDriver.buildServer(aeron) - - val handshakePublication = handshakeDriver.publication - val handshakeSubscription = handshakeDriver.subscription - - logger.info(handshakeDriver.serverInfo()) + val ipcHandshakeDriver = IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB, + streamId = IPC_HANDSHAKE_STREAM_ID_PUB, + sessionId = RESERVED_SESSION_ID_INVALID) + ipcHandshakeDriver.buildServer(aeron) + val ipcHandshakePublication = ipcHandshakeDriver.publication + val ipcHandshakeSubscription = ipcHandshakeDriver.subscription -// val ipcHandshakeDriver = IpcMediaDriverConnection( -// streamId = IPC_HANDSHAKE_STREAM_ID_PUB, -// streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB, -// sessionId = RESERVED_SESSION_ID_INVALID -// ) -// ipcHandshakeDriver.buildServer(aeron) -// -// val ipcHandshakePublication = ipcHandshakeDriver.publication -// val ipcHandshakeSubscription = ipcHandshakeDriver.subscription + + val udpHandshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress, + publicationPort = config.publicationPort, + subscriptionPort = config.subscriptionPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID) + + udpHandshakeDriver.buildServer(aeron) + val handshakePublication = udpHandshakeDriver.publication + val handshakeSubscription = udpHandshakeDriver.subscription + + + logger.info(ipcHandshakeDriver.serverInfo()) + logger.info(udpHandshakeDriver.serverInfo()) /** @@ -187,14 +181,14 @@ open class Server(config: ServerConfiguration = ServerC * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery * properties from failure and streams with mechanical sympathy. */ - val handshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + val udpHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE val sessionId = header.sessionId() - // note: this address will ALWAYS be an IP:PORT combo + // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!) val remoteIpAndPort = (header.context() as Image).sourceIdentity() // split @@ -204,23 +198,28 @@ open class Server(config: ServerConfiguration = ServerC val clientAddress = IPv4.toInt(clientAddressString) val message = readHandshakeMessage(buffer, offset, length, header) - - actionDispatch.launch { - handshake.processHandshakeMessageServer(handshakePublication, - sessionId, - clientAddressString, - clientAddress, - message, - this@Server, - aeron) - } + handshake.processHandshakeMessageServer(this@Server, + handshakePublication, + sessionId, + clientAddressString, + clientAddress, + message, + aeron) } - val ipcInitialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + + val ipcHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! - actionDispatch.launch { - println("GOT MESSAGE!") - } + // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. + // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE + val sessionId = header.sessionId() + + val message = readHandshakeMessage(buffer, offset, length, header) + handshake.processHandshakeMessageServer(this@Server, + ipcHandshakePublication, + sessionId, + message, + aeron) } actionDispatch.launch { @@ -236,10 +235,10 @@ open class Server(config: ServerConfiguration = ServerC // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // this checks to see if there are NEW clients on the handshake ports - pollCount += handshakeSubscription.poll(handshakeHandler, 2) + pollCount += handshakeSubscription.poll(udpHandshakeHandler, 1) // this checks to see if there are NEW clients via IPC -// pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100) + pollCount += ipcHandshakeSubscription.poll(ipcHandshakeHandler, 1) // this manages existing clients (for cleanup + connection polling) @@ -248,12 +247,12 @@ open class Server(config: ServerConfiguration = ServerC var shouldCleanupConnection = false if (connection.isExpired()) { - logger.trace {"[${connection.sessionId}] connection expired"} + logger.trace {"[${connection.id}] connection expired"} shouldCleanupConnection = true } else if (connection.isClosed()) { - logger.trace {"[${connection.sessionId}] connection closed"} + logger.trace {"[${connection.id}] connection closed"} shouldCleanupConnection = true } @@ -268,7 +267,7 @@ open class Server(config: ServerConfiguration = ServerC false } }, { connectionToClean -> - logger.info {"[${connectionToClean.sessionId}] cleaned-up connection"} + logger.info {"[${connectionToClean.id}] cleaned-up connection"} // have to free up resources! handshake.cleanup(connectionToClean) @@ -294,16 +293,10 @@ open class Server(config: ServerConfiguration = ServerC handshakePublication.close() handshakeSubscription.close() -// ipcHandshakePublication.close() -// ipcHandshakeSubscription.close() + ipcHandshakePublication.close() + ipcHandshakeSubscription.close() } } - - - // we now BLOCK until the stop method is called. - if (blockUntilTerminate) { - waitForShutdown(); - } } /** @@ -364,8 +357,7 @@ open class Server(config: ServerConfiguration = ServerC /** * Closes the server and all it's connections. After a close, you may call 'bind' again. */ - override fun close() { - super.close() + override fun close0() { bindAlreadyCalled = false // when we call close, it will shutdown the polling mechanism, so we have to manually cleanup the connections and call server-notifyDisconnect @@ -436,55 +428,6 @@ open class Server(config: ServerConfiguration = ServerC } - - -// enum class STATE { -// ERROR, WAIT, CONTINUE -// } - -// fun verifyClassRegistration(metaChannel: MetaChannel, registration: Registration): STATE { -// if (registration.upgradeType == UpgradeType.FRAGMENTED) { -// val fragment = registration.payload!! -// -// // this means that the registrations are FRAGMENTED! -// // max size of ALL fragments is xxx * 127 -// if (metaChannel.fragmentedRegistrationDetails == null) { -// metaChannel.remainingFragments = fragment[1] -// metaChannel.fragmentedRegistrationDetails = ByteArray(Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * fragment[1]) -// } -// System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE, fragment.size - 2) -// -// metaChannel.remainingFragments-- -// -// if (fragment[0] + 1 == fragment[1].toInt()) { -// // this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive) -// val correctSize = Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * (fragment[1] - 1) + (fragment.size - 2) -// val correctlySized = ByteArray(correctSize) -// System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize) -// metaChannel.fragmentedRegistrationDetails = correctlySized -// } -// if (metaChannel.remainingFragments.toInt() == 0) { -// // there are no more fragments available -// val details = metaChannel.fragmentedRegistrationDetails -// metaChannel.fragmentedRegistrationDetails = null -// if (!serialization.verifyKryoRegistration(details)) { -// // error -// return STATE.ERROR -// } -// } else { -// // wait for more fragments -// return STATE.WAIT -// } -// } else { -// if (!serialization.verifyKryoRegistration(registration.payload!!)) { -// return STATE.ERROR -// } -// } -// return STATE.CONTINUE -// } - - - // RMI notes (in multiple places, copypasta, because this is confusing if not written down) // // only server can create a global object (in itself, via save) @@ -532,7 +475,7 @@ open class Server(config: ServerConfiguration = ServerC * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveGlobalObject(`object`: Any): Int { + fun saveGlobalObject(`object`: Any): Int { val rmiId = rmiGlobalSupport.saveImplObject(`object`) if (rmiId == RemoteObjectStorage.INVALID_RMI) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") @@ -561,7 +504,7 @@ open class Server(config: ServerConfiguration = ServerC * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { + fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { val success = rmiGlobalSupport.saveImplObject(`object`, objectId) if (!success) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 3519e448..d9aec9cd 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -16,6 +16,7 @@ package dorkbox.network.connection import dorkbox.netUtil.IPv4 +import dorkbox.network.aeron.server.RandomIdAllocator import dorkbox.network.connection.ping.PingFuture import dorkbox.network.connection.ping.PingMessage import dorkbox.network.rmi.RemoteObject @@ -32,6 +33,7 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.agrona.DirectBuffer +import org.agrona.collections.Int2IntCounterMap import java.io.IOException import java.util.concurrent.TimeUnit @@ -46,76 +48,68 @@ open class Connection(connectionParameters: ConnectionParams<*>) { /** * The publication port (used by aeron) for this connection. This is from the perspective of the server! */ - internal val subscriptionPort: Int - internal val publicationPort: Int + private val subscriptionPort: Int + private val publicationPort: Int /** - * the stream id of this connection. + * the stream id of this connection. Can be 0 for IPC connections */ - internal val streamId: Int + private val streamId: Int /** * the session id of this connection. This value is UNIQUE */ - internal val sessionId: Int - - /** - * the id of this connection. This value is UNIQUE - */ val id: Int - get() = sessionId /** - * the remote address, as a string. + * the remote address, as a string. Will be "ipc" for IPC connections */ val remoteAddress: String /** - * the remote address, as an integer. + * the remote address, as an integer. Can be 0 for IPC connections */ - val remoteAddressInt: Int - + private val remoteAddressInt: Int /** * @return true if this connection is an IPC connection */ - val isIPC = connectionParameters.mediaDriverConnection is IpcMediaDriverConnection + val isIpc = connectionParameters.mediaDriverConnection is IpcMediaDriverConnection /** * @return true if this connection is a network connection */ val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection - - - - - /** - * Returns the last calculated TCP return trip time, or -1 if or the [PingMessage] response has not yet been received. - */ - val lastRoundTripTime: Int - get() { - val pingFuture2 = pingFuture - return pingFuture2?.response ?: -1 - } - /** * the endpoint associated with this connection */ internal val endPoint = connectionParameters.endPoint - private val listenerManager = atomic?>(null) + + private val listenerManager = atomic?>(null) val logger = endPoint.logger + internal var preCloseAction: suspend () -> Unit = {} + internal var postCloseAction: suspend () -> Unit = {} + private val isClosed = atomic(false) + /** + // * Returns the last calculated TCP return trip time, or -1 if or the [PingMessage] response has not yet been received. + // */ +// val lastRoundTripTime: Int +// get() { +// val pingFuture2 = pingFuture +// return pingFuture2?.response ?: -1 +// } @Volatile private var pingFuture: PingFuture? = null // while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error. - private var remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED + private val remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED // The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter) // The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this @@ -128,6 +122,8 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // a record of how many messages are in progress of being sent. When closing the connection, this number must be 0 private val messagesInProgress = atomic(0) + val toString0: () -> String + init { val mediaDriverConnection = connectionParameters.mediaDriverConnection @@ -135,12 +131,25 @@ open class Connection(connectionParameters: ConnectionParams<*>) { subscription = mediaDriverConnection.subscription publication = mediaDriverConnection.publication - subscriptionPort = mediaDriverConnection.subscriptionPort - publicationPort = mediaDriverConnection.publicationPort - remoteAddress = mediaDriverConnection.address - remoteAddressInt = IPv4.toInt(remoteAddress) - streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server! - sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server! + remoteAddress = mediaDriverConnection.address // this can be the IP address or "ipc" word + id = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server! + + if (mediaDriverConnection is IpcMediaDriverConnection) { + streamId = 0 // this is because with IPC, we have stream sub/pub (which are replaced as port sub/pub) + subscriptionPort = mediaDriverConnection.streamIdSubscription + publicationPort = mediaDriverConnection.streamId + remoteAddressInt = 0 + + toString0 = { "[$id] IPC [$subscriptionPort|$publicationPort]" } + } else { + streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server! + subscriptionPort = mediaDriverConnection.subscriptionPort + publicationPort = mediaDriverConnection.publicationPort + remoteAddressInt = IPv4.toInt(mediaDriverConnection.address) + + toString0 = { "[$id] $remoteAddress [$publicationPort|$subscriptionPort]" } + } + messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! @@ -155,7 +164,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { /** - * Has the remote ECC public key changed. This can be useful if specific actions are necessary when the key has changed. + * @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed. */ fun hasRemoteKeyChanged(): Boolean { return remoteKeyChanged @@ -272,15 +281,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) { return messagesInProgress.value } - /** - * @return `true` if this connection has no subscribers (which means this connection longer has a remote connection) + * @return `true` if this connection has no subscribers (which means this connection does not have a remote connection) */ internal fun isExpired(): Boolean { - return !subscription.isConnected + // cannot use subscription.isConnected !!! images can be in a state of flux. We only care if there are NO images. + return subscription.hasNoImages() } - /** * @return `true` if this connection has been closed */ @@ -300,7 +308,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire. if (isClosed.compareAndSet(expect = false, update = true)) { - logger.info {"[${sessionId}] closed connection"} + logger.info {"[$id] closed connection"} subscription.close() @@ -332,11 +340,19 @@ open class Connection(connectionParameters: ConnectionParams<*>) { rmiConnectionSupport.clearProxyObjects() + // This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper + // lock-stop ordering for how disconnect and connect work with each-other + preCloseAction() + // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback endPoint.actionDispatch.launch { // a connection might have also registered for disconnect events listenerManager.value?.notifyDisconnect(this@Connection) } + + // This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper + // lock-stop ordering for how disconnect and connect work with each-other + postCloseAction() } } @@ -387,11 +403,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // // override fun toString(): String { - return "$remoteAddress $publicationPort/$subscriptionPort ID: $sessionId" + return toString0() } override fun hashCode(): Int { - return sessionId + return id } override fun equals(other: Any?): Boolean { @@ -406,9 +422,21 @@ open class Connection(connectionParameters: ConnectionParams<*>) { } val other1 = other as Connection - return sessionId == other1.sessionId + return id == other1.id } + // cleans up the connection information + fun cleanup(connectionsPerIpCounts: Int2IntCounterMap, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) { + if (isIpc) { + sessionIdAllocator.free(subscriptionPort) + sessionIdAllocator.free(publicationPort) + streamIdAllocator.free(streamId) + } else { + connectionsPerIpCounts.getAndDecrement(remoteAddressInt) + sessionIdAllocator.free(id) + streamIdAllocator.free(streamId) + } + } // RMI notes (in multiple places, copypasta, because this is confusing if not written down) // diff --git a/src/dorkbox/network/connection/CryptoManagement.kt b/src/dorkbox/network/connection/CryptoManagement.kt index 84455377..1484e1b6 100644 --- a/src/dorkbox/network/connection/CryptoManagement.kt +++ b/src/dorkbox/network/connection/CryptoManagement.kt @@ -69,6 +69,11 @@ internal class CryptoManagement(val logger: KLogger, val secureRandom = SecureRandom(settingsStore.getSalt()) + private val iv = ByteArray(GCM_IV_LENGTH) + private val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) + val cryptOutput = AeronOutput() + val cryptInput = AeronInput() + private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation init { @@ -177,6 +182,7 @@ internal class CryptoManagement(val logger: KLogger, return SecretKeySpec(hash.digest(), "AES") } + // NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt) fun encrypt(clientPublicKeyBytes: ByteArray, publicationPort: Int, subscriptionPort: Int, @@ -185,29 +191,24 @@ internal class CryptoManagement(val logger: KLogger, kryoRmiIds: IntArray): ByteArray { val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) - - val iv = ByteArray(GCM_IV_LENGTH) secureRandom.nextBytes(iv) - - val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec) // now create the byte array that holds all our data - val data = AeronOutput() - data.writeInt(connectionSessionId) - data.writeInt(connectionStreamId) - data.writeInt(publicationPort) - data.writeInt(subscriptionPort) - data.writeInt(kryoRmiIds.size) + cryptOutput.reset() + cryptOutput.writeInt(connectionSessionId) + cryptOutput.writeInt(connectionStreamId) + cryptOutput.writeInt(publicationPort) + cryptOutput.writeInt(subscriptionPort) + cryptOutput.writeInt(kryoRmiIds.size) kryoRmiIds.forEach { - data.writeInt(it) + cryptOutput.writeInt(it) } - val bytes = data.toBytes() - - return iv + aesCipher.doFinal(bytes) + return iv + aesCipher.doFinal(cryptOutput.toBytes()) } + // NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt) fun decrypt(registrationData: ByteArray?, serverPublicKeyBytes: ByteArray?): ClientConnectionInfo? { if (registrationData == null || serverPublicKeyBytes == null) { return null @@ -216,7 +217,6 @@ internal class CryptoManagement(val logger: KLogger, val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes) // now read the encrypted data - val iv = ByteArray(GCM_IV_LENGTH) registrationData.copyInto(destination = iv, endIndex = GCM_IV_LENGTH) @@ -226,21 +226,19 @@ internal class CryptoManagement(val logger: KLogger, // now decrypt the data - val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec) - val data = AeronInput(aesCipher.doFinal(secretBytes)) + cryptInput.buffer = aesCipher.doFinal(secretBytes) - - val sessionId = data.readInt() - val streamId = data.readInt() - val publicationPort = data.readInt() - val subscriptionPort = data.readInt() + val sessionId = cryptInput.readInt() + val streamId = cryptInput.readInt() + val publicationPort = cryptInput.readInt() + val subscriptionPort = cryptInput.readInt() val rmiIds = mutableListOf() - val rmiIdSize = data.readInt() + val rmiIdSize = cryptInput.readInt() for (i in 0 until rmiIdSize) { - rmiIds.add(data.readInt()) + rmiIds.add(cryptInput.readInt()) } // now read data off diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 907c2c55..d5c1ae17 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -22,6 +22,7 @@ import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.connection.ping.PingMessage import dorkbox.network.ipFilter.IpFilterRule +import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerGlobal import dorkbox.network.rmi.messages.RmiMessage @@ -43,7 +44,6 @@ import mu.KLogger import mu.KotlinLogging import org.agrona.DirectBuffer import java.io.File -import java.util.concurrent.CountDownLatch // If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets! @@ -117,7 +117,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A internal val listenerManager = ListenerManager() internal val connections = ConnectionManager() - private var mediaDriverContext: MediaDriver.Context? = null + internal var mediaDriverContext: MediaDriver.Context? = null private var mediaDriver: MediaDriver? = null private var aeron: Aeron? = null @@ -136,7 +136,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A private val shutdown = atomic(false) @Volatile - private var shutdownLatch: CountDownLatch = CountDownLatch(1) + private var shutdownLatch: SuspendWaiter = SuspendWaiter() // we only want one instance of these created. These will be called appropriately val settingsStore: SettingsStore @@ -219,7 +219,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A if (config.aeronLogDirectory == null) { val baseFileLocation = config.suggestAeronLogLocation(logger) - val aeronLogDirectory = File(baseFileLocation, "aeron-" + type.simpleName) +// val aeronLogDirectory = File(baseFileLocation, "aeron-" + type.simpleName) + val aeronLogDirectory = File(baseFileLocation, "aeron") aeronDirAlreadyExists = aeronLogDirectory.exists() config.aeronLogDirectory = aeronLogDirectory } @@ -229,6 +230,47 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A logger.warn("Aeron log directory already exists! This might not be what you want!") } + val threadFactory = NamedThreadFactory("Aeron", false) + + // LOW-LATENCY SETTINGS + // .termBufferSparseFile(false) + // .useWindowsHighResTimer(true) + // .threadingMode(ThreadingMode.DEDICATED) + // .conductorIdleStrategy(BusySpinIdleStrategy.INSTANCE) + // .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE) + // .senderIdleStrategy(NoOpIdleStrategy.INSTANCE); + // setProperty(DISABLE_BOUNDS_CHECKS_PROP_NAME, "true"); + // setProperty("aeron.mtu.length", "16384"); + // setProperty("aeron.socket.so_sndbuf", "2097152"); + // setProperty("aeron.socket.so_rcvbuf", "2097152"); + // setProperty("aeron.rcv.initial.window.length", "2097152"); + + // driver context must happen in the initializer, because we have a Server.isRunning() method that uses the mediaDriverContext (without bind) + val mDrivercontext = MediaDriver.Context() + .publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW) + .publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH) + .dirDeleteOnStart(true) + .dirDeleteOnShutdown(true) + .conductorThreadFactory(threadFactory) + .receiverThreadFactory(threadFactory) + .senderThreadFactory(threadFactory) + .sharedNetworkThreadFactory(threadFactory) + .sharedThreadFactory(threadFactory) + .threadingMode(config.threadingMode) + .mtuLength(config.networkMtuSize) + .socketSndbufLength(config.sendBufferSize) + .socketRcvbufLength(config.receiveBufferSize) + + mDrivercontext + .aeronDirectoryName(config.aeronLogDirectory!!.absolutePath) + .concludeAeronDirectory() + + mDrivercontext.ipcTermBufferLength(16 * 1024 * 1024) // default: 64 megs each is HUGE + mDrivercontext.publicationTermBufferLength(4 * 1024 * 1024) // default: 16 megs each is HUGE (we run out of space in production w/ lots of clients) + + mediaDriverContext = mDrivercontext + + // serialization stuff serialization = config.serialization sendIdleStrategy = config.sendIdleStrategy @@ -250,45 +292,26 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A internal fun initEndpointState(): Aeron { val aeronDirectory = config.aeronLogDirectory!!.absolutePath - val threadFactory = NamedThreadFactory("Aeron", false) - - // LOW-LATENCY SETTINGS - // .termBufferSparseFile(false) - // .useWindowsHighResTimer(true) - // .threadingMode(ThreadingMode.DEDICATED) - // .conductorIdleStrategy(BusySpinIdleStrategy.INSTANCE) - // .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE) - // .senderIdleStrategy(NoOpIdleStrategy.INSTANCE); - mediaDriverContext = MediaDriver.Context() - .publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW) - .publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH) - .dirDeleteOnStart(true) - .dirDeleteOnShutdown(true) - .conductorThreadFactory(threadFactory) - .receiverThreadFactory(threadFactory) - .senderThreadFactory(threadFactory) - .sharedNetworkThreadFactory(threadFactory) - .sharedThreadFactory(threadFactory) - .threadingMode(config.threadingMode) - .mtuLength(config.networkMtuSize) - .socketSndbufLength(config.sendBufferSize) - .socketRcvbufLength(config.receiveBufferSize) - .aeronDirectoryName(aeronDirectory) - - val aeronContext = Aeron.Context().aeronDirectoryName(aeronDirectory) - - mediaDriver = try { - MediaDriver.launch(mediaDriverContext) - } catch (e: Exception) { - listenerManager.notifyError(e) - throw e + if (!isRunning()) { + // the server always creates a media driver. + mediaDriver = try { + MediaDriver.launch(mediaDriverContext) + } catch (e: Exception) { + listenerManager.notifyError(e) + throw e + } } + val aeronContext = Aeron.Context() + aeronContext + .aeronDirectoryName(aeronDirectory) + .concludeAeronDirectory() + try { aeron = Aeron.connect(aeronContext) } catch (e: Exception) { try { - mediaDriver!!.close() + mediaDriver?.close() } catch (secondaryException: Exception) { e.addSuppressed(secondaryException) } @@ -299,8 +322,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A shutdown.getAndSet(false) - shutdownLatch.countDown() - shutdownLatch = CountDownLatch(1) + shutdownLatch = SuspendWaiter() return aeron!! } @@ -466,11 +488,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } // more critical error sending the message. we shouldn't retry or anything. - listenerManager.notifyError(newException("Error sending message. ${errorCodeName(result)}")) + listenerManager.notifyError( + newException("[${publication.sessionId()}] Error sending handshake message. $message (${errorCodeName(result)})")) return } } catch (e: Exception) { - listenerManager.notifyError(newException("Error serializing message $message", e)) + listenerManager.notifyError(newException("[${publication.sessionId()}] Error serializing handshake message $message", e)) } finally { sendIdleStrategy.reset() serialization.returnKryo(kryo) @@ -622,12 +645,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } // more critical error sending the message. we shouldn't retry or anything. - logger.error("Error sending message. ${errorCodeName(result)}") + logger.error("[${publication.sessionId()}] Error sending message. $message (${errorCodeName(result)})") return } } catch (e: Exception) { - logger.error("Error serializing message $message", e) + logger.error("[${publication.sessionId()}] Error serializing message $message", e) } finally { sendIdleStrategy.reset() serialization.returnKryo(kryo) @@ -671,8 +694,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A /** * Waits for this endpoint to be closed */ - fun waitForShutdown() { - shutdownLatch.await() + suspend fun waitForClose() { + shutdownLatch.doWait() } /** @@ -681,10 +704,11 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A * @return true if the client/server is active and running */ fun isRunning(): Boolean { - return mediaDriverContext?.isDriverActive(10_000, logger::debug) ?: false + // if the media driver is running, it will be a quick connection. Usually 100ms or so + return mediaDriverContext?.isDriverActive(1_000, logger::debug) ?: false } - override fun close() { + final override fun close() { if (shutdown.compareAndSet(expect = false, update = true)) { aeron?.close() mediaDriver?.close() @@ -700,7 +724,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } } - shutdownLatch.countDown() + close0() + + // if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now) + shutdownLatch.cancel() } } + + internal open fun close0() {} } diff --git a/src/dorkbox/network/connection/MediaDriverConnection.kt b/src/dorkbox/network/connection/MediaDriverConnection.kt index 6b95fe4a..8e588094 100644 --- a/src/dorkbox/network/connection/MediaDriverConnection.kt +++ b/src/dorkbox/network/connection/MediaDriverConnection.kt @@ -198,9 +198,10 @@ class IpcMediaDriverConnection(override val streamId: Int, val streamIdSubscription: Int, override val sessionId: Int, private val connectionTimeoutMS: Long = 30_000, - override val isReliable: Boolean = true) : MediaDriverConnection { + ) : MediaDriverConnection { - override val address = "" + override val isReliable = true + override val address = "ipc" override val subscriptionPort = 0 override val publicationPort = 0 @@ -209,10 +210,6 @@ class IpcMediaDriverConnection(override val streamId: Int, var success: Boolean = false - - init { - } - private fun uri(): ChannelUriStringBuilder { val builder = ChannelUriStringBuilder().media("ipc") if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { @@ -226,14 +223,10 @@ class IpcMediaDriverConnection(override val streamId: Int, override suspend fun buildClient(aeron: Aeron) { // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. val subscriptionUri = uri() -// .controlEndpoint("$address:$subscriptionPort") -// .controlMode("dynamic") - // Create a publication at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() -// .endpoint("$address:$publicationPort") // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe @@ -288,15 +281,10 @@ class IpcMediaDriverConnection(override val streamId: Int, override fun buildServer(aeron: Aeron) { // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. val subscriptionUri = uri() -// .endpoint("$address:$subscriptionPort") - // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() -// .controlEndpoint("$address:$publicationPort") -// .controlMode("dynamic") - // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. @@ -305,22 +293,29 @@ class IpcMediaDriverConnection(override val streamId: Int, } override fun clientInfo() : String { - return "" + return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { + "[$sessionId] aeron connection established to [$streamIdSubscription|$streamId]" + } else { + "Connecting IPC with handshake to [$streamIdSubscription|$streamId]" + } } override fun serverInfo() : String { - return "" - } - - fun connect() : Pair { - return Pair("","") + return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { + "[$sessionId] IPC listening on [$streamIdSubscription|$streamId] " + } else { + "IPC listening with handshake on [$streamIdSubscription|$streamId]" + } } override fun close() { - + if (success) { + subscription.close() + publication.close() + } } override fun toString(): String { - return "$address [$subscriptionPort|$publicationPort] [$streamId|$sessionId]" + return "[$streamIdSubscription|$streamId] [$sessionId]" } } diff --git a/src/dorkbox/network/handshake/ClientConnectionInfo.kt b/src/dorkbox/network/handshake/ClientConnectionInfo.kt index 2b5bc4be..10312aa7 100644 --- a/src/dorkbox/network/handshake/ClientConnectionInfo.kt +++ b/src/dorkbox/network/handshake/ClientConnectionInfo.kt @@ -15,10 +15,10 @@ */ package dorkbox.network.handshake -internal class ClientConnectionInfo(val subscriptionPort: Int, - val publicationPort: Int, +internal class ClientConnectionInfo(val subscriptionPort: Int = 0, + val publicationPort: Int = 0, val sessionId: Int, - val streamId: Int, - val publicKey: ByteArray, + val streamId: Int = 0, + val publicKey: ByteArray = ByteArray(0), val kryoIdsForRmi: IntArray) { } diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index ab2ec0ef..ea56b622 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -22,7 +22,6 @@ import dorkbox.network.connection.Connection import dorkbox.network.connection.CryptoManagement import dorkbox.network.connection.EndPoint import dorkbox.network.connection.MediaDriverConnection -import dorkbox.network.connection.UdpMediaDriverConnection import io.aeron.FragmentAssembler import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.Header @@ -82,7 +81,28 @@ internal class ClientHandshake(private val logger: KLogg // The message was intended for this client. Try to parse it as one of the available message types. // this message is ENCRYPTED! connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey) + } + HandshakeMessage.HELLO_ACK_IPC -> { + // The message was intended for this client. Try to parse it as one of the available message types. + // this message is ENCRYPTED! + val cryptInput = crypto.cryptInput + cryptInput.buffer = message.registrationData + val sessionId = cryptInput.readInt() + val streamSubId = cryptInput.readInt() + val streamPubId = cryptInput.readInt() + + val rmiIds = mutableListOf() + val rmiIdSize = cryptInput.readInt() + for (i in 0 until rmiIdSize) { + rmiIds.add(cryptInput.readInt()) + } + + // now read data off + connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId, + subscriptionPort = streamSubId, + publicationPort = streamPubId, + kryoIdsForRmi = rmiIds.toIntArray()) } HandshakeMessage.DONE_ACK -> { connectionDone = true @@ -124,7 +144,7 @@ internal class ClientHandshake(private val logger: KLogg while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` - pollCount = subscription.poll(handler, 2) + pollCount = subscription.poll(handler, 1) if (failed != null) { // no longer necessary to hold this connection open @@ -150,7 +170,7 @@ internal class ClientHandshake(private val logger: KLogg return connectionHelloInfo!! } - suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean { + suspend fun handshakeDone(mediaConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { val registrationMessage = HandshakeMessage.doneFromClient() // Send the done message to the server. @@ -168,7 +188,7 @@ internal class ClientHandshake(private val logger: KLogg while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` - pollCount = subscription.poll(handler, 2) + pollCount = subscription.poll(handler, 1) if (failed != null) { // no longer necessary to hold this connection open diff --git a/src/dorkbox/network/handshake/HandshakeMessage.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt index 708f022a..b8b79453 100644 --- a/src/dorkbox/network/handshake/HandshakeMessage.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -53,8 +53,9 @@ internal class HandshakeMessage private constructor() { const val INVALID = -1 const val HELLO = 0 const val HELLO_ACK = 1 - const val DONE = 2 - const val DONE_ACK = 3 + const val HELLO_ACK_IPC = 2 + const val DONE = 3 + const val DONE_ACK = 4 fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray, registrationRmiIdData: IntArray): HandshakeMessage { val hello = HandshakeMessage() @@ -73,6 +74,13 @@ internal class HandshakeMessage private constructor() { return hello } + fun helloAckIpcToClient(sessionId: Int): HandshakeMessage { + val hello = HandshakeMessage() + hello.state = HELLO_ACK_IPC + hello.sessionId = sessionId // has to be the same as before (the client expects this) + return hello + } + fun doneFromClient(): HandshakeMessage { val hello = HandshakeMessage() hello.state = DONE @@ -99,6 +107,7 @@ internal class HandshakeMessage private constructor() { INVALID -> "INVALID" HELLO -> "HELLO" HELLO_ACK -> "HELLO_ACK" + HELLO_ACK_IPC -> "HELLO_ACK_IPC" DONE -> "DONE" DONE_ACK -> "DONE_ACK" else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!" diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index d7cb88eb..e4a8517b 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -15,24 +15,31 @@ */ package dorkbox.network.handshake +import com.github.benmanes.caffeine.cache.Cache +import com.github.benmanes.caffeine.cache.Caffeine +import com.github.benmanes.caffeine.cache.RemovalCause +import com.github.benmanes.caffeine.cache.RemovalListener import dorkbox.network.Server import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.client.ClientRejectedException +import dorkbox.network.aeron.client.ClientTimedOutException import dorkbox.network.aeron.server.AllocationException import dorkbox.network.aeron.server.RandomIdAllocator import dorkbox.network.aeron.server.ServerException import dorkbox.network.connection.Connection import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.EndPoint +import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.UdpMediaDriverConnection import io.aeron.Aeron import io.aeron.Publication import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import mu.KLogger import org.agrona.collections.Int2IntCounterMap -import org.agrona.collections.Int2ObjectHashMap +import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantReadWriteLock import kotlin.concurrent.write @@ -40,12 +47,25 @@ import kotlin.concurrent.write /** * @throws IllegalArgumentException If the port range is not valid */ +@Suppress("DuplicatedCode") internal class ServerHandshake(private val logger: KLogger, private val config: ServerConfiguration, private val listenerManager: ListenerManager) { private val pendingConnectionsLock = ReentrantReadWriteLock() - private val pendingConnections = Int2ObjectHashMap() + private val pendingConnections: Cache = Caffeine.newBuilder() + .expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS) + .removalListener(RemovalListener { _, value, cause -> + if (cause == RemovalCause.EXPIRED) { + @Suppress("UNCHECKED_CAST") + val connection = value as CONNECTION + + listenerManager.notifyError(ClientTimedOutException("[${connection.id}] Waiting for registration response from client")) + runBlocking { + connection.close() + } + } + }).build() private val connectionsPerIpCounts = Int2IntCounterMap(0) @@ -54,51 +74,244 @@ internal class ServerHandshake(private val logger: KLog EndPoint.RESERVED_SESSION_ID_HIGH) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) - // note: CANNOT be called in action dispatch - fun processHandshakeMessageServer(handshakePublication: Publication, - sessionId: Int, - clientAddressString: String, - clientAddress: Int, - message: Any?, - server: Server, - aeron: Aeron) { + + /** + * @return true if we should continue parsing the incoming message, false if we should abort + */ + private fun validateMessageTypeAndDoPending(server: Server, + handshakePublication: Publication, + message: Any?, + sessionId: Int, + connectionString: String): Boolean { // VALIDATE:: a Registration object is the only acceptable message during the connection phase if (message !is HandshakeMessage) { - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) + listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request")) server.actionDispatch.launch { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) } - return + return false } - val clientPublicKeyBytes = message.publicKey - val validateRemoteAddress: PublicKeyValidationState - - // check to see if this is a pending connection if (message.state == HandshakeMessage.DONE) { - pendingConnectionsLock.write { - val pendingConnection = pendingConnections.remove(sessionId) - if (pendingConnection != null) { - logger.trace { "Connection from client $clientAddressString done with handshake." } + val pendingConnection = pendingConnectionsLock.write { + val con = pendingConnections.getIfPresent(sessionId) + pendingConnections.invalidate(sessionId) + con + } - // this enables the connection to start polling for messages - server.connections.add(pendingConnection) + if (pendingConnection == null) { + logger.error { "[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!" } + } else { + logger.trace { "[${pendingConnection.id}] Connection from client $connectionString done with handshake." } - server.actionDispatch.launch { - // now tell the client we are done - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) - listenerManager.notifyConnect(pendingConnection) - } + // this enables the connection to start polling for messages + server.connections.add(pendingConnection) - return + server.actionDispatch.launch { + // now tell the client we are done + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) + listenerManager.notifyConnect(pendingConnection) } } + + return false + } + + return true + } + + + // note: CANNOT be called in action dispatch + fun processHandshakeMessageServer(server: Server, + handshakePublication: Publication, + sessionId: Int, + message: Any?, + aeron: Aeron) { + + val connectionString = "IPC" + + if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, connectionString)) { + return + } + message as HandshakeMessage + + val serialization = config.serialization + + // VALIDATE:: make sure the serialization matches between the client/server! + if (!serialization.verifyKryoRegistration(message.registrationData!!)) { + listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Registration data mismatch.")) + return } + ///// + ///// + ///// DONE WITH VALIDATION + ///// + ///// + + + // allocate session/stream id's + val connectionSessionId: Int + try { + connectionSessionId = sessionIdAllocator.allocate() + } catch (e: AllocationException) { + listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) + } + return + } + + + val connectionStreamPubId: Int + try { + connectionStreamPubId = streamIdAllocator.allocate() + } catch (e: AllocationException) { + // have to unwind actions! + sessionIdAllocator.free(connectionSessionId) + + listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) + } + return + } + + val connectionStreamSubId: Int + try { + connectionStreamSubId = streamIdAllocator.allocate() + } catch (e: AllocationException) { + // have to unwind actions! + sessionIdAllocator.free(connectionSessionId) + sessionIdAllocator.free(connectionStreamPubId) + + listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) + } + return + } + + + // create a new connection. The session ID is encrypted. + try { + // connection timeout of 0 doesn't matter. it is not used by the server + val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId, + streamIdSubscription = connectionStreamSubId, + sessionId = connectionSessionId, + connectionTimeoutMS = 0) + + // we have to construct how the connection will communicate! + clientConnection.buildServer(aeron) + + logger.info { + "[${clientConnection.sessionId}] aeron IPC connection established to $clientConnection" + } + + val connection = server.newConnection(ConnectionParams(server, clientConnection, PublicKeyValidationState.VALID)) + + // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) + @Suppress("UNCHECKED_CAST") + val permitConnection = listenerManager.notifyFilter(connection) + if (!permitConnection) { + // have to unwind actions! + sessionIdAllocator.free(connectionSessionId) + streamIdAllocator.free(connectionStreamPubId) + + val exception = ClientRejectedException("Connection was not permitted!") + ListenerManager.cleanStackTrace(exception) + listenerManager.notifyError(connection, exception) + + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Connection was not permitted!")) + } + + return + } + + + /////////////// + //// RMI + /////////////// + + // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information + // NOTE: This modifies the readKryo! This cannot be on a different thread! + serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage -> + listenerManager.notifyError(connection, + ClientRejectedException(errorMessage)) + } + + + + /////////////// + /// HANDSHAKE + /////////////// + + + + // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! + val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId) + + + // if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint + + // now create the encrypted payload, using ECDH + val cryptOutput = server.crypto.cryptOutput + cryptOutput.reset() + cryptOutput.writeInt(connectionSessionId) + cryptOutput.writeInt(connectionStreamSubId) + cryptOutput.writeInt(connectionStreamPubId) + + val kryoRmiIds = serialization.getKryoRmiIds() + cryptOutput.writeInt(kryoRmiIds.size) + kryoRmiIds.forEach { + cryptOutput.writeInt(it) + } + + successMessage.registrationData = cryptOutput.toBytes() + + successMessage.publicKey = server.crypto.publicKeyBytes + + // before we notify connect, we have to wait for the client to tell us that they can receive data + pendingConnectionsLock.write { + pendingConnections.put(sessionId, connection) + } + + // this tells the client all of the info to connect. + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, successMessage) + } + } catch (e: Exception) { + // have to unwind actions! + sessionIdAllocator.free(connectionSessionId) + streamIdAllocator.free(connectionStreamPubId) + + listenerManager.notifyError(ServerException("Connection handshake from $connectionString crashed! Message $message", e)) + } + + } + + // note: CANNOT be called in action dispatch + fun processHandshakeMessageServer(server: Server, + handshakePublication: Publication, + sessionId: Int, + clientAddressString: String, + clientAddress: Int, + message: Any?, + aeron: Aeron) { + + if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) { + return + } + message as HandshakeMessage + + val clientPublicKeyBytes = message.publicKey + val validateRemoteAddress: PublicKeyValidationState val serialization = config.serialization try { @@ -272,7 +485,7 @@ internal class ServerHandshake(private val logger: KLog // before we notify connect, we have to wait for the client to tell us that they can receive data pendingConnectionsLock.write { - pendingConnections[sessionId] = connection + pendingConnections.put(sessionId, connection) } // this tells the client all of the info to connect. @@ -293,8 +506,7 @@ internal class ServerHandshake(private val logger: KLog * Free up resources from the closed connection */ fun cleanup(connection: CONNECTION) { - connectionsPerIpCounts.getAndDecrement(connection.remoteAddressInt) - sessionIdAllocator.free(connection.sessionId) - streamIdAllocator.free(connection.streamId) + connection.cleanup(connectionsPerIpCounts, sessionIdAllocator, streamIdAllocator) + pendingConnections.invalidateAll() } } diff --git a/src/dorkbox/network/other/coroutines/SuspendWaiter.kt b/src/dorkbox/network/other/coroutines/SuspendWaiter.kt new file mode 100644 index 00000000..c16b3db9 --- /dev/null +++ b/src/dorkbox/network/other/coroutines/SuspendWaiter.kt @@ -0,0 +1,29 @@ +package dorkbox.network.other.coroutines + +import kotlinx.coroutines.channels.Channel + +// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting +// https://kotlinlang.org/docs/reference/coroutines/channels.html +class SuspendWaiter(private val channel: Channel = Channel()) { + // "receive' suspends until another coroutine invokes "send" + // and + // "send" suspends until another coroutine invokes "receive". + suspend fun doWait() { + try { + channel.receive() + } catch (ignored: Exception) { + } + } + suspend fun doNotify() { + try { + channel.send(Unit) + } catch (ignored: Exception) { + } + } + fun cancel() { + try { + channel.cancel() + } catch (ignored: Exception) { + } + } +}