Fixed issues with connection handshake with the same computer, on multiple clients.

This commit is contained in:
nathan 2020-09-23 17:08:25 +02:00
parent 8605040819
commit 93e406289c
7 changed files with 196 additions and 120 deletions

View File

@ -115,6 +115,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("BlockingMethodInNonBlockingContext")
suspend fun connect(remoteAddress: String,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
when {
@ -270,14 +271,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
val handshake = ClientHandshake(config, crypto, this)
val handshake = ClientHandshake(crypto, this)
val handshakeConnection = if (autoChangeToIpc || canUseIPC) {
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead
val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId,
streamId = ipcPublicationId,
sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID,
// "fast" connection timeout, since this is IPC
connectionTimeoutMS = 1000)
sessionId = AeronConfig.RESERVED_SESSION_ID_INVALID)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
try {
@ -356,14 +355,13 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
IpcMediaDriverConnection(sessionId = connectionInfo.sessionId,
// NOTE: pub/sub must be switched!
streamIdSubscription = connectionInfo.publicationPort,
streamId = connectionInfo.subscriptionPort,
connectionTimeoutMS = connectionTimeoutMS)
streamId = connectionInfo.subscriptionPort)
}
else {
UdpMediaDriverConnection(address = handshakeConnection.address!!,
// NOTE: pub/sub must be switched!
subscriptionPort = connectionInfo.publicationPort,
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
@ -473,7 +471,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
if (newConnection.isClosed()) {
if (newConnection.isClosedViaAeron()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug {"[${newConnection.id}] connection expired"}

View File

@ -27,7 +27,9 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.coroutines.SuspendWaiter
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
@ -167,6 +169,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processIpcHandshakeMessageServer(this@Server,
publication,
sessionId,
@ -234,6 +247,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
@ -304,6 +328,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
@ -374,6 +409,17 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
@ -464,7 +510,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
// so we can modify this as we go
connections.forEach { connection ->
if (connection.isClosed()) {
if (connection.isClosedViaAeron()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug { "[${connection.id}] connection expired" }

View File

@ -31,18 +31,14 @@ import java.net.Inet4Address
import java.net.InetAddress
import java.util.concurrent.TimeUnit
interface MediaDriverConnection : AutoCloseable {
val address: InetAddress?
val streamId: Int
val sessionId: Int
abstract class MediaDriverConnection(val address: InetAddress?,
val publicationPort: Int, val subscriptionPort: Int,
val streamId: Int, val sessionId: Int,
val connectionTimeoutMS: Long, val isReliable: Boolean) : AutoCloseable {
val subscriptionPort: Int
val publicationPort: Int
lateinit var subscription: Subscription
lateinit var publication: Publication
val subscription: Subscription
val publication: Publication
val isReliable: Boolean
suspend fun addSubscriptionWithRetry(aeron: Aeron, uri: String, streamId: Int, logger: KLogger): Subscription {
// If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -79,27 +75,25 @@ interface MediaDriverConnection : AutoCloseable {
}
@Throws(ClientTimedOutException::class)
suspend fun buildClient(aeron: Aeron, logger: KLogger)
suspend fun buildServer(aeron: Aeron, logger: KLogger)
abstract suspend fun buildClient(aeron: Aeron, logger: KLogger)
abstract suspend fun buildServer(aeron: Aeron, logger: KLogger)
fun clientInfo() : String
fun serverInfo() : String
abstract fun clientInfo() : String
abstract fun serverInfo() : String
}
/**
* For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER.
* A connection timeout of 0, means to wait forever
*/
class UdpMediaDriverConnection(override val address: InetAddress,
override val publicationPort: Int,
override val subscriptionPort: Int,
override val streamId: Int,
override val sessionId: Int,
private val connectionTimeoutMS: Long = 0,
override val isReliable: Boolean = true) : MediaDriverConnection {
override lateinit var subscription: Subscription
override lateinit var publication: Publication
class UdpMediaDriverConnection(address: InetAddress,
publicationPort: Int,
subscriptionPort: Int,
streamId: Int,
sessionId: Int,
connectionTimeoutMS: Long = 0,
isReliable: Boolean = true) :
MediaDriverConnection(address, publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
var success: Boolean = false
@ -163,7 +157,7 @@ class UdpMediaDriverConnection(override val address: InetAddress,
val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS)
var startTime = System.nanoTime()
while (timoutInNanos == 0L || System.nanoTime() - startTime < timoutInNanos) {
if (subscription.isConnected && subscription.imageCount() > 0) {
if (subscription.isConnected) {
success = true
break
}
@ -198,8 +192,8 @@ class UdpMediaDriverConnection(override val address: InetAddress,
this.success = true
this.subscription = subscription
this.publication = publication
this.subscription = subscription
}
override suspend fun buildServer(aeron: Aeron, logger: KLogger) {
@ -236,6 +230,8 @@ class UdpMediaDriverConnection(override val address: InetAddress,
}
override fun clientInfo(): String {
address!!
return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) {
"Connecting to ${IP.toString(address)} [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else {
@ -251,7 +247,7 @@ class UdpMediaDriverConnection(override val address: InetAddress,
IPv4.WILDCARD.hostAddress + "/" + address.hostAddress
}
} else {
IP.toString(address)
IP.toString(address!!)
}
return if (sessionId != AeronConfig.RESERVED_SESSION_ID_INVALID) {
@ -275,20 +271,13 @@ class UdpMediaDriverConnection(override val address: InetAddress,
/**
* For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER
* NOTE: IPC connection will ALWAYS have a timeout of 1 second to connect. This is IPC, it should connect fast
*/
class IpcMediaDriverConnection(override val streamId: Int,
class IpcMediaDriverConnection(streamId: Int,
val streamIdSubscription: Int,
override val sessionId: Int,
private val connectionTimeoutMS: Long = 30_000,
) : MediaDriverConnection {
override val address: InetAddress? = null
override val isReliable = true
override val subscriptionPort = 0
override val publicationPort = 0
override lateinit var subscription: Subscription
override lateinit var publication: Publication
sessionId: Int,
) :
MediaDriverConnection(null, 0, 0, streamId, sessionId, 1_000, true) {
var success: Boolean = false
@ -301,7 +290,11 @@ class IpcMediaDriverConnection(override val streamId: Int,
return builder
}
@Throws(ClientTimedOutException::class)
/**
* Set up the subscription + publication channels to the server
*
* @throws ClientTimedOutException if we cannot connect to the server in the designated time
*/
override suspend fun buildClient(aeron: Aeron, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
@ -366,6 +359,9 @@ class IpcMediaDriverConnection(override val streamId: Int,
this.subscription = subscription
}
/**
* Setup the subscription + publication channels on the server
*/
override suspend fun buildServer(aeron: Aeron, logger: KLogger) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.

View File

@ -293,16 +293,59 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
}
/**
* @return `true` if this connection has been closed
* Adds a function that will be called when a client/server "disconnects" with
* each other
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun isClosed(): Boolean {
val hasNoImages = subscription.hasNoImages()
if (hasNoImages) {
suspend fun onDisconnect(function: suspend (Connection) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onDisconnect(function)
}
/**
* Adds a function that will be called only for this connection, when a client/server receives a message
*/
suspend fun <MESSAGE> onMessage(function: suspend (Connection, MESSAGE) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
}
/**
* We must account for network blips. They blips will be recovered by aeron, but we want to make sure that we are actually
* disconnected for a set period of time before we start the close process for a connection
*
* @return `true` if this connection has been closed via aeron
*/
fun isClosedViaAeron(): Boolean {
val isNotConnected = !subscription.isConnected && !publication.isConnected
if (isNotConnected) {
// 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them).
return System.nanoTime() - connectionInitTime >= TimeUnit.SECONDS.toNanos(1)
}
return hasNoImages
return isNotConnected
}
/**
@ -370,47 +413,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
}
}
/**
* Adds a function that will be called when a client/server "disconnects" with
* each other
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
suspend fun onDisconnect(function: suspend (Connection) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onDisconnect(function)
}
/**
* Adds a function that will be called only for this connection, when a client/server receives a message
*/
suspend fun <MESSAGE> onMessage(function: suspend (Connection, MESSAGE) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager()
}
listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a message object was received from a remote peer.
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
}
//
//
// Generic object methods

View File

@ -42,6 +42,9 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
@Volatile
private var connectionDone = false
@Volatile
private var needToRetry = false
@Volatile
private var failed: Exception? = null
@ -64,11 +67,18 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
}
// this is an error message
if (message.sessionId == 0) {
if (message.state == HandshakeMessage.INVALID) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@FragmentAssembler
}
// this is an retry message
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.RETRY) {
needToRetry = true
return@FragmentAssembler
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " +
@ -172,7 +182,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis()
var startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
@ -184,6 +194,13 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
throw failed as Exception
}
if (needToRetry) {
needToRetry = false
// start over with the timeout!
startTime = System.currentTimeMillis()
}
if (connectionDone) {
break
}

View File

@ -50,7 +50,8 @@ internal class HandshakeMessage private constructor() {
var registrationRmiIdData: IntArray? = null
companion object {
const val INVALID = -1
const val INVALID = -2
const val RETRY = -1
const val HELLO = 0
const val HELLO_ACK = 1
const val HELLO_ACK_IPC = 2
@ -99,9 +100,17 @@ internal class HandshakeMessage private constructor() {
return error
}
fun retry(errorMessage: String): HandshakeMessage {
val error = HandshakeMessage()
error.state = RETRY
error.errorMessage = errorMessage
return error
}
fun toStateString(state: Int) : String {
return when(state) {
INVALID -> "INVALID"
RETRY -> "RETRY"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"

View File

@ -30,11 +30,13 @@ import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.exceptions.ServerException
import io.aeron.Aeron
import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger
@ -42,6 +44,7 @@ import java.net.Inet4Address
import java.net.InetAddress
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
@ -78,21 +81,32 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
/**
* @return true if we should continue parsing the incoming message, false if we should abort
*/
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD. ONLY RESPONSES ARE ON ACTION DISPATCH!
private fun validateMessageTypeAndDoPending(server: Server<CONNECTION>,
actionDispatch: CoroutineScope,
handshakePublication: Publication,
message: Any?,
message: HandshakeMessage,
sessionId: Int,
connectionString: String): Boolean {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request"))
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
// check to see if this sessionId is ALREADY in use by another connection!
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.HELLO) {
val hasExistingSessionId = pendingConnectionsLock.read {
pendingConnections.getIfPresent(sessionId) != null
}
return false
if (hasExistingSessionId) {
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
listenerManager.notifyError(ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry."))
actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
}
return false
}
return true
}
// check to see if this is a pending connection
@ -112,11 +126,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
server.addConnection(pendingConnection)
// now tell the client we are done
runBlocking {
actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
}
server.actionDispatch.launch {
// this must be THE ONLY THING in this class to use the action dispatch!
listenerManager.notifyConnect(pendingConnection)
}
}
@ -178,15 +190,14 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
fun processIpcHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
message: Any?,
message: HandshakeMessage,
aeron: Aeron) {
val connectionString = "IPC"
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, connectionString)) {
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, connectionString)) {
return
}
message as HandshakeMessage
val serialization = config.serialization
@ -242,11 +253,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// create a new connection. The session ID is encrypted.
try {
// connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId,
streamIdSubscription = connectionStreamSubId,
sessionId = connectionSessionId,
connectionTimeoutMS = 0)
sessionId = connectionSessionId)
// we have to construct how the connection will communicate!
runBlocking {
@ -332,14 +341,13 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionId: Int,
clientAddressString: String,
clientAddress: InetAddress,
message: Any?,
message: HandshakeMessage,
aeron: Aeron,
isIpv6Wildcard: Boolean) {
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) {
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, clientAddressString)) {
return
}
message as HandshakeMessage
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState