From 2a6c2796926d896a54880d5f5891d4b3b4ccc066 Mon Sep 17 00:00:00 2001 From: nathan Date: Thu, 3 Sep 2020 01:31:08 +0200 Subject: [PATCH] Fixed handshake race condition which resulted in empty messages. Cleaned up code, added more debug info --- src/dorkbox/network/Client.kt | 213 +++++++++--------- src/dorkbox/network/Server.kt | 59 +++-- src/dorkbox/network/connection/EndPoint.kt | 7 +- .../network/handshake/ClientHandshake.kt | 24 +- .../network/handshake/HandshakeMessage.kt | 22 +- 5 files changed, 161 insertions(+), 164 deletions(-) diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index f50a8612..6776f590 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -25,7 +25,6 @@ import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.EndPoint import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager -import dorkbox.network.connection.Ping import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.handshake.ClientHandshake @@ -60,16 +59,16 @@ open class Client(config: Configuration = Configuration * For the IPC (Inter-Process-Communication) address. it must be: * - the IPC integer ID, "0x1337c0de", "0x12312312", etc. */ - private var remoteAddress = "" + private var remoteAddress0 = "" @Volatile private var isConnected = false // is valid when there is a connection to the server, otherwise it is null - private var connection: CONNECTION? = null + private var connection0: CONNECTION? = null @Volatile - protected var connectionTimeoutMS: Long = 5000 // default is 5 seconds + private var connectionTimeoutMS: Long = 5_000 // default is 5 seconds private val previousClosedConnectionActivity: Long = 0 @@ -133,11 +132,16 @@ open class Client(config: Configuration = Configuration } lockStepForReconnect.lazySet(null) - connection = null + connection0 = null // we are done with initial configuration, now initialize aeron and the general state of this endpoint val aeron = initEndpointState() + // only change LOCALHOST -> IPC if the media driver is ALREADY running! + val canAutoChangeToIpc = config.enableIpcForLoopback && isRunning() + if (canAutoChangeToIpc) { + logger.trace { "Media driver is running. Support for enable auto-switch from LOCALHOST -> IPC enabled" } + } this.connectionTimeoutMS = connectionTimeoutMS val isIpcConnection: Boolean @@ -149,56 +153,56 @@ open class Client(config: Configuration = Configuration when (remoteAddress) { "0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") "loopback", "localhost", "lo", "" -> { - if (config.enableIpcForLoopback) { + if (canAutoChangeToIpc) { isIpcConnection = true - logger.info("Auto-changing network connection from $remoteAddress -> IPC") - this.remoteAddress = "ipc" + logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } + this.remoteAddress0 = "ipc" } else { isIpcConnection = false - this.remoteAddress = IPv4.LOCALHOST.hostAddress + this.remoteAddress0 = IPv4.LOCALHOST.hostAddress } } "0x" -> { isIpcConnection = true - this.remoteAddress = "ipc" + this.remoteAddress0 = "ipc" } else -> when { IPv4.isLoopback(remoteAddress) -> { - if (config.enableIpcForLoopback) { + if (canAutoChangeToIpc) { isIpcConnection = true - logger.info("Auto-changing network connection from $remoteAddress -> IPC") - this.remoteAddress = "ipc" + logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } + this.remoteAddress0 = "ipc" } else { isIpcConnection = false - this.remoteAddress = IPv4.LOCALHOST.hostAddress + this.remoteAddress0 = IPv4.LOCALHOST.hostAddress } } IPv6.isLoopback(remoteAddress) -> { - if (config.enableIpcForLoopback) { + if (canAutoChangeToIpc) { isIpcConnection = true - logger.info("Auto-changing network connection from $remoteAddress -> IPC") - this.remoteAddress = "ipc" + logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } + this.remoteAddress0 = "ipc" } else { isIpcConnection = false - this.remoteAddress = IPv6.LOCALHOST.hostAddress + this.remoteAddress0 = IPv6.LOCALHOST.hostAddress } } else -> { isIpcConnection = false - this.remoteAddress = remoteAddress + this.remoteAddress0 = remoteAddress } } } - if (IPv6.isValid(this.remoteAddress)) { + if (IPv6.isValid(this.remoteAddress0)) { // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so // if we are IPv6, the IP must be in '[]' - if (this.remoteAddress.count { it == '[' } < 1 && - this.remoteAddress.count { it == ']' } < 1) { + if (this.remoteAddress0.count { it == '[' } < 1 && + this.remoteAddress0.count { it == ']' } < 1) { - this.remoteAddress = """[${this.remoteAddress}]""" + this.remoteAddress0 = """[${this.remoteAddress0}]""" } } @@ -212,7 +216,7 @@ open class Client(config: Configuration = Configuration sessionId = RESERVED_SESSION_ID_INVALID) } else { - UdpMediaDriverConnection(address = this.remoteAddress, + UdpMediaDriverConnection(address = this.remoteAddress0, publicationPort = config.subscriptionPort, subscriptionPort = config.publicationPort, streamId = UDP_HANDSHAKE_STREAM_ID, @@ -237,7 +241,7 @@ open class Client(config: Configuration = Configuration val validateRemoteAddress = if (isIpcConnection) { PublicKeyValidationState.VALID } else { - crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) + crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress0), connectionInfo.publicKey) } if (validateRemoteAddress == PublicKeyValidationState.INVALID) { @@ -320,7 +324,7 @@ open class Client(config: Configuration = Configuration newConnection.preCloseAction = { // this is called whenever connection.close() is called by the framework or via client.close() if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) { - listenerManager.notifyError(getConnection(), IllegalStateException("lockStep for reconnect was in the wrong state!")) + listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!")) } } newConnection.postCloseAction = { @@ -331,7 +335,7 @@ open class Client(config: Configuration = Configuration // manually call it. // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback actionDispatch.launch { - listenerManager.notifyDisconnect(getConnection()) + listenerManager.notifyDisconnect(connection) } // in case notifyDisconnect called client.connect().... cancel them waiting @@ -339,55 +343,52 @@ open class Client(config: Configuration = Configuration lockStepForReconnect.value?.cancel() } - connection = newConnection + connection0 = newConnection connections.add(newConnection) - // have to make a new thread to listen for incoming data! - // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them - actionDispatch.launch { - val pollIdleStrategy = config.pollIdleStrategy - - while (!isShutdown()) { - // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. - var shouldCleanupConnection = false - - if (newConnection.isExpired()) { - logger.debug {"[${newConnection.id}] connection expired"} - shouldCleanupConnection = true - } - - else if (newConnection.isClosed()) { - logger.debug {"[${newConnection.id}] connection closed"} - shouldCleanupConnection = true - } - - - if (shouldCleanupConnection) { - close() - return@launch - } - else { - // Polls the AERON media driver subscription channel for incoming messages - val pollCount = newConnection.pollSubscriptions() - - // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) - pollIdleStrategy.idle(pollCount) - } - } - } - - // tell the server our connection handshake is done, and the connection can now listen for data. val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) - // no longer necessary to hold the handshake connection open - handshakeConnection.close() - if (canFinishConnecting) { isConnected = true + // we poll for new messages AFTER `handshake.handshakeDone`, because the aeron media driver will queue up the messages for us. + // we want to make sure to call notify connect BEFORE processing new messages. + + // have to make a new thread to listen for incoming data! + // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them actionDispatch.launch { listenerManager.notifyConnect(newConnection) + + val pollIdleStrategy = config.pollIdleStrategy + + while (!isShutdown()) { + // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. + var shouldCleanupConnection = false + + if (newConnection.isExpired()) { + logger.debug {"[${newConnection.id}] connection expired"} + shouldCleanupConnection = true + } + + else if (newConnection.isClosed()) { + logger.debug {"[${newConnection.id}] connection closed"} + shouldCleanupConnection = true + } + + + if (shouldCleanupConnection) { + close() + return@launch + } + else { + // Polls the AERON media driver subscription channel for incoming messages + val pollCount = newConnection.pollSubscriptions() + + // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) + pollIdleStrategy.idle(pollCount) + } + } } } else { close() @@ -399,52 +400,47 @@ open class Client(config: Configuration = Configuration } /** - * @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed. + * true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed. */ - fun hasRemoteKeyChanged(): Boolean { - return getConnection().hasRemoteKeyChanged() - } + val remoteKeyHasChanged: Boolean + get() = connection.hasRemoteKeyChanged() /** - * @return the remote address, as a string. + * the remote address, as a string. */ - fun getRemoteHost(): String { - return this.remoteAddress - } + val remoteAddress: String + get() = remoteAddress0 /** - * @return true if this connection is an IPC connection + * true if this connection is an IPC connection */ - fun isIPC(): Boolean { - return getConnection().isIpc - } + val isIPC: Boolean + get() = connection.isIpc /** * @return true if this connection is a network connection */ - fun isNetwork(): Boolean { - return getConnection().isNetwork - } + val isNetwork: Boolean + get() = connection.isNetwork /** * @return the connection (TCP or IPC) id of this connection. */ - fun id(): Int { - return getConnection().id - } + val id: Int + get() = connection.id /** - * @return the connection used by the client, this is only valid after the client has connected + * the connection used by the client, this is only valid after the client has connected */ - fun getConnection(): CONNECTION { - return connection as CONNECTION - } + val connection: CONNECTION + get() = connection0 as CONNECTION + /** * @throws ClientException when a message cannot be sent */ suspend fun send(message: Any) { - val c = connection + val c = connection0 if (c != null) { c.send(message) } else { @@ -455,24 +451,25 @@ open class Client(config: Configuration = Configuration /** * @throws ClientException when a ping cannot be sent */ - suspend fun ping(): Ping { - val c = connection - if (c != null) { - return c.ping() - } else { - throw ClientException("Cannot ping a connection when there is no connection!") - } - } +// suspend fun ping(): Ping { +// val c = connection +// if (c != null) { +// return c.ping() +// } else { +// throw ClientException("Cannot ping a connection when there is no connection!") +// } +// } + /** + * Removes the specified host address from the list of registered server keys. + */ @Throws(SecurityException::class) - fun removeRegisteredServerKey(hostAddress: Int) { - val savedPublicKey = settingsStore.getRegisteredServerKey(hostAddress) + fun removeRegisteredServerKey(hostAddress: String) { + val address = IPv4.toInt(hostAddress) + val savedPublicKey = settingsStore.getRegisteredServerKey(address) if (savedPublicKey != null) { - val logger2 = logger - if (logger2.isDebugEnabled) { - logger2.debug("Deleting remote IP address key ${IPv4.toString(hostAddress)}") - } - settingsStore.removeRegisteredServerKey(hostAddress) + logger.debug { "Deleting remote IP address key $hostAddress" } + settingsStore.removeRegisteredServerKey(address) } } @@ -585,7 +582,7 @@ open class Client(config: Configuration = Configuration val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java) @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") - return rmiConnectionSupport.getProxyObject(getConnection(), kryoId, objectId, Iface::class.java) + return rmiConnectionSupport.getProxyObject(connection, kryoId, objectId, Iface::class.java) } /** @@ -615,7 +612,7 @@ open class Client(config: Configuration = Configuration objectParameters as Array @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") - rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, objectParameters, callback) + rmiConnectionSupport.createRemoteObject(connection, kryoId, objectParameters, callback) } /** @@ -642,7 +639,7 @@ open class Client(config: Configuration = Configuration val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java) @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") - rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, null, callback) + rmiConnectionSupport.createRemoteObject(connection, kryoId, null, callback) } // @@ -730,6 +727,6 @@ open class Client(config: Configuration = Configuration // NOTE: It's not possible to have reified inside a virtual function // https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") - return rmiGlobalSupport.getGlobalRemoteObject(getConnection(), objectId, Iface::class.java) + return rmiGlobalSupport.getGlobalRemoteObject(connection, objectId, Iface::class.java) } } diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index 5f5a4bc3..5ed2a760 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -23,7 +23,6 @@ import dorkbox.network.connection.EndPoint import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.UdpMediaDriverConnection -import dorkbox.network.connection.connectionType.ConnectionProperties import dorkbox.network.connection.connectionType.ConnectionRule import dorkbox.network.handshake.ServerHandshake import dorkbox.network.rmi.RemoteObject @@ -36,7 +35,6 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.agrona.DirectBuffer -import java.net.InetSocketAddress import java.util.concurrent.CopyOnWriteArrayList /** @@ -68,14 +66,15 @@ open class Server(config: ServerConfiguration = ServerC } } - /** * @return true if this server has successfully bound to an IP address and is running */ @Volatile private var bindAlreadyCalled = false - + /** + * Used for handshake connections + */ private val handshake = ServerHandshake(logger, config, listenerManager) /** @@ -403,32 +402,32 @@ open class Server(config: ServerConfiguration = ServerC - /** - * Only called by the server! - * - * If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic. - */ - // after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS) - fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte { - val address = remoteAddress.address - val size = connectionRules.size - - // if it's unknown, then by default we encrypt the traffic - var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT - if (size == 0 && address == IPv4.LOCALHOST) { - // if nothing is specified, then by default localhost is compression and everything else is encrypted - connectionType = ConnectionProperties.COMPRESS - } - for (i in 0 until size) { - val rule = connectionRules[i] ?: continue - if (rule.matches(remoteAddress)) { - connectionType = rule.ruleType() - break - } - } - logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType) - return connectionType.type - } +// /** +// * Only called by the server! +// * +// * If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic. +// */ +// // after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS) +// fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte { +// val address = remoteAddress.address +// val size = connectionRules.size +// +// // if it's unknown, then by default we encrypt the traffic +// var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT +// if (size == 0 && address == IPv4.LOCALHOST) { +// // if nothing is specified, then by default localhost is compression and everything else is encrypted +// connectionType = ConnectionProperties.COMPRESS +// } +// for (i in 0 until size) { +// val rule = connectionRules[i] ?: continue +// if (rule.matches(remoteAddress)) { +// connectionType = rule.ruleType() +// break +// } +// } +// logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType) +// return connectionType.type +// } // RMI notes (in multiple places, copypasta, because this is confusing if not written down) diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 242bb738..1a2ce96c 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -296,6 +296,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A if (type == Server::class.java || !isRunning()) { // the server always creates a the media driver. mediaDriver = try { + logger.debug { "Starting Aeron Media driver..."} MediaDriver.launch(mediaDriverContext) } catch (e: Exception) { listenerManager.notifyError(e) @@ -510,7 +511,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A * * @return the message */ - fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { + internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { try { val message = serialization.readMessage(buffer, offset, length) logger.trace { @@ -540,7 +541,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A * @param header The aeron header information * @param connection The connection this message happened on */ - fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) { + internal fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) { // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! @Suppress("UNCHECKED_CAST") connection as CONNECTION @@ -617,7 +618,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } // NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine! - suspend fun send(message: Any, publication: Publication, connection: Connection) { + internal suspend fun send(message: Any, publication: Publication, connection: Connection) { // The sessionId is globally unique, and is assigned by the server. logger.trace { "[${publication.sessionId()}] send: $message" diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index 1c05fe1a..c3c27b8e 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -88,14 +88,14 @@ internal class ClientHandshake(private val logger: KLogg val cryptInput = crypto.cryptInput cryptInput.buffer = message.registrationData - val sessionId = cryptInput.readInt() + val sessId = cryptInput.readInt() val streamSubId = cryptInput.readInt() val streamPubId = cryptInput.readInt() val regDetailsSize = cryptInput.readInt() val regDetails = cryptInput.readBytes(regDetailsSize) // now read data off - connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId, + connectionHelloInfo = ClientConnectionInfo(sessionId = sessId, subscriptionPort = streamSubId, publicationPort = streamPubId, kryoRegistrationDetails = regDetails) @@ -104,12 +104,7 @@ internal class ClientHandshake(private val logger: KLogg connectionDone = true } else -> { - if (message.state != HandshakeMessage.HELLO_ACK) { - failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK") - } - else if (message.state != HandshakeMessage.DONE_ACK) { - failed = ClientException("[$sessionId] ignored message that is not DONE_ACK") - } + failed = ClientException("[$sessionId] ignored message that is ${HandshakeMessage.toStateString(message.state)}") } } } @@ -162,18 +157,18 @@ internal class ClientHandshake(private val logger: KLogg return connectionHelloInfo!! } - suspend fun handshakeDone(mediaConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { + suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { val registrationMessage = HandshakeMessage.doneFromClient() // Send the done message to the server. - endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage) + endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) // block until we receive the connection information from the server failed = null var pollCount: Int - val subscription = mediaConnection.subscription + val subscription = handshakeConnection.subscription val pollIdleStrategy = endPoint.config.pollIdleStrategy val startTime = System.currentTimeMillis() @@ -184,7 +179,7 @@ internal class ClientHandshake(private val logger: KLogg if (failed != null) { // no longer necessary to hold this connection open - mediaConnection.close() + handshakeConnection.close() throw failed as Exception } @@ -196,9 +191,10 @@ internal class ClientHandshake(private val logger: KLogg pollIdleStrategy.idle(pollCount) } + // no longer necessary to hold this connection open + handshakeConnection.close() + if (!connectionDone) { - // no longer necessary to hold this connection open - mediaConnection.close() throw ClientTimedOutException("Waiting for registration response from server") } diff --git a/src/dorkbox/network/handshake/HandshakeMessage.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt index 319612cc..ffea1c8d 100644 --- a/src/dorkbox/network/handshake/HandshakeMessage.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -98,18 +98,22 @@ internal class HandshakeMessage private constructor() { error.errorMessage = errorMessage return error } + + fun toStateString(state: Int) : String { + return when(state) { + INVALID -> "INVALID" + HELLO -> "HELLO" + HELLO_ACK -> "HELLO_ACK" + HELLO_ACK_IPC -> "HELLO_ACK_IPC" + DONE -> "DONE" + DONE_ACK -> "DONE_ACK" + else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!" + } + } } override fun toString(): String { - val stateStr = when(state) { - INVALID -> "INVALID" - HELLO -> "HELLO" - HELLO_ACK -> "HELLO_ACK" - HELLO_ACK_IPC -> "HELLO_ACK_IPC" - DONE -> "DONE" - DONE_ACK -> "DONE_ACK" - else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!" - } + val stateStr = toStateString(state) val errorMsg = if (errorMessage == null) { ""