No longer using coroutines for adding publication/subscription and closing certain calsses. Better aeron error handling/reporting. Better aeron startup/shutdown. The pending connections cache no longer is ThreadSafe, and no longer is protected via RW lock.

This commit is contained in:
Robinson 2021-04-30 16:01:25 +02:00
parent 3f016672e6
commit cffca943f5
12 changed files with 297 additions and 281 deletions

View File

@ -25,14 +25,15 @@ import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientRejectedException import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.handshake.ClientHandshake import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.ping.Ping
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import dorkbox.util.Sys import dorkbox.util.Sys
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import java.net.Inet4Address import java.net.Inet4Address
import java.net.Inet6Address import java.net.Inet6Address
import java.net.InetAddress import java.net.InetAddress
@ -78,9 +79,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper // This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper
// lock-stop ordering for how disconnect and connect work with each-other // lock-stop ordering for how disconnect and connect work with each-other
private val lockStepForReconnect = atomic<SuspendWaiter?>(null)
// GUARANTEE that the callbacks for 'onDisconnect' happens-before the 'onConnect'. // GUARANTEE that the callbacks for 'onDisconnect' happens-before the 'onConnect'.
private val lockStepForDispatch = atomic<SuspendWaiter?>(null) private val lockStepForConnect = atomic<SuspendWaiter?>(null)
final override fun newException(message: String, cause: Throwable?): Throwable { final override fun newException(message: String, cause: Throwable?): Throwable {
return ClientException(message, cause) return ClientException(message, cause)
@ -104,39 +104,54 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* - an InetAddress address * - an InetAddress address
* *
* ### For the IPC (Inter-Process-Communication) it must be: * ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY.
* - `connect()` * - `connect()`
* - `connect("")` * - `connect("")`
* - `connectIpc()`
* *
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * ### Case does not matter, and "localhost" is the default.
* *
* @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param remoteAddress The network host or ip address
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
* *
* @throws IllegalArgumentException if the remote address is invalid * @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
suspend fun connect(remoteAddress: String, fun connect(remoteAddress: String = "",
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
when { when {
// this is default IPC settings // this is default IPC settings
remoteAddress.isEmpty() -> connect(connectionTimeoutMS = connectionTimeoutMS) remoteAddress.isEmpty() -> {
connectIpc(connectionTimeoutMS = connectionTimeoutMS)
}
IPv4.isPreferred -> connect(remoteAddress = Inet4.toAddress(remoteAddress), IPv4.isPreferred -> {
connectionTimeoutMS = connectionTimeoutMS, connect(
reliable = reliable) remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
IPv6.isPreferred -> connect(remoteAddress = Inet6.toAddress(remoteAddress), IPv6.isPreferred -> {
connectionTimeoutMS = connectionTimeoutMS, connect(
reliable = reliable) remoteAddress = Inet6.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
// if there is no preference, then try to connect via IPv4 // if there is no preference, then try to connect via IPv4
else -> connect(remoteAddress = Inet4.toAddress(remoteAddress), else -> {
connectionTimeoutMS = connectionTimeoutMS, connect(
reliable = reliable) remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
} }
} }
@ -151,22 +166,24 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* - an InetAddress address * - an InetAddress address
* *
* ### For the IPC (Inter-Process-Communication) it must be: * ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY.
* - `connect()` * - `connect()`
* - `connect("")` * - `connect("")`
* - `connectIpc()`
* *
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * ### Case does not matter, and "localhost" is the default.
* *
* @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
* *
* @throws IllegalArgumentException if the remote address is invalid * @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
suspend fun connect(remoteAddress: InetAddress, fun connect(remoteAddress: InetAddress,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
// Default IPC ports are flipped because they are in the perspective of the SERVER // Default IPC ports are flipped because they are in the perspective of the SERVER
connect(remoteAddress = remoteAddress, connect(remoteAddress = remoteAddress,
ipcPublicationId = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, ipcPublicationId = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
@ -187,9 +204,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun connect(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, fun connectIpc(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L) { connectionTimeoutMS: Long = 30_000L) {
// Default IPC ports are flipped because they are in the perspective of the SERVER // Default IPC ports are flipped because they are in the perspective of the SERVER
require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." } require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." }
@ -208,42 +225,41 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* ### For a network address, it can be: * ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org") * - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1") * - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* - an InetAddress address
* *
* ### For the IPC (Inter-Process-Communication) it must be: * ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY. ie: just call `connect()` * - `connect()`
* - Specified EMPTY. ie: just call `connect()` * - `connect("")`
* - `connectIpc()`
* *
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * ### Case does not matter, and "localhost" is the default.
*
* ### Case does not matter, and "localhost" is the default.
* *
* @param remoteAddress The network or if localhost, IPC address for the client to connect to * @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param ipcPublicationId The IPC publication address for the client to connect to * @param ipcPublicationId The IPC publication address for the client to connect to
* @param ipcSubscriptionId The IPC subscription address for the client to connect to * @param ipcSubscriptionId The IPC subscription address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely. * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely.
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
* *
* @throws IllegalArgumentException if the remote address is invalid * @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected * @throws ClientRejectedException if the client connection is rejected
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
private suspend fun connect(remoteAddress: InetAddress? = null, private fun connect(remoteAddress: InetAddress? = null,
// Default IPC ports are flipped because they are in the perspective of the SERVER // Default IPC ports are flipped because they are in the perspective of the SERVER
ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB, ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB, ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
require(connectionTimeoutMS >= 0) { "connectionTimeoutMS '$connectionTimeoutMS' is invalid. It must be >=0" } require(connectionTimeoutMS >= 0) { "connectionTimeoutMS '$connectionTimeoutMS' is invalid. It must be >=0" }
// this will exist ONLY if we are reconnecting via a "disconnect" callback
lockStepForReconnect.value?.doWait()
if (isConnected) { if (isConnected) {
logger.error("Unable to connect when already connected!") logger.error("Unable to connect when already connected!")
return return
} }
lockStepForReconnect.lazySet(null)
// localhost/loopback IP might not always be 127.0.0.1 or ::1 // localhost/loopback IP might not always be 127.0.0.1 or ::1
this.remoteAddress0 = remoteAddress this.remoteAddress0 = remoteAddress
connection0 = null connection0 = null
@ -268,17 +284,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// only change LOCALHOST -> IPC if the media driver is ALREADY running LOCALLY! // only change LOCALHOST -> IPC if the media driver is ALREADY running LOCALLY!
var isUsingIPC = false var isUsingIPC = false
val canUseIPC = config.enableIpc && remoteAddress == null val canUseIPC = config.enableIpc
val autoChangeToIpc = canUseIPC && config.enableIpcForLoopback && val autoChangeToIpc = config.enableIpcForLoopback &&
remoteAddress != null && remoteAddress.isLoopbackAddress && aeronDriver.isRunning() (remoteAddress == null || remoteAddress.isLoopbackAddress) && aeronDriver.isRunning()
if (autoChangeToIpc) {
logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
}
val handshake = ClientHandshake(crypto, this) val handshake = ClientHandshake(crypto, this)
val handshakeConnection = if (autoChangeToIpc || canUseIPC) { val handshakeConnection = if (autoChangeToIpc) {
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead
val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId, val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId,
streamId = ipcPublicationId, streamId = ipcPublicationId,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID) sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID)
@ -288,8 +304,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
ipcConnection.buildClient(aeronDriver, logger) ipcConnection.buildClient(aeronDriver, logger)
isUsingIPC = true isUsingIPC = true
} catch (e: Exception) { } catch (e: Exception) {
// if we specified that we want to use IPC, then we have to throw the timeout exception, because there is no IPC // if we specified that we MUST use IPC, then we have to throw the exception, because there is no IPC
if (canUseIPC) { if (remoteAddress == null) {
throw e throw e
} }
} }
@ -301,7 +317,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// try a UDP connection instead // try a UDP connection instead
val udpConnection = UdpMediaDriverClientConnection( val udpConnection = UdpMediaDriverClientConnection(
address = this.remoteAddress0!!, address = remoteAddress!!,
publicationPort = config.subscriptionPort, publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort, subscriptionPort = config.publicationPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID, streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
@ -316,7 +332,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
else { else {
val test = UdpMediaDriverClientConnection( val test = UdpMediaDriverClientConnection(
address = this.remoteAddress0!!, address = remoteAddress!!,
publicationPort = config.subscriptionPort, publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort, subscriptionPort = config.publicationPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID, streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
@ -343,7 +359,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val validateRemoteAddress = if (isUsingIPC) { val validateRemoteAddress = if (isUsingIPC) {
PublicKeyValidationState.VALID PublicKeyValidationState.VALID
} else { } else {
crypto.validateRemoteAddress(this.remoteAddress0!!, connectionInfo.publicKey) crypto.validateRemoteAddress(remoteAddress!!, connectionInfo.publicKey)
} }
if (validateRemoteAddress == PublicKeyValidationState.INVALID) { if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -438,16 +454,14 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
////////////// //////////////
newConnection.preCloseAction = { newConnection.preCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close() // this is called whenever connection.close() is called by the framework or via client.close()
if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!"))
}
// on the client, we want to GUARANTEE that the disconnect happens-before the connect. // on the client, we want to GUARANTEE that the disconnect happens-before the connect.
if (!lockStepForDispatch.compareAndSet(null, SuspendWaiter())) { if (!lockStepForConnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(connection, IllegalStateException("lockStep for dispatch was in the wrong state!")) listenerManager.notifyError(connection, IllegalStateException("lockStep for onConnect was in the wrong state!"))
} }
} }
newConnection.postCloseAction = { newConnection.postCloseAction = {
isConnected = false
// this is called whenever connection.close() is called by the framework or via client.close() // this is called whenever connection.close() is called by the framework or via client.close()
// make sure to call our client.notifyDisconnect() callbacks // make sure to call our client.notifyDisconnect() callbacks
@ -462,10 +476,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
lockStepForDispatch.value?.cancel() lockStepForDispatch.value?.cancel()
} }
// in case notifyDisconnect called client.connect().... cancel them waiting
isConnected = false
lockStepForReconnect.value?.cancel()
} }
connection0 = newConnection connection0 = newConnection
@ -477,24 +487,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
if (canFinishConnecting) { if (canFinishConnecting) {
isConnected = true isConnected = true
// this forces the current thread to WAIT until poll system has started
val waiter = SuspendWaiter()
// have to make a new thread to listen for incoming data! // have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
@Suppress("EXPERIMENTAL_API_USAGE")
actionDispatch.launch(start = CoroutineStart.UNDISPATCHED) {
lockStepForDispatch.value?.doWait()
// NOTE: UNDISPATCHED means that this coroutine will start as an event loop, instead of concurrently
// we want this behavior INSTEAD OF automatically starting this on a new thread.
listenerManager.notifyConnect(newConnection)
lockStepForDispatch.lazySet(null)
}
// these have to be in two SEPARATE actionDispatch.launch commands.... otherwise... // these have to be in two SEPARATE actionDispatch.launch commands.... otherwise...
// if something inside of notifyConnect is blocking or suspends, then polling will never happen! // if something inside of notifyConnect is blocking or suspends, then polling will never happen!
actionDispatch.launch { actionDispatch.launch {
waiter.doNotify()
val pollIdleStrategy = config.pollIdleStrategy val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) { while (!isShutdown()) {
@ -502,8 +505,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug {"[${newConnection.id}] connection expired"} logger.debug {"[${newConnection.id}] connection expired"}
// NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()` // eventloop is required, because we want to run this code AFTER the current coroutine has finished. This prevents
newConnection.close() // odd race conditions when a client is restarted
actionDispatch.eventLoop {
// NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()`
newConnection.close()
}
return@launch return@launch
} }
else { else {
@ -515,6 +522,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
} }
} }
actionDispatch.eventLoop {
waiter.doWait()
lockStepForConnect.value?.doWait()
listenerManager.notifyConnect(newConnection)
lockStepForConnect.lazySet(null)
}
} else { } else {
close() close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")

View File

@ -36,7 +36,6 @@ import dorkbox.network.rmi.TimeoutException
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.Image import io.aeron.Image
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
@ -156,7 +155,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
return super.getRmiConnectionSupport() return super.getRmiConnectionSupport()
} }
private suspend fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { private fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIpc) { val poller = if (config.enableIpc) {
val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId, val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId,
streamId = config.ipcPublicationId, streamId = config.ipcPublicationId,
@ -180,9 +179,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request"))
actionDispatch.launch { writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler return@FragmentAssembler
} }
@ -190,7 +187,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
publication, publication,
sessionId, sessionId,
message, message,
aeronDriver) aeronDriver)
} }
override fun poll(): Int { return subscription.poll(handler, 1) } override fun poll(): Int { return subscription.poll(handler, 1) }
@ -210,7 +207,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
private suspend fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { private fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv4) { val poller = if (canUseIPv4) {
val driver = UdpMediaDriverServerConnection( val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv4Address!!, listenAddress = listenIPv4Address!!,
@ -260,9 +257,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch { writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler return@FragmentAssembler
} }
@ -293,7 +288,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
private suspend fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { private fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv6) { val poller = if (canUseIPv6) {
val driver = UdpMediaDriverServerConnection( val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!, listenAddress = listenIPv6Address!!,
@ -343,9 +338,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch { writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler return@FragmentAssembler
} }
@ -376,7 +369,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
private suspend fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller { private fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val driver = UdpMediaDriverServerConnection( val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!, listenAddress = listenIPv6Address!!,
publicationPort = config.publicationPort, publicationPort = config.publicationPort,
@ -426,9 +419,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch { writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
return@FragmentAssembler return@FragmentAssembler
} }
@ -468,39 +459,39 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// we are done with initial configuration, now initialize aeron and the general state of this endpoint // we are done with initial configuration, now initialize aeron and the general state of this endpoint
bindAlreadyCalled = true bindAlreadyCalled = true
// this forces the current thread to WAIT until poll system has started
val waiter = SuspendWaiter() val waiter = SuspendWaiter()
actionDispatch.launch {
val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled! val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
val ipv4Poller: AeronPoller
val ipv6Poller: AeronPoller
if (isWildcard) { // if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
// IPv6 will bind to IPv4 wildcard as well!! val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
if (canUseIPv4 && canUseIPv6) { val ipv4Poller: AeronPoller
ipv4Poller = object : AeronPoller { val ipv6Poller: AeronPoller
override fun poll(): Int { return 0 }
override fun close() {} if (isWildcard) {
override fun serverInfo(): String { return "IPv4 Disabled" } // IPv6 will bind to IPv4 wildcard as well!!
} if (canUseIPv4 && canUseIPv6) {
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config) ipv4Poller = object : AeronPoller {
} else { override fun poll(): Int { return 0 }
// only 1 will be a real poller override fun close() {}
ipv4Poller = getIpv4Poller(aeronDriver, config) override fun serverInfo(): String { return "IPv4 Disabled" }
ipv6Poller = getIpv6Poller(aeronDriver, config)
} }
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config)
} else { } else {
// only 1 will be a real poller
ipv4Poller = getIpv4Poller(aeronDriver, config) ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config) ipv6Poller = getIpv6Poller(aeronDriver, config)
} }
} else {
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
}
actionDispatch.launch {
waiter.doNotify() waiter.doNotify()
val pollIdleStrategy = config.pollIdleStrategy val pollIdleStrategy = config.pollIdleStrategy
try { try {
var pollCount: Int var pollCount: Int
@ -590,9 +581,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
jobs.add(job) jobs.add(job)
} }
// reset all of the handshake info
handshake.clear()
// when we close a client or a server, we want to make sure that ALL notifications are finished. // when we close a client or a server, we want to make sure that ALL notifications are finished.
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown // when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
jobs.forEach { it.join() } jobs.forEach { it.join() }
@ -604,6 +592,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
ipv6Poller.close() ipv6Poller.close()
ipcPoller.close() ipcPoller.close()
// clear all of the handshake info
handshake.clear()
// finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close() // finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close()
shutdownEventWaiter.doNotify() shutdownEventWaiter.doNotify()
} }

View File

@ -10,10 +10,9 @@ import io.aeron.Subscription
import io.aeron.driver.MediaDriver import io.aeron.driver.MediaDriver
import io.aeron.exceptions.DriverTimeoutException import io.aeron.exceptions.DriverTimeoutException
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
import org.agrona.concurrent.AgentTerminationException
import org.agrona.concurrent.BackoffIdleStrategy import org.agrona.concurrent.BackoffIdleStrategy
import java.io.File import java.io.File
import java.lang.Thread.sleep import java.lang.Thread.sleep
@ -153,20 +152,9 @@ class AeronDriver(val config: Configuration,
// we DO NOT want to abort the JVM if there are errors. // we DO NOT want to abort the JVM if there are errors.
context.errorHandler { error -> context.errorHandler { error ->
if (error is DriverTimeoutException) { AeronDriver.manageError(logger, error)
// we suppress this because it is already handled
return@errorHandler
}
if (error.cause is BindException) {
// we suppress this because it is already handled
return@errorHandler
}
logger.error("Error in Aeron Media Driver", error)
} }
val aeronDir = File(context.aeronDirectoryName()).absoluteFile val aeronDir = File(context.aeronDirectoryName()).absoluteFile
context.aeronDirectoryName(aeronDir.path) context.aeronDirectoryName(aeronDir.path)
@ -238,9 +226,9 @@ class AeronDriver(val config: Configuration,
* *
* @return true if the media driver is active and running * @return true if the media driver is active and running
*/ */
fun isRunning(context: MediaDriver.Context, timeout: Long = context.driverTimeoutMs()): Boolean { fun isRunning(context: MediaDriver.Context): Boolean {
// if the media driver is running, it will be a quick connection. Usually 100ms or so // if the media driver is running, it will be a quick connection. Usually 100ms or so
return context.isDriverActive(timeout) { } return context.isDriverActive(context.driverTimeoutMs()) { }
} }
/** /**
@ -255,6 +243,26 @@ class AeronDriver(val config: Configuration,
require(config.context != null) { "Configuration context cannot be properly created. Unable to continue!" } require(config.context != null) { "Configuration context cannot be properly created. Unable to continue!" }
} }
fun manageError(logger: KLogger, error: Throwable) {
if (error is DriverTimeoutException) {
// we suppress this because it is already handled
return
}
if (error is AgentTerminationException) {
// we suppress this because it is already handled
return
}
if (error.cause is BindException) {
// we suppress this because it is already handled
return
}
ListenerManager.cleanStackTrace(error)
logger.error("Error in Aeron", error)
}
} }
private val closeRequested = atomic(false) private val closeRequested = atomic(false)
@ -272,6 +280,14 @@ class AeronDriver(val config: Configuration,
// did WE start the media driver, or did SOMEONE ELSE start it? // did WE start the media driver, or did SOMEONE ELSE start it?
private val mediaDriverWasAlreadyRunning: Boolean private val mediaDriverWasAlreadyRunning: Boolean
/**
* @return the configured driver timeout
*/
val driverTimeout: Long by lazy {
mediaDriverContext.driverTimeoutMs()
}
init { init {
mediaDriverContext mediaDriverContext
.conductorThreadFactory(threadFactory) .conductorThreadFactory(threadFactory)
@ -296,14 +312,12 @@ class AeronDriver(val config: Configuration,
// we DO NOT want to abort the JVM if there are errors. // we DO NOT want to abort the JVM if there are errors.
// this replaces the default handler with one that doesn't abort the JVM // this replaces the default handler with one that doesn't abort the JVM
aeronDriverContext.errorHandler { error -> aeronDriverContext.errorHandler { error ->
ListenerManager.cleanStackTrace(error) AeronDriver.manageError(logger, error)
logger.error("Error in Aeron", error)
} }
return aeronDriverContext return aeronDriverContext
} }
/** /**
* @return true if the media driver was started, false if it was not started * @return true if the media driver was started, false if it was not started
*/ */
@ -315,9 +329,18 @@ class AeronDriver(val config: Configuration,
if (mediaDriver == null) { if (mediaDriver == null) {
// only start if we didn't already start... There will be several checks. // only start if we didn't already start... There will be several checks.
if (!isRunning(mediaDriverContext)) { var running = isRunning(mediaDriverContext)
logger.debug("Starting Aeron Media driver in '${mediaDriverContext.aeronDirectory()}'") if (running) {
// wait for a bit, because we are running, but we ALSO issued a START, and expect it to start.
// SOMETIMES aeron is in the middle of shutting down, and this prevents us from trying to connect to
// that instance
logger.debug("Aeron Media driver already running. Double checking status...")
sleep(mediaDriverContext.driverTimeoutMs()/2)
running = isRunning(mediaDriverContext)
}
if (!running) {
logger.debug("Starting Aeron Media driver.")
// try to start. If we start/stop too quickly, it's a problem // try to start. If we start/stop too quickly, it's a problem
var count = 10 var count = 10
@ -327,13 +350,11 @@ class AeronDriver(val config: Configuration,
return true return true
} catch (e: Exception) { } catch (e: Exception) {
logger.warn(e) { "Unable to start the Aeron Media driver. Retrying $count more times..." } logger.warn(e) { "Unable to start the Aeron Media driver. Retrying $count more times..." }
runBlocking { sleep(mediaDriverContext.driverTimeoutMs())
delay(mediaDriverContext.driverTimeoutMs())
}
} }
} }
} else { } else {
logger.debug("Not starting Aeron Media driver. It was already running in '${mediaDriverContext.aeronDirectory()}'") logger.debug("Not starting Aeron Media driver. It was already running.")
} }
} }
@ -389,7 +410,7 @@ class AeronDriver(val config: Configuration,
} }
suspend fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication { fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication {
val uri = publicationUri.build() val uri = publicationUri.build()
// If we start/stop too quickly, we might have the address already in use! Retry a few times. // If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -404,13 +425,14 @@ class AeronDriver(val config: Configuration,
exception = e exception = e
logger.warn { "Unable to add a publication to Aeron. Retrying $count more times..." } logger.warn { "Unable to add a publication to Aeron. Retrying $count more times..." }
// if exceptions are added here, make sure to ALSO suppress them in the context error handler
if (e is DriverTimeoutException) { if (e is DriverTimeoutException) {
delay(mediaDriverContext.driverTimeoutMs()) sleep(mediaDriverContext.driverTimeoutMs())
} }
if (e.cause is BindException) { if (e.cause is BindException) {
// was starting too fast! // was starting too fast!
delay(mediaDriverContext.driverTimeoutMs()) sleep(mediaDriverContext.driverTimeoutMs())
} }
// reasons we cannot add a pub/sub to aeron // reasons we cannot add a pub/sub to aeron
@ -434,7 +456,7 @@ class AeronDriver(val config: Configuration,
throw exception!! throw exception!!
} }
suspend fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription { fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription {
val uri = subscriptionUri.build() val uri = subscriptionUri.build()
// If we start/stop too quickly, we might have the address already in use! Retry a few times. // If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -451,15 +473,16 @@ class AeronDriver(val config: Configuration,
} catch (e: Exception) { } catch (e: Exception) {
// NOTE: this error will be logged in the `aeronDriverContext` logger // NOTE: this error will be logged in the `aeronDriverContext` logger
exception = e exception = e
logger.warn { "Unable to add a sublication to Aeron. Retrying $count more times..." } logger.warn { "Unable to add a subscription to Aeron. Retrying $count more times..." }
// if exceptions are added here, make sure to ALSO suppress them in the context error handler
if (e is DriverTimeoutException) { if (e is DriverTimeoutException) {
delay(mediaDriverContext.driverTimeoutMs()) sleep(mediaDriverContext.driverTimeoutMs())
} }
if (e.cause is BindException) { if (e.cause is BindException) {
// was starting too fast! // was starting too fast!
delay(mediaDriverContext.driverTimeoutMs()) sleep(mediaDriverContext.driverTimeoutMs())
} }
// reasons we cannot add a pub/sub to aeron // reasons we cannot add a pub/sub to aeron
@ -545,6 +568,8 @@ class AeronDriver(val config: Configuration,
logger.warn { "Aeron Media driver at '${mediaDriverContext.aeronDirectory()}' is still running. Waiting for it to stop. Trying $count more times." } logger.warn { "Aeron Media driver at '${mediaDriverContext.aeronDirectory()}' is still running. Waiting for it to stop. Trying $count more times." }
sleep(mediaDriverContext.driverTimeoutMs()) sleep(mediaDriverContext.driverTimeoutMs())
} }
logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" }
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error closing the media driver at '${mediaDriverContext.aeronDirectory()}'", e) logger.error("Error closing the media driver at '${mediaDriverContext.aeronDirectory()}'", e)
} }
@ -552,7 +577,5 @@ class AeronDriver(val config: Configuration,
// Destroys this thread group and all of its subgroups. // Destroys this thread group and all of its subgroups.
// This thread group must be empty, indicating that all threads that had been in this thread group have since stopped. // This thread group must be empty, indicating that all threads that had been in this thread group have since stopped.
threadFactory.group.destroy() threadFactory.group.destroy()
logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" }
} }
} }

View File

@ -18,8 +18,8 @@ package dorkbox.network.aeron
import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.ChannelUriStringBuilder import io.aeron.ChannelUriStringBuilder
import kotlinx.coroutines.delay
import mu.KLogger import mu.KLogger
import java.lang.Thread.sleep
/** /**
* For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER * For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER
@ -47,7 +47,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
* *
* @throws ClientTimedOutException if we cannot connect to the server in the designated time * @throws ClientTimedOutException if we cannot connect to the server in the designated time
*/ */
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID. // Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri() val publicationUri = uri()
@ -64,23 +64,25 @@ internal open class IpcMediaDriverConnection(streamId: Int,
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client. // publication of any state to other threads and not be long running or re-entrant with the client.
var startTime = System.currentTimeMillis()
var success = false
// If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times. // If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times.
val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId) val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription) val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription)
var success = false
// this will wait for the server to acknowledge the connection (all via aeron) // this will wait for the server to acknowledge the connection (all via aeron)
var startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
if (subscription.isConnected && subscription.imageCount() > 0) { if (subscription.isConnected && subscription.imageCount() > 0) {
success = true success = true
break break
} }
delay(timeMillis = 100L) sleep(100L)
} }
if (!success) { if (!success) {
subscription.close() subscription.close()
throw ClientTimedOutException("Creating subscription connection to aeron") throw ClientTimedOutException("Creating subscription connection to aeron")
@ -97,7 +99,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
break break
} }
delay(timeMillis = 100L) sleep(100L)
} }
if (!success) { if (!success) {
@ -116,7 +118,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
* *
* serverAddress is ignored for IPC * serverAddress is ignored for IPC
*/ */
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri() val publicationUri = uri()

View File

@ -32,8 +32,8 @@ abstract class MediaDriverConnection(
@Throws(ClientTimedOutException::class) @Throws(ClientTimedOutException::class)
abstract suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) abstract fun buildClient(aeronDriver: AeronDriver, logger: KLogger)
abstract suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false) abstract fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false)
abstract fun clientInfo() : String abstract fun clientInfo() : String
abstract fun serverInfo() : String abstract fun serverInfo() : String

View File

@ -21,8 +21,8 @@ import dorkbox.network.connection.ListenerManager
import dorkbox.network.exceptions.ClientException import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.ChannelUriStringBuilder import io.aeron.ChannelUriStringBuilder
import kotlinx.coroutines.delay
import mu.KLogger import mu.KLogger
import java.lang.Thread.sleep
import java.net.Inet4Address import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -38,7 +38,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
sessionId: Int, sessionId: Int,
connectionTimeoutMS: Long = 0, connectionTimeoutMS: Long = 0,
isReliable: Boolean = true) : isReliable: Boolean = true) :
MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
var success: Boolean = false var success: Boolean = false
@ -80,7 +80,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
@Throws(ClientException::class) @Throws(ClientException::class)
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
val aeronAddressString = aeronConnectionString(address) val aeronAddressString = aeronConnectionString(address)
// Create a publication at the given address and port, using the given stream ID. // Create a publication at the given address and port, using the given stream ID.
@ -100,6 +100,8 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
logger.trace("client sub URI: $ip ${subscriptionUri.build()}") logger.trace("client sub URI: $ip ${subscriptionUri.build()}")
} }
var success = false
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client. // publication of any state to other threads and not be long running or re-entrant with the client.
// on close, the publication CAN linger (in case a client goes away, and then comes back) // on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -107,7 +109,6 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId) val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId) val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId)
var success = false
// this will wait for the server to acknowledge the connection (all via aeron) // this will wait for the server to acknowledge the connection (all via aeron)
val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS) val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS)
@ -118,12 +119,12 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
break break
} }
delay(timeMillis = 100L) sleep(100L)
} }
if (!success) { if (!success) {
subscription.close() subscription.close()
val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()}") val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()} in ${timoutInNanos}ms")
ListenerManager.cleanStackTrace(ex) ListenerManager.cleanStackTrace(ex)
throw ex throw ex
} }
@ -139,19 +140,18 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
break break
} }
delay(timeMillis = 100L) sleep(100L)
} }
if (!success) { if (!success) {
subscription.close() subscription.close()
publication.close() publication.close()
val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()}") val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()} in ${timoutInNanos}ms")
// ListenerManager.cleanStackTrace(ex) ListenerManager.cleanStackTrace(ex)
throw ex throw ex
} }
this.success = true this.success = true
this.publication = publication this.publication = publication
this.subscription = subscription this.subscription = subscription
} }
@ -164,7 +164,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
} }
} }
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
throw ClientException("Server info not implemented in Client MDC") throw ClientException("Server info not implemented in Client MDC")
} }
override fun serverInfo(): String { override fun serverInfo(): String {

View File

@ -0,0 +1,25 @@
/*
* Copyright 2021 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
abstract class UdpMediaDriverConnection(publicationPort: Int, subscriptionPort: Int,
streamId: Int, sessionId: Int,
connectionTimeoutMS: Long, isReliable: Boolean) :
MediaDriverConnection(publicationPort, subscriptionPort,
streamId, sessionId,
connectionTimeoutMS, isReliable) {
}

View File

@ -36,11 +36,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
sessionId: Int, sessionId: Int,
connectionTimeoutMS: Long = 0, connectionTimeoutMS: Long = 0,
isReliable: Boolean = true) : isReliable: Boolean = true) :
MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) { UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
var success: Boolean = false var success: Boolean = false
protected fun aeronConnectionString(ipAddress: InetAddress): String { private fun aeronConnectionString(ipAddress: InetAddress): String {
return if (ipAddress is Inet4Address) { return if (ipAddress is Inet4Address) {
ipAddress.hostAddress ipAddress.hostAddress
} else { } else {
@ -64,11 +64,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
} }
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) { override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
throw ServerException("Client info not implemented in Server MDC") throw ServerException("Client info not implemented in Server MDC")
} }
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) { override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
val connectionString = aeronConnectionString(listenAddress) val connectionString = aeronConnectionString(listenAddress)
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.

View File

@ -17,8 +17,8 @@ package dorkbox.network.connection
import dorkbox.network.aeron.IpcMediaDriverConnection import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverClientConnection import dorkbox.network.aeron.UdpMediaDriverClientConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverPairedConnection import dorkbox.network.aeron.UdpMediaDriverPairedConnection
import dorkbox.network.aeron.UdpMediaDriverServerConnection
import dorkbox.network.handshake.ConnectionCounts import dorkbox.network.handshake.ConnectionCounts
import dorkbox.network.handshake.RandomIdAllocator import dorkbox.network.handshake.RandomIdAllocator
import dorkbox.network.ping.Ping import dorkbox.network.ping.Ping
@ -33,9 +33,7 @@ import io.aeron.Subscription
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.getAndUpdate
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.io.IOException import java.io.IOException
@ -84,7 +82,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/** /**
* @return true if this connection is a network connection * @return true if this connection is a network connection
*/ */
val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverServerConnection val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection
/** /**
* the endpoint associated with this connection * the endpoint associated with this connection

View File

@ -87,7 +87,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
private val handshakeKryo: KryoExtra private val handshakeKryo: KryoExtra
private val sendIdleStrategy: CoroutineIdleStrategy private val sendIdleStrategy: CoroutineIdleStrategy
private val sendIdleStrategyHandshake: IdleStrategy private val sendIdleStrategyHandShake: IdleStrategy
val pollIdleStrategy: CoroutineIdleStrategy
val pollIdleStrategyHandShake: IdleStrategy
/** /**
* Crypto and signature management * Crypto and signature management
@ -126,7 +129,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// serialization stuff // serialization stuff
serialization = config.serialization serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy sendIdleStrategy = config.sendIdleStrategy
sendIdleStrategyHandshake = sendIdleStrategy.cloneToNormal() pollIdleStrategy = config.pollIdleStrategy
sendIdleStrategyHandShake = sendIdleStrategy.cloneToNormal()
pollIdleStrategyHandShake = pollIdleStrategy.cloneToNormal()
handshakeKryo = serialization.initHandshakeKryo() handshakeKryo = serialization.initHandshakeKryo()
@ -347,7 +353,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*/ */
if (result >= Publication.ADMIN_ACTION) { if (result >= Publication.ADMIN_ACTION) {
// we should retry. // we should retry.
sendIdleStrategyHandshake.idle() sendIdleStrategyHandShake.idle()
continue continue
} }
@ -362,7 +368,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
ListenerManager.cleanStackTrace(exception, 2) // 2 because we do not want to see the stack for the abstract `newException` ListenerManager.cleanStackTrace(exception, 2) // 2 because we do not want to see the stack for the abstract `newException`
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
} finally { } finally {
sendIdleStrategyHandshake.reset() sendIdleStrategyHandShake.reset()
} }
} }
@ -431,7 +437,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
when (message) { when (message) {
is PingMessage -> { is PingMessage -> {
// the ping listener (internal use only!) // the ping listener
actionDispatch.launch { actionDispatch.launch {
pingManager.manage(this@EndPoint, connection, message, logger) pingManager.manage(this@EndPoint, connection, message, logger)
} }
@ -636,22 +642,22 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
final override fun close() { final override fun close() {
if (shutdown.compareAndSet(expect = false, update = true)) { if (shutdown.compareAndSet(expect = false, update = true)) {
logger.info { "Shutting down..." } logger.info { "Shutting down..." }
aeronDriver.close()
runBlocking { runBlocking {
aeronDriver.close()
connections.forEach { connections.forEach {
it.close() it.close()
} }
// the storage is closed via this as well.
storage.close()
// Connections are closed first, because we want to make sure that no RMI messages can be received // Connections are closed first, because we want to make sure that no RMI messages can be received
// when we close the RMI support objects (in which case, weird - but harmless - errors show up) // when we close the RMI support objects (in which case, weird - but harmless - errors show up)
// this will wait for RMI timeouts if there are RMI in-progress. (this happens if we close via and RMI method)
rmiGlobalSupport.close() rmiGlobalSupport.close()
} }
// the storage is closed via this as well.
storage.close()
close0() close0()
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now) // if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)

View File

@ -143,7 +143,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
} }
// called from the connect thread // called from the connect thread
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
failed = null failed = null
oneTimeKey = endPoint.crypto.secureRandom.nextInt() oneTimeKey = endPoint.crypto.secureRandom.nextInt()
val publicKey = endPoint.storage.getPublicKey()!! val publicKey = endPoint.storage.getPublicKey()!!
@ -151,8 +151,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// Send the one-time pad to the server. // Send the one-time pad to the server.
val publication = handshakeConnection.publication val publication = handshakeConnection.publication
val subscription = handshakeConnection.subscription val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey)) endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
@ -191,7 +190,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
} }
// called from the connect thread // called from the connect thread
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey) val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
// Send the done message to the server. // Send the done message to the server.
@ -203,7 +202,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
failed = null failed = null
var pollCount: Int var pollCount: Int
val subscription = handshakeConnection.subscription val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
var startTime = System.currentTimeMillis() var startTime = System.currentTimeMillis()
while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) {

View File

@ -24,23 +24,15 @@ import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.AeronDriver import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.IpcMediaDriverConnection import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverPairedConnection import dorkbox.network.aeron.UdpMediaDriverPairedConnection
import dorkbox.network.connection.Connection import dorkbox.network.connection.*
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.exceptions.* import dorkbox.network.exceptions.*
import io.aeron.Publication import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import java.net.Inet4Address import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
/** /**
@ -51,13 +43,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val config: ServerConfiguration, private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>) { private val listenerManager: ListenerManager<CONNECTION>) {
private val pendingConnectionsLock = ReentrantReadWriteLock() // note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val pendingConnections: Cache<Int, CONNECTION> = Caffeine.newBuilder() private val pendingConnections: Cache<Int, CONNECTION> = Caffeine.newBuilder()
.expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS) .expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong() * 2, TimeUnit.SECONDS)
.removalListener(RemovalListener<Any?, Any?> { _, value, cause -> .removalListener(RemovalListener<Int, CONNECTION> { sessionId, connection, cause ->
if (cause == RemovalCause.EXPIRED) { if (cause == RemovalCause.EXPIRED) {
@Suppress("UNCHECKED_CAST") connection!!
val connection = value as CONNECTION
val exception = ClientTimedOutException("[${connection.id}] Waiting for registration response from client") val exception = ClientTimedOutException("[${connection.id}] Waiting for registration response from client")
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
@ -90,19 +81,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// check to see if this sessionId is ALREADY in use by another connection! // check to see if this sessionId is ALREADY in use by another connection!
// this can happen if there are multiple connections from the SAME ip address (ie: localhost) // this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.HELLO) { if (message.state == HandshakeMessage.HELLO) {
val hasExistingSessionId = pendingConnectionsLock.read { // this should be null.
pendingConnections.getIfPresent(sessionId) != null val hasExistingSessionId = pendingConnections.getIfPresent(sessionId) != null
}
if (hasExistingSessionId) { if (hasExistingSessionId) {
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId // WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
val exception = ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.") val exception = ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.")
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
actionDispatch.launch { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
}
return false return false
} }
@ -111,14 +98,11 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// check to see if this is a pending connection // check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) { if (message.state == HandshakeMessage.DONE) {
val pendingConnection = pendingConnectionsLock.write { val pendingConnection = pendingConnections.getIfPresent(sessionId)
val con = pendingConnections.getIfPresent(sessionId) pendingConnections.invalidate(sessionId)
pendingConnections.invalidate(sessionId)
con
}
if (pendingConnection == null) { if (pendingConnection == null) {
val exception = ClientException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!") val exception = ServerException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!")
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
} else { } else {
@ -127,7 +111,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// this enables the connection to start polling for messages // this enables the connection to start polling for messages
server.addConnection(pendingConnection) server.addConnection(pendingConnection)
// now tell the client we are done // now tell the client we are done
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
@ -165,9 +148,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
return false return false
} }
@ -182,9 +163,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
return false return false
} }
connectionsPerIpCounts.increment(clientAddress, currentCountForIp) connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
@ -193,9 +172,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
return false return false
} }
@ -234,9 +211,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return return
} }
@ -252,9 +227,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return return
} }
@ -270,9 +243,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return return
} }
@ -284,9 +255,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionId = connectionSessionId) sessionId = connectionSessionId)
// we have to construct how the connection will communicate! // we have to construct how the connection will communicate!
runBlocking { clientConnection.buildServer(aeronDriver, logger, true)
clientConnection.buildServer(aeronDriver, logger, true)
}
logger.info { logger.info {
"[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]" "[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]"
@ -326,14 +295,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write { pendingConnections.put(sessionId, connection)
pendingConnections.put(sessionId, connection)
}
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
runBlocking { server.writeHandshakeMessage(handshakePublication, successMessage)
server.writeHandshakeMessage(handshakePublication, successMessage)
}
} catch (e: Exception) { } catch (e: Exception) {
// have to unwind actions! // have to unwind actions!
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
@ -398,9 +363,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return return
} }
@ -417,9 +380,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception) ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
runBlocking { server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return return
} }
@ -450,9 +411,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
message.isReliable) message.isReliable)
// we have to construct how the connection will communicate! // we have to construct how the connection will communicate!
runBlocking { clientConnection.buildServer(aeronDriver, logger, true)
clientConnection.buildServer(aeronDriver, logger, true)
}
logger.info { logger.info {
// (reliable:$isReliable)" // (reliable:$isReliable)"
@ -462,21 +421,19 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val connection = server.newConnection(ConnectionParams(server, clientConnection, validateRemoteAddress)) val connection = server.newConnection(ConnectionParams(server, clientConnection, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
runBlocking { val permitConnection = listenerManager.notifyFilter(connection)
val permitConnection = listenerManager.notifyFilter(connection) if (!permitConnection) {
if (!permitConnection) { // have to unwind actions!
// have to unwind actions! connectionsPerIpCounts.decrementSlow(clientAddress)
connectionsPerIpCounts.decrementSlow(clientAddress) sessionIdAllocator.free(connectionSessionId)
sessionIdAllocator.free(connectionSessionId) streamIdAllocator.free(connectionStreamId)
streamIdAllocator.free(connectionStreamId)
val exception = ClientRejectedException("Connection $clientAddressString was not permitted!") val exception = ClientRejectedException("Connection $clientAddressString was not permitted!")
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception) listenerManager.notifyError(connection, exception)
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
return@runBlocking return
}
} }
@ -503,14 +460,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write { pendingConnections.put(sessionId, connection)
pendingConnections.put(sessionId, connection)
}
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
runBlocking { server.writeHandshakeMessage(handshakePublication, successMessage)
server.writeHandshakeMessage(handshakePublication, successMessage)
}
} catch (e: Exception) { } catch (e: Exception) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress) connectionsPerIpCounts.decrementSlow(clientAddress)
@ -538,6 +491,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
sessionIdAllocator.clear() sessionIdAllocator.clear()
streamIdAllocator.clear() streamIdAllocator.clear()
pendingConnections.invalidateAll() pendingConnections.invalidateAll()
pendingConnections.cleanUp()
} }
} }