Fixed handshake race condition which resulted in empty messages. Cleaned up code, added more debug info

This commit is contained in:
nathan 2020-09-03 01:31:08 +02:00
parent b1e92be50b
commit 2a6c279692
5 changed files with 161 additions and 164 deletions

View File

@ -25,7 +25,6 @@ import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake import dorkbox.network.handshake.ClientHandshake
@ -60,16 +59,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* For the IPC (Inter-Process-Communication) address. it must be: * For the IPC (Inter-Process-Communication) address. it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc. * - the IPC integer ID, "0x1337c0de", "0x12312312", etc.
*/ */
private var remoteAddress = "" private var remoteAddress0 = ""
@Volatile @Volatile
private var isConnected = false private var isConnected = false
// is valid when there is a connection to the server, otherwise it is null // is valid when there is a connection to the server, otherwise it is null
private var connection: CONNECTION? = null private var connection0: CONNECTION? = null
@Volatile @Volatile
protected var connectionTimeoutMS: Long = 5000 // default is 5 seconds private var connectionTimeoutMS: Long = 5_000 // default is 5 seconds
private val previousClosedConnectionActivity: Long = 0 private val previousClosedConnectionActivity: Long = 0
@ -133,11 +132,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
lockStepForReconnect.lazySet(null) lockStepForReconnect.lazySet(null)
connection = null connection0 = null
// we are done with initial configuration, now initialize aeron and the general state of this endpoint // we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState() val aeron = initEndpointState()
// only change LOCALHOST -> IPC if the media driver is ALREADY running!
val canAutoChangeToIpc = config.enableIpcForLoopback && isRunning()
if (canAutoChangeToIpc) {
logger.trace { "Media driver is running. Support for enable auto-switch from LOCALHOST -> IPC enabled" }
}
this.connectionTimeoutMS = connectionTimeoutMS this.connectionTimeoutMS = connectionTimeoutMS
val isIpcConnection: Boolean val isIpcConnection: Boolean
@ -149,56 +153,56 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
when (remoteAddress) { when (remoteAddress) {
"0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") "0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
"loopback", "localhost", "lo", "" -> { "loopback", "localhost", "lo", "" -> {
if (config.enableIpcForLoopback) { if (canAutoChangeToIpc) {
isIpcConnection = true isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC") logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress = "ipc" this.remoteAddress0 = "ipc"
} else { } else {
isIpcConnection = false isIpcConnection = false
this.remoteAddress = IPv4.LOCALHOST.hostAddress this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
} }
} }
"0x" -> { "0x" -> {
isIpcConnection = true isIpcConnection = true
this.remoteAddress = "ipc" this.remoteAddress0 = "ipc"
} }
else -> when { else -> when {
IPv4.isLoopback(remoteAddress) -> { IPv4.isLoopback(remoteAddress) -> {
if (config.enableIpcForLoopback) { if (canAutoChangeToIpc) {
isIpcConnection = true isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC") logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress = "ipc" this.remoteAddress0 = "ipc"
} else { } else {
isIpcConnection = false isIpcConnection = false
this.remoteAddress = IPv4.LOCALHOST.hostAddress this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
} }
} }
IPv6.isLoopback(remoteAddress) -> { IPv6.isLoopback(remoteAddress) -> {
if (config.enableIpcForLoopback) { if (canAutoChangeToIpc) {
isIpcConnection = true isIpcConnection = true
logger.info("Auto-changing network connection from $remoteAddress -> IPC") logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress = "ipc" this.remoteAddress0 = "ipc"
} else { } else {
isIpcConnection = false isIpcConnection = false
this.remoteAddress = IPv6.LOCALHOST.hostAddress this.remoteAddress0 = IPv6.LOCALHOST.hostAddress
} }
} }
else -> { else -> {
isIpcConnection = false isIpcConnection = false
this.remoteAddress = remoteAddress this.remoteAddress0 = remoteAddress
} }
} }
} }
if (IPv6.isValid(this.remoteAddress)) { if (IPv6.isValid(this.remoteAddress0)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]' // if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == '[' } < 1 && if (this.remoteAddress0.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) { this.remoteAddress0.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]""" this.remoteAddress0 = """[${this.remoteAddress0}]"""
} }
} }
@ -212,7 +216,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
sessionId = RESERVED_SESSION_ID_INVALID) sessionId = RESERVED_SESSION_ID_INVALID)
} }
else { else {
UdpMediaDriverConnection(address = this.remoteAddress, UdpMediaDriverConnection(address = this.remoteAddress0,
publicationPort = config.subscriptionPort, publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort, subscriptionPort = config.publicationPort,
streamId = UDP_HANDSHAKE_STREAM_ID, streamId = UDP_HANDSHAKE_STREAM_ID,
@ -237,7 +241,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val validateRemoteAddress = if (isIpcConnection) { val validateRemoteAddress = if (isIpcConnection) {
PublicKeyValidationState.VALID PublicKeyValidationState.VALID
} else { } else {
crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress0), connectionInfo.publicKey)
} }
if (validateRemoteAddress == PublicKeyValidationState.INVALID) { if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -320,7 +324,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
newConnection.preCloseAction = { newConnection.preCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close() // this is called whenever connection.close() is called by the framework or via client.close()
if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) { if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(getConnection(), IllegalStateException("lockStep for reconnect was in the wrong state!")) listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!"))
} }
} }
newConnection.postCloseAction = { newConnection.postCloseAction = {
@ -331,7 +335,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// manually call it. // manually call it.
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
actionDispatch.launch { actionDispatch.launch {
listenerManager.notifyDisconnect(getConnection()) listenerManager.notifyDisconnect(connection)
} }
// in case notifyDisconnect called client.connect().... cancel them waiting // in case notifyDisconnect called client.connect().... cancel them waiting
@ -339,55 +343,52 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
lockStepForReconnect.value?.cancel() lockStepForReconnect.value?.cancel()
} }
connection = newConnection connection0 = newConnection
connections.add(newConnection) connections.add(newConnection)
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.id}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.id}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
// tell the server our connection handshake is done, and the connection can now listen for data. // tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
if (canFinishConnecting) { if (canFinishConnecting) {
isConnected = true isConnected = true
// we poll for new messages AFTER `handshake.handshakeDone`, because the aeron media driver will queue up the messages for us.
// we want to make sure to call notify connect BEFORE processing new messages.
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch { actionDispatch.launch {
listenerManager.notifyConnect(newConnection) listenerManager.notifyConnect(newConnection)
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.id}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.id}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
} }
} else { } else {
close() close()
@ -399,52 +400,47 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
/** /**
* @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed. * true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
*/ */
fun hasRemoteKeyChanged(): Boolean { val remoteKeyHasChanged: Boolean
return getConnection().hasRemoteKeyChanged() get() = connection.hasRemoteKeyChanged()
}
/** /**
* @return the remote address, as a string. * the remote address, as a string.
*/ */
fun getRemoteHost(): String { val remoteAddress: String
return this.remoteAddress get() = remoteAddress0
}
/** /**
* @return true if this connection is an IPC connection * true if this connection is an IPC connection
*/ */
fun isIPC(): Boolean { val isIPC: Boolean
return getConnection().isIpc get() = connection.isIpc
}
/** /**
* @return true if this connection is a network connection * @return true if this connection is a network connection
*/ */
fun isNetwork(): Boolean { val isNetwork: Boolean
return getConnection().isNetwork get() = connection.isNetwork
}
/** /**
* @return the connection (TCP or IPC) id of this connection. * @return the connection (TCP or IPC) id of this connection.
*/ */
fun id(): Int { val id: Int
return getConnection().id get() = connection.id
}
/** /**
* @return the connection used by the client, this is only valid after the client has connected * the connection used by the client, this is only valid after the client has connected
*/ */
fun getConnection(): CONNECTION { val connection: CONNECTION
return connection as CONNECTION get() = connection0 as CONNECTION
}
/** /**
* @throws ClientException when a message cannot be sent * @throws ClientException when a message cannot be sent
*/ */
suspend fun send(message: Any) { suspend fun send(message: Any) {
val c = connection val c = connection0
if (c != null) { if (c != null) {
c.send(message) c.send(message)
} else { } else {
@ -455,24 +451,25 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
/** /**
* @throws ClientException when a ping cannot be sent * @throws ClientException when a ping cannot be sent
*/ */
suspend fun ping(): Ping { // suspend fun ping(): Ping {
val c = connection // val c = connection
if (c != null) { // if (c != null) {
return c.ping() // return c.ping()
} else { // } else {
throw ClientException("Cannot ping a connection when there is no connection!") // throw ClientException("Cannot ping a connection when there is no connection!")
} // }
} // }
/**
* Removes the specified host address from the list of registered server keys.
*/
@Throws(SecurityException::class) @Throws(SecurityException::class)
fun removeRegisteredServerKey(hostAddress: Int) { fun removeRegisteredServerKey(hostAddress: String) {
val savedPublicKey = settingsStore.getRegisteredServerKey(hostAddress) val address = IPv4.toInt(hostAddress)
val savedPublicKey = settingsStore.getRegisteredServerKey(address)
if (savedPublicKey != null) { if (savedPublicKey != null) {
val logger2 = logger logger.debug { "Deleting remote IP address key $hostAddress" }
if (logger2.isDebugEnabled) { settingsStore.removeRegisteredServerKey(address)
logger2.debug("Deleting remote IP address key ${IPv4.toString(hostAddress)}")
}
settingsStore.removeRegisteredServerKey(hostAddress)
} }
} }
@ -585,7 +582,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java) val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java)
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
return rmiConnectionSupport.getProxyObject(getConnection(), kryoId, objectId, Iface::class.java) return rmiConnectionSupport.getProxyObject(connection, kryoId, objectId, Iface::class.java)
} }
/** /**
@ -615,7 +612,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
objectParameters as Array<Any?> objectParameters as Array<Any?>
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, objectParameters, callback) rmiConnectionSupport.createRemoteObject(connection, kryoId, objectParameters, callback)
} }
/** /**
@ -642,7 +639,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java) val kryoId = serialization.getKryoIdForRmiClient(Iface::class.java)
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
rmiConnectionSupport.createRemoteObject(getConnection(), kryoId, null, callback) rmiConnectionSupport.createRemoteObject(connection, kryoId, null, callback)
} }
// //
@ -730,6 +727,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// NOTE: It's not possible to have reified inside a virtual function // NOTE: It's not possible to have reified inside a virtual function
// https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function // https://stackoverflow.com/questions/60037849/kotlin-reified-generic-in-virtual-function
@Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE") @Suppress("NON_PUBLIC_CALL_FROM_PUBLIC_INLINE")
return rmiGlobalSupport.getGlobalRemoteObject(getConnection(), objectId, Iface::class.java) return rmiGlobalSupport.getGlobalRemoteObject(connection, objectId, Iface::class.java)
} }
} }

View File

@ -23,7 +23,6 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionProperties
import dorkbox.network.connection.connectionType.ConnectionRule import dorkbox.network.connection.connectionType.ConnectionRule
import dorkbox.network.handshake.ServerHandshake import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
@ -36,7 +35,6 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.net.InetSocketAddress
import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CopyOnWriteArrayList
/** /**
@ -68,14 +66,15 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
} }
/** /**
* @return true if this server has successfully bound to an IP address and is running * @return true if this server has successfully bound to an IP address and is running
*/ */
@Volatile @Volatile
private var bindAlreadyCalled = false private var bindAlreadyCalled = false
/**
* Used for handshake connections
*/
private val handshake = ServerHandshake(logger, config, listenerManager) private val handshake = ServerHandshake(logger, config, listenerManager)
/** /**
@ -403,32 +402,32 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/** // /**
* Only called by the server! // * Only called by the server!
* // *
* If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic. // * If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic.
*/ // */
// after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS) // // after the handshake, what sort of connection do we want (NONE, COMPRESS, ENCRYPT+COMPRESS)
fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte { // fun getConnectionUpgradeType(remoteAddress: InetSocketAddress): Byte {
val address = remoteAddress.address // val address = remoteAddress.address
val size = connectionRules.size // val size = connectionRules.size
//
// if it's unknown, then by default we encrypt the traffic // // if it's unknown, then by default we encrypt the traffic
var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT // var connectionType = ConnectionProperties.COMPRESS_AND_ENCRYPT
if (size == 0 && address == IPv4.LOCALHOST) { // if (size == 0 && address == IPv4.LOCALHOST) {
// if nothing is specified, then by default localhost is compression and everything else is encrypted // // if nothing is specified, then by default localhost is compression and everything else is encrypted
connectionType = ConnectionProperties.COMPRESS // connectionType = ConnectionProperties.COMPRESS
} // }
for (i in 0 until size) { // for (i in 0 until size) {
val rule = connectionRules[i] ?: continue // val rule = connectionRules[i] ?: continue
if (rule.matches(remoteAddress)) { // if (rule.matches(remoteAddress)) {
connectionType = rule.ruleType() // connectionType = rule.ruleType()
break // break
} // }
} // }
logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType) // logger.debug("Validating {} Permitted type is: {}", remoteAddress, connectionType)
return connectionType.type // return connectionType.type
} // }
// RMI notes (in multiple places, copypasta, because this is confusing if not written down) // RMI notes (in multiple places, copypasta, because this is confusing if not written down)

View File

@ -296,6 +296,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
if (type == Server::class.java || !isRunning()) { if (type == Server::class.java || !isRunning()) {
// the server always creates a the media driver. // the server always creates a the media driver.
mediaDriver = try { mediaDriver = try {
logger.debug { "Starting Aeron Media driver..."}
MediaDriver.launch(mediaDriverContext) MediaDriver.launch(mediaDriverContext)
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(e) listenerManager.notifyError(e)
@ -510,7 +511,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* @return the message * @return the message
*/ */
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
try { try {
val message = serialization.readMessage(buffer, offset, length) val message = serialization.readMessage(buffer, offset, length)
logger.trace { logger.trace {
@ -540,7 +541,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* @param header The aeron header information * @param header The aeron header information
* @param connection The connection this message happened on * @param connection The connection this message happened on
*/ */
fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) { internal fun processMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header, connection: Connection) {
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
connection as CONNECTION connection as CONNECTION
@ -617,7 +618,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
// NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine! // NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine!
suspend fun send(message: Any, publication: Publication, connection: Connection) { internal suspend fun send(message: Any, publication: Publication, connection: Connection) {
// The sessionId is globally unique, and is assigned by the server. // The sessionId is globally unique, and is assigned by the server.
logger.trace { logger.trace {
"[${publication.sessionId()}] send: $message" "[${publication.sessionId()}] send: $message"

View File

@ -88,14 +88,14 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
val cryptInput = crypto.cryptInput val cryptInput = crypto.cryptInput
cryptInput.buffer = message.registrationData cryptInput.buffer = message.registrationData
val sessionId = cryptInput.readInt() val sessId = cryptInput.readInt()
val streamSubId = cryptInput.readInt() val streamSubId = cryptInput.readInt()
val streamPubId = cryptInput.readInt() val streamPubId = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt() val regDetailsSize = cryptInput.readInt()
val regDetails = cryptInput.readBytes(regDetailsSize) val regDetails = cryptInput.readBytes(regDetailsSize)
// now read data off // now read data off
connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId, connectionHelloInfo = ClientConnectionInfo(sessionId = sessId,
subscriptionPort = streamSubId, subscriptionPort = streamSubId,
publicationPort = streamPubId, publicationPort = streamPubId,
kryoRegistrationDetails = regDetails) kryoRegistrationDetails = regDetails)
@ -104,12 +104,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
connectionDone = true connectionDone = true
} }
else -> { else -> {
if (message.state != HandshakeMessage.HELLO_ACK) { failed = ClientException("[$sessionId] ignored message that is ${HandshakeMessage.toStateString(message.state)}")
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
}
else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
} }
} }
} }
@ -162,18 +157,18 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
return connectionHelloInfo!! return connectionHelloInfo!!
} }
suspend fun handshakeDone(mediaConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient() val registrationMessage = HandshakeMessage.doneFromClient()
// Send the done message to the server. // Send the done message to the server.
endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage) endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
// block until we receive the connection information from the server // block until we receive the connection information from the server
failed = null failed = null
var pollCount: Int var pollCount: Int
val subscription = mediaConnection.subscription val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis() val startTime = System.currentTimeMillis()
@ -184,7 +179,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
if (failed != null) { if (failed != null) {
// no longer necessary to hold this connection open // no longer necessary to hold this connection open
mediaConnection.close() handshakeConnection.close()
throw failed as Exception throw failed as Exception
} }
@ -196,9 +191,10 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
pollIdleStrategy.idle(pollCount) pollIdleStrategy.idle(pollCount)
} }
// no longer necessary to hold this connection open
handshakeConnection.close()
if (!connectionDone) { if (!connectionDone) {
// no longer necessary to hold this connection open
mediaConnection.close()
throw ClientTimedOutException("Waiting for registration response from server") throw ClientTimedOutException("Waiting for registration response from server")
} }

View File

@ -98,18 +98,22 @@ internal class HandshakeMessage private constructor() {
error.errorMessage = errorMessage error.errorMessage = errorMessage
return error return error
} }
fun toStateString(state: Int) : String {
return when(state) {
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
}
} }
override fun toString(): String { override fun toString(): String {
val stateStr = when(state) { val stateStr = toStateString(state)
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
val errorMsg = if (errorMessage == null) { val errorMsg = if (errorMessage == null) {
"" ""