Added IPC support, filled out more methods. Better support for connect in a nested disconnect callback. Fixed issue with closing connections-with-handshake-errors on the server pending.

This commit is contained in:
nathan 2020-09-02 02:39:05 +02:00
parent e09fd43e37
commit e20f9b91de
11 changed files with 804 additions and 540 deletions

View File

@ -29,13 +29,14 @@ import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException import dorkbox.util.exceptions.SecurityException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/** /**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -74,6 +75,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization) private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization)
private val lockStepForReconnect = atomic<SuspendWaiter?>(null)
init { init {
// have to do some basic validation of our configuration // have to do some basic validation of our configuration
if (config.publicationPort <= 0) { throw ClientException("configuration port must be > 0") } if (config.publicationPort <= 0) { throw ClientException("configuration port must be > 0") }
@ -106,12 +109,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* - a network name ("localhost", "loopback", "lo", "bob.example.org") * - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1") * - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* *
* ### For the IPC (Inter-Process-Communication) address. it must be: * ### For the IPC (Inter-Process-Communication) it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc. * - EMPTY. ie: just call `connect()`
* *
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
* *
* @param remoteAddress The network or IPC address for the client to connect to * @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @param reliable true if we want to create a reliable connection. IPC connections are always reliable
* *
@ -121,29 +124,55 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
// this will exist ONLY if we are reconnecting via a "disconnect" callback
lockStepForReconnect.value?.doWait()
if (isConnected) { if (isConnected) {
logger.error("Unable to connect when already connected!") logger.error("Unable to connect when already connected!")
return return
} }
lockStepForReconnect.lazySet(null)
connection = null
// we are done with initial configuration, now initialize aeron and the general state of this endpoint // we are done with initial configuration, now initialize aeron and the general state of this endpoint
val aeron = initEndpointState() val aeron = initEndpointState()
this.connectionTimeoutMS = connectionTimeoutMS this.connectionTimeoutMS = connectionTimeoutMS
val isIpcConnection: Boolean
// NETWORK OR IPC ADDRESS
// if we connect to "loopback", then we substitute if for IPC (with log message)
// localhost/loopback IP might not always be 127.0.0.1 or ::1 // localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) { when (remoteAddress) {
"loopback", "localhost", "lo", "" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress "0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
else -> when { "loopback", "localhost", "lo", "" -> {
IPv4.isLoopback(remoteAddress) -> this.remoteAddress = IPv4.LOCALHOST.hostAddress isIpcConnection = true
IPv6.isLoopback(remoteAddress) -> this.remoteAddress = IPv6.LOCALHOST.hostAddress logger.info("Auto-changing network connection from $remoteAddress -> IPC")
else -> this.remoteAddress = remoteAddress // might be IPC address! this.remoteAddress = "ipc"
}
"0x" -> {
isIpcConnection = true
this.remoteAddress = "ipc"
}
else -> when {
IPv4.isLoopback(remoteAddress) -> {
logger.info("Auto-changing network connection from $remoteAddress -> IPC")
isIpcConnection = true
this.remoteAddress = "ipc"
}
IPv6.isLoopback(remoteAddress) -> {
logger.info("Auto-changing network connection from $remoteAddress -> IPC")
isIpcConnection = true
this.remoteAddress = "ipc"
}
else -> {
isIpcConnection = false
this.remoteAddress = remoteAddress
}
} }
}
// if we are IPv4 wildcard
if (this.remoteAddress == "0.0.0.0") {
throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
} }
@ -158,234 +187,232 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
} }
val handshake = ClientHandshake(logger, config, crypto, this) val handshake = ClientHandshake(logger, config, crypto, this)
if (this.remoteAddress.isEmpty()) {
// this is an IPC address
// When conducting IPC transfers, we MUST use the same aeron configuration as the server! // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER
// config.aeronLogDirectory val handshakeConnection = if (isIpcConnection) {
IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB,
streamId = IPC_HANDSHAKE_STREAM_ID_SUB,
// stream IDs are flipped for a client because we operate from the perspective of the server sessionId = RESERVED_SESSION_ID_INVALID)
val handshakeConnection = IpcMediaDriverConnection(
streamId = IPC_HANDSHAKE_STREAM_ID_SUB,
streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB,
sessionId = RESERVED_SESSION_ID_INVALID
)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron)
// logger.debug(handshakeConnection.clientInfo())
println("CONASD")
// this will block until the connection timeout, and throw an exception if we were unable to connect with the server
// @Throws(ConnectTimedOutException::class, ClientRejectedException::class)
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
println("CO23232232323NASD")
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
} }
else { else {
// THIS IS A NETWORK ADDRESS UdpMediaDriverConnection(address = this.remoteAddress,
publicationPort = config.subscriptionPort,
// initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER subscriptionPort = config.publicationPort,
val handshakeConnection = UdpMediaDriverConnection(address = this.remoteAddress, streamId = UDP_HANDSHAKE_STREAM_ID,
publicationPort = config.subscriptionPort, sessionId = RESERVED_SESSION_ID_INVALID,
subscriptionPort = config.publicationPort, connectionTimeoutMS = connectionTimeoutMS,
streamId = UDP_HANDSHAKE_STREAM_ID, isReliable = reliable)
sessionId = RESERVED_SESSION_ID_INVALID, }
connectionTimeoutMS = connectionTimeoutMS,
isReliable = reliable)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron)
logger.info(handshakeConnection.clientInfo())
// this will block until the connection timeout, and throw an exception if we were unable to connect with the server // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron)
// @Throws(ConnectTimedOutException::class, ClientRejectedException::class) logger.info(handshakeConnection.clientInfo())
val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
// VALIDATE:: check to see if the remote connection's public key has changed! // this will block until the connection timeout, and throw an exception if we were unable to connect with the server
val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")
listenerManager.notifyError(exception)
throw exception
}
// VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the // @Throws(ConnectTimedOutException::class, ClientRejectedException::class)
// client will timeout. SPECIFICALLY.... we do not give class serialization/registration info to the client (in case the client val connectionInfo = handshake.handshakeHello(handshakeConnection, connectionTimeoutMS)
// is rogue, we do not want to carelessly provide info.
// we are now connected, so we can connect to the NEW client-specific ports // VALIDATE:: check to see if the remote connection's public key has changed!
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address, val validateRemoteAddress = if (isIpcConnection) {
// NOTE: pub/sub must be switched! PublicKeyValidationState.VALID
publicationPort = connectionInfo.subscriptionPort, } else {
subscriptionPort = connectionInfo.publicationPort, crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
streamId = connectionInfo.streamId, }
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
// only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object)
// does not need to do anything
//
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.info(reliableClientConnection.clientInfo())
// we have to construct how the connection will communicate!
reliableClientConnection.buildClient(aeron)
logger.info {
"Creating new connection to $reliableClientConnection"
}
val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
// VALIDATE are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST")
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!")
listenerManager.notifyError(exception)
throw exception
}
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage ->
listenerManager.notifyError(newConnection,
ClientRejectedException(errorMessage))
}
connection = newConnection
connections.add(newConnection)
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.sessionId}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.sessionId}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
// tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold the handshake connection open
handshakeConnection.close() handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")
listenerManager.notifyError(exception)
throw exception
}
if (canFinishConnecting) {
isConnected = true
actionDispatch.launch { // VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the
listenerManager.notifyConnect(newConnection) // client will timeout. SPECIFICALLY.... we do not give class serialization/registration info to the client (in case the client
} // is rogue, we do not want to carelessly provide info.
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") // we are now connected, so we can connect to the NEW client-specific ports
ListenerManager.cleanStackTrace(exception) val reliableClientConnection = if (isIpcConnection) {
listenerManager.notifyError(exception) IpcMediaDriverConnection(sessionId = connectionInfo.sessionId,
throw exception // NOTE: pub/sub must be switched!
streamIdSubscription = connectionInfo.publicationPort,
streamId = connectionInfo.subscriptionPort,
connectionTimeoutMS = connectionTimeoutMS)
}
else {
UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
subscriptionPort = connectionInfo.publicationPort,
publicationPort = connectionInfo.subscriptionPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
}
// we have to construct how the connection will communicate!
reliableClientConnection.buildClient(aeron)
// only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object)
// does not need to do anything
//
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.info(reliableClientConnection.clientInfo())
val newConnection = if (isIpcConnection) {
newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID))
} else {
newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
}
// VALIDATE are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST")
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!")
listenerManager.notifyError(exception)
throw exception
}
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
serialization.updateKryoIdsForRmi(newConnection, connectionInfo.kryoIdsForRmi) { errorMessage ->
listenerManager.notifyError(newConnection,
ClientRejectedException(errorMessage))
}
//////////////
/// Extra Close action
//////////////
newConnection.preCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close()
if (!lockStepForReconnect.compareAndSet(null, SuspendWaiter())) {
listenerManager.notifyError(getConnection(), IllegalStateException("lockStep for reconnect was in the wrong state!"))
} }
} }
newConnection.postCloseAction = {
// this is called whenever connection.close() is called by the framework or via client.close()
// make sure to call our client.notifyDisconnect() callbacks
// manually call it.
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
actionDispatch.launch {
listenerManager.notifyDisconnect(getConnection())
}
// in case notifyDisconnect called client.connect().... cancel them waiting
isConnected = false
lockStepForReconnect.value?.cancel()
}
connection = newConnection
connections.add(newConnection)
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
while (!isShutdown()) {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (newConnection.isExpired()) {
logger.debug {"[${newConnection.id}] connection expired"}
shouldCleanupConnection = true
}
else if (newConnection.isClosed()) {
logger.debug {"[${newConnection.id}] connection closed"}
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
close()
return@launch
}
else {
// Polls the AERON media driver subscription channel for incoming messages
val pollCount = newConnection.pollSubscriptions()
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
}
}
// tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold the handshake connection open
handshakeConnection.close()
if (canFinishConnecting) {
isConnected = true
actionDispatch.launch {
listenerManager.notifyConnect(newConnection)
}
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
throw exception
}
} }
// override fun hasRemoteKeyChanged(): Boolean { /**
// return connection!!.hasRemoteKeyChanged() * @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
// } */
// fun hasRemoteKeyChanged(): Boolean {
// /** return getConnection().hasRemoteKeyChanged()
// * @return the remote address, as a string. }
// */
// override fun getRemoteHost(): String {
// return connection!!.remoteHost
// }
//
// /**
// * @return true if this connection is established on the loopback interface
// */
// override fun isLoopback(): Boolean {
// return connection!!.isLoopback
// }
//
// override fun isIPC(): Boolean {
// return false
// }
// /**
// * @return true if this connection is a network connection
// */
// override fun isNetwork(): Boolean {
// return false
// }
//
// /**
// * @return the connection (TCP or LOCAL) id of this connection.
// */
// override fun id(): Int {
// return connection!!.id()
// }
//
// /**
// * @return the connection (TCP or LOCAL) id of this connection as a HEX string.
// */
// override fun idAsHex(): String {
// return connection!!.idAsHex()
// }
/** /**
* Fetches the connection used by the client, this is only valid after the client has connected * @return the remote address, as a string.
*/
fun getRemoteHost(): String {
return this.remoteAddress
}
/**
* @return true if this connection is an IPC connection
*/
fun isIPC(): Boolean {
return getConnection().isIpc
}
/**
* @return true if this connection is a network connection
*/
fun isNetwork(): Boolean {
return getConnection().isNetwork
}
/**
* @return the connection (TCP or IPC) id of this connection.
*/
fun id(): Int {
return getConnection().id
}
/**
* @return the connection used by the client, this is only valid after the client has connected
*/ */
fun getConnection(): CONNECTION { fun getConnection(): CONNECTION {
return connection as CONNECTION return connection as CONNECTION
@ -427,32 +454,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
} }
override fun close() {
val con = connection
connection = null
isConnected = false
super.close()
// in the client, "client-notifyDisconnect" will NEVER be called, because it's only called on a connection!
// (meaning, 'connection-notifiyDisconnect' is what is called)
// manually call it.
if (con != null) {
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
val job = actionDispatch.launch {
listenerManager.notifyDisconnect(con)
}
// when we close a client or a server, we want to make sure that ALL notifications are finished.
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
// NOTE: this must be the LAST thing happening!
runBlocking {
job.join()
}
}
}
// RMI notes (in multiple places, copypasta, because this is confusing if not written down) // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //
// only server can create a global object (in itself, via save) // only server can create a global object (in itself, via save)

View File

@ -20,6 +20,7 @@ import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionProperties import dorkbox.network.connection.connectionType.ConnectionProperties
@ -88,7 +89,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// localhost/loopback IP might not always be 127.0.0.1 or ::1 // localhost/loopback IP might not always be 127.0.0.1 or ::1
when (config.listenIpAddress) { when (config.listenIpAddress) {
"loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress "loopback", "localhost", "lo", "" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
else -> when { else -> when {
IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress
@ -132,13 +133,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/** /**
* Binds the server to AERON configuration * Binds the server to AERON configuration
*
* @param blockUntilTerminate if true, will BLOCK until the server [close] method is called, and if you want to continue running code
* after this pass in false
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
@JvmOverloads fun bind() {
suspend fun bind(blockUntilTerminate: Boolean = true) {
if (bindAlreadyCalled) { if (bindAlreadyCalled) {
logger.error("Unable to bind when the server is already running!") logger.error("Unable to bind when the server is already running!")
return return
@ -150,32 +147,29 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
config as ServerConfiguration config as ServerConfiguration
// setup the "HANDSHAKE" ports, for initial clients to connect.
// The is how clients then get the new ports to connect to + other configuration options
val handshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress, val ipcHandshakeDriver = IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB,
publicationPort = config.publicationPort, streamId = IPC_HANDSHAKE_STREAM_ID_PUB,
subscriptionPort = config.subscriptionPort, sessionId = RESERVED_SESSION_ID_INVALID)
streamId = UDP_HANDSHAKE_STREAM_ID, ipcHandshakeDriver.buildServer(aeron)
sessionId = RESERVED_SESSION_ID_INVALID) val ipcHandshakePublication = ipcHandshakeDriver.publication
val ipcHandshakeSubscription = ipcHandshakeDriver.subscription
handshakeDriver.buildServer(aeron)
val handshakePublication = handshakeDriver.publication
val handshakeSubscription = handshakeDriver.subscription
logger.info(handshakeDriver.serverInfo())
// val ipcHandshakeDriver = IpcMediaDriverConnection(
// streamId = IPC_HANDSHAKE_STREAM_ID_PUB, val udpHandshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress,
// streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB, publicationPort = config.publicationPort,
// sessionId = RESERVED_SESSION_ID_INVALID subscriptionPort = config.subscriptionPort,
// ) streamId = UDP_HANDSHAKE_STREAM_ID,
// ipcHandshakeDriver.buildServer(aeron) sessionId = RESERVED_SESSION_ID_INVALID)
//
// val ipcHandshakePublication = ipcHandshakeDriver.publication udpHandshakeDriver.buildServer(aeron)
// val ipcHandshakeSubscription = ipcHandshakeDriver.subscription val handshakePublication = udpHandshakeDriver.publication
val handshakeSubscription = udpHandshakeDriver.subscription
logger.info(ipcHandshakeDriver.serverInfo())
logger.info(udpHandshakeDriver.serverInfo())
/** /**
@ -187,14 +181,14 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy. * properties from failure and streams with mechanical sympathy.
*/ */
val handshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> val udpHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId() val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity() val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split // split
@ -204,23 +198,28 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
val clientAddress = IPv4.toInt(clientAddressString) val clientAddress = IPv4.toInt(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header) val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processHandshakeMessageServer(this@Server,
actionDispatch.launch { handshakePublication,
handshake.processHandshakeMessageServer(handshakePublication, sessionId,
sessionId, clientAddressString,
clientAddressString, clientAddress,
clientAddress, message,
message, aeron)
this@Server,
aeron)
}
} }
val ipcInitialConnectionHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
val ipcHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
actionDispatch.launch { // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
println("GOT MESSAGE!") // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
} val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processHandshakeMessageServer(this@Server,
ipcHandshakePublication,
sessionId,
message,
aeron)
} }
actionDispatch.launch { actionDispatch.launch {
@ -236,10 +235,10 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
// this checks to see if there are NEW clients on the handshake ports // this checks to see if there are NEW clients on the handshake ports
pollCount += handshakeSubscription.poll(handshakeHandler, 2) pollCount += handshakeSubscription.poll(udpHandshakeHandler, 1)
// this checks to see if there are NEW clients via IPC // this checks to see if there are NEW clients via IPC
// pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100) pollCount += ipcHandshakeSubscription.poll(ipcHandshakeHandler, 1)
// this manages existing clients (for cleanup + connection polling) // this manages existing clients (for cleanup + connection polling)
@ -248,12 +247,12 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
var shouldCleanupConnection = false var shouldCleanupConnection = false
if (connection.isExpired()) { if (connection.isExpired()) {
logger.trace {"[${connection.sessionId}] connection expired"} logger.trace {"[${connection.id}] connection expired"}
shouldCleanupConnection = true shouldCleanupConnection = true
} }
else if (connection.isClosed()) { else if (connection.isClosed()) {
logger.trace {"[${connection.sessionId}] connection closed"} logger.trace {"[${connection.id}] connection closed"}
shouldCleanupConnection = true shouldCleanupConnection = true
} }
@ -268,7 +267,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
false false
} }
}, { connectionToClean -> }, { connectionToClean ->
logger.info {"[${connectionToClean.sessionId}] cleaned-up connection"} logger.info {"[${connectionToClean.id}] cleaned-up connection"}
// have to free up resources! // have to free up resources!
handshake.cleanup(connectionToClean) handshake.cleanup(connectionToClean)
@ -294,16 +293,10 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
handshakePublication.close() handshakePublication.close()
handshakeSubscription.close() handshakeSubscription.close()
// ipcHandshakePublication.close() ipcHandshakePublication.close()
// ipcHandshakeSubscription.close() ipcHandshakeSubscription.close()
} }
} }
// we now BLOCK until the stop method is called.
if (blockUntilTerminate) {
waitForShutdown();
}
} }
/** /**
@ -364,8 +357,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/** /**
* Closes the server and all it's connections. After a close, you may call 'bind' again. * Closes the server and all it's connections. After a close, you may call 'bind' again.
*/ */
override fun close() { override fun close0() {
super.close()
bindAlreadyCalled = false bindAlreadyCalled = false
// when we call close, it will shutdown the polling mechanism, so we have to manually cleanup the connections and call server-notifyDisconnect // when we call close, it will shutdown the polling mechanism, so we have to manually cleanup the connections and call server-notifyDisconnect
@ -436,55 +428,6 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
// enum class STATE {
// ERROR, WAIT, CONTINUE
// }
// fun verifyClassRegistration(metaChannel: MetaChannel, registration: Registration): STATE {
// if (registration.upgradeType == UpgradeType.FRAGMENTED) {
// val fragment = registration.payload!!
//
// // this means that the registrations are FRAGMENTED!
// // max size of ALL fragments is xxx * 127
// if (metaChannel.fragmentedRegistrationDetails == null) {
// metaChannel.remainingFragments = fragment[1]
// metaChannel.fragmentedRegistrationDetails = ByteArray(Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * fragment[1])
// }
// System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE, fragment.size - 2)
//
// metaChannel.remainingFragments--
//
// if (fragment[0] + 1 == fragment[1].toInt()) {
// // this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive)
// val correctSize = Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * (fragment[1] - 1) + (fragment.size - 2)
// val correctlySized = ByteArray(correctSize)
// System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize)
// metaChannel.fragmentedRegistrationDetails = correctlySized
// }
// if (metaChannel.remainingFragments.toInt() == 0) {
// // there are no more fragments available
// val details = metaChannel.fragmentedRegistrationDetails
// metaChannel.fragmentedRegistrationDetails = null
// if (!serialization.verifyKryoRegistration(details)) {
// // error
// return STATE.ERROR
// }
// } else {
// // wait for more fragments
// return STATE.WAIT
// }
// } else {
// if (!serialization.verifyKryoRegistration(registration.payload!!)) {
// return STATE.ERROR
// }
// }
// return STATE.CONTINUE
// }
// RMI notes (in multiple places, copypasta, because this is confusing if not written down) // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //
// only server can create a global object (in itself, via save) // only server can create a global object (in itself, via save)
@ -532,7 +475,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveGlobalObject(`object`: Any): Int { fun saveGlobalObject(`object`: Any): Int {
val rmiId = rmiGlobalSupport.saveImplObject(`object`) val rmiId = rmiGlobalSupport.saveImplObject(`object`)
if (rmiId == RemoteObjectStorage.INVALID_RMI) { if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -561,7 +504,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveGlobalObject(`object`: Any, objectId: Int): Boolean { fun saveGlobalObject(`object`: Any, objectId: Int): Boolean {
val success = rmiGlobalSupport.saveImplObject(`object`, objectId) val success = rmiGlobalSupport.saveImplObject(`object`, objectId)
if (!success) { if (!success) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")

View File

@ -16,6 +16,7 @@
package dorkbox.network.connection package dorkbox.network.connection
import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv4
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.connection.ping.PingFuture import dorkbox.network.connection.ping.PingFuture
import dorkbox.network.connection.ping.PingMessage import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
@ -32,6 +33,7 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import org.agrona.collections.Int2IntCounterMap
import java.io.IOException import java.io.IOException
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -46,76 +48,68 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/** /**
* The publication port (used by aeron) for this connection. This is from the perspective of the server! * The publication port (used by aeron) for this connection. This is from the perspective of the server!
*/ */
internal val subscriptionPort: Int private val subscriptionPort: Int
internal val publicationPort: Int private val publicationPort: Int
/** /**
* the stream id of this connection. * the stream id of this connection. Can be 0 for IPC connections
*/ */
internal val streamId: Int private val streamId: Int
/** /**
* the session id of this connection. This value is UNIQUE * the session id of this connection. This value is UNIQUE
*/ */
internal val sessionId: Int
/**
* the id of this connection. This value is UNIQUE
*/
val id: Int val id: Int
get() = sessionId
/** /**
* the remote address, as a string. * the remote address, as a string. Will be "ipc" for IPC connections
*/ */
val remoteAddress: String val remoteAddress: String
/** /**
* the remote address, as an integer. * the remote address, as an integer. Can be 0 for IPC connections
*/ */
val remoteAddressInt: Int private val remoteAddressInt: Int
/** /**
* @return true if this connection is an IPC connection * @return true if this connection is an IPC connection
*/ */
val isIPC = connectionParameters.mediaDriverConnection is IpcMediaDriverConnection val isIpc = connectionParameters.mediaDriverConnection is IpcMediaDriverConnection
/** /**
* @return true if this connection is a network connection * @return true if this connection is a network connection
*/ */
val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection val isNetwork = connectionParameters.mediaDriverConnection is UdpMediaDriverConnection
/**
* Returns the last calculated TCP return trip time, or -1 if or the [PingMessage] response has not yet been received.
*/
val lastRoundTripTime: Int
get() {
val pingFuture2 = pingFuture
return pingFuture2?.response ?: -1
}
/** /**
* the endpoint associated with this connection * the endpoint associated with this connection
*/ */
internal val endPoint = connectionParameters.endPoint internal val endPoint = connectionParameters.endPoint
private val listenerManager = atomic<ListenerManager<Connection>?>(null)
private val listenerManager = atomic<ListenerManager<Connection>?>(null)
val logger = endPoint.logger val logger = endPoint.logger
internal var preCloseAction: suspend () -> Unit = {}
internal var postCloseAction: suspend () -> Unit = {}
private val isClosed = atomic(false) private val isClosed = atomic(false)
/**
// * Returns the last calculated TCP return trip time, or -1 if or the [PingMessage] response has not yet been received.
// */
// val lastRoundTripTime: Int
// get() {
// val pingFuture2 = pingFuture
// return pingFuture2?.response ?: -1
// }
@Volatile @Volatile
private var pingFuture: PingFuture? = null private var pingFuture: PingFuture? = null
// while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error. // while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error.
private var remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED private val remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED
// The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter) // The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter)
// The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this // The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
@ -128,6 +122,8 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// a record of how many messages are in progress of being sent. When closing the connection, this number must be 0 // a record of how many messages are in progress of being sent. When closing the connection, this number must be 0
private val messagesInProgress = atomic(0) private val messagesInProgress = atomic(0)
val toString0: () -> String
init { init {
val mediaDriverConnection = connectionParameters.mediaDriverConnection val mediaDriverConnection = connectionParameters.mediaDriverConnection
@ -135,12 +131,25 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
subscription = mediaDriverConnection.subscription subscription = mediaDriverConnection.subscription
publication = mediaDriverConnection.publication publication = mediaDriverConnection.publication
subscriptionPort = mediaDriverConnection.subscriptionPort remoteAddress = mediaDriverConnection.address // this can be the IP address or "ipc" word
publicationPort = mediaDriverConnection.publicationPort id = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server!
remoteAddress = mediaDriverConnection.address
remoteAddressInt = IPv4.toInt(remoteAddress) if (mediaDriverConnection is IpcMediaDriverConnection) {
streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server! streamId = 0 // this is because with IPC, we have stream sub/pub (which are replaced as port sub/pub)
sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server! subscriptionPort = mediaDriverConnection.streamIdSubscription
publicationPort = mediaDriverConnection.streamId
remoteAddressInt = 0
toString0 = { "[$id] IPC [$subscriptionPort|$publicationPort]" }
} else {
streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server!
subscriptionPort = mediaDriverConnection.subscriptionPort
publicationPort = mediaDriverConnection.publicationPort
remoteAddressInt = IPv4.toInt(mediaDriverConnection.address)
toString0 = { "[$id] $remoteAddress [$publicationPort|$subscriptionPort]" }
}
messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
@ -155,7 +164,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/** /**
* Has the remote ECC public key changed. This can be useful if specific actions are necessary when the key has changed. * @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
*/ */
fun hasRemoteKeyChanged(): Boolean { fun hasRemoteKeyChanged(): Boolean {
return remoteKeyChanged return remoteKeyChanged
@ -272,15 +281,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
return messagesInProgress.value return messagesInProgress.value
} }
/** /**
* @return `true` if this connection has no subscribers (which means this connection longer has a remote connection) * @return `true` if this connection has no subscribers (which means this connection does not have a remote connection)
*/ */
internal fun isExpired(): Boolean { internal fun isExpired(): Boolean {
return !subscription.isConnected // cannot use subscription.isConnected !!! images can be in a state of flux. We only care if there are NO images.
return subscription.hasNoImages()
} }
/** /**
* @return `true` if this connection has been closed * @return `true` if this connection has been closed
*/ */
@ -300,7 +308,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire. // the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire.
if (isClosed.compareAndSet(expect = false, update = true)) { if (isClosed.compareAndSet(expect = false, update = true)) {
logger.info {"[${sessionId}] closed connection"} logger.info {"[$id] closed connection"}
subscription.close() subscription.close()
@ -332,11 +340,19 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
rmiConnectionSupport.clearProxyObjects() rmiConnectionSupport.clearProxyObjects()
// This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper
// lock-stop ordering for how disconnect and connect work with each-other
preCloseAction()
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
endPoint.actionDispatch.launch { endPoint.actionDispatch.launch {
// a connection might have also registered for disconnect events // a connection might have also registered for disconnect events
listenerManager.value?.notifyDisconnect(this@Connection) listenerManager.value?.notifyDisconnect(this@Connection)
} }
// This is set by the client so if there is a "connect()" call in the the disconnect callback, we can have proper
// lock-stop ordering for how disconnect and connect work with each-other
postCloseAction()
} }
} }
@ -387,11 +403,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// //
// //
override fun toString(): String { override fun toString(): String {
return "$remoteAddress $publicationPort/$subscriptionPort ID: $sessionId" return toString0()
} }
override fun hashCode(): Int { override fun hashCode(): Int {
return sessionId return id
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@ -406,9 +422,21 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
} }
val other1 = other as Connection val other1 = other as Connection
return sessionId == other1.sessionId return id == other1.id
} }
// cleans up the connection information
fun cleanup(connectionsPerIpCounts: Int2IntCounterMap, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) {
if (isIpc) {
sessionIdAllocator.free(subscriptionPort)
sessionIdAllocator.free(publicationPort)
streamIdAllocator.free(streamId)
} else {
connectionsPerIpCounts.getAndDecrement(remoteAddressInt)
sessionIdAllocator.free(id)
streamIdAllocator.free(streamId)
}
}
// RMI notes (in multiple places, copypasta, because this is confusing if not written down) // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //

View File

@ -69,6 +69,11 @@ internal class CryptoManagement(val logger: KLogger,
val secureRandom = SecureRandom(settingsStore.getSalt()) val secureRandom = SecureRandom(settingsStore.getSalt())
private val iv = ByteArray(GCM_IV_LENGTH)
private val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)
val cryptOutput = AeronOutput()
val cryptInput = AeronInput()
private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation
init { init {
@ -177,6 +182,7 @@ internal class CryptoManagement(val logger: KLogger,
return SecretKeySpec(hash.digest(), "AES") return SecretKeySpec(hash.digest(), "AES")
} }
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt)
fun encrypt(clientPublicKeyBytes: ByteArray, fun encrypt(clientPublicKeyBytes: ByteArray,
publicationPort: Int, publicationPort: Int,
subscriptionPort: Int, subscriptionPort: Int,
@ -185,29 +191,24 @@ internal class CryptoManagement(val logger: KLogger,
kryoRmiIds: IntArray): ByteArray { kryoRmiIds: IntArray): ByteArray {
val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes)
val iv = ByteArray(GCM_IV_LENGTH)
secureRandom.nextBytes(iv) secureRandom.nextBytes(iv)
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)
aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec) aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec)
// now create the byte array that holds all our data // now create the byte array that holds all our data
val data = AeronOutput() cryptOutput.reset()
data.writeInt(connectionSessionId) cryptOutput.writeInt(connectionSessionId)
data.writeInt(connectionStreamId) cryptOutput.writeInt(connectionStreamId)
data.writeInt(publicationPort) cryptOutput.writeInt(publicationPort)
data.writeInt(subscriptionPort) cryptOutput.writeInt(subscriptionPort)
data.writeInt(kryoRmiIds.size) cryptOutput.writeInt(kryoRmiIds.size)
kryoRmiIds.forEach { kryoRmiIds.forEach {
data.writeInt(it) cryptOutput.writeInt(it)
} }
val bytes = data.toBytes() return iv + aesCipher.doFinal(cryptOutput.toBytes())
return iv + aesCipher.doFinal(bytes)
} }
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt)
fun decrypt(registrationData: ByteArray?, serverPublicKeyBytes: ByteArray?): ClientConnectionInfo? { fun decrypt(registrationData: ByteArray?, serverPublicKeyBytes: ByteArray?): ClientConnectionInfo? {
if (registrationData == null || serverPublicKeyBytes == null) { if (registrationData == null || serverPublicKeyBytes == null) {
return null return null
@ -216,7 +217,6 @@ internal class CryptoManagement(val logger: KLogger,
val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes) val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes)
// now read the encrypted data // now read the encrypted data
val iv = ByteArray(GCM_IV_LENGTH)
registrationData.copyInto(destination = iv, registrationData.copyInto(destination = iv,
endIndex = GCM_IV_LENGTH) endIndex = GCM_IV_LENGTH)
@ -226,21 +226,19 @@ internal class CryptoManagement(val logger: KLogger,
// now decrypt the data // now decrypt the data
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)
aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec) aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec)
val data = AeronInput(aesCipher.doFinal(secretBytes)) cryptInput.buffer = aesCipher.doFinal(secretBytes)
val sessionId = cryptInput.readInt()
val sessionId = data.readInt() val streamId = cryptInput.readInt()
val streamId = data.readInt() val publicationPort = cryptInput.readInt()
val publicationPort = data.readInt() val subscriptionPort = cryptInput.readInt()
val subscriptionPort = data.readInt()
val rmiIds = mutableListOf<Int>() val rmiIds = mutableListOf<Int>()
val rmiIdSize = data.readInt() val rmiIdSize = cryptInput.readInt()
for (i in 0 until rmiIdSize) { for (i in 0 until rmiIdSize) {
rmiIds.add(data.readInt()) rmiIds.add(cryptInput.readInt())
} }
// now read data off // now read data off

View File

@ -22,6 +22,7 @@ import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.connection.ping.PingMessage import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.ipFilter.IpFilterRule import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.RmiManagerGlobal import dorkbox.network.rmi.RmiManagerGlobal
import dorkbox.network.rmi.messages.RmiMessage import dorkbox.network.rmi.messages.RmiMessage
@ -43,7 +44,6 @@ import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.io.File import java.io.File
import java.util.concurrent.CountDownLatch
// If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets! // If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets!
@ -117,7 +117,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val listenerManager = ListenerManager<CONNECTION>() internal val listenerManager = ListenerManager<CONNECTION>()
internal val connections = ConnectionManager<CONNECTION>() internal val connections = ConnectionManager<CONNECTION>()
private var mediaDriverContext: MediaDriver.Context? = null internal var mediaDriverContext: MediaDriver.Context? = null
private var mediaDriver: MediaDriver? = null private var mediaDriver: MediaDriver? = null
private var aeron: Aeron? = null private var aeron: Aeron? = null
@ -136,7 +136,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
private val shutdown = atomic(false) private val shutdown = atomic(false)
@Volatile @Volatile
private var shutdownLatch: CountDownLatch = CountDownLatch(1) private var shutdownLatch: SuspendWaiter = SuspendWaiter()
// we only want one instance of these created. These will be called appropriately // we only want one instance of these created. These will be called appropriately
val settingsStore: SettingsStore val settingsStore: SettingsStore
@ -219,7 +219,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
if (config.aeronLogDirectory == null) { if (config.aeronLogDirectory == null) {
val baseFileLocation = config.suggestAeronLogLocation(logger) val baseFileLocation = config.suggestAeronLogLocation(logger)
val aeronLogDirectory = File(baseFileLocation, "aeron-" + type.simpleName) // val aeronLogDirectory = File(baseFileLocation, "aeron-" + type.simpleName)
val aeronLogDirectory = File(baseFileLocation, "aeron")
aeronDirAlreadyExists = aeronLogDirectory.exists() aeronDirAlreadyExists = aeronLogDirectory.exists()
config.aeronLogDirectory = aeronLogDirectory config.aeronLogDirectory = aeronLogDirectory
} }
@ -229,6 +230,47 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
logger.warn("Aeron log directory already exists! This might not be what you want!") logger.warn("Aeron log directory already exists! This might not be what you want!")
} }
val threadFactory = NamedThreadFactory("Aeron", false)
// LOW-LATENCY SETTINGS
// .termBufferSparseFile(false)
// .useWindowsHighResTimer(true)
// .threadingMode(ThreadingMode.DEDICATED)
// .conductorIdleStrategy(BusySpinIdleStrategy.INSTANCE)
// .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE)
// .senderIdleStrategy(NoOpIdleStrategy.INSTANCE);
// setProperty(DISABLE_BOUNDS_CHECKS_PROP_NAME, "true");
// setProperty("aeron.mtu.length", "16384");
// setProperty("aeron.socket.so_sndbuf", "2097152");
// setProperty("aeron.socket.so_rcvbuf", "2097152");
// setProperty("aeron.rcv.initial.window.length", "2097152");
// driver context must happen in the initializer, because we have a Server.isRunning() method that uses the mediaDriverContext (without bind)
val mDrivercontext = MediaDriver.Context()
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true)
.dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory)
.senderThreadFactory(threadFactory)
.sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory)
.threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize)
.socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize)
mDrivercontext
.aeronDirectoryName(config.aeronLogDirectory!!.absolutePath)
.concludeAeronDirectory()
mDrivercontext.ipcTermBufferLength(16 * 1024 * 1024) // default: 64 megs each is HUGE
mDrivercontext.publicationTermBufferLength(4 * 1024 * 1024) // default: 16 megs each is HUGE (we run out of space in production w/ lots of clients)
mediaDriverContext = mDrivercontext
// serialization stuff // serialization stuff
serialization = config.serialization serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy sendIdleStrategy = config.sendIdleStrategy
@ -250,45 +292,26 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal fun initEndpointState(): Aeron { internal fun initEndpointState(): Aeron {
val aeronDirectory = config.aeronLogDirectory!!.absolutePath val aeronDirectory = config.aeronLogDirectory!!.absolutePath
val threadFactory = NamedThreadFactory("Aeron", false) if (!isRunning()) {
// the server always creates a media driver.
// LOW-LATENCY SETTINGS mediaDriver = try {
// .termBufferSparseFile(false) MediaDriver.launch(mediaDriverContext)
// .useWindowsHighResTimer(true) } catch (e: Exception) {
// .threadingMode(ThreadingMode.DEDICATED) listenerManager.notifyError(e)
// .conductorIdleStrategy(BusySpinIdleStrategy.INSTANCE) throw e
// .receiverIdleStrategy(NoOpIdleStrategy.INSTANCE) }
// .senderIdleStrategy(NoOpIdleStrategy.INSTANCE);
mediaDriverContext = MediaDriver.Context()
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true)
.dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory)
.senderThreadFactory(threadFactory)
.sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory)
.threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize)
.socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize)
.aeronDirectoryName(aeronDirectory)
val aeronContext = Aeron.Context().aeronDirectoryName(aeronDirectory)
mediaDriver = try {
MediaDriver.launch(mediaDriverContext)
} catch (e: Exception) {
listenerManager.notifyError(e)
throw e
} }
val aeronContext = Aeron.Context()
aeronContext
.aeronDirectoryName(aeronDirectory)
.concludeAeronDirectory()
try { try {
aeron = Aeron.connect(aeronContext) aeron = Aeron.connect(aeronContext)
} catch (e: Exception) { } catch (e: Exception) {
try { try {
mediaDriver!!.close() mediaDriver?.close()
} catch (secondaryException: Exception) { } catch (secondaryException: Exception) {
e.addSuppressed(secondaryException) e.addSuppressed(secondaryException)
} }
@ -299,8 +322,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
shutdown.getAndSet(false) shutdown.getAndSet(false)
shutdownLatch.countDown() shutdownLatch = SuspendWaiter()
shutdownLatch = CountDownLatch(1)
return aeron!! return aeron!!
} }
@ -466,11 +488,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
// more critical error sending the message. we shouldn't retry or anything. // more critical error sending the message. we shouldn't retry or anything.
listenerManager.notifyError(newException("Error sending message. ${errorCodeName(result)}")) listenerManager.notifyError(
newException("[${publication.sessionId()}] Error sending handshake message. $message (${errorCodeName(result)})"))
return return
} }
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(newException("Error serializing message $message", e)) listenerManager.notifyError(newException("[${publication.sessionId()}] Error serializing handshake message $message", e))
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnKryo(kryo) serialization.returnKryo(kryo)
@ -622,12 +645,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
// more critical error sending the message. we shouldn't retry or anything. // more critical error sending the message. we shouldn't retry or anything.
logger.error("Error sending message. ${errorCodeName(result)}") logger.error("[${publication.sessionId()}] Error sending message. $message (${errorCodeName(result)})")
return return
} }
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error serializing message $message", e) logger.error("[${publication.sessionId()}] Error serializing message $message", e)
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnKryo(kryo) serialization.returnKryo(kryo)
@ -671,8 +694,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
/** /**
* Waits for this endpoint to be closed * Waits for this endpoint to be closed
*/ */
fun waitForShutdown() { suspend fun waitForClose() {
shutdownLatch.await() shutdownLatch.doWait()
} }
/** /**
@ -681,10 +704,11 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* @return true if the client/server is active and running * @return true if the client/server is active and running
*/ */
fun isRunning(): Boolean { fun isRunning(): Boolean {
return mediaDriverContext?.isDriverActive(10_000, logger::debug) ?: false // if the media driver is running, it will be a quick connection. Usually 100ms or so
return mediaDriverContext?.isDriverActive(1_000, logger::debug) ?: false
} }
override fun close() { final override fun close() {
if (shutdown.compareAndSet(expect = false, update = true)) { if (shutdown.compareAndSet(expect = false, update = true)) {
aeron?.close() aeron?.close()
mediaDriver?.close() mediaDriver?.close()
@ -700,7 +724,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
} }
shutdownLatch.countDown() close0()
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)
shutdownLatch.cancel()
} }
} }
internal open fun close0() {}
} }

View File

@ -198,9 +198,10 @@ class IpcMediaDriverConnection(override val streamId: Int,
val streamIdSubscription: Int, val streamIdSubscription: Int,
override val sessionId: Int, override val sessionId: Int,
private val connectionTimeoutMS: Long = 30_000, private val connectionTimeoutMS: Long = 30_000,
override val isReliable: Boolean = true) : MediaDriverConnection { ) : MediaDriverConnection {
override val address = "" override val isReliable = true
override val address = "ipc"
override val subscriptionPort = 0 override val subscriptionPort = 0
override val publicationPort = 0 override val publicationPort = 0
@ -209,10 +210,6 @@ class IpcMediaDriverConnection(override val streamId: Int,
var success: Boolean = false var success: Boolean = false
init {
}
private fun uri(): ChannelUriStringBuilder { private fun uri(): ChannelUriStringBuilder {
val builder = ChannelUriStringBuilder().media("ipc") val builder = ChannelUriStringBuilder().media("ipc")
if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
@ -226,14 +223,10 @@ class IpcMediaDriverConnection(override val streamId: Int,
override suspend fun buildClient(aeron: Aeron) { override suspend fun buildClient(aeron: Aeron) {
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri() val subscriptionUri = uri()
// .controlEndpoint("$address:$subscriptionPort")
// .controlMode("dynamic")
// Create a publication at the given address and port, using the given stream ID. // Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri() val publicationUri = uri()
// .endpoint("$address:$publicationPort")
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
@ -288,15 +281,10 @@ class IpcMediaDriverConnection(override val streamId: Int,
override fun buildServer(aeron: Aeron) { override fun buildServer(aeron: Aeron) {
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri() val subscriptionUri = uri()
// .endpoint("$address:$subscriptionPort")
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri() val publicationUri = uri()
// .controlEndpoint("$address:$publicationPort")
// .controlMode("dynamic")
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client. // publication of any state to other threads and not be long running or re-entrant with the client.
@ -305,22 +293,29 @@ class IpcMediaDriverConnection(override val streamId: Int,
} }
override fun clientInfo() : String { override fun clientInfo() : String {
return "" return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] aeron connection established to [$streamIdSubscription|$streamId]"
} else {
"Connecting IPC with handshake to [$streamIdSubscription|$streamId]"
}
} }
override fun serverInfo() : String { override fun serverInfo() : String {
return "" return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
} "[$sessionId] IPC listening on [$streamIdSubscription|$streamId] "
} else {
fun connect() : Pair<String, String> { "IPC listening with handshake on [$streamIdSubscription|$streamId]"
return Pair("","") }
} }
override fun close() { override fun close() {
if (success) {
subscription.close()
publication.close()
}
} }
override fun toString(): String { override fun toString(): String {
return "$address [$subscriptionPort|$publicationPort] [$streamId|$sessionId]" return "[$streamIdSubscription|$streamId] [$sessionId]"
} }
} }

View File

@ -15,10 +15,10 @@
*/ */
package dorkbox.network.handshake package dorkbox.network.handshake
internal class ClientConnectionInfo(val subscriptionPort: Int, internal class ClientConnectionInfo(val subscriptionPort: Int = 0,
val publicationPort: Int, val publicationPort: Int = 0,
val sessionId: Int, val sessionId: Int,
val streamId: Int, val streamId: Int = 0,
val publicKey: ByteArray, val publicKey: ByteArray = ByteArray(0),
val kryoIdsForRmi: IntArray) { val kryoIdsForRmi: IntArray) {
} }

View File

@ -22,7 +22,6 @@ import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.MediaDriverConnection import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
@ -82,7 +81,28 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
// The message was intended for this client. Try to parse it as one of the available message types. // The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED! // this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey) connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
}
HandshakeMessage.HELLO_ACK_IPC -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
val cryptInput = crypto.cryptInput
cryptInput.buffer = message.registrationData
val sessionId = cryptInput.readInt()
val streamSubId = cryptInput.readInt()
val streamPubId = cryptInput.readInt()
val rmiIds = mutableListOf<Int>()
val rmiIdSize = cryptInput.readInt()
for (i in 0 until rmiIdSize) {
rmiIds.add(cryptInput.readInt())
}
// now read data off
connectionHelloInfo = ClientConnectionInfo(sessionId = sessionId,
subscriptionPort = streamSubId,
publicationPort = streamPubId,
kryoIdsForRmi = rmiIds.toIntArray())
} }
HandshakeMessage.DONE_ACK -> { HandshakeMessage.DONE_ACK -> {
connectionDone = true connectionDone = true
@ -124,7 +144,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
pollCount = subscription.poll(handler, 2) pollCount = subscription.poll(handler, 1)
if (failed != null) { if (failed != null) {
// no longer necessary to hold this connection open // no longer necessary to hold this connection open
@ -150,7 +170,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
return connectionHelloInfo!! return connectionHelloInfo!!
} }
suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean { suspend fun handshakeDone(mediaConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient() val registrationMessage = HandshakeMessage.doneFromClient()
// Send the done message to the server. // Send the done message to the server.
@ -168,7 +188,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment. // NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
pollCount = subscription.poll(handler, 2) pollCount = subscription.poll(handler, 1)
if (failed != null) { if (failed != null) {
// no longer necessary to hold this connection open // no longer necessary to hold this connection open

View File

@ -53,8 +53,9 @@ internal class HandshakeMessage private constructor() {
const val INVALID = -1 const val INVALID = -1
const val HELLO = 0 const val HELLO = 0
const val HELLO_ACK = 1 const val HELLO_ACK = 1
const val DONE = 2 const val HELLO_ACK_IPC = 2
const val DONE_ACK = 3 const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray, registrationRmiIdData: IntArray): HandshakeMessage { fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray, registrationRmiIdData: IntArray): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
@ -73,6 +74,13 @@ internal class HandshakeMessage private constructor() {
return hello return hello
} }
fun helloAckIpcToClient(sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK_IPC
hello.sessionId = sessionId // has to be the same as before (the client expects this)
return hello
}
fun doneFromClient(): HandshakeMessage { fun doneFromClient(): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = DONE hello.state = DONE
@ -99,6 +107,7 @@ internal class HandshakeMessage private constructor() {
INVALID -> "INVALID" INVALID -> "INVALID"
HELLO -> "HELLO" HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK" HELLO_ACK -> "HELLO_ACK"
HELLO_ACK_IPC -> "HELLO_ACK_IPC"
DONE -> "DONE" DONE -> "DONE"
DONE_ACK -> "DONE_ACK" DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!" else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"

View File

@ -15,24 +15,31 @@
*/ */
package dorkbox.network.handshake package dorkbox.network.handshake
import com.github.benmanes.caffeine.cache.Cache
import com.github.benmanes.caffeine.cache.Caffeine
import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import dorkbox.network.Server import dorkbox.network.Server
import dorkbox.network.ServerConfiguration import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.client.ClientRejectedException import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.server.AllocationException import dorkbox.network.aeron.server.AllocationException
import dorkbox.network.aeron.server.RandomIdAllocator import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.aeron.server.ServerException import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.Aeron import io.aeron.Aeron
import io.aeron.Publication import io.aeron.Publication
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import org.agrona.collections.Int2IntCounterMap import org.agrona.collections.Int2IntCounterMap
import org.agrona.collections.Int2ObjectHashMap import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.write import kotlin.concurrent.write
@ -40,12 +47,25 @@ import kotlin.concurrent.write
/** /**
* @throws IllegalArgumentException If the port range is not valid * @throws IllegalArgumentException If the port range is not valid
*/ */
@Suppress("DuplicatedCode")
internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger, internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
private val config: ServerConfiguration, private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>) { private val listenerManager: ListenerManager<CONNECTION>) {
private val pendingConnectionsLock = ReentrantReadWriteLock() private val pendingConnectionsLock = ReentrantReadWriteLock()
private val pendingConnections = Int2ObjectHashMap<CONNECTION>() private val pendingConnections: Cache<Int,CONNECTION> = Caffeine.newBuilder()
.expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS)
.removalListener(RemovalListener<Any?, Any?> { _, value, cause ->
if (cause == RemovalCause.EXPIRED) {
@Suppress("UNCHECKED_CAST")
val connection = value as CONNECTION
listenerManager.notifyError(ClientTimedOutException("[${connection.id}] Waiting for registration response from client"))
runBlocking {
connection.close()
}
}
}).build()
private val connectionsPerIpCounts = Int2IntCounterMap(0) private val connectionsPerIpCounts = Int2IntCounterMap(0)
@ -54,51 +74,244 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
EndPoint.RESERVED_SESSION_ID_HIGH) EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
// note: CANNOT be called in action dispatch
fun processHandshakeMessageServer(handshakePublication: Publication, /**
sessionId: Int, * @return true if we should continue parsing the incoming message, false if we should abort
clientAddressString: String, */
clientAddress: Int, private fun validateMessageTypeAndDoPending(server: Server<CONNECTION>,
message: Any?, handshakePublication: Publication,
server: Server<CONNECTION>, message: Any?,
aeron: Aeron) { sessionId: Int,
connectionString: String): Boolean {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request"))
server.actionDispatch.launch { server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
} }
return return false
} }
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
// check to see if this is a pending connection // check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) { if (message.state == HandshakeMessage.DONE) {
pendingConnectionsLock.write { val pendingConnection = pendingConnectionsLock.write {
val pendingConnection = pendingConnections.remove(sessionId) val con = pendingConnections.getIfPresent(sessionId)
if (pendingConnection != null) { pendingConnections.invalidate(sessionId)
logger.trace { "Connection from client $clientAddressString done with handshake." } con
}
// this enables the connection to start polling for messages if (pendingConnection == null) {
server.connections.add(pendingConnection) logger.error { "[$sessionId] Error! Connection from client $connectionString was null, and cannot complete handshake!" }
} else {
logger.trace { "[${pendingConnection.id}] Connection from client $connectionString done with handshake." }
server.actionDispatch.launch { // this enables the connection to start polling for messages
// now tell the client we are done server.connections.add(pendingConnection)
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
listenerManager.notifyConnect(pendingConnection)
}
return server.actionDispatch.launch {
// now tell the client we are done
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
listenerManager.notifyConnect(pendingConnection)
} }
} }
return false
}
return true
}
// note: CANNOT be called in action dispatch
fun processHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
message: Any?,
aeron: Aeron) {
val connectionString = "IPC"
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, connectionString)) {
return
}
message as HandshakeMessage
val serialization = config.serialization
// VALIDATE:: make sure the serialization matches between the client/server!
if (!serialization.verifyKryoRegistration(message.registrationData!!)) {
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Registration data mismatch."))
return
} }
/////
/////
///// DONE WITH VALIDATION
/////
/////
// allocate session/stream id's
val connectionSessionId: Int
try {
connectionSessionId = sessionIdAllocator.allocate()
} catch (e: AllocationException) {
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return
}
val connectionStreamPubId: Int
try {
connectionStreamPubId = streamIdAllocator.allocate()
} catch (e: AllocationException) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return
}
val connectionStreamSubId: Int
try {
connectionStreamSubId = streamIdAllocator.allocate()
} catch (e: AllocationException) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
sessionIdAllocator.free(connectionStreamPubId)
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return
}
// create a new connection. The session ID is encrypted.
try {
// connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = IpcMediaDriverConnection(streamId = connectionStreamPubId,
streamIdSubscription = connectionStreamSubId,
sessionId = connectionSessionId,
connectionTimeoutMS = 0)
// we have to construct how the connection will communicate!
clientConnection.buildServer(aeron)
logger.info {
"[${clientConnection.sessionId}] aeron IPC connection established to $clientConnection"
}
val connection = server.newConnection(ConnectionParams(server, clientConnection, PublicKeyValidationState.VALID))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST")
val permitConnection = listenerManager.notifyFilter(connection)
if (!permitConnection) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamPubId)
val exception = ClientRejectedException("Connection was not permitted!")
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
}
return
}
///////////////
//// RMI
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
// NOTE: This modifies the readKryo! This cannot be on a different thread!
serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage ->
listenerManager.notifyError(connection,
ClientRejectedException(errorMessage))
}
///////////////
/// HANDSHAKE
///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId)
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
// now create the encrypted payload, using ECDH
val cryptOutput = server.crypto.cryptOutput
cryptOutput.reset()
cryptOutput.writeInt(connectionSessionId)
cryptOutput.writeInt(connectionStreamSubId)
cryptOutput.writeInt(connectionStreamPubId)
val kryoRmiIds = serialization.getKryoRmiIds()
cryptOutput.writeInt(kryoRmiIds.size)
kryoRmiIds.forEach {
cryptOutput.writeInt(it)
}
successMessage.registrationData = cryptOutput.toBytes()
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write {
pendingConnections.put(sessionId, connection)
}
// this tells the client all of the info to connect.
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, successMessage)
}
} catch (e: Exception) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamPubId)
listenerManager.notifyError(ServerException("Connection handshake from $connectionString crashed! Message $message", e))
}
}
// note: CANNOT be called in action dispatch
fun processHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
clientAddressString: String,
clientAddress: Int,
message: Any?,
aeron: Aeron) {
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) {
return
}
message as HandshakeMessage
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
val serialization = config.serialization val serialization = config.serialization
try { try {
@ -272,7 +485,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write { pendingConnectionsLock.write {
pendingConnections[sessionId] = connection pendingConnections.put(sessionId, connection)
} }
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
@ -293,8 +506,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
* Free up resources from the closed connection * Free up resources from the closed connection
*/ */
fun cleanup(connection: CONNECTION) { fun cleanup(connection: CONNECTION) {
connectionsPerIpCounts.getAndDecrement(connection.remoteAddressInt) connection.cleanup(connectionsPerIpCounts, sessionIdAllocator, streamIdAllocator)
sessionIdAllocator.free(connection.sessionId) pendingConnections.invalidateAll()
streamIdAllocator.free(connection.streamId)
} }
} }

View File

@ -0,0 +1,29 @@
package dorkbox.network.other.coroutines
import kotlinx.coroutines.channels.Channel
// this is bi-directional waiting. The method names to not reflect this, however there is no possibility of race conditions w.r.t. waiting
// https://kotlinlang.org/docs/reference/coroutines/channels.html
class SuspendWaiter(private val channel: Channel<Unit> = Channel()) {
// "receive' suspends until another coroutine invokes "send"
// and
// "send" suspends until another coroutine invokes "receive".
suspend fun doWait() {
try {
channel.receive()
} catch (ignored: Exception) {
}
}
suspend fun doNotify() {
try {
channel.send(Unit)
} catch (ignored: Exception) {
}
}
fun cancel() {
try {
channel.cancel()
} catch (ignored: Exception) {
}
}
}