No longer using coroutines for adding publication/subscription and closing certain calsses. Better aeron error handling/reporting. Better aeron startup/shutdown. The pending connections cache no longer is ThreadSafe, and no longer is protected via RW lock.

This commit is contained in:
Robinson 2021-04-30 16:01:25 +02:00
parent 3f016672e6
commit cffca943f5
12 changed files with 297 additions and 281 deletions

View File

@ -25,14 +25,15 @@ import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.ping.Ping
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException
import dorkbox.util.Sys
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
@ -78,9 +79,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper
// lock-stop ordering for how disconnect and connect work with each-other
private val lockStepForReconnect = atomic<SuspendWaiter?>(null)
// GUARANTEE that the callbacks for 'onDisconnect' happens-before the 'onConnect'.
private val lockStepForDispatch = atomic<SuspendWaiter?>(null)
private val lockStepForConnect = atomic<SuspendWaiter?>(null)
final override fun newException(message: String, cause: Throwable?): Throwable {
return ClientException(message, cause)
@ -104,39 +104,54 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* - an InetAddress address
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY.
* - `connect()`
* - `connect("")`
* - `connectIpc()`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
* ### Case does not matter, and "localhost" is the default.
*
* @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param remoteAddress The network host or ip address
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
* @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("BlockingMethodInNonBlockingContext")
suspend fun connect(remoteAddress: String,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
fun connect(remoteAddress: String = "",
connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
when {
// this is default IPC settings
remoteAddress.isEmpty() -> connect(connectionTimeoutMS = connectionTimeoutMS)
remoteAddress.isEmpty() -> {
connectIpc(connectionTimeoutMS = connectionTimeoutMS)
}
IPv4.isPreferred -> connect(remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
IPv4.isPreferred -> {
connect(
remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
IPv6.isPreferred -> connect(remoteAddress = Inet6.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
IPv6.isPreferred -> {
connect(
remoteAddress = Inet6.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
// if there is no preference, then try to connect via IPv4
else -> connect(remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
else -> {
connect(
remoteAddress = Inet4.toAddress(remoteAddress),
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable
)
}
}
}
@ -151,22 +166,24 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* - an InetAddress address
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY.
* - `connect()`
* - `connect("")`
* - `connectIpc()`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
* ### Case does not matter, and "localhost" is the default.
*
* @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
* @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
suspend fun connect(remoteAddress: InetAddress,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
fun connect(remoteAddress: InetAddress,
connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
// Default IPC ports are flipped because they are in the perspective of the SERVER
connect(remoteAddress = remoteAddress,
ipcPublicationId = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
@ -187,9 +204,9 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("DuplicatedCode")
suspend fun connect(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L) {
fun connectIpc(ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L) {
// Default IPC ports are flipped because they are in the perspective of the SERVER
require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." }
@ -208,42 +225,41 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* - an InetAddress address
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY. ie: just call `connect()`
* - Specified EMPTY. ie: just call `connect()`
* - `connect()`
* - `connect("")`
* - `connectIpc()`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
* ### Case does not matter, and "localhost" is the default.
*
* ### Case does not matter, and "localhost" is the default.
*
* @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param ipcPublicationId The IPC publication address for the client to connect to
* @param ipcSubscriptionId The IPC subscription address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely.
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
* @param reliable true if we want to create a reliable connection (for UDP connections, is message loss acceptable?).
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("DuplicatedCode")
private suspend fun connect(remoteAddress: InetAddress? = null,
// Default IPC ports are flipped because they are in the perspective of the SERVER
ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
private fun connect(remoteAddress: InetAddress? = null,
// Default IPC ports are flipped because they are in the perspective of the SERVER
ipcPublicationId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = AeronDriver.IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L,
reliable: Boolean = true) {
require(connectionTimeoutMS >= 0) { "connectionTimeoutMS '$connectionTimeoutMS' is invalid. It must be >=0" }
// this will exist ONLY if we are reconnecting via a "disconnect" callback
lockStepForReconnect.value?.doWait()
if (isConnected) {
logger.error("Unable to connect when already connected!")
return
}
lockStepForReconnect.lazySet(null)
// localhost/loopback IP might not always be 127.0.0.1 or ::1
this.remoteAddress0 = remoteAddress
connection0 = null
@ -268,17 +284,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// only change LOCALHOST -> IPC if the media driver is ALREADY running LOCALLY!
var isUsingIPC = false
val canUseIPC = config.enableIpc && remoteAddress == null
val autoChangeToIpc = canUseIPC && config.enableIpcForLoopback &&
remoteAddress != null && remoteAddress.isLoopbackAddress && aeronDriver.isRunning()
if (autoChangeToIpc) {
logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
}
val canUseIPC = config.enableIpc
val autoChangeToIpc = config.enableIpcForLoopback &&
(remoteAddress == null || remoteAddress.isLoopbackAddress) && aeronDriver.isRunning()
val handshake = ClientHandshake(crypto, this)
val handshakeConnection = if (autoChangeToIpc || canUseIPC) {
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via UDP instead
val handshakeConnection = if (autoChangeToIpc) {
logger.info {"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead
val ipcConnection = IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId,
streamId = ipcPublicationId,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID)
@ -288,8 +304,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
ipcConnection.buildClient(aeronDriver, logger)
isUsingIPC = true
} catch (e: Exception) {
// if we specified that we want to use IPC, then we have to throw the timeout exception, because there is no IPC
if (canUseIPC) {
// if we specified that we MUST use IPC, then we have to throw the exception, because there is no IPC
if (remoteAddress == null) {
throw e
}
}
@ -301,7 +317,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// try a UDP connection instead
val udpConnection = UdpMediaDriverClientConnection(
address = this.remoteAddress0!!,
address = remoteAddress!!,
publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
@ -316,7 +332,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
else {
val test = UdpMediaDriverClientConnection(
address = this.remoteAddress0!!,
address = remoteAddress!!,
publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
@ -343,7 +359,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val validateRemoteAddress = if (isUsingIPC) {
PublicKeyValidationState.VALID
} else {
crypto.validateRemoteAddress(this.remoteAddress0!!, connectionInfo.publicKey)
crypto.validateRemoteAddress(remoteAddress!!, connectionInfo.publicKey)
}
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -438,16 +454,14 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
//////////////
newConnection.preCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close()
if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(connection, IllegalStateException("lockStep for reconnect was in the wrong state!"))
}
// on the client, we want to GUARANTEE that the disconnect happens-before the connect.
if (!lockStepForDispatch.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(connection, IllegalStateException("lockStep for dispatch was in the wrong state!"))
if (!lockStepForConnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(connection, IllegalStateException("lockStep for onConnect was in the wrong state!"))
}
}
newConnection.postCloseAction = {
isConnected = false
// this is called whenever connection.close() is called by the framework or via client.close()
// make sure to call our client.notifyDisconnect() callbacks
@ -462,10 +476,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
lockStepForDispatch.value?.cancel()
}
// in case notifyDisconnect called client.connect().... cancel them waiting
isConnected = false
lockStepForReconnect.value?.cancel()
}
connection0 = newConnection
@ -477,24 +487,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
if (canFinishConnecting) {
isConnected = true
// this forces the current thread to WAIT until poll system has started
val waiter = SuspendWaiter()
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
@Suppress("EXPERIMENTAL_API_USAGE")
actionDispatch.launch(start = CoroutineStart.UNDISPATCHED) {
lockStepForDispatch.value?.doWait()
// NOTE: UNDISPATCHED means that this coroutine will start as an event loop, instead of concurrently
// we want this behavior INSTEAD OF automatically starting this on a new thread.
listenerManager.notifyConnect(newConnection)
lockStepForDispatch.lazySet(null)
}
// these have to be in two SEPARATE actionDispatch.launch commands.... otherwise...
// if something inside of notifyConnect is blocking or suspends, then polling will never happen!
actionDispatch.launch {
waiter.doNotify()
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
@ -502,8 +505,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug {"[${newConnection.id}] connection expired"}
// NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()`
newConnection.close()
// eventloop is required, because we want to run this code AFTER the current coroutine has finished. This prevents
// odd race conditions when a client is restarted
actionDispatch.eventLoop {
// NOTE: We do not shutdown the client!! The client is only closed by explicitly calling `client.close()`
newConnection.close()
}
return@launch
}
else {
@ -515,6 +522,16 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
}
}
actionDispatch.eventLoop {
waiter.doWait()
lockStepForConnect.value?.doWait()
listenerManager.notifyConnect(newConnection)
lockStepForConnect.lazySet(null)
}
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")

View File

@ -36,7 +36,6 @@ import dorkbox.network.rmi.TimeoutException
import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.Header
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
@ -156,7 +155,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
return super.getRmiConnectionSupport()
}
private suspend fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
private fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIpc) {
val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId,
streamId = config.ipcPublicationId,
@ -180,9 +179,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from IPC not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
return@FragmentAssembler
}
@ -190,7 +187,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
publication,
sessionId,
message,
aeronDriver)
aeronDriver)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
@ -210,7 +207,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
}
@Suppress("DuplicatedCode")
private suspend fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
private fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv4) {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv4Address!!,
@ -260,9 +257,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
return@FragmentAssembler
}
@ -293,7 +288,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
}
@Suppress("DuplicatedCode")
private suspend fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
private fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv6) {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!,
@ -343,9 +338,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
return@FragmentAssembler
}
@ -376,7 +369,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
}
@Suppress("DuplicatedCode")
private suspend fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
private fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!,
publicationPort = config.publicationPort,
@ -426,9 +419,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"))
actionDispatch.launch {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
}
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
return@FragmentAssembler
}
@ -468,39 +459,39 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
bindAlreadyCalled = true
// this forces the current thread to WAIT until poll system has started
val waiter = SuspendWaiter()
actionDispatch.launch {
val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
val ipv4Poller: AeronPoller
val ipv6Poller: AeronPoller
val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
if (isWildcard) {
// IPv6 will bind to IPv4 wildcard as well!!
if (canUseIPv4 && canUseIPv6) {
ipv4Poller = object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config)
} else {
// only 1 will be a real poller
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
val ipv4Poller: AeronPoller
val ipv6Poller: AeronPoller
if (isWildcard) {
// IPv6 will bind to IPv4 wildcard as well!!
if (canUseIPv4 && canUseIPv6) {
ipv4Poller = object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config)
} else {
// only 1 will be a real poller
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
}
} else {
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
}
actionDispatch.launch {
waiter.doNotify()
val pollIdleStrategy = config.pollIdleStrategy
try {
var pollCount: Int
@ -590,9 +581,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
jobs.add(job)
}
// reset all of the handshake info
handshake.clear()
// when we close a client or a server, we want to make sure that ALL notifications are finished.
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
jobs.forEach { it.join() }
@ -604,6 +592,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
ipv6Poller.close()
ipcPoller.close()
// clear all of the handshake info
handshake.clear()
// finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close()
shutdownEventWaiter.doNotify()
}

View File

@ -10,10 +10,9 @@ import io.aeron.Subscription
import io.aeron.driver.MediaDriver
import io.aeron.exceptions.DriverTimeoutException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import mu.KLogger
import mu.KotlinLogging
import org.agrona.concurrent.AgentTerminationException
import org.agrona.concurrent.BackoffIdleStrategy
import java.io.File
import java.lang.Thread.sleep
@ -153,20 +152,9 @@ class AeronDriver(val config: Configuration,
// we DO NOT want to abort the JVM if there are errors.
context.errorHandler { error ->
if (error is DriverTimeoutException) {
// we suppress this because it is already handled
return@errorHandler
}
if (error.cause is BindException) {
// we suppress this because it is already handled
return@errorHandler
}
logger.error("Error in Aeron Media Driver", error)
AeronDriver.manageError(logger, error)
}
val aeronDir = File(context.aeronDirectoryName()).absoluteFile
context.aeronDirectoryName(aeronDir.path)
@ -238,9 +226,9 @@ class AeronDriver(val config: Configuration,
*
* @return true if the media driver is active and running
*/
fun isRunning(context: MediaDriver.Context, timeout: Long = context.driverTimeoutMs()): Boolean {
fun isRunning(context: MediaDriver.Context): Boolean {
// if the media driver is running, it will be a quick connection. Usually 100ms or so
return context.isDriverActive(timeout) { }
return context.isDriverActive(context.driverTimeoutMs()) { }
}
/**
@ -255,6 +243,26 @@ class AeronDriver(val config: Configuration,
require(config.context != null) { "Configuration context cannot be properly created. Unable to continue!" }
}
fun manageError(logger: KLogger, error: Throwable) {
if (error is DriverTimeoutException) {
// we suppress this because it is already handled
return
}
if (error is AgentTerminationException) {
// we suppress this because it is already handled
return
}
if (error.cause is BindException) {
// we suppress this because it is already handled
return
}
ListenerManager.cleanStackTrace(error)
logger.error("Error in Aeron", error)
}
}
private val closeRequested = atomic(false)
@ -272,6 +280,14 @@ class AeronDriver(val config: Configuration,
// did WE start the media driver, or did SOMEONE ELSE start it?
private val mediaDriverWasAlreadyRunning: Boolean
/**
* @return the configured driver timeout
*/
val driverTimeout: Long by lazy {
mediaDriverContext.driverTimeoutMs()
}
init {
mediaDriverContext
.conductorThreadFactory(threadFactory)
@ -296,14 +312,12 @@ class AeronDriver(val config: Configuration,
// we DO NOT want to abort the JVM if there are errors.
// this replaces the default handler with one that doesn't abort the JVM
aeronDriverContext.errorHandler { error ->
ListenerManager.cleanStackTrace(error)
logger.error("Error in Aeron", error)
AeronDriver.manageError(logger, error)
}
return aeronDriverContext
}
/**
* @return true if the media driver was started, false if it was not started
*/
@ -315,9 +329,18 @@ class AeronDriver(val config: Configuration,
if (mediaDriver == null) {
// only start if we didn't already start... There will be several checks.
if (!isRunning(mediaDriverContext)) {
logger.debug("Starting Aeron Media driver in '${mediaDriverContext.aeronDirectory()}'")
var running = isRunning(mediaDriverContext)
if (running) {
// wait for a bit, because we are running, but we ALSO issued a START, and expect it to start.
// SOMETIMES aeron is in the middle of shutting down, and this prevents us from trying to connect to
// that instance
logger.debug("Aeron Media driver already running. Double checking status...")
sleep(mediaDriverContext.driverTimeoutMs()/2)
running = isRunning(mediaDriverContext)
}
if (!running) {
logger.debug("Starting Aeron Media driver.")
// try to start. If we start/stop too quickly, it's a problem
var count = 10
@ -327,13 +350,11 @@ class AeronDriver(val config: Configuration,
return true
} catch (e: Exception) {
logger.warn(e) { "Unable to start the Aeron Media driver. Retrying $count more times..." }
runBlocking {
delay(mediaDriverContext.driverTimeoutMs())
}
sleep(mediaDriverContext.driverTimeoutMs())
}
}
} else {
logger.debug("Not starting Aeron Media driver. It was already running in '${mediaDriverContext.aeronDirectory()}'")
logger.debug("Not starting Aeron Media driver. It was already running.")
}
}
@ -389,7 +410,7 @@ class AeronDriver(val config: Configuration,
}
suspend fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication {
fun addPublicationWithRetry(publicationUri: ChannelUriStringBuilder, streamId: Int): Publication {
val uri = publicationUri.build()
// If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -404,13 +425,14 @@ class AeronDriver(val config: Configuration,
exception = e
logger.warn { "Unable to add a publication to Aeron. Retrying $count more times..." }
// if exceptions are added here, make sure to ALSO suppress them in the context error handler
if (e is DriverTimeoutException) {
delay(mediaDriverContext.driverTimeoutMs())
sleep(mediaDriverContext.driverTimeoutMs())
}
if (e.cause is BindException) {
// was starting too fast!
delay(mediaDriverContext.driverTimeoutMs())
sleep(mediaDriverContext.driverTimeoutMs())
}
// reasons we cannot add a pub/sub to aeron
@ -434,7 +456,7 @@ class AeronDriver(val config: Configuration,
throw exception!!
}
suspend fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription {
fun addSubscriptionWithRetry(subscriptionUri: ChannelUriStringBuilder, streamId: Int): Subscription {
val uri = subscriptionUri.build()
// If we start/stop too quickly, we might have the address already in use! Retry a few times.
@ -451,15 +473,16 @@ class AeronDriver(val config: Configuration,
} catch (e: Exception) {
// NOTE: this error will be logged in the `aeronDriverContext` logger
exception = e
logger.warn { "Unable to add a sublication to Aeron. Retrying $count more times..." }
logger.warn { "Unable to add a subscription to Aeron. Retrying $count more times..." }
// if exceptions are added here, make sure to ALSO suppress them in the context error handler
if (e is DriverTimeoutException) {
delay(mediaDriverContext.driverTimeoutMs())
sleep(mediaDriverContext.driverTimeoutMs())
}
if (e.cause is BindException) {
// was starting too fast!
delay(mediaDriverContext.driverTimeoutMs())
sleep(mediaDriverContext.driverTimeoutMs())
}
// reasons we cannot add a pub/sub to aeron
@ -545,6 +568,8 @@ class AeronDriver(val config: Configuration,
logger.warn { "Aeron Media driver at '${mediaDriverContext.aeronDirectory()}' is still running. Waiting for it to stop. Trying $count more times." }
sleep(mediaDriverContext.driverTimeoutMs())
}
logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" }
} catch (e: Exception) {
logger.error("Error closing the media driver at '${mediaDriverContext.aeronDirectory()}'", e)
}
@ -552,7 +577,5 @@ class AeronDriver(val config: Configuration,
// Destroys this thread group and all of its subgroups.
// This thread group must be empty, indicating that all threads that had been in this thread group have since stopped.
threadFactory.group.destroy()
logger.debug { "Closed the media driver at '${mediaDriverContext.aeronDirectory()}'" }
}
}

View File

@ -18,8 +18,8 @@ package dorkbox.network.aeron
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.ChannelUriStringBuilder
import kotlinx.coroutines.delay
import mu.KLogger
import java.lang.Thread.sleep
/**
* For a client, the streamId specified here MUST be manually flipped because they are in the perspective of the SERVER
@ -47,7 +47,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
*
* @throws ClientTimedOutException if we cannot connect to the server in the designated time
*/
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()
@ -64,23 +64,25 @@ internal open class IpcMediaDriverConnection(streamId: Int,
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
var startTime = System.currentTimeMillis()
var success = false
// If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times.
val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription)
var success = false
// this will wait for the server to acknowledge the connection (all via aeron)
var startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
if (subscription.isConnected && subscription.imageCount() > 0) {
success = true
break
}
delay(timeMillis = 100L)
sleep(100L)
}
if (!success) {
subscription.close()
throw ClientTimedOutException("Creating subscription connection to aeron")
@ -97,7 +99,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
break
}
delay(timeMillis = 100L)
sleep(100L)
}
if (!success) {
@ -116,7 +118,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
*
* serverAddress is ignored for IPC
*/
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()

View File

@ -32,8 +32,8 @@ abstract class MediaDriverConnection(
@Throws(ClientTimedOutException::class)
abstract suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger)
abstract suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false)
abstract fun buildClient(aeronDriver: AeronDriver, logger: KLogger)
abstract fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false)
abstract fun clientInfo() : String
abstract fun serverInfo() : String

View File

@ -21,8 +21,8 @@ import dorkbox.network.connection.ListenerManager
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.ChannelUriStringBuilder
import kotlinx.coroutines.delay
import mu.KLogger
import java.lang.Thread.sleep
import java.net.Inet4Address
import java.net.InetAddress
import java.util.concurrent.TimeUnit
@ -38,7 +38,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
sessionId: Int,
connectionTimeoutMS: Long = 0,
isReliable: Boolean = true) :
MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
var success: Boolean = false
@ -80,7 +80,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
@Suppress("DuplicatedCode")
@Throws(ClientException::class)
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
val aeronAddressString = aeronConnectionString(address)
// Create a publication at the given address and port, using the given stream ID.
@ -100,6 +100,8 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
logger.trace("client sub URI: $ip ${subscriptionUri.build()}")
}
var success = false
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -107,7 +109,6 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
val publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
val subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId)
var success = false
// this will wait for the server to acknowledge the connection (all via aeron)
val timoutInNanos = TimeUnit.MILLISECONDS.toNanos(connectionTimeoutMS)
@ -118,12 +119,12 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
break
}
delay(timeMillis = 100L)
sleep(100L)
}
if (!success) {
subscription.close()
val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()}")
val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()} in ${timoutInNanos}ms")
ListenerManager.cleanStackTrace(ex)
throw ex
}
@ -139,19 +140,18 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
break
}
delay(timeMillis = 100L)
sleep(100L)
}
if (!success) {
subscription.close()
publication.close()
val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()}")
// ListenerManager.cleanStackTrace(ex)
val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()} in ${timoutInNanos}ms")
ListenerManager.cleanStackTrace(ex)
throw ex
}
this.success = true
this.publication = publication
this.subscription = subscription
}
@ -164,7 +164,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
}
}
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
throw ClientException("Server info not implemented in Client MDC")
}
override fun serverInfo(): String {

View File

@ -0,0 +1,25 @@
/*
* Copyright 2021 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
abstract class UdpMediaDriverConnection(publicationPort: Int, subscriptionPort: Int,
streamId: Int, sessionId: Int,
connectionTimeoutMS: Long, isReliable: Boolean) :
MediaDriverConnection(publicationPort, subscriptionPort,
streamId, sessionId,
connectionTimeoutMS, isReliable) {
}

View File

@ -36,11 +36,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
sessionId: Int,
connectionTimeoutMS: Long = 0,
isReliable: Boolean = true) :
MediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutMS, isReliable) {
var success: Boolean = false
protected fun aeronConnectionString(ipAddress: InetAddress): String {
private fun aeronConnectionString(ipAddress: InetAddress): String {
return if (ipAddress is Inet4Address) {
ipAddress.hostAddress
} else {
@ -64,11 +64,11 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
}
@Suppress("DuplicatedCode")
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
override fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
throw ServerException("Client info not implemented in Server MDC")
}
override suspend fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
val connectionString = aeronConnectionString(listenAddress)
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.

View File

@ -17,8 +17,8 @@ package dorkbox.network.connection
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverClientConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverPairedConnection
import dorkbox.network.aeron.UdpMediaDriverServerConnection
import dorkbox.network.handshake.ConnectionCounts
import dorkbox.network.handshake.RandomIdAllocator
import dorkbox.network.ping.Ping
@ -33,9 +33,7 @@ import io.aeron.Subscription
import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import java.io.IOException
@ -84,7 +82,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/**
* @return true if this connection is a network connection
*/
val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverServerConnection
val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection
/**
* the endpoint associated with this connection

View File

@ -87,7 +87,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
private val handshakeKryo: KryoExtra
private val sendIdleStrategy: CoroutineIdleStrategy
private val sendIdleStrategyHandshake: IdleStrategy
private val sendIdleStrategyHandShake: IdleStrategy
val pollIdleStrategy: CoroutineIdleStrategy
val pollIdleStrategyHandShake: IdleStrategy
/**
* Crypto and signature management
@ -126,7 +129,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// serialization stuff
serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy
sendIdleStrategyHandshake = sendIdleStrategy.cloneToNormal()
pollIdleStrategy = config.pollIdleStrategy
sendIdleStrategyHandShake = sendIdleStrategy.cloneToNormal()
pollIdleStrategyHandShake = pollIdleStrategy.cloneToNormal()
handshakeKryo = serialization.initHandshakeKryo()
@ -347,7 +353,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*/
if (result >= Publication.ADMIN_ACTION) {
// we should retry.
sendIdleStrategyHandshake.idle()
sendIdleStrategyHandShake.idle()
continue
}
@ -362,7 +368,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
ListenerManager.cleanStackTrace(exception, 2) // 2 because we do not want to see the stack for the abstract `newException`
listenerManager.notifyError(exception)
} finally {
sendIdleStrategyHandshake.reset()
sendIdleStrategyHandShake.reset()
}
}
@ -431,7 +437,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
when (message) {
is PingMessage -> {
// the ping listener (internal use only!)
// the ping listener
actionDispatch.launch {
pingManager.manage(this@EndPoint, connection, message, logger)
}
@ -636,22 +642,22 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
final override fun close() {
if (shutdown.compareAndSet(expect = false, update = true)) {
logger.info { "Shutting down..." }
aeronDriver.close()
runBlocking {
aeronDriver.close()
connections.forEach {
it.close()
}
// the storage is closed via this as well.
storage.close()
// Connections are closed first, because we want to make sure that no RMI messages can be received
// when we close the RMI support objects (in which case, weird - but harmless - errors show up)
// this will wait for RMI timeouts if there are RMI in-progress. (this happens if we close via and RMI method)
rmiGlobalSupport.close()
}
// the storage is closed via this as well.
storage.close()
close0()
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)

View File

@ -143,7 +143,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
}
// called from the connect thread
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
failed = null
oneTimeKey = endPoint.crypto.secureRandom.nextInt()
val publicKey = endPoint.storage.getPublicKey()!!
@ -151,8 +151,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// Send the one-time pad to the server.
val publication = handshakeConnection.publication
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
@ -191,7 +190,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
}
// called from the connect thread
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
// Send the done message to the server.
@ -203,7 +202,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
failed = null
var pollCount: Int
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
var startTime = System.currentTimeMillis()
while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) {

View File

@ -24,23 +24,15 @@ import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverPairedConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.*
import dorkbox.network.exceptions.*
import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger
import java.net.Inet4Address
import java.net.InetAddress
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
/**
@ -51,13 +43,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>) {
private val pendingConnectionsLock = ReentrantReadWriteLock()
// note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val pendingConnections: Cache<Int, CONNECTION> = Caffeine.newBuilder()
.expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS)
.removalListener(RemovalListener<Any?, Any?> { _, value, cause ->
.expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong() * 2, TimeUnit.SECONDS)
.removalListener(RemovalListener<Int, CONNECTION> { sessionId, connection, cause ->
if (cause == RemovalCause.EXPIRED) {
@Suppress("UNCHECKED_CAST")
val connection = value as CONNECTION
connection!!
val exception = ClientTimedOutException("[${connection.id}] Waiting for registration response from client")
ListenerManager.noStackTrace(exception)
@ -90,19 +81,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// check to see if this sessionId is ALREADY in use by another connection!
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.HELLO) {
val hasExistingSessionId = pendingConnectionsLock.read {
pendingConnections.getIfPresent(sessionId) != null
}
// this should be null.
val hasExistingSessionId = pendingConnections.getIfPresent(sessionId) != null
if (hasExistingSessionId) {
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
val exception = ClientException("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.")
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
return false
}
@ -111,14 +98,11 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) {
val pendingConnection = pendingConnectionsLock.write {
val con = pendingConnections.getIfPresent(sessionId)
pendingConnections.invalidate(sessionId)
con
}
val pendingConnection = pendingConnections.getIfPresent(sessionId)
pendingConnections.invalidate(sessionId)
if (pendingConnection == null) {
val exception = ClientException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!")
val exception = ServerException("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!")
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
} else {
@ -127,7 +111,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// this enables the connection to start polling for messages
server.addConnection(pendingConnection)
// now tell the client we are done
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
@ -165,9 +148,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
return false
}
@ -182,9 +163,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
return false
}
connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
@ -193,9 +172,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
return false
}
@ -234,9 +211,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
return
}
@ -252,9 +227,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
return
}
@ -270,9 +243,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
return
}
@ -284,9 +255,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionId = connectionSessionId)
// we have to construct how the connection will communicate!
runBlocking {
clientConnection.buildServer(aeronDriver, logger, true)
}
clientConnection.buildServer(aeronDriver, logger, true)
logger.info {
"[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]"
@ -326,14 +295,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write {
pendingConnections.put(sessionId, connection)
}
pendingConnections.put(sessionId, connection)
// this tells the client all of the info to connect.
runBlocking {
server.writeHandshakeMessage(handshakePublication, successMessage)
}
server.writeHandshakeMessage(handshakePublication, successMessage)
} catch (e: Exception) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
@ -398,9 +363,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
return
}
@ -417,9 +380,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.noStackTrace(exception)
listenerManager.notifyError(exception)
runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
return
}
@ -450,9 +411,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
message.isReliable)
// we have to construct how the connection will communicate!
runBlocking {
clientConnection.buildServer(aeronDriver, logger, true)
}
clientConnection.buildServer(aeronDriver, logger, true)
logger.info {
// (reliable:$isReliable)"
@ -462,21 +421,19 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val connection = server.newConnection(ConnectionParams(server, clientConnection, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
runBlocking {
val permitConnection = listenerManager.notifyFilter(connection)
if (!permitConnection) {
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
val permitConnection = listenerManager.notifyFilter(connection)
if (!permitConnection) {
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
val exception = ClientRejectedException("Connection $clientAddressString was not permitted!")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
val exception = ClientRejectedException("Connection $clientAddressString was not permitted!")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
return@runBlocking
}
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
return
}
@ -503,14 +460,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write {
pendingConnections.put(sessionId, connection)
}
pendingConnections.put(sessionId, connection)
// this tells the client all of the info to connect.
runBlocking {
server.writeHandshakeMessage(handshakePublication, successMessage)
}
server.writeHandshakeMessage(handshakePublication, successMessage)
} catch (e: Exception) {
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
@ -538,6 +491,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
sessionIdAllocator.clear()
streamIdAllocator.clear()
pendingConnections.invalidateAll()
pendingConnections.cleanUp()
}
}