Fixed issues with connection handshake with the same computer, on multiple clients.

This commit is contained in:
nathan 2020-09-23 17:08:25 +02:00
parent 8605040819
commit 93e406289c
7 changed files with 196 additions and 120 deletions

View File

@ -115,6 +115,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
@Suppress("BlockingMethodInNonBlockingContext")
suspend fun connect(remoteAddress: String, suspend fun connect(remoteAddress: String,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
when { when {
@ -270,14 +271,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
val handshake = ClientHandshake(config, crypto, this) val handshake = ClientHandshake(crypto, this)
val handshakeConnection = if (autoChangeToIpc || canUseIPC) { val handshakeConnection = if (autoChangeToIpc || canUseIPC) {
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead // MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead
val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId, val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId,
streamId = ipcPublicationId, streamId = ipcPublicationId,
sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID, sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID)
// "fast" connection timeout, since this is IPC
connectionTimeoutMS = 1000)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
try { try {
@ -356,14 +355,13 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
IpcMediaDriverConnection(sessionId = connectionInfo.sessionId, IpcMediaDriverConnection(sessionId = connectionInfo.sessionId,
// NOTE: pub/sub must be switched! // NOTE: pub/sub must be switched!
streamIdSubscription = connectionInfo.publicationPort, streamIdSubscription = connectionInfo.publicationPort,
streamId = connectionInfo.subscriptionPort, streamId = connectionInfo.subscriptionPort)
connectionTimeoutMS = connectionTimeoutMS)
} }
else { else {
UdpMediaDriverConnection(address = handshakeConnection.address!!, UdpMediaDriverConnection(address = handshakeConnection.address!!,
// NOTE: pub/sub must be switched! // NOTE: pub/sub must be switched!
subscriptionPort = connectionInfo.publicationPort,
publicationPort = connectionInfo.subscriptionPort, publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId, streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId, sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS, connectionTimeoutMS = connectionTimeoutMS,
@ -473,7 +471,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val pollIdleStrategy = config.pollIdleStrategy val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) { while (!isShutdown()) {
if (newConnection.isClosed()) { if (newConnection.isClosedViaAeron()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug {"[${newConnection.id}] connection expired"} logger.debug {"[${newConnection.id}] connection expired"}

View File

@ -27,7 +27,9 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connectionType.ConnectionRule import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.coroutines.SuspendWaiter import dorkbox.network.coroutines.SuspendWaiter
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ServerException import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.handshake.ServerHandshake import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
@ -167,6 +169,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val sessionId = header.sessionId() val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header) val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processIpcHandshakeMessageServer(this@Server, handshake.processIpcHandshakeMessageServer(this@Server,
publication, publication,
sessionId, sessionId,
@ -234,6 +247,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header) val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server, handshake.processUdpHandshakeMessageServer(this@Server,
publication, publication,
sessionId, sessionId,
@ -304,6 +328,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header) val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server, handshake.processUdpHandshakeMessageServer(this@Server,
publication, publication,
sessionId, sessionId,
@ -374,6 +409,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header) val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server, handshake.processUdpHandshakeMessageServer(this@Server,
publication, publication,
sessionId, sessionId,
@ -464,7 +510,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator, // this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
// so we can modify this as we go // so we can modify this as we go
connections.forEach { connection -> connections.forEach { connection ->
if (connection.isClosed()) { if (connection.isClosedViaAeron()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug { "[${connection.id}] connection expired" } logger.debug { "[${connection.id}] connection expired" }

View File

@ -31,18 +31,14 @@ import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
interface MediaDriverConnection : AutoCloseable { abstract class MediaDriverConnection(val address: InetAddress?,
val address: InetAddress? val publicationPort: Int, val subscriptionPort: Int,
val streamId: Int val streamId: Int, val sessionId: Int,
val sessionId: Int val connectionTimeoutMS: Long, val isReliable: Boolean) : AutoCloseable {
val subscriptionPort: Int lateinit var subscription: Subscription
val publicationPort: Int lateinit var publication: Publication
val subscription: Subscription
val publication: Publication
val isReliable: Boolean
suspend fun addSubscriptionWithRetry(aeron: Aeron, uri: String, streamId: Int, logger: KLogger): Subscription { suspend fun addSubscriptionWithRetry(aeron: Aeron, uri: String, streamId: Int, logger: KLogger): Subscription {
// If we start/stop too quickly, we might have the address already in use! Retry a few times. // If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -79,27 +75,25 @@ interface MediaDriverConnection : AutoCloseable {
} }
@Throws(ClientTimedOutException::class) @Throws(ClientTimedOutException::class)
suspend fun buildClient(aeron: Aeron, logger: KLogger) abstract suspend fun buildClient(aeron: Aeron, logger: KLogger)
suspend fun buildServer(aeron: Aeron, logger: KLogger) abstract suspend fun buildServer(aeron: Aeron, logger: KLogger)
fun clientInfo() : String abstract fun clientInfo() : String
fun serverInfo() : String abstract fun serverInfo() : String
} }
/** /**
* For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER. * For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER.
* A connection timeout of 0, means to wait forever * A connection timeout of 0, means to wait forever
*/ */
class UdpMediaDriverConnection(override val address: InetAddress, class UdpMediaDriverConnection(address: InetAddress,
override val publicationPort: Int, publicationPort: Int,
override val subscriptionPort: Int, subscriptionPort: Int,
override val streamId: Int, streamId: Int,
override val sessionId: Int, sessionId: Int,
private val connectionTimeoutMS: Long = 0, connectionTimeoutMS: Long = 0,
override val isReliable: Boolean = true) : MediaDriverConnection { isReliable: Boolean = true) :
MediaDriverConnection(address, publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
override lateinit var subscription: Subscription
override lateinit var publication: Publication
var success: Boolean = false var success: Boolean = false
@ -163,7 +157,7 @@ class UdpMediaDriverConnection(override val address: InetAddress,
val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS) val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS)
var startTime = System.nanoTime() var startTime = System.nanoTime()
while (timoutInNanos == 0L || System.nanoTime() - startTime < timoutInNanos) { while (timoutInNanos == 0L || System.nanoTime() - startTime < timoutInNanos) {
if (subscription.isConnected && subscription.imageCount() > 0) { if (subscription.isConnected) {
success = true success = true
break break
} }
@ -198,8 +192,8 @@ class UdpMediaDriverConnection(override val address: InetAddress,
this.success = true this.success = true
this.subscription = subscription
this.publication = publication this.publication = publication
this.subscription = subscription
} }
override suspend fun buildServer(aeron: Aeron, logger: KLogger) { override suspend fun buildServer(aeron: Aeron, logger: KLogger) {
@ -236,6 +230,8 @@ class UdpMediaDriverConnection(override val address: InetAddress,
} }
override fun clientInfo(): String { override fun clientInfo(): String {
address!!
return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) { return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) {
"Connecting to ${IP.toString(address)} [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" "Connecting to ${IP.toString(address)} [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else { } else {
@ -251,7 +247,7 @@ class UdpMediaDriverConnection(override val address: InetAddress,
IPv4.WILDCARD.hostAddress + "/" + address.hostAddress IPv4.WILDCARD.hostAddress + "/" + address.hostAddress
} }
} else { } else {
IP.toString(address) IP.toString(address!!)
} }
return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) { return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) {
@ -275,20 +271,13 @@ class UdpMediaDriverConnection(override val address: InetAddress,
/** /**
* For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER * For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER
* NOTE: IPC connection will ALWAYS have a timeout of 1 second to connect. This is IPC, it should connect fast
*/ */
class IpcMediaDriverConnection(override val streamId: Int, class IpcMediaDriverConnection(streamId: Int,
val streamIdSubscription: Int, val streamIdSubscription: Int,
override val sessionId: Int, sessionId: Int,
private val connectionTimeoutMS: Long = 30_000, ) :
) : MediaDriverConnection { MediaDriverConnection(null, 0, 0, streamId, sessionId, 1_000, true) {
override val address: InetAddress? = null
override val isReliable = true
override val subscriptionPort = 0
override val publicationPort = 0
override lateinit var subscription: Subscription
override lateinit var publication: Publication
var success: Boolean = false var success: Boolean = false
@ -301,7 +290,11 @@ class IpcMediaDriverConnection(override val streamId: Int,
return builder return builder
} }
@Throws(ClientTimedOutException::class) /**
* Set up the subscription + publication channels to the server
*
* @throws ClientTimedOutException if we cannot connect to the server in the designated time
*/
override suspend fun buildClient(aeron: Aeron, logger: KLogger) { override suspend fun buildClient(aeron: Aeron, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID. // Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
@ -366,6 +359,9 @@ class IpcMediaDriverConnection(override val streamId: Int,
this.subscription = subscription this.subscription = subscription
} }
/**
* Setup the subscription + publication channels on the server
*/
override suspend fun buildServer(aeron: Aeron, logger: KLogger) { override suspend fun buildServer(aeron: Aeron, logger: KLogger) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.

View File

@ -293,16 +293,59 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
} }
/** /**
* @return `true` if this connection has been closed * Adds a function that will be called when a client/server "disconnects" with
* each other
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/ */
fun isClosed(): Boolean { suspend fun onDisconnect(function: suspend (Connection) -> Unit) {
val hasNoImages = subscription.hasNoImages() // make sure we atomically create the listener manager, if necessary
if (hasNoImages) { listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onDisconnect(function)
}
/**
* Adds a function that will be called only for this connection, when a client/server receives a message
*/
suspend fun <MESSAGE> onMessage(function: suspend (Connection, MESSAGE) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
}
/**
* We must account for network blips. They blips will be recovered by aeron, but we want to make sure that we are actually
* disconnected for a set period of time before we start the close process for a connection
*
* @return `true` if this connection has been closed via aeron
*/
fun isClosedViaAeron(): Boolean {
val isNotConnected = !subscription.isConnected && !publication.isConnected
if (isNotConnected) {
// 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them). // 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them).
return System.nanoTime() - connectionInitTime >= TimeUnit.SECONDS.toNanos(1) return System.nanoTime() - connectionInitTime >= TimeUnit.SECONDS.toNanos(1)
} }
return hasNoImages return isNotConnected
} }
/** /**
@ -370,47 +413,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
} }
} }
/**
* Adds a function that will be called when a client/server "disconnects" with
* each other
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
suspend fun onDisconnect(function: suspend (Connection) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onDisconnect(function)
}
/**
* Adds a function that will be called only for this connection, when a client/server receives a message
*/
suspend fun <MESSAGE> onMessage(function: suspend (Connection, MESSAGE) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
}
// //
// //
// Generic object methods // Generic object methods

View File

@ -42,6 +42,9 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
@Volatile @Volatile
private var connectionDone = false private var connectionDone = false
@Volatile
private var needToRetry = false
@Volatile @Volatile
private var failed: Exception? = null private var failed: Exception? = null
@ -64,11 +67,18 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
} }
// this is an error message // this is an error message
if (message.sessionId == 0) { if (message.state == HandshakeMessage.INVALID) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}") failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@FragmentAssembler return@FragmentAssembler
} }
// this is an retry message
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.RETRY) {
needToRetry = true
return@FragmentAssembler
}
if (this@ClientHandshake.sessionId != message.sessionId) { if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " + failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " +
@ -172,7 +182,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
val subscription = handshakeConnection.subscription val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis() var startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
@ -184,6 +194,13 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
throw failed as Exception throw failed as Exception
} }
if (needToRetry) {
needToRetry = false
// start over with the timeout!
startTime = System.currentTimeMillis()
}
if (connectionDone) { if (connectionDone) {
break break
} }

View File

@ -50,7 +50,8 @@ internal class HandshakeMessage private constructor() {
var registrationRmiIdData: IntArray? = null var registrationRmiIdData: IntArray? = null
companion object { companion object {
const val INVALID = -1 const val INVALID = -2
const val RETRY = -1
const val HELLO = 0 const val HELLO = 0
const val HELLO_ACK = 1 const val HELLO_ACK = 1
const val HELLO_ACK_IPC = 2 const val HELLO_ACK_IPC = 2
@ -99,9 +100,17 @@ internal class HandshakeMessage private constructor() {
return error return error
} }
fun retry(errorMessage: String): HandshakeMessage {
val error = HandshakeMessage()
error.state = RETRY
error.errorMessage = errorMessage
return error
}
fun toStateString(state: Int) : String { fun toStateString(state: Int) : String {
return when(state) { return when(state) {
INVALID -> "INVALID" INVALID -> "INVALID"
RETRY -> "RETRY"
HELLO -> "HELLO" HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK" HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC" HELLO_ACK_IPC -> "HELLO_ACK_IPC"

View File

@ -30,11 +30,13 @@ import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.exceptions.AllocationException import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientRejectedException import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.exceptions.ServerException import dorkbox.network.exceptions.ServerException
import io.aeron.Aeron import io.aeron.Aeron
import io.aeron.Publication import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
@ -42,6 +44,7 @@ import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write import kotlin.concurrent.write
@ -78,21 +81,32 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
/** /**
* @return true if we should continue parsing the incoming message, false if we should abort * @return true if we should continue parsing the incoming message, false if we should abort
*/ */
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD. ONLY RESPONSES ARE ON ACTION DISPATCH!
private fun validateMessageTypeAndDoPending(server: Server<CONNECTION>, private fun validateMessageTypeAndDoPending(server: Server<CONNECTION>,
actionDispatch: CoroutineScope,
handshakePublication: Publication, handshakePublication: Publication,
message: Any?, message: HandshakeMessage,
sessionId: Int, sessionId: Int,
connectionString: String): Boolean { connectionString: String): Boolean {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // check to see if this sessionId is ALREADY in use by another connection!
if (message !is HandshakeMessage) { // this can happen if there are multiple connections from the SAME ip address (ie: localhost)
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request")) if (message.state == HandshakeMessage.HELLO) {
val hasExistingSessionId = pendingConnectionsLock.read {
runBlocking { pendingConnections.getIfPresent(sessionId) != null
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
} }
return false
if (hasExistingSessionId) {
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
listenerManager.notifyError(ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry."))
actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
}
return false
}
return true
} }
// check to see if this is a pending connection // check to see if this is a pending connection
@ -112,11 +126,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
server.addConnection(pendingConnection) server.addConnection(pendingConnection)
// now tell the client we are done // now tell the client we are done
runBlocking { actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
}
server.actionDispatch.launch {
// this must be THE ONLY THING in this class to use the action dispatch!
listenerManager.notifyConnect(pendingConnection) listenerManager.notifyConnect(pendingConnection)
} }
} }
@ -178,15 +190,14 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
fun processIpcHandshakeMessageServer(server: Server<CONNECTION>, fun processIpcHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication, handshakePublication: Publication,
sessionId: Int, sessionId: Int,
message: Any?, message: HandshakeMessage,
aeron: Aeron) { aeron: Aeron) {
val connectionString = "IPC" val connectionString = "IPC"
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, connectionString)) { if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, connectionString)) {
return return
} }
message as HandshakeMessage
val serialization = config.serialization val serialization = config.serialization
@ -242,11 +253,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// create a new connection. The session ID is encrypted. // create a new connection. The session ID is encrypted.
try { try {
// connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId, val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId,
streamIdSubscription = connectionStreamSubId, streamIdSubscription = connectionStreamSubId,
sessionId = connectionSessionId, sessionId = connectionSessionId)
connectionTimeoutMS = 0)
// we have to construct how the connection will communicate! // we have to construct how the connection will communicate!
runBlocking { runBlocking {
@ -332,14 +341,13 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionId: Int, sessionId: Int,
clientAddressString: String, clientAddressString: String,
clientAddress: InetAddress, clientAddress: InetAddress,
message: Any?, message: HandshakeMessage,
aeron: Aeron, aeron: Aeron,
isIpv6Wildcard: Boolean) { isIpv6Wildcard: Boolean) {
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) { if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, clientAddressString)) {
return return
} }
message as HandshakeMessage
val clientPublicKeyBytes = message.publicKey val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState val validateRemoteAddress: PublicKeyValidationState