From cffca943f5d278983ed6c64b627e69e62049d675 Mon Sep 17 00:00:00 2001 From: Robinson Date: Fri, 30 Apr 2021 16:01:25 +0200 Subject: [PATCH] No longer using coroutines for adding publication/subscription and closing certain calsses. Better aeron error handling/reporting. Better aeron startup/shutdown. The pending connections cache no longer is ThreadSafe, and no longer is protected via RW lock. --- src/dorkbox/network/Client.kt | 173 ++++++++++-------- src/dorkbox/network/Server.kt | 75 ++++---- src/dorkbox/network/aeron/AeronDriver.kt | 91 +++++---- .../network/aeron/IpcMediaDriverConnection.kt | 18 +- .../network/aeron/MediaDriverConnection.kt | 4 +- .../aeron/UdpMediaDriverClientConnection.kt | 22 +-- .../network/aeron/UdpMediaDriverConnection.kt | 25 +++ .../aeron/UdpMediaDriverServerConnection.kt | 8 +- src/dorkbox/network/connection/Connection.kt | 6 +- src/dorkbox/network/connection/EndPoint.kt | 26 ++- .../network/handshake/ClientHandshake.kt | 9 +- .../network/handshake/ServerHandshake.kt | 121 ++++-------- 12 files changed, 297 insertions(+), 281 deletions(-) create mode 100644 src/dorkbox/network/aeron/UdpMediaDriverConnection.kt diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 85904d46..88410296 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -25,14 +25,15 @@ import dorkbox.network.exceptions.ClientException import dorkbox.network.exceptions.ClientRejectedException import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.handshake.ClientHandshake +import dorkbox.network.ping.Ping import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.TimeoutException import dorkbox.util.Sys import kotlinx.atomicfu.atomic -import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import java.net.Inet4Address import java.net.Inet6Address import java.net.InetAddress @@ -78,9 +79,8 @@ open class Client(config: Configuration = Configuration // This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper // lock-stop ordering for how disconnect and connect work with each-other - private val lockStepForReconnect = atomic(null) // GUARANTEE that the callbacks for 'onDisconnect' happens-before the 'onConnect'. - private val lockStepForDispatch = atomic(null) + private val lockStepForConnect = atomic(null) final override fun newException(message: String, cause: Throwable?): Throwable { return ClientException(message, cause) @@ -104,39 +104,54 @@ open class Client(config: Configuration = Configuration * - an InetAddress address * * ### For the IPC (Inter-Process-Communication) it must be: - * - EMPTY. * - `connect()` * - `connect("")` + * - `connectIpc()` * - * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') + * ### Case does not matter, and "localhost" is the default. * - * @param remoteAddress The network or if localhost, IPC address for the client to connect to + * @param remoteAddress The network host or ip address * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely - * @param reliable true if we want to create a reliable connection. IPC connections are always reliable + * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?). * * @throws IllegalArgumentException if the remote address is invalid * @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientRejectedException if the client connection is rejected */ @Suppress("BlockingMethodInNonBlockingContext") - suspend fun connect(remoteAddress: String, - connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + fun connect(remoteAddress: String = "", + connectionTimeoutMS: Long = 30_000L, + reliable: Boolean = true) { when { // this is default IPC settings - remoteAddress.isEmpty() -> connect(connectionTimeoutMS = connectionTimeoutMS) + remoteAddress.isEmpty() -> { + connectIpc(connectionTimeoutMS = connectionTimeoutMS) + } - IPv4.isPreferred -> connect(remoteAddress = Inet4.toAddress(remoteAddress), - connectionTimeoutMS = connectionTimeoutMS, - reliable = reliable) + IPv4.isPreferred -> { + connect( + remoteAddress = Inet4.toAddress(remoteAddress), + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable + ) + } - IPv6.isPreferred -> connect(remoteAddress = Inet6.toAddress(remoteAddress), - connectionTimeoutMS = connectionTimeoutMS, - reliable = reliable) + IPv6.isPreferred -> { + connect( + remoteAddress = Inet6.toAddress(remoteAddress), + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable + ) + } // if there is no preference, then try to connect via IPv4 - else -> connect(remoteAddress = Inet4.toAddress(remoteAddress), - connectionTimeoutMS = connectionTimeoutMS, - reliable = reliable) + else -> { + connect( + remoteAddress = Inet4.toAddress(remoteAddress), + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable + ) + } } } @@ -151,22 +166,24 @@ open class Client(config: Configuration = Configuration * - an InetAddress address * * ### For the IPC (Inter-Process-Communication) it must be: - * - EMPTY. * - `connect()` * - `connect("")` + * - `connectIpc()` * - * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') + * ### Case does not matter, and "localhost" is the default. * * @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely - * @param reliable true if we want to create a reliable connection. IPC connections are always reliable + * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?). * * @throws IllegalArgumentException if the remote address is invalid * @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientRejectedException if the client connection is rejected */ - suspend fun connect(remoteAddress: InetAddress, - connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + fun connect(remoteAddress: InetAddress, + connectionTimeoutMS: Long = 30_000L, + reliable: Boolean = true) { + // Default IPC ports are flipped because they are in the perspective of the SERVER connect(remoteAddress = remoteAddress, ipcPublicationId = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, @@ -187,9 +204,9 @@ open class Client(config: Configuration = Configuration * @throws ClientRejectedException if the client connection is rejected */ @Suppress("DuplicatedCode") - suspend fun connect(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, - ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, - connectionTimeoutMS: Long = 30_000L) { + fun connectIpc(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, + ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, + connectionTimeoutMS: Long = 30_000L) { // Default IPC ports are flipped because they are in the perspective of the SERVER require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." } @@ -208,42 +225,41 @@ open class Client(config: Configuration = Configuration * ### For a network address, it can be: * - a network name ("localhost", "loopback", "lo", "bob.example.org") * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * - an InetAddress address * * ### For the IPC (Inter-Process-Communication) it must be: - * - EMPTY. ie: just call `connect()` - * - Specified EMPTY. ie: just call `connect()` + * - `connect()` + * - `connect("")` + * - `connectIpc()` * - * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') + * ### Case does not matter, and "localhost" is the default. + * + * ### Case does not matter, and "localhost" is the default. * * @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param ipcPublicationId The IPC publication address for the client to connect to * @param ipcSubscriptionId The IPC subscription address for the client to connect to * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely. - * @param reliable true if we want to create a reliable connection. IPC connections are always reliable + * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?). * * @throws IllegalArgumentException if the remote address is invalid * @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientRejectedException if the client connection is rejected */ @Suppress("DuplicatedCode") - private suspend fun connect(remoteAddress: InetAddress? = null, - // Default IPC ports are flipped because they are in the perspective of the SERVER - ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, - ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, - connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { - + private fun connect(remoteAddress: InetAddress? = null, + // Default IPC ports are flipped because they are in the perspective of the SERVER + ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, + ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, + connectionTimeoutMS: Long = 30_000L, + reliable: Boolean = true) { require(connectionTimeoutMS >= 0) { "connectionTimeoutMS '$connectionTimeoutMS' is invalid. It must be >=0" } - // this will exist ONLY if we are reconnecting via a "disconnect" callback - lockStepForReconnect.value?.doWait() - if (isConnected) { logger.error("Unable to connect when already connected!") return } - lockStepForReconnect.lazySet(null) - // localhost/loopback IP might not always be 127.0.0.1 or ::1 this.remoteAddress0 = remoteAddress connection0 = null @@ -268,17 +284,17 @@ open class Client(config: Configuration = Configuration // only change LOCALHOST -> IPC if the media driver is ALREADY running LOCALLY! var isUsingIPC = false - val canUseIPC = config.enableIpc && remoteAddress == null - val autoChangeToIpc = canUseIPC && config.enableIpcForLoopback && - remoteAddress != null && remoteAddress.isLoopbackAddress && aeronDriver.isRunning() - if (autoChangeToIpc) { - logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" } - } + val canUseIPC = config.enableIpc + val autoChangeToIpc = config.enableIpcForLoopback && + (remoteAddress == null || remoteAddress.isLoopbackAddress) && aeronDriver.isRunning() + val handshake = ClientHandshake(crypto, this) - val handshakeConnection = if (autoChangeToIpc || canUseIPC) { - // MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead + val handshakeConnection = if (autoChangeToIpc) { + logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" } + + // MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId, streamId = ipcPublicationId, sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID) @@ -288,8 +304,8 @@ open class Client(config: Configuration = Configuration ipcConnection.buildClient(aeronDriver, logger) isUsingIPC = true } catch (e: Exception) { - // if we specified that we want to use IPC, then we have to throw the timeout exception, because there is no IPC - if (canUseIPC) { + // if we specified that we MUST use IPC, then we have to throw the exception, because there is no IPC + if (remoteAddress == null) { throw e } } @@ -301,7 +317,7 @@ open class Client(config: Configuration = Configuration // try a UDP connection instead val udpConnection = UdpMediaDriverClientConnection( - address = this.remoteAddress0!!, + address = remoteAddress!!, publicationPort = config.subscriptionPort, subscriptionPort = config.publicationPort, streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID, @@ -316,7 +332,7 @@ open class Client(config: Configuration = Configuration } else { val test = UdpMediaDriverClientConnection( - address = this.remoteAddress0!!, + address = remoteAddress!!, publicationPort = config.subscriptionPort, subscriptionPort = config.publicationPort, streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID, @@ -343,7 +359,7 @@ open class Client(config: Configuration = Configuration val validateRemoteAddress = if (isUsingIPC) { PublicKeyValidationState.VALID } else { - crypto.validateRemoteAddress(this.remoteAddress0!!, connectionInfo.publicKey) + crypto.validateRemoteAddress(remoteAddress!!, connectionInfo.publicKey) } if (validateRemoteAddress == PublicKeyValidationState.INVALID) { @@ -438,16 +454,14 @@ open class Client(config: Configuration = Configuration ////////////// newConnection.preCloseAction = { // this is called whenever connection.close() is called by the framework or via client.close() - if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) { - listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!")) - } // on the client, we want to GUARANTEE that the disconnect happens-before the connect. - if (!lockStepForDispatch.compareAndSet(null, SuspendWaiter())) { - listenerManager.notifyError(connection, IllegalStateException("lockStep for dispatch was in the wrong state!")) + if (!lockStepForConnect.compareAndSet(null, SuspendWaiter())) { + listenerManager.notifyError(connection, IllegalStateException("lockStep for onConnect was in the wrong state!")) } } newConnection.postCloseAction = { + isConnected = false // this is called whenever connection.close() is called by the framework or via client.close() // make sure to call our client.notifyDisconnect() callbacks @@ -462,10 +476,6 @@ open class Client(config: Configuration = Configuration lockStepForDispatch.value?.cancel() } - - // in case notifyDisconnect called client.connect().... cancel them waiting - isConnected = false - lockStepForReconnect.value?.cancel() } connection0 = newConnection @@ -477,24 +487,17 @@ open class Client(config: Configuration = Configuration if (canFinishConnecting) { isConnected = true + // this forces the current thread to WAIT until poll system has started + val waiter = SuspendWaiter() + // have to make a new thread to listen for incoming data! // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them - // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback - @Suppress("EXPERIMENTAL_API_USAGE") - actionDispatch.launch(start = CoroutineStart.UNDISPATCHED) { - lockStepForDispatch.value?.doWait() - - // NOTE: UNDISPATCHED means that this coroutine will start as an event loop, instead of concurrently - // we want this behavior INSTEAD OF automatically starting this on a new thread. - listenerManager.notifyConnect(newConnection) - - lockStepForDispatch.lazySet(null) - } - // these have to be in two SEPARATE actionDispatch.launch commands.... otherwise... // if something inside of notifyConnect is blocking or suspends, then polling will never happen! actionDispatch.launch { + waiter.doNotify() + val pollIdleStrategy = config.pollIdleStrategy while (!isShutdown()) { @@ -502,8 +505,12 @@ open class Client(config: Configuration = Configuration // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. logger.debug {"[${newConnection.id}] connection expired"} - // NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()` - newConnection.close() + // eventloop is required, because we want to run this code AFTER the current coroutine has finished. This prevents + // odd race conditions when a client is restarted + actionDispatch.eventLoop { + // NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()` + newConnection.close() + } return@launch } else { @@ -515,6 +522,16 @@ open class Client(config: Configuration = Configuration } } } + + actionDispatch.eventLoop { + waiter.doWait() + + lockStepForConnect.value?.doWait() + + listenerManager.notifyConnect(newConnection) + + lockStepForConnect.lazySet(null) + } } else { close() val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index bd75caa0..8549f7a9 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -36,7 +36,6 @@ import dorkbox.network.rmi.TimeoutException import io.aeron.FragmentAssembler import io.aeron.Image import io.aeron.logbuffer.Header -import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Job import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking @@ -156,7 +155,7 @@ open class Server(config: ServerConfiguration = ServerC return super.getRmiConnectionSupport() } - private suspend fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { + private fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { val poller = if (config.enableIpc) { val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId, streamId = config.ipcPublicationId, @@ -180,9 +179,7 @@ open class Server(config: ServerConfiguration = ServerC if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request")) - actionDispatch.launch { - writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) - } + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) return@FragmentAssembler } @@ -190,7 +187,7 @@ open class Server(config: ServerConfiguration = ServerC publication, sessionId, message, - aeronDriver) + aeronDriver) } override fun poll(): Int { return subscription.poll(handler, 1) } @@ -210,7 +207,7 @@ open class Server(config: ServerConfiguration = ServerC } @Suppress("DuplicatedCode") - private suspend fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { + private fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { val poller = if (canUseIPv4) { val driver = UdpMediaDriverServerConnection( listenAddress = listenIPv4Address!!, @@ -260,9 +257,7 @@ open class Server(config: ServerConfiguration = ServerC if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) - actionDispatch.launch { - writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) - } + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) return@FragmentAssembler } @@ -293,7 +288,7 @@ open class Server(config: ServerConfiguration = ServerC } @Suppress("DuplicatedCode") - private suspend fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { + private fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { val poller = if (canUseIPv6) { val driver = UdpMediaDriverServerConnection( listenAddress = listenIPv6Address!!, @@ -343,9 +338,7 @@ open class Server(config: ServerConfiguration = ServerC if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) - actionDispatch.launch { - writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) - } + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) return@FragmentAssembler } @@ -376,7 +369,7 @@ open class Server(config: ServerConfiguration = ServerC } @Suppress("DuplicatedCode") - private suspend fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { + private fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { val driver = UdpMediaDriverServerConnection( listenAddress = listenIPv6Address!!, publicationPort = config.publicationPort, @@ -426,9 +419,7 @@ open class Server(config: ServerConfiguration = ServerC if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) - actionDispatch.launch { - writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) - } + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) return@FragmentAssembler } @@ -468,39 +459,39 @@ open class Server(config: ServerConfiguration = ServerC // we are done with initial configuration, now initialize aeron and the general state of this endpoint bindAlreadyCalled = true + // this forces the current thread to WAIT until poll system has started val waiter = SuspendWaiter() - actionDispatch.launch { - val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config) - // if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled! - val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD - val ipv4Poller: AeronPoller - val ipv6Poller: AeronPoller + val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config) - if (isWildcard) { - // IPv6 will bind to IPv4 wildcard as well!! - if (canUseIPv4 && canUseIPv6) { - ipv4Poller = object : AeronPoller { - override fun poll(): Int { return 0 } - override fun close() {} - override fun serverInfo(): String { return "IPv4 Disabled" } - } - ipv6Poller = getIpv6WildcardPoller(aeronDriver, config) - } else { - // only 1 will be a real poller - ipv4Poller = getIpv4Poller(aeronDriver, config) - ipv6Poller = getIpv6Poller(aeronDriver, config) + // if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled! + val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD + val ipv4Poller: AeronPoller + val ipv6Poller: AeronPoller + + if (isWildcard) { + // IPv6 will bind to IPv4 wildcard as well!! + if (canUseIPv4 && canUseIPv6) { + ipv4Poller = object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPv4 Disabled" } } + ipv6Poller = getIpv6WildcardPoller(aeronDriver, config) } else { + // only 1 will be a real poller ipv4Poller = getIpv4Poller(aeronDriver, config) ipv6Poller = getIpv6Poller(aeronDriver, config) } + } else { + ipv4Poller = getIpv4Poller(aeronDriver, config) + ipv6Poller = getIpv6Poller(aeronDriver, config) + } + actionDispatch.launch { waiter.doNotify() - val pollIdleStrategy = config.pollIdleStrategy - try { var pollCount: Int @@ -590,9 +581,6 @@ open class Server(config: ServerConfiguration = ServerC jobs.add(job) } - // reset all of the handshake info - handshake.clear() - // when we close a client or a server, we want to make sure that ALL notifications are finished. // when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown jobs.forEach { it.join() } @@ -604,6 +592,9 @@ open class Server(config: ServerConfiguration = ServerC ipv6Poller.close() ipcPoller.close() + // clear all of the handshake info + handshake.clear() + // finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close() shutdownEventWaiter.doNotify() } diff --git a/src/dorkbox/network/aeron/AeronDriver.kt b/src/dorkbox/network/aeron/AeronDriver.kt index 75464a0d..06d4357c 100644 --- a/src/dorkbox/network/aeron/AeronDriver.kt +++ b/src/dorkbox/network/aeron/AeronDriver.kt @@ -10,10 +10,9 @@ import io.aeron.Subscription import io.aeron.driver.MediaDriver import io.aeron.exceptions.DriverTimeoutException import kotlinx.atomicfu.atomic -import kotlinx.coroutines.delay -import kotlinx.coroutines.runBlocking import mu.KLogger import mu.KotlinLogging +import org.agrona.concurrent.AgentTerminationException import org.agrona.concurrent.BackoffIdleStrategy import java.io.File import java.lang.Thread.sleep @@ -153,20 +152,9 @@ class AeronDriver(val config: Configuration, // we DO NOT want to abort the JVM if there are errors. context.errorHandler { error -> - if (error is DriverTimeoutException) { - // we suppress this because it is already handled - return@errorHandler - } - - if (error.cause is BindException) { - // we suppress this because it is already handled - return@errorHandler - } - - logger.error("Error in Aeron Media Driver", error) + AeronDriver.manageError(logger, error) } - val aeronDir = File(context.aeronDirectoryName()).absoluteFile context.aeronDirectoryName(aeronDir.path) @@ -238,9 +226,9 @@ class AeronDriver(val config: Configuration, * * @return true if the media driver is active and running */ - fun isRunning(context: MediaDriver.Context, timeout: Long = context.driverTimeoutMs()): Boolean { + fun isRunning(context: MediaDriver.Context): Boolean { // if the media driver is running, it will be a quick connection. Usually 100ms or so - return context.isDriverActive(timeout) { } + return context.isDriverActive(context.driverTimeoutMs()) { } } /** @@ -255,6 +243,26 @@ class AeronDriver(val config: Configuration, require(config.context != null) { "Configuration context cannot be properly created. Unable to continue!" } } + + fun manageError(logger: KLogger, error: Throwable) { + if (error is DriverTimeoutException) { + // we suppress this because it is already handled + return + } + + if (error is AgentTerminationException) { + // we suppress this because it is already handled + return + } + + if (error.cause is BindException) { + // we suppress this because it is already handled + return + } + + ListenerManager.cleanStackTrace(error) + logger.error("Error in Aeron", error) + } } private val closeRequested = atomic(false) @@ -272,6 +280,14 @@ class AeronDriver(val config: Configuration, // did WE start the media driver, or did SOMEONE ELSE start it? private val mediaDriverWasAlreadyRunning: Boolean + /** + * @return the configured driver timeout + */ + val driverTimeout: Long by lazy { + mediaDriverContext.driverTimeoutMs() + } + + init { mediaDriverContext .conductorThreadFactory(threadFactory) @@ -296,14 +312,12 @@ class AeronDriver(val config: Configuration, // we DO NOT want to abort the JVM if there are errors. // this replaces the default handler with one that doesn't abort the JVM aeronDriverContext.errorHandler { error -> - ListenerManager.cleanStackTrace(error) - logger.error("Error in Aeron", error) + AeronDriver.manageError(logger, error) } return aeronDriverContext } - /** * @return true if the media driver was started, false if it was not started */ @@ -315,9 +329,18 @@ class AeronDriver(val config: Configuration, if (mediaDriver == null) { // only start if we didn't already start... There will be several checks. - if (!isRunning(mediaDriverContext)) { - logger.debug("Starting Aeron Media driver in '${mediaDriverContext.aeronDirectory()}'") + var running = isRunning(mediaDriverContext) + if (running) { + // wait for a bit, because we are running, but we ALSO issued a START, and expect it to start. + // SOMETIMES aeron is in the middle of shutting down, and this prevents us from trying to connect to + // that instance + logger.debug("Aeron Media driver already running. Double checking status...") + sleep(mediaDriverContext.driverTimeoutMs()/2) + running = isRunning(mediaDriverContext) + } + if (!running) { + logger.debug("Starting Aeron Media driver.") // try to start. If we start/stop too quickly, it's a problem var count = 10 @@ -327,13 +350,11 @@ class AeronDriver(val config: Configuration, return true } catch (e: Exception) { logger.warn(e) { "Unable to start the Aeron Media driver. Retrying $count more times..." } - runBlocking { - delay(mediaDriverContext.driverTimeoutMs()) - } + sleep(mediaDriverContext.driverTimeoutMs()) } } } else { - logger.debug("Not starting Aeron Media driver. It was already running in '${mediaDriverContext.aeronDirectory()}'") + logger.debug("Not starting Aeron Media driver. It was already running.") } } @@ -389,7 +410,7 @@ class AeronDriver(val config: Configuration, } - suspend fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication { + fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication { val uri = publicationUri.build() // If we start/stop too quickly, we might have the address already in use! Retry a few times. @@ -404,13 +425,14 @@ class AeronDriver(val config: Configuration, exception = e logger.warn { "Unable to add a publication to Aeron. Retrying $count more times..." } + // if exceptions are added here, make sure to ALSO suppress them in the context error handler if (e is DriverTimeoutException) { - delay(mediaDriverContext.driverTimeoutMs()) + sleep(mediaDriverContext.driverTimeoutMs()) } if (e.cause is BindException) { // was starting too fast! - delay(mediaDriverContext.driverTimeoutMs()) + sleep(mediaDriverContext.driverTimeoutMs()) } // reasons we cannot add a pub/sub to aeron @@ -434,7 +456,7 @@ class AeronDriver(val config: Configuration, throw exception!! } - suspend fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription { + fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription { val uri = subscriptionUri.build() // If we start/stop too quickly, we might have the address already in use! Retry a few times. @@ -451,15 +473,16 @@ class AeronDriver(val config: Configuration, } catch (e: Exception) { // NOTE: this error will be logged in the `aeronDriverContext` logger exception = e - logger.warn { "Unable to add a sublication to Aeron. Retrying $count more times..." } + logger.warn { "Unable to add a subscription to Aeron. Retrying $count more times..." } + // if exceptions are added here, make sure to ALSO suppress them in the context error handler if (e is DriverTimeoutException) { - delay(mediaDriverContext.driverTimeoutMs()) + sleep(mediaDriverContext.driverTimeoutMs()) } if (e.cause is BindException) { // was starting too fast! - delay(mediaDriverContext.driverTimeoutMs()) + sleep(mediaDriverContext.driverTimeoutMs()) } // reasons we cannot add a pub/sub to aeron @@ -545,6 +568,8 @@ class AeronDriver(val config: Configuration, logger.warn { "Aeron Media driver at '${mediaDriverContext.aeronDirectory()}' is still running. Waiting for it to stop. Trying $count more times." } sleep(mediaDriverContext.driverTimeoutMs()) } + logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" } + } catch (e: Exception) { logger.error("Error closing the media driver at '${mediaDriverContext.aeronDirectory()}'", e) } @@ -552,7 +577,5 @@ class AeronDriver(val config: Configuration, // Destroys this thread group and all of its subgroups. // This thread group must be empty, indicating that all threads that had been in this thread group have since stopped. threadFactory.group.destroy() - - logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" } } } diff --git a/src/dorkbox/network/aeron/IpcMediaDriverConnection.kt b/src/dorkbox/network/aeron/IpcMediaDriverConnection.kt index 340a9efc..2d9830d6 100644 --- a/src/dorkbox/network/aeron/IpcMediaDriverConnection.kt +++ b/src/dorkbox/network/aeron/IpcMediaDriverConnection.kt @@ -18,8 +18,8 @@ package dorkbox.network.aeron import dorkbox.network.exceptions.ClientTimedOutException import io.aeron.ChannelUriStringBuilder -import kotlinx.coroutines.delay import mu.KLogger +import java.lang.Thread.sleep /** * For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER @@ -47,7 +47,7 @@ internal open class IpcMediaDriverConnection(streamId: Int, * * @throws ClientTimedOutException if we cannot connect to the server in the designated time */ - override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { + override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { // Create a publication at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() @@ -64,23 +64,25 @@ internal open class IpcMediaDriverConnection(streamId: Int, // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. + var startTime = System.currentTimeMillis() + + var success = false + // If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times. val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId) val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription) - var success = false - // this will wait for the server to acknowledge the connection (all via aeron) - var startTime = System.currentTimeMillis() while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { if (subscription.isConnected && subscription.imageCount() > 0) { success = true break } - delay(timeMillis = 100L) + sleep(100L) } + if (!success) { subscription.close() throw ClientTimedOutException("Creating subscription connection to aeron") @@ -97,7 +99,7 @@ internal open class IpcMediaDriverConnection(streamId: Int, break } - delay(timeMillis = 100L) + sleep(100L) } if (!success) { @@ -116,7 +118,7 @@ internal open class IpcMediaDriverConnection(streamId: Int, * * serverAddress is ignored for IPC */ - override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { + override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() diff --git a/src/dorkbox/network/aeron/MediaDriverConnection.kt b/src/dorkbox/network/aeron/MediaDriverConnection.kt index e0af5269..e44c0df4 100644 --- a/src/dorkbox/network/aeron/MediaDriverConnection.kt +++ b/src/dorkbox/network/aeron/MediaDriverConnection.kt @@ -32,8 +32,8 @@ abstract class MediaDriverConnection( @Throws(ClientTimedOutException::class) - abstract suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) - abstract suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false) + abstract fun buildClient(aeronDriver: AeronDriver, logger: KLogger) + abstract fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false) abstract fun clientInfo() : String abstract fun serverInfo() : String diff --git a/src/dorkbox/network/aeron/UdpMediaDriverClientConnection.kt b/src/dorkbox/network/aeron/UdpMediaDriverClientConnection.kt index e878bd93..0dce4303 100644 --- a/src/dorkbox/network/aeron/UdpMediaDriverClientConnection.kt +++ b/src/dorkbox/network/aeron/UdpMediaDriverClientConnection.kt @@ -21,8 +21,8 @@ import dorkbox.network.connection.ListenerManager import dorkbox.network.exceptions.ClientException import dorkbox.network.exceptions.ClientTimedOutException import io.aeron.ChannelUriStringBuilder -import kotlinx.coroutines.delay import mu.KLogger +import java.lang.Thread.sleep import java.net.Inet4Address import java.net.InetAddress import java.util.concurrent.TimeUnit @@ -38,7 +38,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, sessionId: Int, connectionTimeoutMS: Long = 0, isReliable: Boolean = true) : - MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { + UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { var success: Boolean = false @@ -80,7 +80,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, @Suppress("DuplicatedCode") @Throws(ClientException::class) - override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { + override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { val aeronAddressString = aeronConnectionString(address) // Create a publication at the given address and port, using the given stream ID. @@ -100,6 +100,8 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, logger.trace("client sub URI: $ip ${subscriptionUri.build()}") } + var success = false + // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. // on close, the publication CAN linger (in case a client goes away, and then comes back) @@ -107,7 +109,6 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId) val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId) - var success = false // this will wait for the server to acknowledge the connection (all via aeron) val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS) @@ -118,12 +119,12 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, break } - delay(timeMillis = 100L) + sleep(100L) } if (!success) { subscription.close() - val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()}") + val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()} in ${timoutInNanos}ms") ListenerManager.cleanStackTrace(ex) throw ex } @@ -139,19 +140,18 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, break } - delay(timeMillis = 100L) + sleep(100L) } if (!success) { subscription.close() publication.close() - val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()}") -// ListenerManager.cleanStackTrace(ex) + val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()} in ${timoutInNanos}ms") + ListenerManager.cleanStackTrace(ex) throw ex } this.success = true - this.publication = publication this.subscription = subscription } @@ -164,7 +164,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress, } } - override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { + override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { throw ClientException("Server info not implemented in Client MDC") } override fun serverInfo(): String { diff --git a/src/dorkbox/network/aeron/UdpMediaDriverConnection.kt b/src/dorkbox/network/aeron/UdpMediaDriverConnection.kt new file mode 100644 index 00000000..0ae4c37d --- /dev/null +++ b/src/dorkbox/network/aeron/UdpMediaDriverConnection.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2021 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dorkbox.network.aeron + +abstract class UdpMediaDriverConnection(publicationPort: Int, subscriptionPort: Int, + streamId: Int, sessionId: Int, + connectionTimeoutMS: Long, isReliable: Boolean) : + MediaDriverConnection(publicationPort, subscriptionPort, + streamId, sessionId, + connectionTimeoutMS, isReliable) { +} diff --git a/src/dorkbox/network/aeron/UdpMediaDriverServerConnection.kt b/src/dorkbox/network/aeron/UdpMediaDriverServerConnection.kt index da140684..b1602863 100644 --- a/src/dorkbox/network/aeron/UdpMediaDriverServerConnection.kt +++ b/src/dorkbox/network/aeron/UdpMediaDriverServerConnection.kt @@ -36,11 +36,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres sessionId: Int, connectionTimeoutMS: Long = 0, isReliable: Boolean = true) : - MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { + UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { var success: Boolean = false - protected fun aeronConnectionString(ipAddress: InetAddress): String { + private fun aeronConnectionString(ipAddress: InetAddress): String { return if (ipAddress is Inet4Address) { ipAddress.hostAddress } else { @@ -64,11 +64,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres } @Suppress("DuplicatedCode") - override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { + override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { throw ServerException("Client info not implemented in Server MDC") } - override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { + override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { val connectionString = aeronConnectionString(listenAddress) // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 3caa5a59..f06f6ecc 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -17,8 +17,8 @@ package dorkbox.network.connection import dorkbox.network.aeron.IpcMediaDriverConnection import dorkbox.network.aeron.UdpMediaDriverClientConnection +import dorkbox.network.aeron.UdpMediaDriverConnection import dorkbox.network.aeron.UdpMediaDriverPairedConnection -import dorkbox.network.aeron.UdpMediaDriverServerConnection import dorkbox.network.handshake.ConnectionCounts import dorkbox.network.handshake.RandomIdAllocator import dorkbox.network.ping.Ping @@ -33,9 +33,7 @@ import io.aeron.Subscription import io.aeron.logbuffer.Header import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate -import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.delay -import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.agrona.DirectBuffer import java.io.IOException @@ -84,7 +82,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { /** * @return true if this connection is a network connection */ - val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverServerConnection + val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection /** * the endpoint associated with this connection diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index d3525eea..8eecc866 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -87,7 +87,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A private val handshakeKryo: KryoExtra private val sendIdleStrategy: CoroutineIdleStrategy - private val sendIdleStrategyHandshake: IdleStrategy + private val sendIdleStrategyHandShake: IdleStrategy + + val pollIdleStrategy: CoroutineIdleStrategy + val pollIdleStrategyHandShake: IdleStrategy /** * Crypto and signature management @@ -126,7 +129,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A // serialization stuff serialization = config.serialization sendIdleStrategy = config.sendIdleStrategy - sendIdleStrategyHandshake = sendIdleStrategy.cloneToNormal() + pollIdleStrategy = config.pollIdleStrategy + + sendIdleStrategyHandShake = sendIdleStrategy.cloneToNormal() + pollIdleStrategyHandShake = pollIdleStrategy.cloneToNormal() handshakeKryo = serialization.initHandshakeKryo() @@ -347,7 +353,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A */ if (result >= Publication.ADMIN_ACTION) { // we should retry. - sendIdleStrategyHandshake.idle() + sendIdleStrategyHandShake.idle() continue } @@ -362,7 +368,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A ListenerManager.cleanStackTrace(exception, 2) // 2 because we do not want to see the stack for the abstract `newException` listenerManager.notifyError(exception) } finally { - sendIdleStrategyHandshake.reset() + sendIdleStrategyHandShake.reset() } } @@ -431,7 +437,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A when (message) { is PingMessage -> { - // the ping listener (internal use only!) + // the ping listener actionDispatch.launch { pingManager.manage(this@EndPoint, connection, message, logger) } @@ -636,22 +642,22 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A final override fun close() { if (shutdown.compareAndSet(expect = false, update = true)) { logger.info { "Shutting down..." } + aeronDriver.close() runBlocking { - aeronDriver.close() - connections.forEach { it.close() } - // the storage is closed via this as well. - storage.close() - // Connections are closed first, because we want to make sure that no RMI messages can be received // when we close the RMI support objects (in which case, weird - but harmless - errors show up) + // this will wait for RMI timeouts if there are RMI in-progress. (this happens if we close via and RMI method) rmiGlobalSupport.close() } + // the storage is closed via this as well. + storage.close() + close0() // if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now) diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index a3f28586..7050025a 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -143,7 +143,7 @@ internal class ClientHandshake(private val crypto: Crypt } // called from the connect thread - suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { + fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { failed = null oneTimeKey = endPoint.crypto.secureRandom.nextInt() val publicKey = endPoint.storage.getPublicKey()!! @@ -151,8 +151,7 @@ internal class ClientHandshake(private val crypto: Crypt // Send the one-time pad to the server. val publication = handshakeConnection.publication val subscription = handshakeConnection.subscription - val pollIdleStrategy = endPoint.config.pollIdleStrategy - + val pollIdleStrategy = endPoint.pollIdleStrategyHandShake endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey)) @@ -191,7 +190,7 @@ internal class ClientHandshake(private val crypto: Crypt } // called from the connect thread - suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { + fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey) // Send the done message to the server. @@ -203,7 +202,7 @@ internal class ClientHandshake(private val crypto: Crypt failed = null var pollCount: Int val subscription = handshakeConnection.subscription - val pollIdleStrategy = endPoint.config.pollIdleStrategy + val pollIdleStrategy = endPoint.pollIdleStrategyHandShake var startTime = System.currentTimeMillis() while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) { diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index 0da74155..be6873be 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -24,23 +24,15 @@ import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.AeronDriver import dorkbox.network.aeron.IpcMediaDriverConnection import dorkbox.network.aeron.UdpMediaDriverPairedConnection -import dorkbox.network.connection.Connection -import dorkbox.network.connection.ConnectionParams -import dorkbox.network.connection.ListenerManager -import dorkbox.network.connection.PublicKeyValidationState +import dorkbox.network.connection.* import dorkbox.network.exceptions.* import io.aeron.Publication import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.CoroutineStart -import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import mu.KLogger import java.net.Inet4Address import java.net.InetAddress import java.util.concurrent.TimeUnit -import java.util.concurrent.locks.ReentrantReadWriteLock -import kotlin.concurrent.read -import kotlin.concurrent.write /** @@ -51,13 +43,12 @@ internal class ServerHandshake(private val logger: KLog private val config: ServerConfiguration, private val listenerManager: ListenerManager) { - private val pendingConnectionsLock = ReentrantReadWriteLock() + // note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close private val pendingConnections: Cache = Caffeine.newBuilder() - .expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS) - .removalListener(RemovalListener { _, value, cause -> + .expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong() * 2, TimeUnit.SECONDS) + .removalListener(RemovalListener { sessionId, connection, cause -> if (cause == RemovalCause.EXPIRED) { - @Suppress("UNCHECKED_CAST") - val connection = value as CONNECTION + connection!! val exception = ClientTimedOutException("[${connection.id}] Waiting for registration response from client") ListenerManager.noStackTrace(exception) @@ -90,19 +81,15 @@ internal class ServerHandshake(private val logger: KLog // check to see if this sessionId is ALREADY in use by another connection! // this can happen if there are multiple connections from the SAME ip address (ie: localhost) if (message.state == HandshakeMessage.HELLO) { - val hasExistingSessionId = pendingConnectionsLock.read { - pendingConnections.getIfPresent(sessionId) != null - } - + // this should be null. + val hasExistingSessionId = pendingConnections.getIfPresent(sessionId) != null if (hasExistingSessionId) { // WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId val exception = ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.") ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!")) return false } @@ -111,14 +98,11 @@ internal class ServerHandshake(private val logger: KLog // check to see if this is a pending connection if (message.state == HandshakeMessage.DONE) { - val pendingConnection = pendingConnectionsLock.write { - val con = pendingConnections.getIfPresent(sessionId) - pendingConnections.invalidate(sessionId) - con - } + val pendingConnection = pendingConnections.getIfPresent(sessionId) + pendingConnections.invalidate(sessionId) if (pendingConnection == null) { - val exception = ClientException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!") + val exception = ServerException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!") ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) } else { @@ -127,7 +111,6 @@ internal class ServerHandshake(private val logger: KLog // this enables the connection to start polling for messages server.addConnection(pendingConnection) - // now tell the client we are done // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback @@ -165,9 +148,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) return false } @@ -182,9 +163,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) return false } connectionsPerIpCounts.increment(clientAddress, currentCountForIp) @@ -193,9 +172,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) return false } @@ -234,9 +211,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) return } @@ -252,9 +227,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) return } @@ -270,9 +243,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) return } @@ -284,9 +255,7 @@ internal class ServerHandshake(private val logger: KLog sessionId = connectionSessionId) // we have to construct how the connection will communicate! - runBlocking { - clientConnection.buildServer(aeronDriver, logger, true) - } + clientConnection.buildServer(aeronDriver, logger, true) logger.info { "[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]" @@ -326,14 +295,10 @@ internal class ServerHandshake(private val logger: KLog successMessage.publicKey = server.crypto.publicKeyBytes // before we notify connect, we have to wait for the client to tell us that they can receive data - pendingConnectionsLock.write { - pendingConnections.put(sessionId, connection) - } + pendingConnections.put(sessionId, connection) // this tells the client all of the info to connect. - runBlocking { - server.writeHandshakeMessage(handshakePublication, successMessage) - } + server.writeHandshakeMessage(handshakePublication, successMessage) } catch (e: Exception) { // have to unwind actions! sessionIdAllocator.free(connectionSessionId) @@ -398,9 +363,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) return } @@ -417,9 +380,7 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.noStackTrace(exception) listenerManager.notifyError(exception) - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) return } @@ -450,9 +411,7 @@ internal class ServerHandshake(private val logger: KLog message.isReliable) // we have to construct how the connection will communicate! - runBlocking { - clientConnection.buildServer(aeronDriver, logger, true) - } + clientConnection.buildServer(aeronDriver, logger, true) logger.info { // (reliable:$isReliable)" @@ -462,21 +421,19 @@ internal class ServerHandshake(private val logger: KLog val connection = server.newConnection(ConnectionParams(server, clientConnection, validateRemoteAddress)) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) - runBlocking { - val permitConnection = listenerManager.notifyFilter(connection) - if (!permitConnection) { - // have to unwind actions! - connectionsPerIpCounts.decrementSlow(clientAddress) - sessionIdAllocator.free(connectionSessionId) - streamIdAllocator.free(connectionStreamId) + val permitConnection = listenerManager.notifyFilter(connection) + if (!permitConnection) { + // have to unwind actions! + connectionsPerIpCounts.decrementSlow(clientAddress) + sessionIdAllocator.free(connectionSessionId) + streamIdAllocator.free(connectionStreamId) - val exception = ClientRejectedException("Connection $clientAddressString was not permitted!") - ListenerManager.cleanStackTrace(exception) - listenerManager.notifyError(connection, exception) + val exception = ClientRejectedException("Connection $clientAddressString was not permitted!") + ListenerManager.cleanStackTrace(exception) + listenerManager.notifyError(connection, exception) - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) - return@runBlocking - } + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) + return } @@ -503,14 +460,10 @@ internal class ServerHandshake(private val logger: KLog successMessage.publicKey = server.crypto.publicKeyBytes // before we notify connect, we have to wait for the client to tell us that they can receive data - pendingConnectionsLock.write { - pendingConnections.put(sessionId, connection) - } + pendingConnections.put(sessionId, connection) // this tells the client all of the info to connect. - runBlocking { - server.writeHandshakeMessage(handshakePublication, successMessage) - } + server.writeHandshakeMessage(handshakePublication, successMessage) } catch (e: Exception) { // have to unwind actions! connectionsPerIpCounts.decrementSlow(clientAddress) @@ -538,6 +491,8 @@ internal class ServerHandshake(private val logger: KLog // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD sessionIdAllocator.clear() streamIdAllocator.clear() + pendingConnections.invalidateAll() + pendingConnections.cleanUp() } }