From 93e406289cdfa48948e4cc1d56f1c293cf6d1e3f Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 23 Sep 2020 17:08:25 +0200 Subject: [PATCH] Fixed issues with connection handshake with the same computer, on multiple clients. --- src/dorkbox/network/Client.kt | 14 ++- src/dorkbox/network/Server.kt | 48 +++++++++- .../network/aeron/MediaDriverConnection.kt | 76 +++++++-------- src/dorkbox/network/connection/Connection.kt | 94 ++++++++++--------- .../network/handshake/ClientHandshake.kt | 21 ++++- .../network/handshake/HandshakeMessage.kt | 11 ++- .../network/handshake/ServerHandshake.kt | 52 +++++----- 7 files changed, 196 insertions(+), 120 deletions(-) diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 70751ec4..3f0a54bd 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -115,6 +115,7 @@ open class Client(config: Configuration = Configuration * @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientRejectedException if the client connection is rejected */ + @Suppress("BlockingMethodInNonBlockingContext") suspend fun connect(remoteAddress: String, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { when { @@ -270,14 +271,12 @@ open class Client(config: Configuration = Configuration } - val handshake = ClientHandshake(config, crypto, this) + val handshake = ClientHandshake(crypto, this) val handshakeConnection = if (autoChangeToIpc || canUseIPC) { // MAYBE the server doesn't have IPC enabled? 
If no, we need to connect via UDP instead val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId, streamId = ipcPublicationId, - sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID, - // "fast" connection timeout, since this is IPC - connectionTimeoutMS = 1000) + sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID) // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports try { @@ -356,14 +355,13 @@ open class Client(config: Configuration = Configuration IpcMediaDriverConnection(sessionId = connectionInfo.sessionId, // NOTE: pub/sub must be switched! streamIdSubscription = connectionInfo.publicationPort, - streamId = connectionInfo.subscriptionPort, - connectionTimeoutMS = connectionTimeoutMS) + streamId = connectionInfo.subscriptionPort) } else { UdpMediaDriverConnection(address = handshakeConnection.address!!, // NOTE: pub/sub must be switched! - subscriptionPort = connectionInfo.publicationPort, publicationPort = connectionInfo.subscriptionPort, + subscriptionPort = connectionInfo.publicationPort, streamId = connectionInfo.streamId, sessionId = connectionInfo.sessionId, connectionTimeoutMS = connectionTimeoutMS, @@ -473,7 +471,7 @@ open class Client(config: Configuration = Configuration val pollIdleStrategy = config.pollIdleStrategy while (!isShutdown()) { - if (newConnection.isClosed()) { + if (newConnection.isClosedViaAeron()) { // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. 
logger.debug {"[${newConnection.id}] connection expired"} diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index 4ca5c0a7..bb32123f 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -27,7 +27,9 @@ import dorkbox.network.connection.EndPoint import dorkbox.network.connection.ListenerManager import dorkbox.network.connectionType.ConnectionRule import dorkbox.network.coroutines.SuspendWaiter +import dorkbox.network.exceptions.ClientRejectedException import dorkbox.network.exceptions.ServerException +import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.handshake.ServerHandshake import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage @@ -167,6 +169,17 @@ open class Server(config: ServerConfiguration = ServerC val sessionId = header.sessionId() val message = readHandshakeMessage(buffer, offset, length, header) + + // VALIDATE:: a Registration object is the only acceptable message during the connection phase + if (message !is HandshakeMessage) { + listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request")) + + actionDispatch.launch { + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) + } + return@FragmentAssembler + } + handshake.processIpcHandshakeMessageServer(this@Server, publication, sessionId, @@ -234,6 +247,17 @@ open class Server(config: ServerConfiguration = ServerC val message = readHandshakeMessage(buffer, offset, length, header) + + // VALIDATE:: a Registration object is the only acceptable message during the connection phase + if (message !is HandshakeMessage) { + listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! 
Invalid connection request")) + + actionDispatch.launch { + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) + } + return@FragmentAssembler + } + handshake.processUdpHandshakeMessageServer(this@Server, publication, sessionId, @@ -304,6 +328,17 @@ open class Server(config: ServerConfiguration = ServerC val message = readHandshakeMessage(buffer, offset, length, header) + + // VALIDATE:: a Registration object is the only acceptable message during the connection phase + if (message !is HandshakeMessage) { + listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) + + actionDispatch.launch { + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) + } + return@FragmentAssembler + } + handshake.processUdpHandshakeMessageServer(this@Server, publication, sessionId, @@ -374,6 +409,17 @@ open class Server(config: ServerConfiguration = ServerC val message = readHandshakeMessage(buffer, offset, length, header) + + // VALIDATE:: a Registration object is the only acceptable message during the connection phase + if (message !is HandshakeMessage) { + listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) + + actionDispatch.launch { + writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request")) + } + return@FragmentAssembler + } + handshake.processUdpHandshakeMessageServer(this@Server, publication, sessionId, @@ -464,7 +510,7 @@ open class Server(config: ServerConfiguration = ServerC // this manages existing clients (for cleanup + connection polling). 
This has a concurrent iterator, // so we can modify this as we go connections.forEach { connection -> - if (connection.isClosed()) { + if (connection.isClosedViaAeron()) { // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. logger.debug { "[${connection.id}] connection expired" } diff --git a/src/dorkbox/network/aeron/MediaDriverConnection.kt b/src/dorkbox/network/aeron/MediaDriverConnection.kt index 3f72b13f..dd851a17 100644 --- a/src/dorkbox/network/aeron/MediaDriverConnection.kt +++ b/src/dorkbox/network/aeron/MediaDriverConnection.kt @@ -31,18 +31,14 @@ import java.net.Inet4Address import java.net.InetAddress import java.util.concurrent.TimeUnit -interface MediaDriverConnection : AutoCloseable { - val address: InetAddress? - val streamId: Int - val sessionId: Int +abstract class MediaDriverConnection(val address: InetAddress?, + val publicationPort: Int, val subscriptionPort: Int, + val streamId: Int, val sessionId: Int, + val connectionTimeoutMS: Long, val isReliable: Boolean) : AutoCloseable { - val subscriptionPort: Int - val publicationPort: Int + lateinit var subscription: Subscription + lateinit var publication: Publication - val subscription: Subscription - val publication: Publication - - val isReliable: Boolean suspend fun addSubscriptionWithRetry(aeron: Aeron, uri: String, streamId: Int, logger: KLogger): Subscription { // If we start/stop too quickly, we might have the address already in use! Retry a few times. 
@@ -79,27 +75,25 @@ interface MediaDriverConnection : AutoCloseable { } @Throws(ClientTimedOutException::class) - suspend fun buildClient(aeron: Aeron, logger: KLogger) - suspend fun buildServer(aeron: Aeron, logger: KLogger) + abstract suspend fun buildClient(aeron: Aeron, logger: KLogger) + abstract suspend fun buildServer(aeron: Aeron, logger: KLogger) - fun clientInfo() : String - fun serverInfo() : String + abstract fun clientInfo() : String + abstract fun serverInfo() : String } /** * For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER. * A connection timeout of 0, means to wait forever */ -class UdpMediaDriverConnection(override val address: InetAddress, - override val publicationPort: Int, - override val subscriptionPort: Int, - override val streamId: Int, - override val sessionId: Int, - private val connectionTimeoutMS: Long = 0, - override val isReliable: Boolean = true) : MediaDriverConnection { - - override lateinit var subscription: Subscription - override lateinit var publication: Publication +class UdpMediaDriverConnection(address: InetAddress, + publicationPort: Int, + subscriptionPort: Int, + streamId: Int, + sessionId: Int, + connectionTimeoutMS: Long = 0, + isReliable: Boolean = true) : + MediaDriverConnection(address, publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { var success: Boolean = false @@ -163,7 +157,7 @@ class UdpMediaDriverConnection(override val address: InetAddress, val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS) var startTime = System.nanoTime() while (timoutInNanos == 0L || System.nanoTime() - startTime < timoutInNanos) { - if (subscription.isConnected && subscription.imageCount() > 0) { + if (subscription.isConnected) { success = true break } @@ -198,8 +192,8 @@ class UdpMediaDriverConnection(override val address: InetAddress, this.success = true - this.subscription = subscription this.publication = 
publication + this.subscription = subscription } override suspend fun buildServer(aeron: Aeron, logger: KLogger) { @@ -236,6 +230,8 @@ class UdpMediaDriverConnection(override val address: InetAddress, } override fun clientInfo(): String { + address!! + return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) { "Connecting to ${IP.toString(address)} [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" } else { @@ -251,7 +247,7 @@ class UdpMediaDriverConnection(override val address: InetAddress, IPv4.WILDCARD.hostAddress + "/" + address.hostAddress } } else { - IP.toString(address) + IP.toString(address!!) } return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) { @@ -275,20 +271,13 @@ class UdpMediaDriverConnection(override val address: InetAddress, /** * For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER + * NOTE: IPC connection will ALWAYS have a timeout of 1 second to connect. This is IPC, it should connect fast */ -class IpcMediaDriverConnection(override val streamId: Int, +class IpcMediaDriverConnection(streamId: Int, val streamIdSubscription: Int, - override val sessionId: Int, - private val connectionTimeoutMS: Long = 30_000, - ) : MediaDriverConnection { - - override val address: InetAddress? 
= null - override val isReliable = true - override val subscriptionPort = 0 - override val publicationPort = 0 - - override lateinit var subscription: Subscription - override lateinit var publication: Publication + sessionId: Int, + ) : + MediaDriverConnection(null, 0, 0, streamId, sessionId, 1_000, true) { var success: Boolean = false @@ -301,7 +290,11 @@ class IpcMediaDriverConnection(override val streamId: Int, return builder } - @Throws(ClientTimedOutException::class) + /** + * Set up the subscription + publication channels to the server + * + * @throws ClientTimedOutException if we cannot connect to the server in the designated time + */ override suspend fun buildClient(aeron: Aeron, logger: KLogger) { // Create a publication at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. @@ -366,6 +359,9 @@ class IpcMediaDriverConnection(override val streamId: Int, this.subscription = subscription } + /** + * Setup the subscription + publication channels on the server + */ override suspend fun buildServer(aeron: Aeron, logger: KLogger) { // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 34d46f57..78d557de 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -293,16 +293,59 @@ open class Connection(connectionParameters: ConnectionParams<*>) { } /** - * @return `true` if this connection has been closed + * Adds a function that will be called when a client/server "disconnects" with + * each other + * + * For a server, this function will be called for ALL clients. 
+ * + * It is POSSIBLE to add a server CONNECTION only (ie, not global) listener + * (via connection.addListener), meaning that ONLY that listener attached to + * the connection is notified on that event (ie, admin type listeners) */ - fun isClosed(): Boolean { - val hasNoImages = subscription.hasNoImages() - if (hasNoImages) { + suspend fun onDisconnect(function: suspend (Connection) -> Unit) { + // make sure we atomically create the listener manager, if necessary + listenerManager.getAndUpdate { origManager -> + origManager ?: ListenerManager() + } + + listenerManager.value!!.onDisconnect(function) + } + + /** + * Adds a function that will be called only for this connection, when a client/server receives a message + */ + suspend fun onMessage(function: suspend (Connection, MESSAGE) -> Unit) { + // make sure we atomically create the listener manager, if necessary + listenerManager.getAndUpdate { origManager -> + origManager ?: ListenerManager() + } + + listenerManager.value!!.onMessage(function) + } + + /** + * Invoked when a message object was received from a remote peer. + * + * This is ALWAYS called on a new dispatch + */ + internal suspend fun notifyOnMessage(message: Any): Boolean { + return listenerManager.value?.notifyOnMessage(this, message) ?: false + } + + /** + * We must account for network blips. The blips will be recovered by aeron, but we want to make sure that we are actually + * disconnected for a set period of time before we start the close process for a connection + * + * @return `true` if this connection has been closed via aeron + */ + fun isClosedViaAeron(): Boolean { + val isNotConnected = !subscription.isConnected && !publication.isConnected + if (isNotConnected) { // 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them). 
return System.nanoTime() - connectionInitTime >= TimeUnit.SECONDS.toNanos(1) } - return hasNoImages + return isNotConnected } /** @@ -370,47 +413,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) { } } - /** - * Adds a function that will be called when a client/server "disconnects" with - * each other - * - * For a server, this function will be called for ALL clients. - * - * It is POSSIBLE to add a server CONNECTION only (ie, not global) listener - * (via connection.addListener), meaning that ONLY that listener attached to - * the connection is notified on that event (ie, admin type listeners) - */ - suspend fun onDisconnect(function: suspend (Connection) -> Unit) { - // make sure we atomically create the listener manager, if necessary - listenerManager.getAndUpdate { origManager -> - origManager ?: ListenerManager() - } - - listenerManager.value!!.onDisconnect(function) - } - - /** - * Adds a function that will be called only for this connection, when a client/server receives a message - */ - suspend fun onMessage(function: suspend (Connection, MESSAGE) -> Unit) { - // make sure we atomically create the listener manager, if necessary - listenerManager.getAndUpdate { origManager -> - origManager ?: ListenerManager() - } - - listenerManager.value!!.onMessage(function) - } - - /** - * Invoked when a message object was received from a remote peer. 
- * - * This is ALWAYS called on a new dispatch - */ - internal suspend fun notifyOnMessage(message: Any): Boolean { - return listenerManager.value?.notifyOnMessage(this, message) ?: false - } - - // // // Generic object methods diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index 495898d0..243c89be 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -42,6 +42,9 @@ internal class ClientHandshake(private val crypto: Crypt @Volatile private var connectionDone = false + @Volatile + private var needToRetry = false + @Volatile private var failed: Exception? = null @@ -64,11 +67,18 @@ internal class ClientHandshake(private val crypto: Crypt } // this is an error message - if (message.sessionId == 0) { + if (message.state == HandshakeMessage.INVALID) { failed = ClientException("[$sessionId] error: ${message.errorMessage}") return@FragmentAssembler } + // this is a retry message + // this can happen if there are multiple connections from the SAME ip address (ie: localhost) + if (message.state == HandshakeMessage.RETRY) { + needToRetry = true + return@FragmentAssembler + } + if (this@ClientHandshake.sessionId != message.sessionId) { failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " + @@ -172,7 +182,7 @@ internal class ClientHandshake(private val crypto: Crypt val subscription = handshakeConnection.subscription val pollIdleStrategy = endPoint.config.pollIdleStrategy - val startTime = System.currentTimeMillis() + var startTime = System.currentTimeMillis() while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. 
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` @@ -184,6 +194,13 @@ internal class ClientHandshake(private val crypto: Crypt throw failed as Exception } + if (needToRetry) { + needToRetry = false + + // start over with the timeout! + startTime = System.currentTimeMillis() + } + if (connectionDone) { break } diff --git a/src/dorkbox/network/handshake/HandshakeMessage.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt index 6452bd6e..d94df6d5 100644 --- a/src/dorkbox/network/handshake/HandshakeMessage.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -50,7 +50,8 @@ internal class HandshakeMessage private constructor() { var registrationRmiIdData: IntArray? = null companion object { - const val INVALID = -1 + const val INVALID = -2 + const val RETRY = -1 const val HELLO = 0 const val HELLO_ACK = 1 const val HELLO_ACK_IPC = 2 @@ -99,9 +100,17 @@ internal class HandshakeMessage private constructor() { return error } + fun retry(errorMessage: String): HandshakeMessage { + val error = HandshakeMessage() + error.state = RETRY + error.errorMessage = errorMessage + return error + } + fun toStateString(state: Int) : String { return when(state) { INVALID -> "INVALID" + RETRY -> "RETRY" HELLO -> "HELLO" HELLO_ACK -> "HELLO_ACK" HELLO_ACK_IPC -> "HELLO_ACK_IPC" diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index a5b1a358..4a25579c 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -30,11 +30,13 @@ import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.exceptions.AllocationException +import dorkbox.network.exceptions.ClientException import dorkbox.network.exceptions.ClientRejectedException import dorkbox.network.exceptions.ClientTimedOutException import 
dorkbox.network.exceptions.ServerException import io.aeron.Aeron import io.aeron.Publication +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import mu.KLogger @@ -42,6 +44,7 @@ import java.net.Inet4Address import java.net.InetAddress import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantReadWriteLock +import kotlin.concurrent.read import kotlin.concurrent.write @@ -78,21 +81,32 @@ internal class ServerHandshake(private val logger: KLog /** * @return true if we should continue parsing the incoming message, false if we should abort */ - // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD. ONLY RESPONSES ARE ON ACTION DISPATCH! private fun validateMessageTypeAndDoPending(server: Server, + actionDispatch: CoroutineScope, handshakePublication: Publication, - message: Any?, + message: HandshakeMessage, sessionId: Int, connectionString: String): Boolean { - // VALIDATE:: a Registration object is the only acceptable message during the connection phase - if (message !is HandshakeMessage) { - listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request")) - - runBlocking { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) + // check to see if this sessionId is ALREADY in use by another connection! + // this can happen if there are multiple connections from the SAME ip address (ie: localhost) + if (message.state == HandshakeMessage.HELLO) { + val hasExistingSessionId = pendingConnectionsLock.read { + pendingConnections.getIfPresent(sessionId) != null } - return false + + if (hasExistingSessionId) { + // WHOOPS! 
tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId + listenerManager.notifyError(ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.")) + + actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!")) + } + return false + } + + return true } // check to see if this is a pending connection @@ -112,11 +126,9 @@ internal class ServerHandshake(private val logger: KLog server.addConnection(pendingConnection) // now tell the client we are done - runBlocking { + actionDispatch.launch { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) - } - server.actionDispatch.launch { - // this must be THE ONLY THING in this class to use the action dispatch! + listenerManager.notifyConnect(pendingConnection) } } @@ -178,15 +190,14 @@ internal class ServerHandshake(private val logger: KLog fun processIpcHandshakeMessageServer(server: Server, handshakePublication: Publication, sessionId: Int, - message: Any?, + message: HandshakeMessage, aeron: Aeron) { val connectionString = "IPC" - if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, connectionString)) { + if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, connectionString)) { return } - message as HandshakeMessage val serialization = config.serialization @@ -242,11 +253,9 @@ internal class ServerHandshake(private val logger: KLog // create a new connection. The session ID is encrypted. try { - // connection timeout of 0 doesn't matter. 
it is not used by the server val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId, streamIdSubscription = connectionStreamSubId, - sessionId = connectionSessionId, - connectionTimeoutMS = 0) + sessionId = connectionSessionId) // we have to construct how the connection will communicate! runBlocking { @@ -332,14 +341,13 @@ internal class ServerHandshake(private val logger: KLog sessionId: Int, clientAddressString: String, clientAddress: InetAddress, - message: Any?, + message: HandshakeMessage, aeron: Aeron, isIpv6Wildcard: Boolean) { - if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) { + if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, clientAddressString)) { return } - message as HandshakeMessage val clientPublicKeyBytes = message.publicKey val validateRemoteAddress: PublicKeyValidationState