Cleaned up log messages, cleaned up handshakes, cleaned up polling code. Added unit test for concurrent client connections
parent
08ecaaf1c7
commit
e8f7c8d8d3
|
@ -18,7 +18,7 @@
|
|||
<logger name="ch.qos.logback" level="ERROR"/>
|
||||
|
||||
|
||||
<root level="INFO"> <!-- Release: ERROR -->
|
||||
<root level="TRACE"> <!-- Release: ERROR -->
|
||||
<appender-ref ref="STDOUT"/>
|
||||
</root>
|
||||
</configuration>
|
||||
|
|
|
@ -317,7 +317,7 @@ open class Client<CONNECTION : Connection>(
|
|||
require(connectionTimeoutSec >= 0) { "connectionTimeoutSec '$connectionTimeoutSec' is invalid. It must be >=0" }
|
||||
|
||||
if (isConnected) {
|
||||
logger.error("Unable to connect when already connected!")
|
||||
logger.error { "Unable to connect when already connected!" }
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -327,11 +327,9 @@ open class Client<CONNECTION : Connection>(
|
|||
|
||||
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
|
||||
try {
|
||||
runBlocking {
|
||||
initEndpointState()
|
||||
}
|
||||
initEndpointState()
|
||||
} catch (e: Exception) {
|
||||
logger.error("Unable to initialize the endpoint state", e)
|
||||
logger.error(e) { "Unable to initialize the endpoint state" }
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -349,7 +347,6 @@ open class Client<CONNECTION : Connection>(
|
|||
require(false) { "Cannot connect to ${IP.toString(remoteAddress)} It is an invalid address!" }
|
||||
}
|
||||
|
||||
|
||||
// IPC can be enabled TWO ways!
|
||||
// - config.enableIpc
|
||||
// - NULL remoteAddress
|
||||
|
@ -376,7 +373,7 @@ open class Client<CONNECTION : Connection>(
|
|||
buildUdpHandshake(connectionTimeoutSec = handshakeTimeout, reliable = reliable)
|
||||
}
|
||||
|
||||
logger.info(handshakeConnection.clientInfo())
|
||||
logger.info { handshakeConnection.clientInfo }
|
||||
|
||||
|
||||
connect0(handshake, handshakeConnection, handshakeTimeout)
|
||||
|
@ -389,13 +386,13 @@ open class Client<CONNECTION : Connection>(
|
|||
// short delay, since it failed we want to limit the retry rate to something slower than "as fast as the CPU can do it"
|
||||
delay(500)
|
||||
if (logger.isTraceEnabled) {
|
||||
logger.trace("Unable to connect, retrying", e)
|
||||
logger.trace(e) { "Unable to connect, retrying..." }
|
||||
} else {
|
||||
logger.error("Unable to connect, retrying ${e.message}")
|
||||
logger.info { "Unable to connect, retrying..." }
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
logger.error("Un-recoverable error during handshake. Aborting.", e)
|
||||
logger.error(e) { "Un-recoverable error during handshake. Aborting." }
|
||||
listenerManager.notifyError(e)
|
||||
throw e
|
||||
}
|
||||
|
@ -405,13 +402,9 @@ open class Client<CONNECTION : Connection>(
|
|||
|
||||
private suspend fun buildIpcHandshake(ipcSubscriptionId: Int, ipcPublicationId: Int, connectionTimeoutSec: Int, reliable: Boolean): MediaDriverConnection {
|
||||
if (remoteAddress == null) {
|
||||
logger.info {
|
||||
"IPC enabled."
|
||||
}
|
||||
logger.info { "IPC enabled." }
|
||||
} else {
|
||||
logger.info {
|
||||
"IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC"
|
||||
}
|
||||
logger.info { "IPC for loopback enabled and aeron is already running. Auto-changing network connection from ${IP.toString(remoteAddress!!)} -> IPC" }
|
||||
}
|
||||
|
||||
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead
|
||||
|
@ -487,13 +480,13 @@ open class Client<CONNECTION : Connection>(
|
|||
val validateRemoteAddress = if (isUsingIPC) {
|
||||
PublicKeyValidationState.VALID
|
||||
} else {
|
||||
crypto.validateRemoteAddress(remoteAddress!!, connectionInfo.publicKey)
|
||||
crypto.validateRemoteAddress(remoteAddress!!, remoteAddressString, connectionInfo.publicKey)
|
||||
}
|
||||
|
||||
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
|
||||
handshakeConnection.close()
|
||||
val exception = ClientRejectedException("Connection to ${IP.toString(remoteAddress!!)} not allowed! Public key mismatch.")
|
||||
logger.error("Validation error", exception)
|
||||
logger.error(exception) { "Validation error" }
|
||||
throw exception
|
||||
}
|
||||
|
||||
|
@ -530,7 +523,7 @@ open class Client<CONNECTION : Connection>(
|
|||
// does not need to do anything
|
||||
//
|
||||
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
|
||||
logger.info(clientConnection.clientInfo())
|
||||
logger.info { clientConnection.clientInfo }
|
||||
|
||||
|
||||
///////////////
|
||||
|
@ -567,11 +560,11 @@ open class Client<CONNECTION : Connection>(
|
|||
handshakeConnection.close()
|
||||
val exception = ClientRejectedException("Connection to ${IP.toString(remoteAddress!!)} was not permitted!")
|
||||
ListenerManager.cleanStackTrace(exception)
|
||||
logger.error("Permission error", exception)
|
||||
logger.error(exception) { "Permission error" }
|
||||
throw exception
|
||||
}
|
||||
|
||||
logger.info("Adding new signature for ${IP.toString(remoteAddress!!)} : ${connectionInfo.publicKey.toHexString()}")
|
||||
logger.info { "Adding new signature for ${IP.toString(remoteAddress!!)} : ${connectionInfo.publicKey.toHexString()}" }
|
||||
storage.addRegisteredServerKey(remoteAddress!!, connectionInfo.publicKey)
|
||||
}
|
||||
|
||||
|
@ -584,7 +577,7 @@ open class Client<CONNECTION : Connection>(
|
|||
|
||||
// on the client, we want to GUARANTEE that the disconnect happens-before connect.
|
||||
if (!lockStepForConnect.compareAndSet(null, SuspendWaiter())) {
|
||||
logger.error("Connection ${newConnection.id}", "close lockStep for disconnect was in the wrong state!")
|
||||
logger.error { "Connection ${newConnection.id} : close lockStep for disconnect was in the wrong state!" }
|
||||
}
|
||||
}
|
||||
newConnection.postCloseAction = {
|
||||
|
@ -603,7 +596,7 @@ open class Client<CONNECTION : Connection>(
|
|||
connection0 = newConnection
|
||||
addConnection(newConnection)
|
||||
|
||||
logger.error { "Connection created, finishing handshake" }
|
||||
logger.error { "Connection created, finishing handshake: ${handshake.connectKey}" }
|
||||
|
||||
// tell the server our connection handshake is done, and the connection can now listen for data.
|
||||
// also closes the handshake (will also throw connect timeout exception)
|
||||
|
@ -614,7 +607,7 @@ open class Client<CONNECTION : Connection>(
|
|||
canFinishConnecting = try {
|
||||
handshake.done(handshakeConnection, successAttemptTimeout)
|
||||
} catch (e: ClientException) {
|
||||
logger.error("Error during handshake", e)
|
||||
logger.error(e) { "Error during handshake" }
|
||||
false
|
||||
}
|
||||
}
|
||||
|
@ -638,7 +631,7 @@ open class Client<CONNECTION : Connection>(
|
|||
while (!isShutdown()) {
|
||||
if (newConnection.isClosedViaAeron()) {
|
||||
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
|
||||
logger.debug {"[${newConnection.id}] connection expired"}
|
||||
logger.debug { "[${newConnection.id}] connection expired" }
|
||||
|
||||
// event-loop is required, because we want to run this code AFTER the current coroutine has finished. This prevents
|
||||
// odd race conditions when a client is restarted. Can only be run from inside another co-routine!
|
||||
|
@ -670,7 +663,7 @@ open class Client<CONNECTION : Connection>(
|
|||
} else {
|
||||
close()
|
||||
|
||||
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
|
||||
val exception = ClientRejectedException("Unable to connect with server: ${handshakeConnection.clientInfo}")
|
||||
ListenerManager.cleanStackTrace(exception)
|
||||
throw exception
|
||||
}
|
||||
|
@ -692,7 +685,13 @@ open class Client<CONNECTION : Connection>(
|
|||
* the remote address, as a string.
|
||||
*/
|
||||
val remoteAddressString: String
|
||||
get() = remoteAddress0?.hostAddress ?: "ipc"
|
||||
get() {
|
||||
return when (val address = remoteAddress) {
|
||||
is Inet4Address -> IPv4.toString(address)
|
||||
is Inet6Address -> IPv6.toString(address, true)
|
||||
else -> "ipc"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* true if this connection is an IPC connection
|
||||
|
@ -731,7 +730,7 @@ open class Client<CONNECTION : Connection>(
|
|||
c.send(message)
|
||||
} else {
|
||||
val exception = ClientException("Cannot send a message when there is no connection!")
|
||||
logger.error("No connection!", exception)
|
||||
logger.error(exception) { "No connection!" }
|
||||
false
|
||||
}
|
||||
}
|
||||
|
@ -760,7 +759,7 @@ open class Client<CONNECTION : Connection>(
|
|||
if (c != null) {
|
||||
return pingManager.ping(c, pingTimeoutSeconds, actionDispatch, responseManager, logger, function)
|
||||
} else {
|
||||
logger.error("No connection!", ClientException("Cannot send a ping when there is no connection!"))
|
||||
logger.error(ClientException("Cannot send a ping when there is no connection!")) { "No connection!" }
|
||||
}
|
||||
|
||||
return false
|
||||
|
|
|
@ -15,15 +15,12 @@
|
|||
*/
|
||||
package dorkbox.network
|
||||
|
||||
import dorkbox.netUtil.IP
|
||||
import dorkbox.netUtil.IPv4
|
||||
import dorkbox.netUtil.IPv6
|
||||
import dorkbox.netUtil.Inet4
|
||||
import dorkbox.netUtil.Inet6
|
||||
import dorkbox.network.aeron.AeronDriver
|
||||
import dorkbox.network.aeron.AeronPoller
|
||||
import dorkbox.network.aeron.IpcMediaDriverConnection
|
||||
import dorkbox.network.aeron.UdpMediaDriverServerConnection
|
||||
import dorkbox.network.connection.Connection
|
||||
import dorkbox.network.connection.ConnectionParams
|
||||
import dorkbox.network.connection.EndPoint
|
||||
|
@ -31,16 +28,13 @@ import dorkbox.network.connection.eventLoop
|
|||
import dorkbox.network.connectionType.ConnectionRule
|
||||
import dorkbox.network.coroutines.SuspendWaiter
|
||||
import dorkbox.network.exceptions.ServerException
|
||||
import dorkbox.network.handshake.HandshakeMessage
|
||||
import dorkbox.network.handshake.ServerHandshake
|
||||
import dorkbox.network.handshake.ServerHandshakePollers
|
||||
import dorkbox.network.rmi.RmiSupportServer
|
||||
import io.aeron.FragmentAssembler
|
||||
import io.aeron.Image
|
||||
import io.aeron.logbuffer.Header
|
||||
import kotlinx.coroutines.CoroutineStart
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.agrona.DirectBuffer
|
||||
import java.net.InetAddress
|
||||
import java.util.concurrent.*
|
||||
|
||||
|
@ -109,20 +103,6 @@ open class Server<CONNECTION : Connection>(
|
|||
Server::class.java.simpleName)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Gets the version number.
|
||||
|
@ -165,7 +145,7 @@ open class Server<CONNECTION : Connection>(
|
|||
/**
|
||||
* Used for handshake connections
|
||||
*/
|
||||
private val handshake = ServerHandshake(logger, config, listenerManager)
|
||||
internal val handshake = ServerHandshake(logger, config, listenerManager)
|
||||
|
||||
/**
|
||||
* Maintains a thread-safe collection of rules used to define the connection type with this server.
|
||||
|
@ -175,8 +155,8 @@ open class Server<CONNECTION : Connection>(
|
|||
/**
|
||||
* true if the following network stacks are available for use
|
||||
*/
|
||||
private val canUseIPv4 = config.enableIPv4 && IPv4.isAvailable
|
||||
private val canUseIPv6 = config.enableIPv6 && IPv6.isAvailable
|
||||
internal val canUseIPv4 = config.enableIPv4 && IPv4.isAvailable
|
||||
internal val canUseIPv6 = config.enableIPv6 && IPv6.isAvailable
|
||||
|
||||
|
||||
// localhost/loopback IP might not always be 127.0.0.1 or ::1
|
||||
|
@ -221,337 +201,20 @@ open class Server<CONNECTION : Connection>(
|
|||
return ServerException(message, cause)
|
||||
}
|
||||
|
||||
private fun getIpcPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
|
||||
val poller = if (config.enableIpc) {
|
||||
val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId,
|
||||
streamId = config.ipcPublicationId,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID)
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
val message = readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error("[$sessionId] Connection from IPC not allowed! Invalid connection request")
|
||||
|
||||
try {
|
||||
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processIpcHandshakeMessageServer(this@Server,
|
||||
rmiConnectionSupport,
|
||||
publication,
|
||||
sessionId,
|
||||
message,
|
||||
aeronDriver,
|
||||
connectionFunc,
|
||||
logger)
|
||||
}
|
||||
|
||||
override fun poll(): Int { return subscription.poll(handler, 1) }
|
||||
override fun close() { driver.close() }
|
||||
override fun serverInfo(): String { return driver.serverInfo() }
|
||||
}
|
||||
} else {
|
||||
object : AeronPoller {
|
||||
override fun poll(): Int { return 0 }
|
||||
override fun close() {}
|
||||
override fun serverInfo(): String { return "IPC Disabled" }
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(poller.serverInfo())
|
||||
return poller
|
||||
}
|
||||
|
||||
@Suppress("DuplicatedCode")
|
||||
private fun getIpv4Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
|
||||
val poller = if (canUseIPv4) {
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = listenIPv4Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
// val port = remoteIpAndPort.substring(splitPoint+1)
|
||||
|
||||
// this should never be null, because we are feeding it a valid IP address from aeron
|
||||
val clientAddress = IPv4.toAddressUnsafe(clientAddressString)
|
||||
|
||||
|
||||
val message = readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
|
||||
|
||||
try {
|
||||
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(this@Server,
|
||||
rmiConnectionSupport,
|
||||
publication,
|
||||
sessionId,
|
||||
clientAddressString,
|
||||
clientAddress,
|
||||
message,
|
||||
aeronDriver,
|
||||
false,
|
||||
connectionFunc,
|
||||
logger)
|
||||
}
|
||||
|
||||
override fun poll(): Int { return subscription.poll(handler, 1) }
|
||||
override fun close() { driver.close() }
|
||||
override fun serverInfo(): String { return driver.serverInfo() }
|
||||
}
|
||||
} else {
|
||||
object : AeronPoller {
|
||||
override fun poll(): Int { return 0 }
|
||||
override fun close() {}
|
||||
override fun serverInfo(): String { return "IPv4 Disabled" }
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(poller.serverInfo())
|
||||
return poller
|
||||
}
|
||||
|
||||
@Suppress("DuplicatedCode")
|
||||
private fun getIpv6Poller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
|
||||
val poller = if (canUseIPv6) {
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = listenIPv6Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
// val port = remoteIpAndPort.substring(splitPoint+1)
|
||||
|
||||
// this should never be null, because we are feeding it a valid IP address from aeron
|
||||
val clientAddress = IPv6.toAddress(clientAddressString)!!
|
||||
|
||||
|
||||
val message = readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
|
||||
|
||||
try {
|
||||
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(this@Server,
|
||||
rmiConnectionSupport,
|
||||
publication,
|
||||
sessionId,
|
||||
clientAddressString,
|
||||
clientAddress,
|
||||
message,
|
||||
aeronDriver,
|
||||
false,
|
||||
connectionFunc,
|
||||
logger)
|
||||
}
|
||||
|
||||
override fun poll(): Int { return subscription.poll(handler, 1) }
|
||||
override fun close() { driver.close() }
|
||||
override fun serverInfo(): String { return driver.serverInfo() }
|
||||
}
|
||||
} else {
|
||||
object : AeronPoller {
|
||||
override fun poll(): Int { return 0 }
|
||||
override fun close() {}
|
||||
override fun serverInfo(): String { return "IPv6 Disabled" }
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(poller.serverInfo())
|
||||
return poller
|
||||
}
|
||||
|
||||
@Suppress("DuplicatedCode")
|
||||
private fun getIpv6WildcardPoller(aeronDriver: AeronDriver, config: ServerConfiguration): AeronPoller {
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = listenIPv6Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
val poller = object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
// val port = remoteIpAndPort.substring(splitPoint+1)
|
||||
|
||||
// this should never be null, because we are feeding it a valid IP address from aeron
|
||||
// maybe IPv4, maybe IPv6! This is slower than if we ALREADY know what it is.
|
||||
val clientAddress = IP.toAddress(clientAddressString)!!
|
||||
|
||||
|
||||
val message = readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error("[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request")
|
||||
|
||||
try {
|
||||
writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(this@Server,
|
||||
rmiConnectionSupport,
|
||||
publication,
|
||||
sessionId,
|
||||
clientAddressString,
|
||||
clientAddress,
|
||||
message,
|
||||
aeronDriver,
|
||||
true,
|
||||
connectionFunc,
|
||||
logger)
|
||||
}
|
||||
|
||||
override fun poll(): Int { return subscription.poll(handler, 1) }
|
||||
override fun close() { driver.close() }
|
||||
override fun serverInfo(): String { return driver.serverInfo() }
|
||||
}
|
||||
|
||||
logger.info(poller.serverInfo())
|
||||
return poller
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds the server to AERON configuration
|
||||
*/
|
||||
@Suppress("DuplicatedCode")
|
||||
fun bind() {
|
||||
if (bindAlreadyCalled) {
|
||||
logger.error("Unable to bind when the server is already running!")
|
||||
logger.error { "Unable to bind when the server is already running!" }
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
runBlocking {
|
||||
initEndpointState()
|
||||
}
|
||||
initEndpointState()
|
||||
} catch (e: Exception) {
|
||||
logger.error("Unable to initialize the endpoint state", e)
|
||||
logger.error(e) { "Unable to initialize the endpoint state" }
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -563,33 +226,30 @@ open class Server<CONNECTION : Connection>(
|
|||
// this forces the current thread to WAIT until poll system has started
|
||||
val waiter = SuspendWaiter()
|
||||
|
||||
val ipcPoller: AeronPoller = getIpcPoller(aeronDriver, config)
|
||||
val ipcPoller: AeronPoller = ServerHandshakePollers.IPC(aeronDriver, config, this)
|
||||
|
||||
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
|
||||
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
|
||||
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address == IPv6.WILDCARD
|
||||
val ipv4Poller: AeronPoller
|
||||
val ipv6Poller: AeronPoller
|
||||
|
||||
if (isWildcard) {
|
||||
// IPv6 will bind to IPv4 wildcard as well!!
|
||||
if (canUseIPv4 && canUseIPv6) {
|
||||
ipv4Poller = object : AeronPoller {
|
||||
override fun poll(): Int { return 0 }
|
||||
override fun close() {}
|
||||
override fun serverInfo(): String { return "IPv4 Disabled" }
|
||||
}
|
||||
ipv6Poller = getIpv6WildcardPoller(aeronDriver, config)
|
||||
// IPv6 will bind to IPv4 wildcard as well, so don't bind both!
|
||||
ipv4Poller = ServerHandshakePollers.disabled("IPv4 Disabled")
|
||||
ipv6Poller = ServerHandshakePollers.ip6Wildcard(aeronDriver, config, this)
|
||||
} else {
|
||||
// only 1 will be a real poller
|
||||
ipv4Poller = getIpv4Poller(aeronDriver, config)
|
||||
ipv6Poller = getIpv6Poller(aeronDriver, config)
|
||||
ipv4Poller = ServerHandshakePollers.ip4(aeronDriver, config, this)
|
||||
ipv6Poller = ServerHandshakePollers.ip6(aeronDriver, config, this)
|
||||
}
|
||||
} else {
|
||||
ipv4Poller = getIpv4Poller(aeronDriver, config)
|
||||
ipv6Poller = getIpv6Poller(aeronDriver, config)
|
||||
ipv4Poller = ServerHandshakePollers.ip4(aeronDriver, config, this)
|
||||
ipv6Poller = ServerHandshakePollers.ip6(aeronDriver, config, this)
|
||||
}
|
||||
|
||||
actionDispatch.launch {
|
||||
|
||||
actionDispatch.launch(start = CoroutineStart.ATOMIC) {
|
||||
waiter.doNotify()
|
||||
|
||||
val pollIdleStrategy = config.pollIdleStrategy
|
||||
|
@ -609,7 +269,6 @@ open class Server<CONNECTION : Connection>(
|
|||
// this checks to see if there are NEW clients via IPC
|
||||
pollCount += ipcPoller.poll()
|
||||
|
||||
|
||||
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
|
||||
// so we can modify this as we go
|
||||
connections.forEach { connection ->
|
||||
|
@ -656,7 +315,7 @@ open class Server<CONNECTION : Connection>(
|
|||
connections.clear()
|
||||
|
||||
cons.forEach { connection ->
|
||||
logger.error("${connection.id} cleanup")
|
||||
logger.error { "${connection.id} cleanup and close" }
|
||||
// have to free up resources!
|
||||
// NOTE: This can only occur on the polling dispatch thread!!
|
||||
handshake.cleanup(connection)
|
||||
|
@ -680,7 +339,7 @@ open class Server<CONNECTION : Connection>(
|
|||
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
|
||||
jobs.forEach { it.join() }
|
||||
} catch (e: Exception) {
|
||||
logger.error("Unexpected error during server message polling!", e)
|
||||
logger.error(e) { "Unexpected error during server message polling!" }
|
||||
} finally {
|
||||
ipv4Poller.close()
|
||||
ipv6Poller.close()
|
||||
|
@ -742,7 +401,6 @@ open class Server<CONNECTION : Connection>(
|
|||
shutdownEventWaiter.doWait()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -11,8 +11,6 @@ import io.aeron.Subscription
|
|||
import io.aeron.driver.MediaDriver
|
||||
import io.aeron.exceptions.DriverTimeoutException
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import mu.KLogger
|
||||
import mu.KotlinLogging
|
||||
import org.agrona.concurrent.BackoffIdleStrategy
|
||||
|
@ -20,6 +18,8 @@ import org.slf4j.LoggerFactory
|
|||
import java.io.File
|
||||
import java.lang.Thread.sleep
|
||||
import java.net.BindException
|
||||
import java.util.concurrent.locks.*
|
||||
import kotlin.concurrent.write
|
||||
|
||||
/**
|
||||
* Class for managing the Aeron+Media drivers
|
||||
|
@ -57,7 +57,7 @@ class AeronDriver(
|
|||
private const val AERON_PUBLICATION_LINGER_TIMEOUT = 5_000L // in MS
|
||||
|
||||
// prevents multiple instances, within the same JVM, from starting at the exact same time.
|
||||
private val startMutex = Mutex()
|
||||
private val lock = ReentrantReadWriteLock()
|
||||
|
||||
private fun setConfigDefaults(config: Configuration, logger: KLogger) {
|
||||
// explicitly don't set defaults if we already have the context defined!
|
||||
|
@ -322,8 +322,9 @@ class AeronDriver(
|
|||
*
|
||||
* @throws Exception if there is a problem starting the media driver
|
||||
*/
|
||||
suspend fun start() {
|
||||
startMutex.withLock {
|
||||
fun start() {
|
||||
// Note: A mutex doesn't work so well.
|
||||
lock.write {
|
||||
if (closeRequested.value) {
|
||||
logger.debug("Resetting media driver context")
|
||||
|
||||
|
|
|
@ -3,6 +3,6 @@ package dorkbox.network.aeron
|
|||
internal interface AeronPoller {
|
||||
fun poll(): Int
|
||||
fun close()
|
||||
fun serverInfo(): String
|
||||
}
|
||||
|
||||
val serverInfo: String
|
||||
}
|
||||
|
|
|
@ -33,8 +33,6 @@ internal open class IpcMediaDriverConnection(streamId: Int,
|
|||
) :
|
||||
MediaDriverConnection(0, 0, streamId, sessionId, 10, true) {
|
||||
|
||||
var success: Boolean = false
|
||||
|
||||
private fun uri(): ChannelUriStringBuilder {
|
||||
val builder = ChannelUriStringBuilder().media("ipc")
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
|
@ -87,6 +85,7 @@ internal open class IpcMediaDriverConnection(streamId: Int,
|
|||
|
||||
if (!success) {
|
||||
subscription.close()
|
||||
|
||||
val clientTimedOutException = ClientTimedOutException("Creating subscription connection to aeron")
|
||||
ListenerManager.cleanStackTraceInternal(clientTimedOutException)
|
||||
throw clientTimedOutException
|
||||
|
@ -116,8 +115,8 @@ internal open class IpcMediaDriverConnection(streamId: Int,
|
|||
}
|
||||
|
||||
this.success = true
|
||||
this.publication = publication
|
||||
this.subscription = subscription
|
||||
this.publication = publication
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -146,34 +145,28 @@ internal open class IpcMediaDriverConnection(streamId: Int,
|
|||
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
|
||||
|
||||
// If we start/stop too quickly, we might have the aeron connectivity issues! Retry a few times.
|
||||
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
|
||||
success = true
|
||||
subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamIdSubscription)
|
||||
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
|
||||
}
|
||||
|
||||
override fun clientInfo() : String {
|
||||
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
override val clientInfo : String by lazy {
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
"[$sessionId] IPC connection established to [$streamIdSubscription|$streamId]"
|
||||
} else {
|
||||
"Connecting handshake to IPC [$streamIdSubscription|$streamId]"
|
||||
}
|
||||
}
|
||||
|
||||
override fun serverInfo() : String {
|
||||
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
"[$sessionId] IPC listening on [$streamIdSubscription|$streamId] "
|
||||
override val serverInfo : String by lazy {
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
"[$sessionId] IPC listening on [$streamIdSubscription|$streamId] [$sessionId]"
|
||||
} else {
|
||||
"Listening handshake on IPC [$streamIdSubscription|$streamId]"
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
if (success) {
|
||||
subscription.close()
|
||||
publication.close()
|
||||
"Listening handshake on IPC [$streamIdSubscription|$streamId] [$sessionId]"
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "[$streamIdSubscription|$streamId] [$sessionId]"
|
||||
return serverInfo
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ abstract class MediaDriverConnection(
|
|||
val streamId: Int, val sessionId: Int,
|
||||
val connectionTimeoutSec: Int, val isReliable: Boolean) : AutoCloseable {
|
||||
|
||||
var success: Boolean = false
|
||||
lateinit var subscription: Subscription
|
||||
lateinit var publication: Publication
|
||||
|
||||
|
@ -33,6 +34,13 @@ abstract class MediaDriverConnection(
|
|||
abstract suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger)
|
||||
abstract fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean = false)
|
||||
|
||||
abstract fun clientInfo() : String
|
||||
abstract fun serverInfo() : String
|
||||
abstract val clientInfo : String
|
||||
abstract val serverInfo : String
|
||||
|
||||
override fun close() {
|
||||
if (success) {
|
||||
publication.close()
|
||||
subscription.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,22 +40,6 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
|
|||
isReliable: Boolean = true) :
|
||||
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutSec, isReliable) {
|
||||
|
||||
var success: Boolean = false
|
||||
|
||||
private fun aeronConnectionString(ipAddress: InetAddress): String {
|
||||
return if (ipAddress is Inet4Address) {
|
||||
ipAddress.hostAddress
|
||||
} else {
|
||||
// IPv6 requires the address to be bracketed by [...]
|
||||
val host = ipAddress.hostAddress
|
||||
if (host[0] == '[') {
|
||||
host
|
||||
} else {
|
||||
"[${ipAddress.hostAddress}]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val addressString: String by lazy {
|
||||
IP.toString(address)
|
||||
}
|
||||
|
@ -80,7 +64,18 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
|
|||
|
||||
@Suppress("DuplicatedCode")
|
||||
override suspend fun buildClient(aeronDriver: AeronDriver, logger: KLogger) {
|
||||
val aeronAddressString = aeronConnectionString(address)
|
||||
val aeronAddressString = if (address is Inet4Address) {
|
||||
address.hostAddress
|
||||
} else {
|
||||
// IPv6 requires the address to be bracketed by [...]
|
||||
val host = address.hostAddress
|
||||
if (host[0] == '[') {
|
||||
host
|
||||
} else {
|
||||
"[${address.hostAddress}]"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Create a publication at the given address and port, using the given stream ID.
|
||||
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
|
||||
|
@ -123,6 +118,7 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
|
|||
|
||||
if (!success) {
|
||||
subscription.close()
|
||||
|
||||
val ex = ClientTimedOutException("Cannot create subscription: $ip ${subscriptionUri.build()} in ${timoutInNanos}ms")
|
||||
ListenerManager.cleanStackTraceInternal(ex)
|
||||
throw ex
|
||||
|
@ -145,18 +141,19 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
|
|||
if (!success) {
|
||||
subscription.close()
|
||||
publication.close()
|
||||
|
||||
val ex = ClientTimedOutException("Cannot create publication: $ip ${publicationUri.build()} in ${timoutInNanos}ms")
|
||||
ListenerManager.cleanStackTrace(ex)
|
||||
throw ex
|
||||
}
|
||||
|
||||
this.success = true
|
||||
this.publication = publication
|
||||
this.subscription = subscription
|
||||
this.publication = publication
|
||||
}
|
||||
|
||||
override fun clientInfo(): String {
|
||||
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
override val clientInfo: String by lazy {
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
"Connecting to $addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
|
||||
} else {
|
||||
"Connecting handshake to $addressString [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
|
||||
|
@ -166,18 +163,12 @@ internal class UdpMediaDriverClientConnection(val address: InetAddress,
|
|||
override fun buildServer(aeronDriver: AeronDriver, logger: KLogger, pairConnection: Boolean) {
|
||||
throw ClientException("Server info not implemented in Client MDC")
|
||||
}
|
||||
override fun serverInfo(): String {
|
||||
override val serverInfo: String
|
||||
get() {
|
||||
throw ClientException("Server info not implemented in Client MDC")
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
if (success) {
|
||||
subscription.close()
|
||||
publication.close()
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "$addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
|
||||
return clientInfo
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,8 +38,6 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
|
|||
isReliable: Boolean = true) :
|
||||
UdpMediaDriverConnection(publicationPort, subscriptionPort, streamId, sessionId, connectionTimeoutSec, isReliable) {
|
||||
|
||||
var success: Boolean = false
|
||||
|
||||
private fun aeronConnectionString(ipAddress: InetAddress): String {
|
||||
return if (ipAddress is Inet4Address) {
|
||||
ipAddress.hostAddress
|
||||
|
@ -54,7 +52,7 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
|
|||
}
|
||||
}
|
||||
|
||||
protected fun uri(): ChannelUriStringBuilder {
|
||||
private fun uri(): ChannelUriStringBuilder {
|
||||
val builder = ChannelUriStringBuilder().reliable(isReliable).media("udp")
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
builder.sessionId(sessionId)
|
||||
|
@ -99,15 +97,17 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
|
|||
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
|
||||
|
||||
// If we start/stop too quickly, we might have the address already in use! Retry a few times.
|
||||
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
|
||||
success = true
|
||||
subscription = aeronDriver.addSubscriptionWithRetry(subscriptionUri, streamId)
|
||||
publication = aeronDriver.addPublicationWithRetry(publicationUri, streamId)
|
||||
}
|
||||
|
||||
override fun clientInfo(): String {
|
||||
throw ServerException("Client info not implemented in Server MDC")
|
||||
}
|
||||
override val clientInfo: String
|
||||
get() {
|
||||
throw ServerException("Client info not implemented in Server MDC")
|
||||
}
|
||||
|
||||
override fun serverInfo(): String {
|
||||
override val serverInfo: String by lazy {
|
||||
val address = if (listenAddress == IPv4.WILDCARD || listenAddress == IPv6.WILDCARD) {
|
||||
if (listenAddress == IPv4.WILDCARD) {
|
||||
listenAddress.hostAddress
|
||||
|
@ -118,21 +118,14 @@ internal open class UdpMediaDriverServerConnection(val listenAddress: InetAddres
|
|||
IP.toString(listenAddress)
|
||||
}
|
||||
|
||||
return if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
if (sessionId != AeronDriver.RESERVED_SESSION_ID_INVALID) {
|
||||
"Listening on $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
|
||||
} else {
|
||||
"Listening handshake on $address [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
if (success) {
|
||||
subscription.close()
|
||||
publication.close()
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "$IP.toString(listenAddress) [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
|
||||
return serverInfo
|
||||
}
|
||||
}
|
||||
|
|
|
@ -346,7 +346,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
|
|||
|
||||
// the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire.
|
||||
if (isClosed.compareAndSet(expect = false, update = true)) {
|
||||
logger.info {"[$id] connection closed"}
|
||||
logger.debug {"[$id] connection closing"}
|
||||
|
||||
subscription.close()
|
||||
|
||||
|
@ -393,6 +393,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
|
|||
// This is set by the client/server so if there is a "connect()" call in the the disconnect callback, we can have proper
|
||||
// lock-stop ordering for how disconnect and connect work with each-other
|
||||
postCloseAction()
|
||||
logger.debug {"[$id] connection closed"}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ package dorkbox.network.connection
|
|||
|
||||
import dorkbox.bytes.Hash
|
||||
import dorkbox.bytes.toHexString
|
||||
import dorkbox.netUtil.IP
|
||||
import dorkbox.network.handshake.ClientConnectionInfo
|
||||
import dorkbox.network.serialization.AeronInput
|
||||
import dorkbox.network.serialization.AeronOutput
|
||||
|
@ -137,9 +136,9 @@ internal class CryptoManagement(val logger: KLogger,
|
|||
* @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.)
|
||||
* false if we should abort
|
||||
*/
|
||||
internal fun validateRemoteAddress(remoteAddress: InetAddress, publicKey: ByteArray?): PublicKeyValidationState {
|
||||
internal fun validateRemoteAddress(remoteAddress: InetAddress, remoteAddressString: String, publicKey: ByteArray?): PublicKeyValidationState {
|
||||
if (publicKey == null) {
|
||||
logger.error("Error validating public key for ${IP.toString(remoteAddress)}! It was null (and should not have been)")
|
||||
logger.error("Error validating public key for ${remoteAddressString}! It was null (and should not have been)")
|
||||
return PublicKeyValidationState.INVALID
|
||||
}
|
||||
|
||||
|
@ -150,18 +149,18 @@ internal class CryptoManagement(val logger: KLogger,
|
|||
if (!publicKey.contentEquals(savedPublicKey)) {
|
||||
return if (enableRemoteSignatureValidation) {
|
||||
// keys do not match, abort!
|
||||
logger.error("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Denying connection attempt")
|
||||
logger.error("The public key for remote connection $remoteAddressString does not match. Denying connection attempt")
|
||||
PublicKeyValidationState.INVALID
|
||||
}
|
||||
else {
|
||||
logger.warn("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Permitting connection attempt.")
|
||||
logger.warn("The public key for remote connection $remoteAddressString does not match. Permitting connection attempt.")
|
||||
PublicKeyValidationState.TAMPERED
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: SecurityException) {
|
||||
// keys do not match, abort!
|
||||
logger.error("Error validating public key for ${IP.toString(remoteAddress)}!", e)
|
||||
logger.error("Error validating public key for $remoteAddressString!", e)
|
||||
return PublicKeyValidationState.INVALID
|
||||
}
|
||||
|
||||
|
|
|
@ -215,7 +215,7 @@ internal constructor(val type: Class<*>,
|
|||
/**
|
||||
* @throws Exception if there is a problem starting the media driver
|
||||
*/
|
||||
internal suspend fun initEndpointState() {
|
||||
internal fun initEndpointState() {
|
||||
shutdown.getAndSet(false)
|
||||
shutdownWaiter = SuspendWaiter()
|
||||
|
||||
|
@ -363,9 +363,7 @@ internal constructor(val type: Class<*>,
|
|||
@Suppress("DuplicatedCode")
|
||||
internal fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) {
|
||||
// The handshake sessionId IS NOT globally unique
|
||||
logger.trace {
|
||||
"[${publication.sessionId()}] send HS: $message"
|
||||
}
|
||||
logger.trace { "[${message.connectKey}] send HS: $message" }
|
||||
|
||||
try {
|
||||
// we are not thread-safe!
|
||||
|
@ -431,11 +429,9 @@ internal constructor(val type: Class<*>,
|
|||
internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
|
||||
return try {
|
||||
// NOTE: This ABSOLUTELY MUST be done on the same thread! This cannot be done on a new one, because the buffer could change!
|
||||
val message = handshakeKryo.read(buffer, offset, length)
|
||||
val message = handshakeKryo.read(buffer, offset, length) as HandshakeMessage
|
||||
|
||||
logger.trace {
|
||||
"[${header.sessionId()}] received HS: $message"
|
||||
}
|
||||
logger.trace { "[${message.connectKey}] received HS: $message" }
|
||||
|
||||
message
|
||||
} catch (e: Exception) {
|
||||
|
@ -723,6 +719,7 @@ internal constructor(val type: Class<*>,
|
|||
|
||||
runBlocking {
|
||||
connections.forEach {
|
||||
logger.info { "Closing connection: ${it.id}" }
|
||||
it.close()
|
||||
}
|
||||
|
||||
|
@ -739,6 +736,8 @@ internal constructor(val type: Class<*>,
|
|||
|
||||
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)
|
||||
shutdownWaiter.cancel()
|
||||
|
||||
logger.info { "Done shutting down..." }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,9 +42,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
|
||||
private val handler: FragmentHandler
|
||||
|
||||
// a one-time key for connecting
|
||||
// used to keep track and associate UDP/IPC handshakes between client/server
|
||||
@Volatile
|
||||
var oneTimeKey = 0
|
||||
var connectKey = 0L
|
||||
|
||||
@Volatile
|
||||
private var connectionHelloInfo: ClientConnectionInfo? = null
|
||||
|
@ -64,21 +64,20 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
failedException = null
|
||||
needToRetry = false
|
||||
|
||||
// it must be a registration message
|
||||
if (message !is HandshakeMessage) {
|
||||
failedException = ClientRejectedException("[$sessionId] cancelled handshake for unrecognized message: $message")
|
||||
failedException = ClientRejectedException("[${header.sessionId()}] cancelled handshake for unrecognized message: $message")
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
// this is an error message
|
||||
if (message.state == HandshakeMessage.INVALID) {
|
||||
val cause = ServerException(message.errorMessage ?: "Unknown").apply { stackTrace = stackTrace.copyOfRange(0, 1) }
|
||||
failedException = ClientRejectedException("[$sessionId] cancelled handshake", cause)
|
||||
failedException = ClientRejectedException("[${message.connectKey}] cancelled handshake", cause)
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
|
@ -89,8 +88,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
if (oneTimeKey != message.oneTimeKey) {
|
||||
logger.error("[$message] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})")
|
||||
if (connectKey != message.connectKey) {
|
||||
logger.error("Ignored handshake (client connect key: ${message.connectKey}) intended for another client (mine is:" +
|
||||
" ${connectKey})")
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
|
@ -106,7 +106,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
if (registrationData != null && serverPublicKeyBytes != null) {
|
||||
connectionHelloInfo = crypto.decrypt(registrationData, serverPublicKeyBytes)
|
||||
} else {
|
||||
failedException = ClientRejectedException("[$message.sessionId] canceled handshake for message without registration and/or public key info")
|
||||
failedException = ClientRejectedException("[${message.connectKey}] canceled handshake for message without registration and/or public key info")
|
||||
}
|
||||
}
|
||||
HandshakeMessage.HELLO_ACK_IPC -> {
|
||||
|
@ -129,7 +129,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
publicationPort = streamPubId,
|
||||
kryoRegistrationDetails = regDetails)
|
||||
} else {
|
||||
failedException = ClientRejectedException("[$message.sessionId] canceled handshake for message without registration data")
|
||||
failedException = ClientRejectedException("[${message.connectKey}] canceled handshake for message without registration data")
|
||||
}
|
||||
}
|
||||
HandshakeMessage.DONE_ACK -> {
|
||||
|
@ -137,16 +137,28 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
}
|
||||
else -> {
|
||||
val stateString = HandshakeMessage.toStateString(message.state)
|
||||
failedException = ClientRejectedException("[$sessionId] cancelled handshake for message that is $stateString")
|
||||
failedException = ClientRejectedException("[${message.connectKey}] cancelled handshake for message that is $stateString")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure that NON-ZERO is returned
|
||||
*/
|
||||
private fun getSafeConnectKey(): Long {
|
||||
var key = endPoint.crypto.secureRandom.nextLong()
|
||||
while (key == 0L) {
|
||||
key = endPoint.crypto.secureRandom.nextLong()
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// called from the connect thread
|
||||
fun hello(handshakeConnection: MediaDriverConnection, connectionTimeoutSec: Int) : ClientConnectionInfo {
|
||||
failedException = null
|
||||
oneTimeKey = endPoint.crypto.secureRandom.nextInt()
|
||||
connectKey = getSafeConnectKey()
|
||||
val publicKey = endPoint.storage.getPublicKey()!!
|
||||
|
||||
// Send the one-time pad to the server.
|
||||
|
@ -155,8 +167,11 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
val pollIdleStrategy = endPoint.pollIdleStrategyHandShake
|
||||
|
||||
try {
|
||||
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
|
||||
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(connectKey, publicKey))
|
||||
} catch (e: Exception) {
|
||||
publication.close()
|
||||
subscription.close()
|
||||
|
||||
logger.error("Handshake error!", e)
|
||||
throw e
|
||||
}
|
||||
|
@ -181,6 +196,9 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
|
||||
val failedEx = failedException
|
||||
if (failedEx != null) {
|
||||
publication.close()
|
||||
subscription.close()
|
||||
|
||||
// no longer necessary to hold this connection open (if not a failure, we close the handshake after the DONE message)
|
||||
handshakeConnection.close()
|
||||
ListenerManager.cleanStackTraceInternal(failedEx)
|
||||
|
@ -188,8 +206,12 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
}
|
||||
|
||||
if (connectionHelloInfo == null) {
|
||||
publication.close()
|
||||
subscription.close()
|
||||
|
||||
// no longer necessary to hold this connection open (if not a failure, we close the handshake after the DONE message)
|
||||
handshakeConnection.close()
|
||||
|
||||
val exception = ClientTimedOutException("Waiting for registration response from server")
|
||||
ListenerManager.cleanStackTraceInternal(exception)
|
||||
throw exception
|
||||
|
@ -200,12 +222,14 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
|
||||
// called from the connect thread
|
||||
suspend fun done(handshakeConnection: MediaDriverConnection, connectionTimeoutSec: Int): Boolean {
|
||||
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
|
||||
val registrationMessage = HandshakeMessage.doneFromClient(connectKey)
|
||||
|
||||
// Send the done message to the server.
|
||||
try {
|
||||
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
|
||||
} catch (e: Exception) {
|
||||
handshakeConnection.close()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -254,11 +278,13 @@ internal class ClientHandshake<CONNECTION: Connection>(
|
|||
throw exception
|
||||
}
|
||||
|
||||
logger.error{"[${subscription.streamId()}] handshake done"}
|
||||
|
||||
return connectionDone
|
||||
}
|
||||
|
||||
fun reset() {
|
||||
oneTimeKey = 0
|
||||
connectKey = 0L
|
||||
connectionHelloInfo = null
|
||||
connectionDone = false
|
||||
needToRetry = false
|
||||
|
|
|
@ -24,9 +24,10 @@ internal class HandshakeMessage private constructor() {
|
|||
var publicKey: ByteArray? = null
|
||||
|
||||
|
||||
// used to keep track and associate UDP/etc sessions. This is always defined by the server
|
||||
// a sessionId if '0', means we are still figuring it out.
|
||||
var oneTimeKey = 0
|
||||
// used to keep track and associate UDP/IPC handshakes between client/server
|
||||
// The connection info (session ID, etc) necessary to make a connection to the server are encrypted with the clients public key.
|
||||
// so EVEN IF you can guess someone's connectKey, you must also know their private key in order to connect as them.
|
||||
var connectKey = 0L
|
||||
|
||||
// -1 means there is an error
|
||||
var state = INVALID
|
||||
|
@ -35,9 +36,6 @@ internal class HandshakeMessage private constructor() {
|
|||
|
||||
var publicationPort = 0
|
||||
var subscriptionPort = 0
|
||||
var sessionId = 0
|
||||
var streamId = 0
|
||||
|
||||
|
||||
|
||||
// by default, this will be a reliable connection. When the client connects to the server, the client will specify if the new connection
|
||||
|
@ -45,9 +43,8 @@ internal class HandshakeMessage private constructor() {
|
|||
val isReliable = true
|
||||
|
||||
|
||||
// the client sends it's registration data to the server to make sure that the registered classes are the same between the client/server
|
||||
// the client sends its registration data to the server to make sure that the registered classes are the same between the client/server
|
||||
var registrationData: ByteArray? = null
|
||||
var registrationRmiIdData: IntArray? = null
|
||||
|
||||
companion object {
|
||||
const val INVALID = -2
|
||||
|
@ -58,42 +55,39 @@ internal class HandshakeMessage private constructor() {
|
|||
const val DONE = 3
|
||||
const val DONE_ACK = 4
|
||||
|
||||
fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage {
|
||||
fun helloFromClient(connectKey: Long, publicKey: ByteArray): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO
|
||||
hello.oneTimeKey = oneTimeKey
|
||||
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
|
||||
hello.publicKey = publicKey
|
||||
return hello
|
||||
}
|
||||
|
||||
fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
fun helloAckToClient(connectKey: Long): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO_ACK
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
|
||||
return hello
|
||||
}
|
||||
|
||||
fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
fun helloAckIpcToClient(connectKey: Long): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO_ACK_IPC
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
|
||||
return hello
|
||||
}
|
||||
|
||||
fun doneFromClient(oneTimeKey: Int): HandshakeMessage {
|
||||
fun doneFromClient(connectKey: Long): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = DONE
|
||||
hello.oneTimeKey = oneTimeKey
|
||||
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
|
||||
return hello
|
||||
}
|
||||
|
||||
fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
fun doneToClient(connectKey: Long): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = DONE_ACK
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
|
||||
return hello
|
||||
}
|
||||
|
||||
|
@ -134,7 +128,6 @@ internal class HandshakeMessage private constructor() {
|
|||
", Error: $errorMessage"
|
||||
}
|
||||
|
||||
|
||||
return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)"
|
||||
return "HandshakeMessage($stateStr$errorMsg)"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
package dorkbox.network.handshake
|
||||
|
||||
import dorkbox.netUtil.IP
|
||||
import dorkbox.network.Server
|
||||
import dorkbox.network.ServerConfiguration
|
||||
import dorkbox.network.aeron.AeronDriver
|
||||
|
@ -24,11 +25,11 @@ import dorkbox.network.connection.Connection
|
|||
import dorkbox.network.connection.ConnectionParams
|
||||
import dorkbox.network.connection.ListenerManager
|
||||
import dorkbox.network.connection.PublicKeyValidationState
|
||||
import dorkbox.network.connection.eventLoop
|
||||
import dorkbox.network.exceptions.AllocationException
|
||||
import dorkbox.network.rmi.RmiManagerConnections
|
||||
import io.aeron.Publication
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import mu.KLogger
|
||||
import net.jodah.expiringmap.ExpirationPolicy
|
||||
|
@ -40,8 +41,10 @@ import java.util.concurrent.*
|
|||
|
||||
/**
|
||||
* 'notifyConnect' must be THE ONLY THING in this class to use the action dispatch!
|
||||
*
|
||||
* NOTE: all methods in here are called by the SAME thread!
|
||||
*/
|
||||
@Suppress("DuplicatedCode")
|
||||
@Suppress("DuplicatedCode", "JoinDeclarationAndAssignment")
|
||||
internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
|
||||
private val config: ServerConfiguration,
|
||||
private val listenerManager: ListenerManager<CONNECTION>) {
|
||||
|
@ -50,10 +53,14 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
private val pendingConnections = ExpiringMap.builder()
|
||||
.expiration(config.connectionCloseTimeoutInSeconds.toLong() * 2, TimeUnit.SECONDS)
|
||||
.expirationPolicy(ExpirationPolicy.CREATED)
|
||||
.expirationListener<Int, CONNECTION> { sessionId, connection ->
|
||||
expirePendingConnections(sessionId, connection)
|
||||
.expirationListener<Long, CONNECTION> { clientConnectKey, connection ->
|
||||
// this blocks until it fully runs (which is ok. this is fast)
|
||||
logger.error { "[${clientConnectKey} Connection (${connection.id}) Timed out waiting for registration response from client" }
|
||||
runBlocking {
|
||||
connection.close()
|
||||
}
|
||||
}
|
||||
.build<Int, CONNECTION>()
|
||||
.build<Long, CONNECTION>()
|
||||
|
||||
|
||||
private val connectionsPerIpCounts = ConnectionCounts()
|
||||
|
@ -62,17 +69,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
private val sessionIdAllocator = RandomIdAllocator(AeronDriver.RESERVED_SESSION_ID_LOW, AeronDriver.RESERVED_SESSION_ID_HIGH)
|
||||
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
|
||||
|
||||
|
||||
private fun expirePendingConnections(sessionId: Int, connection: CONNECTION) {
|
||||
// this blocks until it fully runs (which is ok. this is fast)
|
||||
logger.error("[${connection.id}] Timed out waiting for registration response from client")
|
||||
|
||||
pendingConnections.remove(sessionId)
|
||||
runBlocking {
|
||||
connection.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return true if we should continue parsing the incoming message, false if we should abort
|
||||
*/
|
||||
|
@ -82,24 +78,22 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
actionDispatch: CoroutineScope,
|
||||
handshakePublication: Publication,
|
||||
message: HandshakeMessage,
|
||||
sessionId: Int,
|
||||
connectionString: String,
|
||||
logger: KLogger
|
||||
): Boolean {
|
||||
|
||||
// check to see if this sessionId is ALREADY in use by another connection!
|
||||
// this can happen if there are multiple connections from the SAME ip address (ie: localhost)
|
||||
if (message.state == HandshakeMessage.HELLO) {
|
||||
// this should be null.
|
||||
val hasExistingSessionId = pendingConnections[sessionId] != null
|
||||
if (hasExistingSessionId) {
|
||||
val hasExistingConnectionInProgress = pendingConnections[message.connectKey] != null
|
||||
if (hasExistingConnectionInProgress) {
|
||||
// WHOOPS! tell the client that it needs to retry, since a DIFFERENT client has a handshake in progress with the same sessionId
|
||||
logger.error("[$sessionId] Connection from $connectionString had an in-use session ID! Telling client to retry.")
|
||||
logger.error { "[${message.connectKey}] Connection from $connectionString had an in-use session ID! Telling client to retry." }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.retry("Handshake already in progress for sessionID!"))
|
||||
} catch (e: Error) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -109,19 +103,17 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
// check to see if this is a pending connection
|
||||
if (message.state == HandshakeMessage.DONE) {
|
||||
val pendingConnection = pendingConnections[sessionId]
|
||||
pendingConnections.remove(sessionId)
|
||||
|
||||
val pendingConnection = pendingConnections.remove(message.connectKey)
|
||||
if (pendingConnection == null) {
|
||||
logger.error("[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!")
|
||||
logger.error { "[${message.connectKey}] Error! Pending connection from client $connectionString was null, and cannot complete handshake!" }
|
||||
} else {
|
||||
logger.trace { "[${pendingConnection.id}] Connection from client $connectionString done with handshake." }
|
||||
logger.trace { "[${message.connectKey}] Connection (${pendingConnection.id}) from $connectionString done with handshake." }
|
||||
|
||||
pendingConnection.postCloseAction = {
|
||||
// called on connection.close()
|
||||
|
||||
// this always has to be on event dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
|
||||
actionDispatch.eventLoop {
|
||||
actionDispatch.launch {
|
||||
listenerManager.notifyDisconnect(pendingConnection)
|
||||
}
|
||||
}
|
||||
|
@ -132,17 +124,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
// now tell the client we are done
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId))
|
||||
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.connectKey))
|
||||
// this always has to be on event dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
|
||||
actionDispatch.eventLoop {
|
||||
actionDispatch.launch {
|
||||
listenerManager.notifyConnect(pendingConnection)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -170,7 +160,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -182,23 +172,23 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
|
||||
connectionsPerIpCounts.decrement(clientAddress, currentCountForIp)
|
||||
|
||||
logger.error("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")
|
||||
logger.error { "Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return false
|
||||
}
|
||||
connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
|
||||
} catch (e: Exception) {
|
||||
logger.error("could not validate client message", e)
|
||||
logger.error(e) { "could not validate client message" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,7 +201,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
server: Server<CONNECTION>,
|
||||
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
|
||||
handshakePublication: Publication,
|
||||
sessionId: Int,
|
||||
message: HandshakeMessage,
|
||||
aeronDriver: AeronDriver,
|
||||
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
|
||||
|
@ -220,7 +209,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
val connectionString = "IPC"
|
||||
|
||||
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, sessionId, connectionString, logger)) {
|
||||
if (!validateMessageTypeAndDoPending(server, server.actionDispatch, handshakePublication, message, connectionString, logger)) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -238,12 +227,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
try {
|
||||
connectionSessionId = sessionIdAllocator.allocate()
|
||||
} catch (e: AllocationException) {
|
||||
logger.error("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!")
|
||||
logger.error { "Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -256,12 +245,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
// have to unwind actions!
|
||||
sessionIdAllocator.free(connectionSessionId)
|
||||
|
||||
logger.error("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")
|
||||
logger.error { "Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -274,12 +263,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
sessionIdAllocator.free(connectionSessionId)
|
||||
sessionIdAllocator.free(connectionStreamPubId)
|
||||
|
||||
logger.error("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")
|
||||
logger.error { "Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -294,9 +283,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
// we have to construct how the connection will communicate!
|
||||
clientConnection.buildServer(aeronDriver, logger, true)
|
||||
|
||||
logger.info {
|
||||
"[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]"
|
||||
}
|
||||
logger.info { "[${clientConnection.sessionId}] IPC connection established to [${clientConnection.streamIdSubscription}|${clientConnection.streamId}]" }
|
||||
|
||||
val connection = connectionFunc(ConnectionParams(server, clientConnection, PublicKeyValidationState.VALID, rmiConnectionSupport))
|
||||
|
||||
|
@ -311,7 +298,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
|
||||
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
|
||||
val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId)
|
||||
val successMessage = HandshakeMessage.helloAckIpcToClient(message.connectKey)
|
||||
|
||||
|
||||
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
|
||||
|
@ -332,58 +319,70 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
successMessage.publicKey = server.crypto.publicKeyBytes
|
||||
|
||||
// before we notify connect, we have to wait for the client to tell us that they can receive data
|
||||
pendingConnections[sessionId] = connection
|
||||
pendingConnections[message.connectKey] = connection
|
||||
|
||||
// this tells the client all of the info to connect.
|
||||
// this tells the client all the info to connect.
|
||||
server.writeHandshakeMessage(handshakePublication, successMessage) // exception is already caught!
|
||||
} catch (e: Exception) {
|
||||
// have to unwind actions!
|
||||
sessionIdAllocator.free(connectionSessionId)
|
||||
streamIdAllocator.free(connectionStreamPubId)
|
||||
|
||||
logger.error("Connection handshake from $connectionString crashed! Message $message", e)
|
||||
logger.error(e) { "Connection handshake from $connectionString crashed! Message $message" }
|
||||
}
|
||||
}
|
||||
|
||||
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
|
||||
fun processUdpHandshakeMessageServer(server: Server<CONNECTION>,
|
||||
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
|
||||
handshakePublication: Publication,
|
||||
sessionId: Int,
|
||||
clientAddressString: String,
|
||||
clientAddress: InetAddress,
|
||||
message: HandshakeMessage,
|
||||
aeronDriver: AeronDriver,
|
||||
isIpv6Wildcard: Boolean,
|
||||
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
|
||||
logger: KLogger) {
|
||||
fun processUdpHandshakeMessageServer(
|
||||
server: Server<CONNECTION>,
|
||||
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
|
||||
handshakePublication: Publication,
|
||||
remoteIpAndPort: String,
|
||||
message: HandshakeMessage,
|
||||
aeronDriver: AeronDriver,
|
||||
isIpv6Wildcard: Boolean,
|
||||
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
|
||||
logger: KLogger
|
||||
) {
|
||||
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
// val port = remoteIpAndPort.substring(splitPoint+1)
|
||||
|
||||
// this should never be null, because we are feeding it a valid IP address from aeron
|
||||
val clientAddress = IP.toAddress(clientAddressString)
|
||||
if (clientAddress == null) {
|
||||
logger.error { "Connection from $clientAddressString not allowed! Invalid IP address!" }
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Manage the Handshake state
|
||||
if (!validateMessageTypeAndDoPending(
|
||||
server,
|
||||
server.actionDispatch,
|
||||
handshakePublication,
|
||||
message,
|
||||
sessionId,
|
||||
clientAddressString,
|
||||
logger
|
||||
server, server.actionDispatch, handshakePublication, message, clientAddressString, logger
|
||||
)) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
||||
val clientPublicKeyBytes = message.publicKey
|
||||
val validateRemoteAddress: PublicKeyValidationState
|
||||
val serialization = config.serialization
|
||||
|
||||
// VALIDATE:: check to see if the remote connection's public key has changed!
|
||||
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
|
||||
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientAddressString, clientPublicKeyBytes)
|
||||
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
|
||||
logger.error("Connection from $clientAddressString not allowed! Public key mismatch.")
|
||||
logger.error { "Connection from $clientAddressString not allowed! Public key mismatch." }
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if (!clientAddress.isLoopbackAddress &&
|
||||
!validateUdpConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress, logger)) {
|
||||
// we do not want to limit loopback addresses!
|
||||
// we do not want to limit the loopback addresses!
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -403,12 +402,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
// have to unwind actions!
|
||||
connectionsPerIpCounts.decrementSlow(clientAddress)
|
||||
|
||||
logger.error("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")
|
||||
logger.error { "Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -422,12 +421,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
connectionsPerIpCounts.decrementSlow(clientAddress)
|
||||
sessionIdAllocator.free(connectionSessionId)
|
||||
|
||||
logger.error("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")
|
||||
logger.error { "Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -461,10 +460,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
// we have to construct how the connection will communicate!
|
||||
clientConnection.buildServer(aeronDriver, logger, true)
|
||||
|
||||
logger.info {
|
||||
// (reliable:$isReliable)"
|
||||
"Creating new connection from $clientAddressString [$subscriptionPort|$publicationPort] [$connectionStreamId|$connectionSessionId] (reliable:${message.isReliable})"
|
||||
}
|
||||
logger.info { "Creating new connection from $clientConnection" }
|
||||
|
||||
val connection = connectionFunc(ConnectionParams(server, clientConnection, validateRemoteAddress, rmiConnectionSupport))
|
||||
|
||||
|
@ -476,12 +472,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
sessionIdAllocator.free(connectionSessionId)
|
||||
streamIdAllocator.free(connectionStreamId)
|
||||
|
||||
logger.error("Connection $clientAddressString was not permitted!")
|
||||
logger.error { "Connection $clientAddressString was not permitted!" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
|
||||
} catch (e: Exception) {
|
||||
logger.error("Handshake error!", e)
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -492,9 +488,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
///////////////
|
||||
|
||||
|
||||
|
||||
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
|
||||
val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId)
|
||||
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
|
||||
|
||||
|
||||
// Also send the RMI registration data to the client (so the client doesn't register anything)
|
||||
|
@ -510,7 +505,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
successMessage.publicKey = server.crypto.publicKeyBytes
|
||||
|
||||
// before we notify connect, we have to wait for the client to tell us that they can receive data
|
||||
pendingConnections[sessionId] = connection
|
||||
pendingConnections[message.connectKey] = connection
|
||||
|
||||
// this tells the client all the info to connect.
|
||||
server.writeHandshakeMessage(handshakePublication, successMessage) // exception is already caught
|
||||
|
@ -520,7 +515,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
sessionIdAllocator.free(connectionSessionId)
|
||||
streamIdAllocator.free(connectionStreamId)
|
||||
|
||||
logger.error("Connection handshake from $clientAddressString crashed! Message $message", e)
|
||||
logger.error(e) { "Connection handshake from $clientAddressString crashed! Message $message" }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,338 @@
|
|||
@file:Suppress("MemberVisibilityCanBePrivate", "DuplicatedCode")
|
||||
|
||||
package dorkbox.network.handshake
|
||||
|
||||
import dorkbox.network.Server
|
||||
import dorkbox.network.ServerConfiguration
|
||||
import dorkbox.network.aeron.AeronDriver
|
||||
import dorkbox.network.aeron.AeronPoller
|
||||
import dorkbox.network.aeron.IpcMediaDriverConnection
|
||||
import dorkbox.network.aeron.UdpMediaDriverServerConnection
|
||||
import dorkbox.network.connection.Connection
|
||||
import io.aeron.FragmentAssembler
|
||||
import io.aeron.Image
|
||||
import io.aeron.logbuffer.Header
|
||||
import org.agrona.DirectBuffer
|
||||
|
||||
internal object ServerHandshakePollers {
|
||||
fun disabled(serverInfo: String): AeronPoller {
|
||||
return object : AeronPoller {
|
||||
override fun poll(): Int { return 0 }
|
||||
override fun close() {}
|
||||
override val serverInfo = serverInfo
|
||||
}
|
||||
}
|
||||
|
||||
fun <CONNECTION : Connection> IPC(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
|
||||
val logger = server.logger
|
||||
val rmiConnectionSupport = server.rmiConnectionSupport
|
||||
val connectionFunc = server.connectionFunc
|
||||
val handshake = server.handshake
|
||||
|
||||
val poller = if (config.enableIpc) {
|
||||
val driver = IpcMediaDriverConnection(
|
||||
streamIdSubscription = config.ipcSubscriptionId,
|
||||
streamId = config.ipcPublicationId,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID
|
||||
)
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
val message = server.readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error { "[$sessionId] Connection from IPC not allowed! Invalid connection request" }
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processIpcHandshakeMessageServer(
|
||||
server, rmiConnectionSupport, publication, message, aeronDriver, connectionFunc, logger
|
||||
)
|
||||
}
|
||||
|
||||
override fun poll(): Int {
|
||||
return subscription.poll(handler, 1)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
driver.close()
|
||||
}
|
||||
|
||||
override val serverInfo = driver.serverInfo
|
||||
}
|
||||
} else {
|
||||
disabled("IPC Disabled")
|
||||
}
|
||||
|
||||
logger.info { poller.serverInfo }
|
||||
return poller
|
||||
}
|
||||
|
||||
|
||||
|
||||
fun <CONNECTION : Connection> ip4(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
|
||||
val logger = server.logger
|
||||
val rmiConnectionSupport = server.rmiConnectionSupport
|
||||
val connectionFunc = server.connectionFunc
|
||||
val handshake = server.handshake
|
||||
|
||||
val poller = if (server.canUseIPv4) {
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = server.listenIPv4Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
|
||||
)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
val message = server.readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error {
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
|
||||
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
|
||||
}
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(
|
||||
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, false, connectionFunc, logger
|
||||
)
|
||||
}
|
||||
|
||||
override fun poll(): Int {
|
||||
return subscription.poll(handler, 1)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
driver.close()
|
||||
}
|
||||
|
||||
override val serverInfo = driver.serverInfo
|
||||
}
|
||||
} else {
|
||||
disabled("IPv4 Disabled")
|
||||
}
|
||||
|
||||
logger.info { poller.serverInfo }
|
||||
return poller
|
||||
}
|
||||
|
||||
fun <CONNECTION : Connection> ip6(aeronDriver: AeronDriver, config: ServerConfiguration, server: Server<CONNECTION>): AeronPoller {
|
||||
val logger = server.logger
|
||||
val rmiConnectionSupport = server.rmiConnectionSupport
|
||||
val connectionFunc = server.connectionFunc
|
||||
val handshake = server.handshake
|
||||
|
||||
val poller = if (server.canUseIPv6) {
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = server.listenIPv6Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
|
||||
)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
val message = server.readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error {
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
|
||||
}
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(
|
||||
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, false, connectionFunc, logger
|
||||
)
|
||||
}
|
||||
|
||||
override fun poll(): Int {
|
||||
return subscription.poll(handler, 1)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
driver.close()
|
||||
}
|
||||
|
||||
override val serverInfo = driver.serverInfo
|
||||
}
|
||||
} else {
|
||||
disabled("IPv6 Disabled")
|
||||
}
|
||||
|
||||
logger.info { poller.serverInfo }
|
||||
return poller
|
||||
}
|
||||
|
||||
fun <CONNECTION : Connection> ip6Wildcard(
|
||||
aeronDriver: AeronDriver,
|
||||
config: ServerConfiguration,
|
||||
server: Server<CONNECTION>
|
||||
): AeronPoller {
|
||||
val logger = server.logger
|
||||
val rmiConnectionSupport = server.rmiConnectionSupport
|
||||
val connectionFunc = server.connectionFunc
|
||||
val handshake = server.handshake
|
||||
|
||||
val driver = UdpMediaDriverServerConnection(
|
||||
listenAddress = server.listenIPv6Address!!,
|
||||
publicationPort = config.publicationPort,
|
||||
subscriptionPort = config.subscriptionPort,
|
||||
streamId = AeronDriver.UDP_HANDSHAKE_STREAM_ID,
|
||||
sessionId = AeronDriver.RESERVED_SESSION_ID_INVALID,
|
||||
connectionTimeoutSec = config.connectionCloseTimeoutInSeconds
|
||||
)
|
||||
|
||||
driver.buildServer(aeronDriver, logger)
|
||||
|
||||
val publication = driver.publication
|
||||
val subscription = driver.subscription
|
||||
|
||||
val poller = object : AeronPoller {
|
||||
/**
|
||||
* Note:
|
||||
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
|
||||
* desired, then limiting message sizes to MTU size is a good practice.
|
||||
*
|
||||
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
|
||||
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
|
||||
* properties from failure and streams with mechanical sympathy.
|
||||
*/
|
||||
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
|
||||
|
||||
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
|
||||
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
|
||||
val sessionId = header.sessionId()
|
||||
|
||||
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
|
||||
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
|
||||
|
||||
val message = server.readHandshakeMessage(buffer, offset, length, header)
|
||||
|
||||
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
|
||||
if (message !is HandshakeMessage) {
|
||||
logger.error {
|
||||
// split
|
||||
val splitPoint = remoteIpAndPort.lastIndexOf(':')
|
||||
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
|
||||
"[$sessionId] Connection from $clientAddressString not allowed! Invalid connection request"
|
||||
}
|
||||
|
||||
try {
|
||||
server.writeHandshakeMessage(publication, HandshakeMessage.error("Invalid connection request"))
|
||||
} catch (e: Exception) {
|
||||
logger.error(e) { "Handshake error!" }
|
||||
}
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
handshake.processUdpHandshakeMessageServer(
|
||||
server, rmiConnectionSupport, publication, remoteIpAndPort, message, aeronDriver, true, connectionFunc, logger
|
||||
)
|
||||
}
|
||||
|
||||
override fun poll(): Int {
|
||||
return subscription.poll(handler, 1)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
driver.close()
|
||||
}
|
||||
|
||||
override val serverInfo = driver.serverInfo
|
||||
}
|
||||
|
||||
logger.info { poller.serverInfo }
|
||||
return poller
|
||||
}
|
||||
}
|
|
@ -36,10 +36,7 @@ package dorkboxTest.network
|
|||
|
||||
import ch.qos.logback.classic.Level
|
||||
import ch.qos.logback.classic.Logger
|
||||
import ch.qos.logback.classic.encoder.PatternLayoutEncoder
|
||||
import ch.qos.logback.classic.joran.JoranConfigurator
|
||||
import ch.qos.logback.classic.spi.ILoggingEvent
|
||||
import ch.qos.logback.core.ConsoleAppender
|
||||
import dorkbox.network.Client
|
||||
import dorkbox.network.Configuration
|
||||
import dorkbox.network.Server
|
||||
|
@ -55,20 +52,58 @@ import org.junit.After
|
|||
import org.junit.Assert
|
||||
import org.junit.Before
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.io.File
|
||||
import java.lang.Thread.sleep
|
||||
import java.lang.reflect.Field
|
||||
import java.lang.reflect.Method
|
||||
import java.util.concurrent.*
|
||||
|
||||
abstract class BaseTest {
|
||||
@Volatile
|
||||
private var latch = CountDownLatch(1)
|
||||
|
||||
@Volatile
|
||||
private var autoFailThread: Thread? = null
|
||||
|
||||
companion object {
|
||||
const val LOCALHOST = "localhost"
|
||||
|
||||
// wait minimum of 2 minutes before we automatically fail the unit test.
|
||||
var AUTO_FAIL_TIMEOUT: Long = 120L
|
||||
|
||||
init {
|
||||
if (OS.javaVersion >= 9) {
|
||||
// disableAccessWarnings
|
||||
try {
|
||||
val unsafeClass = Class.forName("sun.misc.Unsafe")
|
||||
val field: Field = unsafeClass.getDeclaredField("theUnsafe")
|
||||
field.isAccessible = true
|
||||
val unsafe: Any = field.get(null)
|
||||
val putObjectVolatile: Method = unsafeClass.getDeclaredMethod("putObjectVolatile", Any::class.java, Long::class.javaPrimitiveType, Any::class.java)
|
||||
val staticFieldOffset: Method = unsafeClass.getDeclaredMethod("staticFieldOffset", Field::class.java)
|
||||
val loggerClass = Class.forName("jdk.internal.module.IllegalAccessLogger")
|
||||
val loggerField: Field = loggerClass.getDeclaredField("logger")
|
||||
val offset = staticFieldOffset.invoke(unsafe, loggerField) as Long
|
||||
putObjectVolatile.invoke(unsafe, loggerClass, offset, null)
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
|
||||
// if (System.getProperty("logback.configurationFile") == null) {
|
||||
// val file = File("logback.xml")
|
||||
// if (file.canRead()) {
|
||||
// System.setProperty("logback.configurationFile", file.toPath().toRealPath().toFile().toString())
|
||||
// } else {
|
||||
// System.setProperty("logback.configurationFile", "logback.xml")
|
||||
// }
|
||||
// }
|
||||
|
||||
// setLogLevel(Level.TRACE)
|
||||
// setLogLevel(Level.ERROR)
|
||||
setLogLevel(Level.DEBUG)
|
||||
|
||||
// we want our entropy generation to be simple (ie, no user interaction to generate)
|
||||
try {
|
||||
Entropy.init(SimpleEntropy::class.java)
|
||||
} catch (e: InitializationException) {
|
||||
e.printStackTrace()
|
||||
}
|
||||
}
|
||||
|
||||
fun clientConfig(block: Configuration.() -> Unit = {}): Configuration {
|
||||
|
||||
val configuration = Configuration()
|
||||
|
@ -102,66 +137,43 @@ abstract class BaseTest {
|
|||
|
||||
// assume SLF4J is bound to logback in the current environment
|
||||
val rootLogger = LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME) as Logger
|
||||
rootLogger.detachAndStopAllAppenders()
|
||||
rootLogger.level = level
|
||||
|
||||
val context = rootLogger.loggerContext
|
||||
val jc = JoranConfigurator()
|
||||
context.reset() // override default configuration
|
||||
|
||||
val jc = JoranConfigurator()
|
||||
jc.context = context
|
||||
|
||||
|
||||
context.getLogger(Server::class.simpleName).level = level
|
||||
context.getLogger(Client::class.simpleName).level = level
|
||||
jc.doConfigure(File("logback.xml").absoluteFile)
|
||||
|
||||
// we only want error messages
|
||||
val kryoLogger = LoggerFactory.getLogger("com.esotericsoftware") as Logger
|
||||
kryoLogger.level = Level.ERROR
|
||||
|
||||
val encoder = PatternLayoutEncoder()
|
||||
encoder.context = context
|
||||
encoder.pattern = "%date{HH:mm:ss.SSS} %-5level [%logger{35}] %msg%n"
|
||||
encoder.start()
|
||||
val consoleAppender = ConsoleAppender<ILoggingEvent>()
|
||||
consoleAppender.context = context
|
||||
consoleAppender.encoder = encoder
|
||||
consoleAppender.start()
|
||||
rootLogger.addAppender(consoleAppender)
|
||||
// val encoder = PatternLayoutEncoder()
|
||||
// encoder.context = context
|
||||
// encoder.pattern = "%date{HH:mm:ss.SSS} %-5level [%logger{35}] %msg%n"
|
||||
// encoder.start()
|
||||
//
|
||||
// val consoleAppender = ConsoleAppender<ILoggingEvent>()
|
||||
// consoleAppender.context = context
|
||||
// consoleAppender.encoder = encoder
|
||||
// consoleAppender.start()
|
||||
//
|
||||
// rootLogger.addAppender(consoleAppender)
|
||||
|
||||
// context.getLogger(Server::class.simpleName).trace("TESTING")
|
||||
// context.getLogger(Client::class.simpleName).trace("TESTING")
|
||||
}
|
||||
|
||||
// wait minimum of 2 minutes before we automatically fail the unit test.
|
||||
var AUTO_FAIL_TIMEOUT: Long = 120L
|
||||
|
||||
init {
|
||||
if (OS.javaVersion >= 9) {
|
||||
// disableAccessWarnings
|
||||
try {
|
||||
val unsafeClass = Class.forName("sun.misc.Unsafe")
|
||||
val field: Field = unsafeClass.getDeclaredField("theUnsafe")
|
||||
field.isAccessible = true
|
||||
val unsafe: Any = field.get(null)
|
||||
val putObjectVolatile: Method = unsafeClass.getDeclaredMethod("putObjectVolatile", Any::class.java, Long::class.javaPrimitiveType, Any::class.java)
|
||||
val staticFieldOffset: Method = unsafeClass.getDeclaredMethod("staticFieldOffset", Field::class.java)
|
||||
val loggerClass = Class.forName("jdk.internal.module.IllegalAccessLogger")
|
||||
val loggerField: Field = loggerClass.getDeclaredField("logger")
|
||||
val offset = staticFieldOffset.invoke(unsafe, loggerField) as Long
|
||||
putObjectVolatile.invoke(unsafe, loggerClass, offset, null)
|
||||
} catch (ignored: Exception) {
|
||||
}
|
||||
}
|
||||
|
||||
// we want our entropy generation to be simple (ie, no user interaction to generate)
|
||||
try {
|
||||
Entropy.init(SimpleEntropy::class.java)
|
||||
} catch (e: InitializationException) {
|
||||
e.printStackTrace()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Volatile
|
||||
private var latch = CountDownLatch(1)
|
||||
|
||||
@Volatile
|
||||
private var autoFailThread: Thread? = null
|
||||
|
||||
private val endPointConnections: MutableList<EndPoint<*>> = CopyOnWriteArrayList()
|
||||
|
||||
@Volatile
|
||||
|
@ -170,10 +182,6 @@ abstract class BaseTest {
|
|||
init {
|
||||
println("---- " + this.javaClass.simpleName)
|
||||
|
||||
setLogLevel(Level.TRACE)
|
||||
// setLogLevel(Level.ERROR)
|
||||
// setLogLevel(Level.DEBUG)
|
||||
|
||||
// we must always make sure that aeron is shut-down before starting again.
|
||||
while (Server.isRunning(serverConfig())) {
|
||||
println("Aeron was still running. Waiting for it to stop...")
|
||||
|
@ -211,7 +219,7 @@ abstract class BaseTest {
|
|||
if (endPoint is Client) {
|
||||
endPoint.close()
|
||||
latch.countDown()
|
||||
println("Done with ${endPoint.type.simpleName}")
|
||||
println("Done closing: ${endPoint.type.simpleName}")
|
||||
} else {
|
||||
remainingConnections.add(endPoint)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import dorkbox.network.aeron.AeronDriver
|
|||
import dorkbox.network.connection.Connection
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert
|
||||
import org.junit.Test
|
||||
import java.io.IOException
|
||||
|
@ -205,9 +204,7 @@ class DisconnectReconnectTest : BaseTest() {
|
|||
fun manualMediaDriverAndReconnectClient() {
|
||||
// NOTE: once a config is assigned to a driver, the config cannot be changed
|
||||
val aeronDriver = AeronDriver(serverConfig())
|
||||
runBlocking {
|
||||
aeronDriver.start()
|
||||
}
|
||||
aeronDriver.start()
|
||||
|
||||
run {
|
||||
val serverConfiguration = serverConfig()
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
package dorkboxTest.network
|
||||
|
||||
import ch.qos.logback.classic.Level
|
||||
import dorkbox.network.Client
|
||||
import dorkbox.network.Server
|
||||
import dorkbox.network.connection.Connection
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert
|
||||
import org.junit.Test
|
||||
import java.lang.Thread.sleep
|
||||
|
||||
class MultiClientTest : BaseTest() {
|
||||
private val totalCount = 8 // this number is dependent on the number of CPU cores on the box!
|
||||
private val clientConnectCount = atomic(0)
|
||||
private val serverConnectCount = atomic(0)
|
||||
private val disconnectCount = atomic(0)
|
||||
|
||||
@Test
|
||||
fun multiConnectClient() {
|
||||
setLogLevel(Level.TRACE)
|
||||
|
||||
|
||||
// clients first, so they try to connect to the server at (roughly) the same time
|
||||
val clients = mutableListOf<Client<Connection>>()
|
||||
for (i in 1..totalCount) {
|
||||
val client: Client<Connection> = Client(clientConfig(), "Client$i")
|
||||
client.onConnect {
|
||||
clientConnectCount.getAndIncrement()
|
||||
logger.error("${this.id} - Connected $i!")
|
||||
}
|
||||
client.onDisconnect {
|
||||
disconnectCount.getAndIncrement()
|
||||
logger.error("${this.id} - Disconnected $i!")
|
||||
}
|
||||
addEndPoint(client)
|
||||
clients += client
|
||||
}
|
||||
|
||||
GlobalScope.launch {
|
||||
clients.forEach {
|
||||
// long connection timeout, since the more that try to connect at the same time, the longer it takes to setup aeron (since it's all shared)
|
||||
launch { it.connect(LOCALHOST, 30*totalCount) }
|
||||
}
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
sleep(5000L)
|
||||
println("Starting server...")
|
||||
val configuration = serverConfig()
|
||||
|
||||
val server: Server<Connection> = Server(configuration)
|
||||
addEndPoint(server)
|
||||
server.onConnect {
|
||||
val count = serverConnectCount.incrementAndGet()
|
||||
|
||||
logger.error("${this.id} - Connecting $count ....")
|
||||
close()
|
||||
|
||||
if (count == totalCount) {
|
||||
logger.error { "Stopping endpoints!" }
|
||||
stopEndPoints(10000L)
|
||||
}
|
||||
}
|
||||
|
||||
server.bind()
|
||||
}
|
||||
|
||||
waitForThreads()
|
||||
|
||||
Assert.assertEquals(totalCount, clientConnectCount.value)
|
||||
Assert.assertEquals(totalCount, serverConnectCount.value)
|
||||
Assert.assertEquals(totalCount, disconnectCount.value)
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package dorkboxTest.network
|
||||
|
||||
import dorkbox.network.Client
|
||||
import dorkbox.network.Server
|
||||
import dorkbox.network.connection.Connection
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
class MultiConnectTest : BaseTest() {
|
||||
private val reconnectCount = atomic(0)
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
fun multiConnectClient() {
|
||||
// clients first, so they try to connect to the server at (roughly) the same time
|
||||
val config = clientConfig()
|
||||
|
||||
val client1: Client<Connection> = Client(config)
|
||||
val client2: Client<Connection> = Client(config)
|
||||
|
||||
addEndPoint(client1)
|
||||
addEndPoint(client2)
|
||||
client1.onDisconnect {
|
||||
logger.error("Disconnected 1!")
|
||||
}
|
||||
client2.onDisconnect {
|
||||
logger.error("Disconnected 2!")
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
launch { client1.connect(LOCALHOST) }
|
||||
launch { client2.connect(LOCALHOST) }
|
||||
}
|
||||
// GlobalScope.launch {
|
||||
// client1.connect(LOCALHOST)
|
||||
// }
|
||||
//
|
||||
// GlobalScope.launch {
|
||||
// client2.connect(LOCALHOST)
|
||||
// }
|
||||
|
||||
println("Starting server...")
|
||||
|
||||
run {
|
||||
val configuration = serverConfig()
|
||||
|
||||
val server: Server<Connection> = Server(configuration)
|
||||
addEndPoint(server)
|
||||
server.bind()
|
||||
|
||||
server.onConnect {
|
||||
logger.error("Disconnecting after 10 seconds.")
|
||||
delay(10.seconds)
|
||||
|
||||
logger.error("Disconnecting....")
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
waitForThreads()
|
||||
|
||||
System.err.println("Connection count (after reconnecting) is: " + reconnectCount.value)
|
||||
Assert.assertEquals(4, reconnectCount.value)
|
||||
}
|
||||
|
||||
interface CloseIface {
|
||||
suspend fun close()
|
||||
}
|
||||
|
||||
class CloseImpl : CloseIface {
|
||||
override suspend fun close() {
|
||||
// the connection specific one is called instead
|
||||
}
|
||||
|
||||
suspend fun close(connection: Connection) {
|
||||
connection.close()
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue