Cleaned up log messages, cleaned up handshakes, cleaned up polling code. Added unit test for concurrent client connections

old_release
Robinson 2022-05-30 02:45:50 +02:00
parent 08ecaaf1c7
commit e8f7c8d8d3
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
20 changed files with 743 additions and 752 deletions

View File

@ -18,7 +18,7 @@
<logger name="ch.qos.logback" level="ERROR"/>
<root level="INFO"> <!-- Release: ERROR -->
<root level="TRACE"> <!-- Release: ERROR -->
<appender-ref ref="STDOUT"/>
</root>
</configuration>

View File

@ -317,7 +317,7 @@ open class Client<CONNECTION : Connection>(
require(connectionTimeoutSec >= 0) { "connectionTimeoutSec '$connectionTimeoutSec' is invalid. It must be >=0" }
if (isConnected) {
logger.error("Unable to connect when already connected!")
logger.error { "Unable to connect when already connected!" }
return
}
@ -327,11 +327,9 @@ open class Client<CONNECTION : Connection>(
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
try {
runBlocking {
initEndpointState()
}
initEndpointState()
} catch (e: Exception) {
logger.error("Unable to initialize the endpoint state", e)
logger.error(e) { "Unable to initialize the endpoint state" }
return
}
@ -349,7 +347,6 @@ open class Client<CONNECTION : Connection>(
require(false) { "Cannot connect to ${IP.toString(remoteAddress)} It is an invalid address!" }
}
// IPC can be enabled TWO ways!
// - config.enableIpc
// - NULL remoteAddress
@ -376,7 +373,7 @@ open class Client<CONNECTION : Connection>(
buildUdpHandshake(connectionTimeoutSec = handshakeTimeout, reliable = reliable)
}
logger.info(handshakeConnection.clientInfo())
logger.info { handshakeConnection.clientInfo }
connect0(handshake, handshakeConnection, handshakeTimeout)
@ -389,13 +386,13 @@ open class Client<CONNECTION : Connection>(
// short delay, since it failed we want to limit the retry rate to something slower than "as fast as the CPU can do it"
delay(500)
if (logger.isTraceEnabled) {
logger.trace("Unable to connect, retrying", e)
logger.trace(e) { "Unable to connect, retrying..." }
} else {
logger.error("Unable to connect, retrying ${e.message}")
logger.info { "Unable to connect, retrying..." }
}
} catch (e: Exception) {
logger.error("Un-recoverable error during handshake. Aborting.", e)
logger.error(e) { "Un-recoverable error during handshake. Aborting." }
listenerManager.notifyError(e)
throw e
}
@ -405,13 +402,9 @@ open class Client<CONNECTION : Connection>(
private suspend fun buildIpcHandshake(ipcSubscriptionId: Int, ipcPublicationId: Int, connectionTimeoutSec: Int, reliable: Boolean): MediaDriverConnection {
if (remoteAddress == null) {
logger.info {
"IPC enabled."
}
logger.info { "IPC enabled." }
} else {
logger.info {
"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC"
}
logger.info { "IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
}
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead
@ -487,13 +480,13 @@ open class Client<CONNECTION : Connection>(
val validateRemoteAddress = if (isUsingIPC) {
PublicKeyValidationState.VALID
} else {
crypto.validateRemoteAddress(remoteAddress!!, connectionInfo.publicKey)
crypto.validateRemoteAddress(remoteAddress!!, remoteAddressString, connectionInfo.publicKey)
}
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
handshakeConnection.close()
val exception = ClientRejectedException("Connection to ${IP.toString(remoteAddress!!)} not allowed! Public key mismatch.")
logger.error("Validation error", exception)
logger.error(exception) { "Validation error" }
throw exception
}
@ -530,7 +523,7 @@ open class Client<CONNECTION : Connection>(
// does not need to do anything
//
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.info(clientConnection.clientInfo())
logger.info { clientConnection.clientInfo }
///////////////
@ -567,11 +560,11 @@ open class Client<CONNECTION : Connection>(
handshakeConnection.close()
val exception = ClientRejectedException("Connection to ${IP.toString(remoteAddress!!)} was not permitted!")
ListenerManager.cleanStackTrace(exception)
logger.error("Permission error", exception)
logger.error(exception) { "Permission error" }
throw exception
}
logger.info("Adding new signature for ${IP.toString(remoteAddress!!)} : ${connectionInfo.publicKey.toHexString()}")
logger.info { "Adding new signature for ${IP.toString(remoteAddress!!)} : ${connectionInfo.publicKey.toHexString()}" }
storage.addRegisteredServerKey(remoteAddress!!, connectionInfo.publicKey)
}
@ -584,7 +577,7 @@ open class Client<CONNECTION : Connection>(
// on the client, we want to GUARANTEE that the disconnect happens-before connect.
if (!lockStepForConnect.compareAndSet(null, SuspendWaiter())) {
logger.error("Connection ${newConnection.id}", "close lockStep for disconnect was in the wrong state!")
logger.error { "Connection ${newConnection.id} : close lockStep for disconnect was in the wrong state!" }
}
}
newConnection.postCloseAction = {
@ -603,7 +596,7 @@ open class Client<CONNECTION : Connection>(
connection0 = newConnection
addConnection(newConnection)
logger.error { "Connection created, finishing handshake" }
logger.error { "Connection created, finishing handshake: ${handshake.connectKey}" }
// tell the server our connection handshake is done, and the connection can now listen for data.
// also closes the handshake (will also throw connect timeout exception)
@ -614,7 +607,7 @@ open class Client<CONNECTION : Connection>(
canFinishConnecting = try {
handshake.done(handshakeConnection, successAttemptTimeout)
} catch (e: ClientException) {
logger.error("Error during handshake", e)
logger.error(e) { "Error during handshake" }
false
}
}
@ -638,7 +631,7 @@ open class Client<CONNECTION : Connection>(
while (!isShutdown()) {
if (newConnection.isClosedViaAeron()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug {"[${newConnection.id}] connection expired"}
logger.debug { "[${newConnection.id}] connection expired" }
// event-loop is required, because we want to run this code AFTER the current coroutine has finished. This prevents
// odd race conditions when a client is restarted. Can only be run from inside another co-routine!
@ -670,7 +663,7 @@ open class Client<CONNECTION : Connection>(
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
val exception = ClientRejectedException("Unable to connect with server: ${handshakeConnection.clientInfo}")
ListenerManager.cleanStackTrace(exception)
throw exception
}
@ -692,7 +685,13 @@ open class Client<CONNECTION : Connection>(
* the remote address, as a string.
*/
val remoteAddressString: String
get() = remoteAddress0?.hostAddress ?: "ipc"
get() {
return when (val address = remoteAddress) {
is Inet4Address -> IPv4.toString(address)
is Inet6Address -> IPv6.toString(address, true)
else -> "ipc"
}
}
/**
* true if this connection is an IPC connection
@ -731,7 +730,7 @@ open class Client<CONNECTION : Connection>(
c.send(message)
} else {
val exception = ClientException("Cannot send a message when there is no connection!")
logger.error("No connection!", exception)
logger.error(exception) { "No connection!" }
false
}
}
@ -760,7 +759,7 @@ open class Client<CONNECTION : Connection>(
if (c != null) {
return pingManager.ping(c, pingTimeoutSeconds, actionDispatch, responseManager, logger, function)
} else {
logger.error("No connection!", ClientException("Cannot send a ping when there is no connection!"))
logger.error(ClientException("Cannot send a ping when there is no connection!")) { "No connection!" }
}
return false

View File

@ -15,15 +15,12 @@
*/
package dorkbox.network
import dorkbox.netUtil.IP
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.netUtil.Inet4
import dorkbox.netUtil.Inet6
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronPoller
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverServerConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
@ -31,16 +28,13 @@ import dorkbox.network.connection.eventLoop
import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.coroutines.SuspendWaiter
import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.handshake.ServerHandshakePollers
import dorkbox.network.rmi.RmiSupportServer
import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.Header
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import java.net.InetAddress
import java.util.concurrent.*
@ -109,20 +103,6 @@ open class Server<CONNECTION : Connection>(
Server::class.java.simpleName)
companion object {
/**
* Gets the version number.
@ -165,7 +145,7 @@ open class Server<CONNECTION : Connection>(
/**
* Used for handshake connections
*/
private val handshake = ServerHandshake(logger, config, listenerManager)
internal val handshake = ServerHandshake(logger, config, listenerManager)
/**
* Maintains a thread-safe collection of rules used to define the connection type with this server.
@ -175,8 +155,8 @@ open class Server<CONNECTION : Connection>(
/**
* true if the following network stacks are available for use
*/
private val canUseIPv4 = config.enableIPv4 && IPv4.isAvailable
private val canUseIPv6 = config.enableIPv6 && IPv6.isAvailable
internal val canUseIPv4 = config.enableIPv4 && IPv4.isAvailable
internal val canUseIPv6 = config.enableIPv6 && IPv6.isAvailable
// localhost/loopback IP might not always be 127.0.0.1 or ::1
@ -221,337 +201,20 @@ open class Server<CONNECTION : Connection>(
return ServerException(message, cause)
}
private fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIpc) {
val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId,
streamId = config.ipcPublicationId,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error("[$sessionId] Connection from IPC not allowed! Invalid connection request")
try {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
}
return@FragmentAssembler
}
handshake.processIpcHandshakeMessageServer(this@Server,
rmiConnectionSupport,
publication,
sessionId,
message,
aeronDriver,
connectionFunc,
logger)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPC Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
@Suppress("DuplicatedCode")
private fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv4) {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv4Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
val clientAddress = IPv4.toAddressUnsafe(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
try {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
rmiConnectionSupport,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeronDriver,
false,
connectionFunc,
logger)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
@Suppress("DuplicatedCode")
private fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val poller = if (canUseIPv6) {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
val clientAddress = IPv6.toAddress(clientAddressString)!!
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
try {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
rmiConnectionSupport,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeronDriver,
false,
connectionFunc,
logger)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv6 Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
@Suppress("DuplicatedCode")
private fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
val driver = UdpMediaDriverServerConnection(
listenAddress = listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
val poller = object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
// maybe IPv4, maybe IPv6! This is slower than if we ALREADY know what it is.
val clientAddress = IP.toAddress(clientAddressString)!!
val message = readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
try {
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(this@Server,
rmiConnectionSupport,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeronDriver,
true,
connectionFunc,
logger)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
logger.info(poller.serverInfo())
return poller
}
/**
* Binds the server to AERON configuration
*/
@Suppress("DuplicatedCode")
fun bind() {
if (bindAlreadyCalled) {
logger.error("Unable to bind when the server is already running!")
logger.error { "Unable to bind when the server is already running!" }
return
}
try {
runBlocking {
initEndpointState()
}
initEndpointState()
} catch (e: Exception) {
logger.error("Unable to initialize the endpoint state", e)
logger.error(e) { "Unable to initialize the endpoint state" }
return
}
@ -563,33 +226,30 @@ open class Server<CONNECTION : Connection>(
// this forces the current thread to WAIT until poll system has started
val waiter = SuspendWaiter()
val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
val ipcPoller: AeronPoller = ServerHandshakePollers.IPC(aeronDriver, config, this)
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address == IPv6.WILDCARD
val ipv4Poller: AeronPoller
val ipv6Poller: AeronPoller
if (isWildcard) {
// IPv6 will bind to IPv4 wildcard as well!!
if (canUseIPv4 && canUseIPv6) {
ipv4Poller = object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config)
// IPv6 will bind to IPv4 wildcard as well, so don't bind both!
ipv4Poller = ServerHandshakePollers.disabled("IPv4 Disabled")
ipv6Poller = ServerHandshakePollers.ip6Wildcard(aeronDriver, config, this)
} else {
// only 1 will be a real poller
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
ipv4Poller = ServerHandshakePollers.ip4(aeronDriver, config, this)
ipv6Poller = ServerHandshakePollers.ip6(aeronDriver, config, this)
}
} else {
ipv4Poller = getIpv4Poller(aeronDriver, config)
ipv6Poller = getIpv6Poller(aeronDriver, config)
ipv4Poller = ServerHandshakePollers.ip4(aeronDriver, config, this)
ipv6Poller = ServerHandshakePollers.ip6(aeronDriver, config, this)
}
actionDispatch.launch {
actionDispatch.launch(start = CoroutineStart.ATOMIC) {
waiter.doNotify()
val pollIdleStrategy = config.pollIdleStrategy
@ -609,7 +269,6 @@ open class Server<CONNECTION : Connection>(
// this checks to see if there are NEW clients via IPC
pollCount += ipcPoller.poll()
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
// so we can modify this as we go
connections.forEach { connection ->
@ -656,7 +315,7 @@ open class Server<CONNECTION : Connection>(
connections.clear()
cons.forEach { connection ->
logger.error("${connection.id} cleanup")
logger.error { "${connection.id} cleanup and close" }
// have to free up resources!
// NOTE: This can only occur on the polling dispatch thread!!
handshake.cleanup(connection)
@ -680,7 +339,7 @@ open class Server<CONNECTION : Connection>(
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
jobs.forEach { it.join() }
} catch (e: Exception) {
logger.error("Unexpected error during server message polling!", e)
logger.error(e) { "Unexpected error during server message polling!" }
} finally {
ipv4Poller.close()
ipv6Poller.close()
@ -742,7 +401,6 @@ open class Server<CONNECTION : Connection>(
shutdownEventWaiter.doWait()
}
}
}

View File

@ -11,8 +11,6 @@ import io.aeron.Subscription
import io.aeron.driver.MediaDriver
import io.aeron.exceptions.DriverTimeoutException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import mu.KotlinLogging
import org.agrona.concurrent.BackoffIdleStrategy
@ -20,6 +18,8 @@ import org.slf4j.LoggerFactory
import java.io.File
import java.lang.Thread.sleep
import java.net.BindException
import java.util.concurrent.locks.*
import kotlin.concurrent.write
/**
* Class for managing the Aeron+Media drivers
@ -57,7 +57,7 @@ class AeronDriver(
private const val AERON_PUBLICATION_LINGER_TIMEOUT = 5_000L // in MS
// prevents multiple instances, within the same JVM, from starting at the exact same time.
private val startMutex = Mutex()
private val lock = ReentrantReadWriteLock()
private fun setConfigDefaults(config: Configuration, logger: KLogger) {
// explicitly don't set defaults if we already have the context defined!
@ -322,8 +322,9 @@ class AeronDriver(
*
* @throws Exception if there is a problem starting the media driver
*/
suspend fun start() {
startMutex.withLock {
fun start() {
// Note: A mutex doesn't work so well.
lock.write {
if (closeRequested.value) {
logger.debug("Resetting media driver context")

View File

@ -3,6 +3,6 @@ package dorkbox.network.aeron
internal interface AeronPoller {
fun poll(): Int
fun close()
fun serverInfo(): String
}
val serverInfo: String
}

View File

@ -33,8 +33,6 @@ internal open class IpcMediaDriverConnection(streamId: Int,
) :
MediaDriverConnection(0, 0, streamId, sessionId, 10, true) {
var success: Boolean = false
private fun uri(): ChannelUriStringBuilder {
val builder = ChannelUriStringBuilder().media("ipc")
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
@ -87,6 +85,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
if (!success) {
subscription.close()
val clientTimedOutException = ClientTimedOutException("Creating subscription connection to aeron")
ListenerManager.cleanStackTraceInternal(clientTimedOutException)
throw clientTimedOutException
@ -116,8 +115,8 @@ internal open class IpcMediaDriverConnection(streamId: Int,
}
this.success = true
this.publication = publication
this.subscription = subscription
this.publication = publication
}
/**
@ -146,34 +145,28 @@ internal open class IpcMediaDriverConnection(streamId: Int,
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
// If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times.
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
success = true
subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription)
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
}
override fun clientInfo() : String {
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
override val clientInfo : String by lazy {
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] IPC connection established to [$streamIdSubscription|$streamId]"
} else {
"Connecting handshake to IPC [$streamIdSubscription|$streamId]"
}
}
override fun serverInfo() : String {
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] IPC listening on [$streamIdSubscription|$streamId] "
override val serverInfo : String by lazy {
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] IPC listening on [$streamIdSubscription|$streamId] [$sessionId]"
} else {
"Listening handshake on IPC [$streamIdSubscription|$streamId]"
}
}
override fun close() {
if (success) {
subscription.close()
publication.close()
"Listening handshake on IPC [$streamIdSubscription|$streamId] [$sessionId]"
}
}
override fun toString(): String {
return "[$streamIdSubscription|$streamId] [$sessionId]"
return serverInfo
}
}

View File

@ -26,6 +26,7 @@ abstract class MediaDriverConnection(
val streamId: Int, val sessionId: Int,
val connectionTimeoutSec: Int, val isReliable: Boolean) : AutoCloseable {
var success: Boolean = false
lateinit var subscription: Subscription
lateinit var publication: Publication
@ -33,6 +34,13 @@ abstract class MediaDriverConnection(
abstract suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger)
abstract fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false)
abstract fun clientInfo() : String
abstract fun serverInfo() : String
abstract val clientInfo : String
abstract val serverInfo : String
override fun close() {
if (success) {
publication.close()
subscription.close()
}
}
}

View File

@ -40,22 +40,6 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
isReliable: Boolean = true) :
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutSec, isReliable) {
var success: Boolean = false
private fun aeronConnectionString(ipAddress: InetAddress): String {
return if (ipAddress is Inet4Address) {
ipAddress.hostAddress
} else {
// IPv6 requires the address to be bracketed by [...]
val host = ipAddress.hostAddress
if (host[0] == '[') {
host
} else {
"[${ipAddress.hostAddress}]"
}
}
}
val addressString: String by lazy {
IP.toString(address)
}
@ -80,7 +64,18 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
@Suppress("DuplicatedCode")
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
val aeronAddressString = aeronConnectionString(address)
val aeronAddressString = if (address is Inet4Address) {
address.hostAddress
} else {
// IPv6 requires the address to be bracketed by [...]
val host = address.hostAddress
if (host[0] == '[') {
host
} else {
"[${address.hostAddress}]"
}
}
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
@ -123,6 +118,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
if (!success) {
subscription.close()
val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()} in ${timoutInNanos}ms")
ListenerManager.cleanStackTraceInternal(ex)
throw ex
@ -145,18 +141,19 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
if (!success) {
subscription.close()
publication.close()
val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()} in ${timoutInNanos}ms")
ListenerManager.cleanStackTrace(ex)
throw ex
}
this.success = true
this.publication = publication
this.subscription = subscription
this.publication = publication
}
override fun clientInfo(): String {
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
override val clientInfo: String by lazy {
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
"Connecting to $addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else {
"Connecting handshake to $addressString [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
@ -166,18 +163,12 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
throw ClientException("Server info not implemented in Client MDC")
}
override fun serverInfo(): String {
override val serverInfo: String
get() {
throw ClientException("Server info not implemented in Client MDC")
}
override fun close() {
if (success) {
subscription.close()
publication.close()
}
}
override fun toString(): String {
return "$addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
return clientInfo
}
}

View File

@ -38,8 +38,6 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
isReliable: Boolean = true) :
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutSec, isReliable) {
var success: Boolean = false
private fun aeronConnectionString(ipAddress: InetAddress): String {
return if (ipAddress is Inet4Address) {
ipAddress.hostAddress
@ -54,7 +52,7 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
}
}
protected fun uri(): ChannelUriStringBuilder {
private fun uri(): ChannelUriStringBuilder {
val builder = ChannelUriStringBuilder().reliable(isReliable).media("udp")
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
builder.sessionId(sessionId)
@ -99,15 +97,17 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
// If we start/stop too quickly, we might have the address already in use! Retry a few times.
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
success = true
subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId)
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
}
override fun clientInfo(): String {
throw ServerException("Client info not implemented in Server MDC")
}
override val clientInfo: String
get() {
throw ServerException("Client info not implemented in Server MDC")
}
override fun serverInfo(): String {
override val serverInfo: String by lazy {
val address = if (listenAddress == IPv4.WILDCARD || listenAddress == IPv6.WILDCARD) {
if (listenAddress == IPv4.WILDCARD) {
listenAddress.hostAddress
@ -118,21 +118,14 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
IP.toString(listenAddress)
}
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
"Listening on $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else {
"Listening handshake on $address [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
}
}
override fun close() {
if (success) {
subscription.close()
publication.close()
}
}
override fun toString(): String {
return "$IP.toString(listenAddress) [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
return serverInfo
}
}

View File

@ -346,7 +346,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire.
if (isClosed.compareAndSet(expect = false, update = true)) {
logger.info {"[$id] connection closed"}
logger.debug {"[$id] connection closing"}
subscription.close()
@ -393,6 +393,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// This is set by the client/server so if there is a "connect()" call in the the disconnect callback, we can have proper
// lock-stop ordering for how disconnect and connect work with each-other
postCloseAction()
logger.debug {"[$id] connection closed"}
}
}

View File

@ -17,7 +17,6 @@ package dorkbox.network.connection
import dorkbox.bytes.Hash
import dorkbox.bytes.toHexString
import dorkbox.netUtil.IP
import dorkbox.network.handshake.ClientConnectionInfo
import dorkbox.network.serialization.AeronInput
import dorkbox.network.serialization.AeronOutput
@ -137,9 +136,9 @@ internal class CryptoManagement(val logger: KLogger,
* @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.)
* false if we should abort
*/
internal fun validateRemoteAddress(remoteAddress: InetAddress, publicKey: ByteArray?): PublicKeyValidationState {
internal fun validateRemoteAddress(remoteAddress: InetAddress, remoteAddressString: String, publicKey: ByteArray?): PublicKeyValidationState {
if (publicKey == null) {
logger.error("Error validating public key for ${IP.toString(remoteAddress)}! It was null (and should not have been)")
logger.error("Error validating public key for ${remoteAddressString}! It was null (and should not have been)")
return PublicKeyValidationState.INVALID
}
@ -150,18 +149,18 @@ internal class CryptoManagement(val logger: KLogger,
if (!publicKey.contentEquals(savedPublicKey)) {
return if (enableRemoteSignatureValidation) {
// keys do not match, abort!
logger.error("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Denying connection attempt")
logger.error("The public key for remote connection $remoteAddressString does not match. Denying connection attempt")
PublicKeyValidationState.INVALID
}
else {
logger.warn("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Permitting connection attempt.")
logger.warn("The public key for remote connection $remoteAddressString does not match. Permitting connection attempt.")
PublicKeyValidationState.TAMPERED
}
}
}
} catch (e: SecurityException) {
// keys do not match, abort!
logger.error("Error validating public key for ${IP.toString(remoteAddress)}!", e)
logger.error("Error validating public key for $remoteAddressString!", e)
return PublicKeyValidationState.INVALID
}

View File

@ -215,7 +215,7 @@ internal constructor(val type: Class<*>,
/**
* @throws Exception if there is a problem starting the media driver
*/
internal suspend fun initEndpointState() {
internal fun initEndpointState() {
shutdown.getAndSet(false)
shutdownWaiter = SuspendWaiter()
@ -363,9 +363,7 @@ internal constructor(val type: Class<*>,
@Suppress("DuplicatedCode")
internal fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) {
// The handshake sessionId IS NOT globally unique
logger.trace {
"[${publication.sessionId()}] send HS: $message"
}
logger.trace { "[${message.connectKey}] send HS: $message" }
try {
// we are not thread-safe!
@ -431,11 +429,9 @@ internal constructor(val type: Class<*>,
internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
return try {
// NOTE: This ABSOLUTELY MUST be done on the same thread! This cannot be done on a new one, because the buffer could change!
val message = handshakeKryo.read(buffer, offset, length)
val message = handshakeKryo.read(buffer, offset, length) as HandshakeMessage
logger.trace {
"[${header.sessionId()}] received HS: $message"
}
logger.trace { "[${message.connectKey}] received HS: $message" }
message
} catch (e: Exception) {
@ -723,6 +719,7 @@ internal constructor(val type: Class<*>,
runBlocking {
connections.forEach {
logger.info { "Closing connection: ${it.id}" }
it.close()
}
@ -739,6 +736,8 @@ internal constructor(val type: Class<*>,
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)
shutdownWaiter.cancel()
logger.info { "Done shutting down..." }
}
}

View File

@ -42,9 +42,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
private val handler: FragmentHandler
// a one-time key for connecting
// used to keep track and associate UDP/IPC handshakes between client/server
@Volatile
var oneTimeKey = 0
var connectKey = 0L
@Volatile
private var connectionHelloInfo: ClientConnectionInfo? = null
@ -64,21 +64,20 @@ internal class ClientHandshake<CONNECTION: Connection>(
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
val sessionId = header.sessionId()
failedException = null
needToRetry = false
// it must be a registration message
if (message !is HandshakeMessage) {
failedException = ClientRejectedException("[$sessionId] cancelled handshake for unrecognized message: $message")
failedException = ClientRejectedException("[${header.sessionId()}] cancelled handshake for unrecognized message: $message")
return@FragmentAssembler
}
// this is an error message
if (message.state == HandshakeMessage.INVALID) {
val cause = ServerException(message.errorMessage ?: "Unknown").apply { stackTrace = stackTrace.copyOfRange(0, 1) }
failedException = ClientRejectedException("[$sessionId] cancelled handshake", cause)
failedException = ClientRejectedException("[${message.connectKey}] cancelled handshake", cause)
return@FragmentAssembler
}
@ -89,8 +88,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
return@FragmentAssembler
}
if (oneTimeKey != message.oneTimeKey) {
logger.error("[$message] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})")
if (connectKey != message.connectKey) {
logger.error("Ignored handshake (client connect key: ${message.connectKey}) intended for another client (mine is:" +
" ${connectKey})")
return@FragmentAssembler
}
@ -106,7 +106,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
if (registrationData != null && serverPublicKeyBytes != null) {
connectionHelloInfo = crypto.decrypt(registrationData, serverPublicKeyBytes)
} else {
failedException = ClientRejectedException("[$message.sessionId] canceled handshake for message without registration and/or public key info")
failedException = ClientRejectedException("[${message.connectKey}] canceled handshake for message without registration and/or public key info")
}
}
HandshakeMessage.HELLO_ACK_IPC -> {
@ -129,7 +129,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
publicationPort = streamPubId,
kryoRegistrationDetails = regDetails)
} else {
failedException = ClientRejectedException("[$message.sessionId] canceled handshake for message without registration data")
failedException = ClientRejectedException("[${message.connectKey}] canceled handshake for message without registration data")
}
}
HandshakeMessage.DONE_ACK -> {
@ -137,16 +137,28 @@ internal class ClientHandshake<CONNECTION: Connection>(
}
else -> {
val stateString = HandshakeMessage.toStateString(message.state)
failedException = ClientRejectedException("[$sessionId] cancelled handshake for message that is $stateString")
failedException = ClientRejectedException("[${message.connectKey}] cancelled handshake for message that is $stateString")
}
}
}
}
/**
* Make sure that NON-ZERO is returned
*/
private fun getSafeConnectKey(): Long {
var key = endPoint.crypto.secureRandom.nextLong()
while (key == 0L) {
key = endPoint.crypto.secureRandom.nextLong()
}
return key
}
// called from the connect thread
fun hello(handshakeConnection: MediaDriverConnection, connectionTimeoutSec: Int) : ClientConnectionInfo {
failedException = null
oneTimeKey = endPoint.crypto.secureRandom.nextInt()
connectKey = getSafeConnectKey()
val publicKey = endPoint.storage.getPublicKey()!!
// Send the one-time pad to the server.
@ -155,8 +167,11 @@ internal class ClientHandshake<CONNECTION: Connection>(
val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
try {
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(connectKey, publicKey))
} catch (e: Exception) {
publication.close()
subscription.close()
logger.error("Handshake error!", e)
throw e
}
@ -181,6 +196,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
val failedEx = failedException
if (failedEx != null) {
publication.close()
subscription.close()
// no longer necessary to hold this connection open (if not a failure, we close the handshake after the DONE message)
handshakeConnection.close()
ListenerManager.cleanStackTraceInternal(failedEx)
@ -188,8 +206,12 @@ internal class ClientHandshake<CONNECTION: Connection>(
}
if (connectionHelloInfo == null) {
publication.close()
subscription.close()
// no longer necessary to hold this connection open (if not a failure, we close the handshake after the DONE message)
handshakeConnection.close()
val exception = ClientTimedOutException("Waiting for registration response from server")
ListenerManager.cleanStackTraceInternal(exception)
throw exception
@ -200,12 +222,14 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
suspend fun done(handshakeConnection: MediaDriverConnection, connectionTimeoutSec: Int): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
val registrationMessage = HandshakeMessage.doneFromClient(connectKey)
// Send the done message to the server.
try {
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
} catch (e: Exception) {
handshakeConnection.close()
return false
}
@ -254,11 +278,13 @@ internal class ClientHandshake<CONNECTION: Connection>(
throw exception
}
logger.error{"[${subscription.streamId()}] handshake done"}
return connectionDone
}
fun reset() {
oneTimeKey = 0
connectKey = 0L
connectionHelloInfo = null
connectionDone = false
needToRetry = false

View File

@ -24,9 +24,10 @@ internal class HandshakeMessage private constructor() {
var publicKey: ByteArray? = null
// used to keep track and associate UDP/etc sessions. This is always defined by the server
// a sessionId if '0', means we are still figuring it out.
var oneTimeKey = 0
// used to keep track and associate UDP/IPC handshakes between client/server
// The connection info (session ID, etc) necessary to make a connection to the server are encrypted with the clients public key.
// so EVEN IF you can guess someone's connectKey, you must also know their private key in order to connect as them.
var connectKey = 0L
// -1 means there is an error
var state = INVALID
@ -35,9 +36,6 @@ internal class HandshakeMessage private constructor() {
var publicationPort = 0
var subscriptionPort = 0
var sessionId = 0
var streamId = 0
// by default, this will be a reliable connection. When the client connects to the server, the client will specify if the new connection
@ -45,9 +43,8 @@ internal class HandshakeMessage private constructor() {
val isReliable = true
// the client sends it's registration data to the server to make sure that the registered classes are the same between the client/server
// the client sends its registration data to the server to make sure that the registered classes are the same between the client/server
var registrationData: ByteArray? = null
var registrationRmiIdData: IntArray? = null
companion object {
const val INVALID = -2
@ -58,42 +55,39 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage {
fun helloFromClient(connectKey: Long, publicKey: ByteArray): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.oneTimeKey = oneTimeKey
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
hello.publicKey = publicKey
return hello
}
fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
fun helloAckToClient(connectKey: Long): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
return hello
}
fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
fun helloAckIpcToClient(connectKey: Long): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK_IPC
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
return hello
}
fun doneFromClient(oneTimeKey: Int): HandshakeMessage {
fun doneFromClient(connectKey: Long): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE
hello.oneTimeKey = oneTimeKey
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
return hello
}
fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
fun doneToClient(connectKey: Long): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE_ACK
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
return hello
}
@ -134,7 +128,6 @@ internal class HandshakeMessage private constructor() {
", Error: $errorMessage"
}
return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)"
return "HandshakeMessage($stateStr$errorMsg)"
}
}

View File

@ -15,6 +15,7 @@
*/
package dorkbox.network.handshake
import dorkbox.netUtil.IP
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.AeronDriver
@ -24,11 +25,11 @@ import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.eventLoop
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.rmi.RmiManagerConnections
import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger
import net.jodah.expiringmap.ExpirationPolicy
@ -40,8 +41,10 @@ import java.util.concurrent.*
/**
* 'notifyConnect' must be THE ONLY THING in this class to use the action dispatch!
*
* NOTE: all methods in here are called by the SAME thread!
*/
@Suppress("DuplicatedCode")
@Suppress("DuplicatedCode", "JoinDeclarationAndAssignment")
internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>) {
@ -50,10 +53,14 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val pendingConnections = ExpiringMap.builder()
.expiration(config.connectionCloseTimeoutInSeconds.toLong() * 2, TimeUnit.SECONDS)
.expirationPolicy(ExpirationPolicy.CREATED)
.expirationListener<Int, CONNECTION> { sessionId, connection ->
expirePendingConnections(sessionId, connection)
.expirationListener<Long, CONNECTION> { clientConnectKey, connection ->
// this blocks until it fully runs (which is ok. this is fast)
logger.error { "[${clientConnectKey} Connection (${connection.id}) Timed out waiting for registration response from client" }
runBlocking {
connection.close()
}
}
.build<Int, CONNECTION>()
.build<Long, CONNECTION>()
private val connectionsPerIpCounts = ConnectionCounts()
@ -62,17 +69,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val sessionIdAllocator = RandomIdAllocator(AeronDriver.RESERVED_SESSION_ID_LOW, AeronDriver.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
private fun expirePendingConnections(sessionId: Int, connection: CONNECTION) {
// this blocks until it fully runs (which is ok. this is fast)
logger.error("[${connection.id}] Timed out waiting for registration response from client")
pendingConnections.remove(sessionId)
runBlocking {
connection.close()
}
}
/**
* @return true if we should continue parsing the incoming message, false if we should abort
*/
@ -82,24 +78,22 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
actionDispatch: CoroutineScope,
handshakePublication: Publication,
message: HandshakeMessage,
sessionId: Int,
connectionString: String,
logger: KLogger
): Boolean {
// check to see if this sessionId is ALREADY in use by another connection!
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
if (message.state == HandshakeMessage.HELLO) {
// this should be null.
val hasExistingSessionId = pendingConnections[sessionId] != null
if (hasExistingSessionId) {
val hasExistingConnectionInProgress = pendingConnections[message.connectKey] != null
if (hasExistingConnectionInProgress) {
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
logger.error("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.")
logger.error { "[${message.connectKey}] Connection from $connectionString had an in-use session ID! Telling client to retry." }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
} catch (e: Error) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return false
}
@ -109,19 +103,17 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) {
val pendingConnection = pendingConnections[sessionId]
pendingConnections.remove(sessionId)
val pendingConnection = pendingConnections.remove(message.connectKey)
if (pendingConnection == null) {
logger.error("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!")
logger.error { "[${message.connectKey}] Error! Pending connection from client $connectionString was null, and cannot complete handshake!" }
} else {
logger.trace { "[${pendingConnection.id}] Connection from client $connectionString done with handshake." }
logger.trace { "[${message.connectKey}] Connection (${pendingConnection.id}) from $connectionString done with handshake." }
pendingConnection.postCloseAction = {
// called on connection.close()
// this always has to be on event dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
actionDispatch.eventLoop {
actionDispatch.launch {
listenerManager.notifyDisconnect(pendingConnection)
}
}
@ -132,17 +124,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// now tell the client we are done
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.connectKey))
// this always has to be on event dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
actionDispatch.eventLoop {
actionDispatch.launch {
listenerManager.notifyConnect(pendingConnection)
}
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
}
return false
}
@ -170,7 +160,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return false
}
@ -182,23 +172,23 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.decrement(clientAddress, currentCountForIp)
logger.error("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")
logger.error { "Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return false
}
connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
} catch (e: Exception) {
logger.error("could not validate client message", e)
logger.error(e) { "could not validate client message" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
}
@ -211,7 +201,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
server: Server<CONNECTION>,
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
message: HandshakeMessage,
aeronDriver: AeronDriver,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
@ -220,7 +209,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val connectionString = "IPC"
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, connectionString, logger)) {
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, connectionString, logger)) {
return
}
@ -238,12 +227,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
try {
connectionSessionId = sessionIdAllocator.allocate()
} catch (e: AllocationException) {
logger.error("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!")
logger.error { "Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -256,12 +245,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
logger.error("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")
logger.error { "Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -274,12 +263,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId)
sessionIdAllocator.free(connectionStreamPubId)
logger.error("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")
logger.error { "Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -294,9 +283,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// we have to construct how the connection will communicate!
clientConnection.buildServer(aeronDriver, logger, true)
logger.info {
"[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]"
}
logger.info { "[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]" }
val connection = connectionFunc(ConnectionParams(server, clientConnection, PublicKeyValidationState.VALID, rmiConnectionSupport))
@ -311,7 +298,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId)
val successMessage = HandshakeMessage.helloAckIpcToClient(message.connectKey)
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
@ -332,58 +319,70 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[sessionId] = connection
pendingConnections[message.connectKey] = connection
// this tells the client all of the info to connect.
// this tells the client all the info to connect.
server.writeHandshakeMessage(handshakePublication, successMessage) // exception is already caught!
} catch (e: Exception) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamPubId)
logger.error("Connection handshake from $connectionString crashed! Message $message", e)
logger.error(e) { "Connection handshake from $connectionString crashed! Message $message" }
}
}
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
fun processUdpHandshakeMessageServer(server: Server<CONNECTION>,
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
clientAddressString: String,
clientAddress: InetAddress,
message: HandshakeMessage,
aeronDriver: AeronDriver,
isIpv6Wildcard: Boolean,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
logger: KLogger) {
fun processUdpHandshakeMessageServer(
server: Server<CONNECTION>,
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
handshakePublication: Publication,
remoteIpAndPort: String,
message: HandshakeMessage,
aeronDriver: AeronDriver,
isIpv6Wildcard: Boolean,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
logger: KLogger
) {
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
val clientAddress = IP.toAddress(clientAddressString)
if (clientAddress == null) {
logger.error { "Connection from $clientAddressString not allowed! Invalid IP address!" }
return
}
// Manage the Handshake state
if (!validateMessageTypeAndDoPending(
server,
server.actionDispatch,
handshakePublication,
message,
sessionId,
clientAddressString,
logger
server, server.actionDispatch, handshakePublication, message, clientAddressString, logger
)) {
return
}
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
val serialization = config.serialization
// VALIDATE:: check to see if the remote connection's public key has changed!
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientAddressString, clientPublicKeyBytes)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
logger.error("Connection from $clientAddressString not allowed! Public key mismatch.")
logger.error { "Connection from $clientAddressString not allowed! Public key mismatch." }
return
}
if (!clientAddress.isLoopbackAddress &&
!validateUdpConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress, logger)) {
// we do not want to limit loopback addresses!
// we do not want to limit the loopback addresses!
return
}
@ -403,12 +402,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
logger.error("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")
logger.error { "Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -422,12 +421,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
logger.error("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")
logger.error { "Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -461,10 +460,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// we have to construct how the connection will communicate!
clientConnection.buildServer(aeronDriver, logger, true)
logger.info {
// (reliable:$isReliable)"
"Creating new connection from $clientAddressString [$subscriptionPort|$publicationPort] [$connectionStreamId|$connectionSessionId] (reliable:${message.isReliable})"
}
logger.info { "Creating new connection from $clientConnection" }
val connection = connectionFunc(ConnectionParams(server, clientConnection, validateRemoteAddress, rmiConnectionSupport))
@ -476,12 +472,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
logger.error("Connection $clientAddressString was not permitted!")
logger.error { "Connection $clientAddressString was not permitted!" }
try {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
logger.error("Handshake error!", e)
logger.error(e) { "Handshake error!" }
}
return
}
@ -492,9 +488,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId)
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
// Also send the RMI registration data to the client (so the client doesn't register anything)
@ -510,7 +505,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[sessionId] = connection
pendingConnections[message.connectKey] = connection
// this tells the client all the info to connect.
server.writeHandshakeMessage(handshakePublication, successMessage) // exception is already caught
@ -520,7 +515,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
logger.error("Connection handshake from $clientAddressString crashed! Message $message", e)
logger.error(e) { "Connection handshake from $clientAddressString crashed! Message $message" }
}
}

View File

@ -0,0 +1,338 @@
@file:Suppress("MemberVisibilityCanBePrivate", "DuplicatedCode")
package dorkbox.network.handshake
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronPoller
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverServerConnection
import dorkbox.network.connection.Connection
import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.Header
import org.agrona.DirectBuffer
internal object ServerHandshakePollers {
fun disabled(serverInfo: String): AeronPoller {
return object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override val serverInfo = serverInfo
}
}
fun <CONNECTION : Connection> IPC(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
val logger = server.logger
val rmiConnectionSupport = server.rmiConnectionSupport
val connectionFunc = server.connectionFunc
val handshake = server.handshake
val poller = if (config.enableIpc) {
val driver = IpcMediaDriverConnection(
streamIdSubscription = config.ipcSubscriptionId,
streamId = config.ipcPublicationId,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID
)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
val message = server.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error { "[$sessionId] Connection from IPC not allowed! Invalid connection request" }
try {
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error(e) { "Handshake error!" }
}
return@FragmentAssembler
}
handshake.processIpcHandshakeMessageServer(
server, rmiConnectionSupport, publication, message, aeronDriver, connectionFunc, logger
)
}
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override fun close() {
driver.close()
}
override val serverInfo = driver.serverInfo
}
} else {
disabled("IPC Disabled")
}
logger.info { poller.serverInfo }
return poller
}
fun <CONNECTION : Connection> ip4(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
val logger = server.logger
val rmiConnectionSupport = server.rmiConnectionSupport
val connectionFunc = server.connectionFunc
val handshake = server.handshake
val poller = if (server.canUseIPv4) {
val driver = UdpMediaDriverServerConnection(
listenAddress = server.listenIPv4Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
val message = server.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error {
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
}
try {
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error(e) { "Handshake error!" }
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, false, connectionFunc, logger
)
}
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override fun close() {
driver.close()
}
override val serverInfo = driver.serverInfo
}
} else {
disabled("IPv4 Disabled")
}
logger.info { poller.serverInfo }
return poller
}
fun <CONNECTION : Connection> ip6(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
val logger = server.logger
val rmiConnectionSupport = server.rmiConnectionSupport
val connectionFunc = server.connectionFunc
val handshake = server.handshake
val poller = if (server.canUseIPv6) {
val driver = UdpMediaDriverServerConnection(
listenAddress = server.listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
val message = server.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error {
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
}
try {
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error(e) { "Handshake error!" }
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, false, connectionFunc, logger
)
}
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override fun close() {
driver.close()
}
override val serverInfo = driver.serverInfo
}
} else {
disabled("IPv6 Disabled")
}
logger.info { poller.serverInfo }
return poller
}
fun <CONNECTION : Connection> ip6Wildcard(
aeronDriver: AeronDriver,
config: ServerConfiguration,
server: Server<CONNECTION>
): AeronPoller {
val logger = server.logger
val rmiConnectionSupport = server.rmiConnectionSupport
val connectionFunc = server.connectionFunc
val handshake = server.handshake
val driver = UdpMediaDriverServerConnection(
listenAddress = server.listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
)
driver.buildServer(aeronDriver, logger)
val publication = driver.publication
val subscription = driver.subscription
val poller = object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
val message = server.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
logger.error {
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
}
try {
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
} catch (e: Exception) {
logger.error(e) { "Handshake error!" }
}
return@FragmentAssembler
}
handshake.processUdpHandshakeMessageServer(
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, true, connectionFunc, logger
)
}
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override fun close() {
driver.close()
}
override val serverInfo = driver.serverInfo
}
logger.info { poller.serverInfo }
return poller
}
}

View File

@ -36,10 +36,7 @@ package dorkboxTest.network
import ch.qos.logback.classic.Level
import ch.qos.logback.classic.Logger
import ch.qos.logback.classic.encoder.PatternLayoutEncoder
import ch.qos.logback.classic.joran.JoranConfigurator
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.ConsoleAppender
import dorkbox.network.Client
import dorkbox.network.Configuration
import dorkbox.network.Server
@ -55,20 +52,58 @@ import org.junit.After
import org.junit.Assert
import org.junit.Before
import org.slf4j.LoggerFactory
import java.io.File
import java.lang.Thread.sleep
import java.lang.reflect.Field
import java.lang.reflect.Method
import java.util.concurrent.*
abstract class BaseTest {
@Volatile
private var latch = CountDownLatch(1)
@Volatile
private var autoFailThread: Thread? = null
companion object {
const val LOCALHOST = "localhost"
// wait minimum of 2 minutes before we automatically fail the unit test.
var AUTO_FAIL_TIMEOUT: Long = 120L
init {
if (OS.javaVersion >= 9) {
// disableAccessWarnings
try {
val unsafeClass = Class.forName("sun.misc.Unsafe")
val field: Field = unsafeClass.getDeclaredField("theUnsafe")
field.isAccessible = true
val unsafe: Any = field.get(null)
val putObjectVolatile: Method = unsafeClass.getDeclaredMethod("putObjectVolatile", Any::class.java, Long::class.javaPrimitiveType, Any::class.java)
val staticFieldOffset: Method = unsafeClass.getDeclaredMethod("staticFieldOffset", Field::class.java)
val loggerClass = Class.forName("jdk.internal.module.IllegalAccessLogger")
val loggerField: Field = loggerClass.getDeclaredField("logger")
val offset = staticFieldOffset.invoke(unsafe, loggerField) as Long
putObjectVolatile.invoke(unsafe, loggerClass, offset, null)
} catch (ignored: Exception) {
}
}
// if (System.getProperty("logback.configurationFile") == null) {
// val file = File("logback.xml")
// if (file.canRead()) {
// System.setProperty("logback.configurationFile", file.toPath().toRealPath().toFile().toString())
// } else {
// System.setProperty("logback.configurationFile", "logback.xml")
// }
// }
// setLogLevel(Level.TRACE)
// setLogLevel(Level.ERROR)
setLogLevel(Level.DEBUG)
// we want our entropy generation to be simple (ie, no user interaction to generate)
try {
Entropy.init(SimpleEntropy::class.java)
} catch (e: InitializationException) {
e.printStackTrace()
}
}
fun clientConfig(block: Configuration.() -> Unit = {}): Configuration {
val configuration = Configuration()
@ -102,66 +137,43 @@ abstract class BaseTest {
// assume SLF4J is bound to logback in the current environment
val rootLogger = LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME) as Logger
rootLogger.detachAndStopAllAppenders()
rootLogger.level = level
val context = rootLogger.loggerContext
val jc = JoranConfigurator()
context.reset() // override default configuration
val jc = JoranConfigurator()
jc.context = context
context.getLogger(Server::class.simpleName).level = level
context.getLogger(Client::class.simpleName).level = level
jc.doConfigure(File("logback.xml").absoluteFile)
// we only want error messages
val kryoLogger = LoggerFactory.getLogger("com.esotericsoftware") as Logger
kryoLogger.level = Level.ERROR
val encoder = PatternLayoutEncoder()
encoder.context = context
encoder.pattern = "%date{HH:mm:ss.SSS} %-5level [%logger{35}] %msg%n"
encoder.start()
val consoleAppender = ConsoleAppender<ILoggingEvent>()
consoleAppender.context = context
consoleAppender.encoder = encoder
consoleAppender.start()
rootLogger.addAppender(consoleAppender)
// val encoder = PatternLayoutEncoder()
// encoder.context = context
// encoder.pattern = "%date{HH:mm:ss.SSS} %-5level [%logger{35}] %msg%n"
// encoder.start()
//
// val consoleAppender = ConsoleAppender<ILoggingEvent>()
// consoleAppender.context = context
// consoleAppender.encoder = encoder
// consoleAppender.start()
//
// rootLogger.addAppender(consoleAppender)
// context.getLogger(Server::class.simpleName).trace("TESTING")
// context.getLogger(Client::class.simpleName).trace("TESTING")
}
// wait minimum of 2 minutes before we automatically fail the unit test.
var AUTO_FAIL_TIMEOUT: Long = 120L
init {
if (OS.javaVersion >= 9) {
// disableAccessWarnings
try {
val unsafeClass = Class.forName("sun.misc.Unsafe")
val field: Field = unsafeClass.getDeclaredField("theUnsafe")
field.isAccessible = true
val unsafe: Any = field.get(null)
val putObjectVolatile: Method = unsafeClass.getDeclaredMethod("putObjectVolatile", Any::class.java, Long::class.javaPrimitiveType, Any::class.java)
val staticFieldOffset: Method = unsafeClass.getDeclaredMethod("staticFieldOffset", Field::class.java)
val loggerClass = Class.forName("jdk.internal.module.IllegalAccessLogger")
val loggerField: Field = loggerClass.getDeclaredField("logger")
val offset = staticFieldOffset.invoke(unsafe, loggerField) as Long
putObjectVolatile.invoke(unsafe, loggerClass, offset, null)
} catch (ignored: Exception) {
}
}
// we want our entropy generation to be simple (ie, no user interaction to generate)
try {
Entropy.init(SimpleEntropy::class.java)
} catch (e: InitializationException) {
e.printStackTrace()
}
}
}
@Volatile
private var latch = CountDownLatch(1)
@Volatile
private var autoFailThread: Thread? = null
private val endPointConnections: MutableList<EndPoint<*>> = CopyOnWriteArrayList()
@Volatile
@ -170,10 +182,6 @@ abstract class BaseTest {
init {
println("---- " + this.javaClass.simpleName)
setLogLevel(Level.TRACE)
// setLogLevel(Level.ERROR)
// setLogLevel(Level.DEBUG)
// we must always make sure that aeron is shut-down before starting again.
while (Server.isRunning(serverConfig())) {
println("Aeron was still running. Waiting for it to stop...")
@ -211,7 +219,7 @@ abstract class BaseTest {
if (endPoint is Client) {
endPoint.close()
latch.countDown()
println("Done with ${endPoint.type.simpleName}")
println("Done closing: ${endPoint.type.simpleName}")
} else {
remainingConnections.add(endPoint)
}

View File

@ -6,7 +6,6 @@ import dorkbox.network.aeron.AeronDriver
import dorkbox.network.connection.Connection
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Test
import java.io.IOException
@ -205,9 +204,7 @@ class DisconnectReconnectTest : BaseTest() {
fun manualMediaDriverAndReconnectClient() {
// NOTE: once a config is assigned to a driver, the config cannot be changed
val aeronDriver = AeronDriver(serverConfig())
runBlocking {
aeronDriver.start()
}
aeronDriver.start()
run {
val serverConfiguration = serverConfig()

View File

@ -0,0 +1,77 @@
package dorkboxTest.network
import ch.qos.logback.classic.Level
import dorkbox.network.Client
import dorkbox.network.Server
import dorkbox.network.connection.Connection
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Test
import java.lang.Thread.sleep
class MultiClientTest : BaseTest() {
private val totalCount = 8 // this number is dependent on the number of CPU cores on the box!
private val clientConnectCount = atomic(0)
private val serverConnectCount = atomic(0)
private val disconnectCount = atomic(0)
@Test
fun multiConnectClient() {
setLogLevel(Level.TRACE)
// clients first, so they try to connect to the server at (roughly) the same time
val clients = mutableListOf<Client<Connection>>()
for (i in 1..totalCount) {
val client: Client<Connection> = Client(clientConfig(), "Client$i")
client.onConnect {
clientConnectCount.getAndIncrement()
logger.error("${this.id} - Connected $i!")
}
client.onDisconnect {
disconnectCount.getAndIncrement()
logger.error("${this.id} - Disconnected $i!")
}
addEndPoint(client)
clients += client
}
GlobalScope.launch {
clients.forEach {
// long connection timeout, since the more that try to connect at the same time, the longer it takes to setup aeron (since it's all shared)
launch { it.connect(LOCALHOST, 30*totalCount) }
}
}
runBlocking {
sleep(5000L)
println("Starting server...")
val configuration = serverConfig()
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.onConnect {
val count = serverConnectCount.incrementAndGet()
logger.error("${this.id} - Connecting $count ....")
close()
if (count == totalCount) {
logger.error { "Stopping endpoints!" }
stopEndPoints(10000L)
}
}
server.bind()
}
waitForThreads()
Assert.assertEquals(totalCount, clientConnectCount.value)
Assert.assertEquals(totalCount, serverConnectCount.value)
Assert.assertEquals(totalCount, disconnectCount.value)
}
}

View File

@ -1,85 +0,0 @@
package dorkboxTest.network
import dorkbox.network.Client
import dorkbox.network.Server
import dorkbox.network.connection.Connection
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.Assert
import org.junit.Ignore
import org.junit.Test
import kotlin.time.Duration.Companion.seconds
class MultiConnectTest : BaseTest() {
private val reconnectCount = atomic(0)
@Ignore
@Test
fun multiConnectClient() {
// clients first, so they try to connect to the server at (roughly) the same time
val config = clientConfig()
val client1: Client<Connection> = Client(config)
val client2: Client<Connection> = Client(config)
addEndPoint(client1)
addEndPoint(client2)
client1.onDisconnect {
logger.error("Disconnected 1!")
}
client2.onDisconnect {
logger.error("Disconnected 2!")
}
runBlocking {
launch { client1.connect(LOCALHOST) }
launch { client2.connect(LOCALHOST) }
}
// GlobalScope.launch {
// client1.connect(LOCALHOST)
// }
//
// GlobalScope.launch {
// client2.connect(LOCALHOST)
// }
println("Starting server...")
run {
val configuration = serverConfig()
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.bind()
server.onConnect {
logger.error("Disconnecting after 10 seconds.")
delay(10.seconds)
logger.error("Disconnecting....")
close()
}
}
waitForThreads()
System.err.println("Connection count (after reconnecting) is: " + reconnectCount.value)
Assert.assertEquals(4, reconnectCount.value)
}
interface CloseIface {
suspend fun close()
}
class CloseImpl : CloseIface {
override suspend fun close() {
// the connection specific one is called instead
}
suspend fun close(connection: Connection) {
connection.close()
}
}
}