Fixed handshake race condition which resulted in empty messages. Cleaned up code, added more debug info

This commit is contained in:
nathan 2020-09-03 01:31:08 +02:00
parent b1e92be50b
commit 2a6c279692
5 changed files with 161 additions and 164 deletions

View File

@ -25,7 +25,6 @@ import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake
@ -60,16 +59,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* For the IPC (Inter-Process-Communication) address. it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc.
*/
private var remoteAddress = ""
private var remoteAddress0 = ""
@Volatile
private var isConnected = false
// is valid when there is a connection to the server, otherwise it is null
private var connection: CONNECTION? = null
private var connection0: CONNECTION? = null
@Volatile
protected var connectionTimeoutMS: Long = 5000 // default is 5 seconds
private var connectionTimeoutMS: Long = 5_000 // default is 5 seconds
private val previousClosedConnectionActivity: Long = 0
@ -133,11 +132,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
lockStepForReconnect.lazySet(null)
connection = null
connection0 = null
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState()
// only change LOCALHOST -> IPC if the media driver is ALREADY running!
val canAutoChangeToIpc = config.enableIpcForLoopback && isRunning()
if (canAutoChangeToIpc) {
logger.trace { "Media driver is running. Support for enable auto-switch from LOCALHOST -> IPC enabled" }
}
this.connectionTimeoutMS = connectionTimeoutMS
val isIpcConnection: Boolean
@ -149,56 +153,56 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
when (remoteAddress) {
"0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
"loopback", "localhost", "lo", "" -> {
if (config.enableIpcForLoopback) {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC")
this.remoteAddress = "ipc"
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress = IPv4.LOCALHOST.hostAddress
this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
}
}
"0x" -> {
isIpcConnection = true
this.remoteAddress = "ipc"
this.remoteAddress0 = "ipc"
}
else -> when {
IPv4.isLoopback(remoteAddress) -> {
if (config.enableIpcForLoopback) {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC")
this.remoteAddress = "ipc"
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress = IPv4.LOCALHOST.hostAddress
this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
}
}
IPv6.isLoopback(remoteAddress) -> {
if (config.enableIpcForLoopback) {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC")
this.remoteAddress = "ipc"
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress = IPv6.LOCALHOST.hostAddress
this.remoteAddress0 = IPv6.LOCALHOST.hostAddress
}
}
else -> {
isIpcConnection = false
this.remoteAddress = remoteAddress
this.remoteAddress0 = remoteAddress
}
}
}
if (IPv6.isValid(this.remoteAddress)) {
if (IPv6.isValid(this.remoteAddress0)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) {
if (this.remoteAddress0.count { it == '[' } < 1 &&
this.remoteAddress0.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]"""
this.remoteAddress0 = """[${this.remoteAddress0}]"""
}
}
@ -212,7 +216,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
sessionId = RESERVED_SESSION_ID_INVALID)
}
else {
UdpMediaDriverConnection(address = this.remoteAddress,
UdpMediaDriverConnection(address = this.remoteAddress0,
publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
@ -237,7 +241,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val validateRemoteAddress = if (isIpcConnection) {
PublicKeyValidationState.VALID
} else {
crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress0), connectionInfo.publicKey)
}
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -320,7 +324,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
newConnection.preCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close()
if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(getConnection(), IllegalStateException("lockStep for reconnect was in the wrong state!"))
listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!"))
}
}
newConnection.postCloseAction = {
@ -331,7 +335,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// manually call it.
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
actionDispatch.launch {
listenerManager.notifyDisconnect(getConnection())
listenerManager.notifyDisconnect(connection)
}
// in case notifyDisconnect called client.connect().... cancel them waiting
@ -339,55 +343,52 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
lockStepForReconnect.value?.cancel()
}
connection = newConnection
connection0 = newConnection
connections.add(newConnection)
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.id}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.id}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
// tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
if (canFinishConnecting) {
isConnected = true
// we poll for new messages AFTER `handshake.handshakeDone`, because the aeron media driver will queue up the messages for us.
// we want to make sure to call notify connect BEFORE processing new messages.
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch {
listenerManager.notifyConnect(newConnection)
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.id}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.id}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
} else {
close()
@ -399,52 +400,47 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
/**
* @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
* true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
*/
fun hasRemoteKeyChanged(): Boolean {
return getConnection().hasRemoteKeyChanged()
}
val remoteKeyHasChanged: Boolean
get() = connection.hasRemoteKeyChanged()
/**
* @return the remote address, as a string.
* the remote address, as a string.
*/
fun getRemoteHost(): String {
return this.remoteAddress
}
val remoteAddress: String
get() = remoteAddress0
/**
* @return true if this connection is an IPC connection
* true if this connection is an IPC connection
*/
fun isIPC(): Boolean {
return getConnection().isIpc
}
val isIPC: Boolean
get() = connection.isIpc
/**
* @return true if this connection is a network connection
*/
fun isNetwork(): Boolean {
return getConnection().isNetwork
}
val isNetwork: Boolean
get() = connection.isNetwork
/**
* @return the connection (TCP or IPC) id of this connection.
*/
fun id(): Int {
return getConnection().id
}
val id: Int
get() = connection.id
/**
* @return the connection used by the client, this is only valid after the client has connected
* the connection used by the client, this is only valid after the client has connected
*/
fun getConnection(): CONNECTION {
return connection as CONNECTION
}
val connection: CONNECTION
get() = connection0 as CONNECTION
/**
* @throws ClientException when a message cannot be sent
*/
suspend fun send(message: Any) {
val c = connection
val c = connection0
if (c != null) {
c.send(message)
} else {
@ -455,24 +451,25 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
/**
* @throws ClientException when a ping cannot be sent
*/
suspend fun ping(): Ping {
val c = connection
if (c != null) {
return c.ping()
} else {
throw ClientException("Cannot ping a connection when there is no connection!")
}
}
// suspend fun ping(): Ping {
// val c = connection
// if (c != null) {
// return c.ping()
// } else {
// throw ClientException("Cannot ping a connection when there is no connection!")
// }
// }
/**
* Removes the specified host address from the list of registered server keys.
*/
@Throws(SecurityException::class)
fun removeRegisteredServerKey(hostAddress: Int) {
val savedPublicKey = settingsStore.getRegisteredServerKey(hostAddress)
fun removeRegisteredServerKey(hostAddress: String) {
val address = IPv4.toInt(hostAddress)
val savedPublicKey = settingsStore.getRegisteredServerKey(address)
if (savedPublicKey != null) {
val logger2 = logger
if (logger2.isDebugEnabled) {
logger2.debug("Deleting remote IP address key ${IPv4.toString(hostAddress)}")
}
settingsStore.removeRegisteredServerKey(hostAddress)
logger.debug { "Deleting remote IP address key $hostAddress" }
settingsStore.removeRegisteredServerKey(address)
}
}
@ -585,7 +582,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java)
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
return rmiConnectionSupport.getProxyObject(getConnection(), kryoId, objectId, Iface::class.java)
return rmiConnectionSupport.getProxyObject(connection, kryoId, objectId, Iface::class.java)
}
/**
@ -615,7 +612,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
objectParameters as Array<Any?>
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, objectParameters, callback)
rmiConnectionSupport.createRemoteObject(connection, kryoId, objectParameters, callback)
}
/**
@ -642,7 +639,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java)
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, null, callback)
rmiConnectionSupport.createRemoteObject(connection, kryoId, null, callback)
}
//
@ -730,6 +727,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// NOTE: It's not possible to have reified inside a virtual function
// https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
return rmiGlobalSupport.getGlobalRemoteObject(getConnection(), objectId, Iface::class.java)
return rmiGlobalSupport.getGlobalRemoteObject(connection, objectId, Iface::class.java)
}
}

View File

@ -23,7 +23,6 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionProperties
import dorkbox.network.connection.connectionType.ConnectionRule
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.rmi.RemoteObject
@ -36,7 +35,6 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import java.net.InetSocketAddress
import java.util.concurrent.CopyOnWriteArrayList
/**
@ -68,14 +66,15 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
}
}
/**
* @return true if this server has successfully bound to an IP address and is running
*/
@Volatile
private var bindAlreadyCalled = false
/**
* Used for handshake connections
*/
private val handshake = ServerHandshake(logger, config, listenerManager)
/**
@ -403,32 +402,32 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/**
* Only called by the server!
*
* If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic.
*/
// after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS)
fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte {
val address = remoteAddress.address
val size = connectionRules.size
// if it's unknown, then by default we encrypt the traffic
var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT
if (size == 0 && address == IPv4.LOCALHOST) {
// if nothing is specified, then by default localhost is compression and everything else is encrypted
connectionType = ConnectionProperties.COMPRESS
}
for (i in 0 until size) {
val rule = connectionRules[i] ?: continue
if (rule.matches(remoteAddress)) {
connectionType = rule.ruleType()
break
}
}
logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType)
return connectionType.type
}
// /**
// * Only called by the server!
// *
// * If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic.
// */
// // after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS)
// fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte {
// val address = remoteAddress.address
// val size = connectionRules.size
//
// // if it's unknown, then by default we encrypt the traffic
// var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT
// if (size == 0 && address == IPv4.LOCALHOST) {
// // if nothing is specified, then by default localhost is compression and everything else is encrypted
// connectionType = ConnectionProperties.COMPRESS
// }
// for (i in 0 until size) {
// val rule = connectionRules[i] ?: continue
// if (rule.matches(remoteAddress)) {
// connectionType = rule.ruleType()
// break
// }
// }
// logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType)
// return connectionType.type
// }
// RMI notes (in multiple places, copypasta, because this is confusing if not written down)

View File

@ -296,6 +296,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
if (type == Server::class.java || !isRunning()) {
// the server always creates a the media driver.
mediaDriver = try {
logger.debug { "Starting Aeron Media driver..."}
MediaDriver.launch(mediaDriverContext)
} catch (e: Exception) {
listenerManager.notifyError(e)
@ -510,7 +511,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*
* @return the message
*/
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
try {
val message = serialization.readMessage(buffer, offset, length)
logger.trace {
@ -540,7 +541,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* @param header The aeron header information
* @param connection The connection this message happened on
*/
fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
internal fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
@Suppress("UNCHECKED_CAST")
connection as CONNECTION
@ -617,7 +618,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
}
// NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine!
suspend fun send(message: Any, publication: Publication, connection: Connection) {
internal suspend fun send(message: Any, publication: Publication, connection: Connection) {
// The sessionId is globally unique, and is assigned by the server.
logger.trace {
"[${publication.sessionId()}] send: $message"

View File

@ -88,14 +88,14 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
val cryptInput = crypto.cryptInput
cryptInput.buffer = message.registrationData
val sessionId = cryptInput.readInt()
val sessId = cryptInput.readInt()
val streamSubId = cryptInput.readInt()
val streamPubId = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val regDetails = cryptInput.readBytes(regDetailsSize)
// now read data off
connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId,
connectionHelloInfo = ClientConnectionInfo(sessionId = sessId,
subscriptionPort = streamSubId,
publicationPort = streamPubId,
kryoRegistrationDetails = regDetails)
@ -104,12 +104,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
connectionDone = true
}
else -> {
if (message.state != HandshakeMessage.HELLO_ACK) {
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
}
else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
failed = ClientException("[$sessionId] ignored message that is ${HandshakeMessage.toStateString(message.state)}")
}
}
}
@ -162,18 +157,18 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
return connectionHelloInfo!!
}
suspend fun handshakeDone(mediaConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient()
// Send the done message to the server.
endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage)
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
// block until we receive the connection information from the server
failed = null
var pollCount: Int
val subscription = mediaConnection.subscription
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis()
@ -184,7 +179,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
if (failed != null) {
// no longer necessary to hold this connection open
mediaConnection.close()
handshakeConnection.close()
throw failed as Exception
}
@ -196,9 +191,10 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
pollIdleStrategy.idle(pollCount)
}
// no longer necessary to hold this connection open
handshakeConnection.close()
if (!connectionDone) {
// no longer necessary to hold this connection open
mediaConnection.close()
throw ClientTimedOutException("Waiting for registration response from server")
}

View File

@ -98,18 +98,22 @@ internal class HandshakeMessage private constructor() {
error.errorMessage = errorMessage
return error
}
fun toStateString(state: Int) : String {
return when(state) {
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
}
}
override fun toString(): String {
val stateStr = when(state) {
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
val stateStr = toStateString(state)
val errorMsg = if (errorMessage == null) {
""