separated handshake from connection management. cleaned up API

This commit is contained in:
nathan 2020-08-15 13:21:20 +02:00
parent 9eb3c122d7
commit 8d9bef0ccc
6 changed files with 337 additions and 443 deletions

View File

@ -18,17 +18,23 @@ package dorkbox.network
import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6 import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.client.ClientException import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.client.ClientTimedOutException import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.server.ClientRejectedException import dorkbox.network.connection.Connection
import dorkbox.network.connection.* import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiSupportConnection import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException import dorkbox.util.exceptions.SecurityException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/** /**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -40,37 +46,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* Gets the version number. * Gets the version number.
*/ */
const val version = "5.0" const val version = "5.0"
/**
* Split array into chunks, max of 256 chunks.
* byte[0] = chunk ID
* byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte)
*/
private fun divideArray(source: ByteArray, chunksize: Int): Array<ByteArray>? {
val fragments = Math.ceil(source.size / chunksize.toDouble()).toInt()
if (fragments > 127) {
// cannot allow more than 127
return null
}
// pre-allocate the memory
val splitArray = Array(fragments) { ByteArray(chunksize + 2) }
var start = 0
for (i in splitArray.indices) {
var length: Int
length = if (start + chunksize > source.size) {
source.size - start
} else {
chunksize
}
splitArray[i] = ByteArray(length + 2)
splitArray[i][0] = i.toByte() // index
splitArray[i][1] = fragments.toByte() // total number of fragments
System.arraycopy(source, start, splitArray[i], 2, length)
start += chunksize
}
return splitArray
}
} }
/** /**
@ -85,7 +60,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
*/ */
private var remoteAddress = "" private var remoteAddress = ""
private val isConnected = atomic(false) @Volatile
private var isConnected = false
// is valid when there is a connection to the server, otherwise it is null // is valid when there is a connection to the server, otherwise it is null
private var connection: CONNECTION? = null private var connection: CONNECTION? = null
@ -95,7 +71,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val previousClosedConnectionActivity: Long = 0 private val previousClosedConnectionActivity: Long = 0
override val handshake = ClientHandshake(logger, config, listenerManager, crypto) private val handshake = ClientHandshake(logger, config, crypto, listenerManager)
private val rmiConnectionSupport = RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch) private val rmiConnectionSupport = RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch)
init { init {
@ -108,8 +84,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
if (config.networkMtuSize <= 0) { throw ClientException("configuration networkMtuSize must be > 0") } if (config.networkMtuSize <= 0) { throw ClientException("configuration networkMtuSize must be > 0") }
if (config.networkMtuSize >= 9 * 1024) { throw ClientException("configuration networkMtuSize must be < ${9 * 1024}") } if (config.networkMtuSize >= 9 * 1024) { throw ClientException("configuration networkMtuSize must be < ${9 * 1024}") }
autoClosableObjects.add(handshake)
} }
override fun newException(message: String, cause: Throwable?): Throwable { override fun newException(message: String, cause: Throwable?): Throwable {
@ -124,26 +98,29 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
/** /**
* Will attempt to connect to the server, with a default 30 second connection timeout and will BLOCK until completed * Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed.
* *
* For a network address, it can be: * Default connection is to localhost
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* *
* For the IPC (Inter-Process-Communication) address. it must be: * ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
*
* ### For the IPC (Inter-Process-Communication) address. it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc. * - the IPC integer ID, "0x1337c0de", "0x12312312", etc.
* *
* Note: Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
* *
* @param remoteAddress The network or IPC address for the client to connect to * @param remoteAddress The network or IPC address for the client to connect to
* @param connectionTimeout wait for x milliseconds. 0 will wait indefinitely * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable * @param reliable true if we want to create a reliable connection. IPC connections are always reliable
* *
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/ */
suspend fun connect(remoteAddress: String = "", connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
if (isConnected.value) { if (isConnected) {
logger.error("Unable to connect when already connected!") logger.error("Unable to connect when already connected!")
return return
} }
@ -151,26 +128,33 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
this.connectionTimeoutMS = connectionTimeoutMS this.connectionTimeoutMS = connectionTimeoutMS
// localhost/loopback IP might not always be 127.0.0.1 or ::1 // localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) { when (remoteAddress) {
"loopback", "localhost", "lo" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress "loopback", "localhost", "lo", "" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
else -> when { else -> when {
remoteAddress.startsWith("127.") -> this.remoteAddress = IPv4.LOCALHOST.hostAddress IPv4.isLoopback(remoteAddress) -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
remoteAddress.startsWith("::1") -> this.remoteAddress = IPv6.LOCALHOST.hostAddress IPv6.isLoopback(remoteAddress) -> this.remoteAddress = IPv6.LOCALHOST.hostAddress
else -> this.remoteAddress = remoteAddress else -> this.remoteAddress = remoteAddress // might be IPC address!
} }
} }
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == ':' } > 1 &&
this.remoteAddress.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]"""
}
// if we are IPv4 wildcard
if (this.remoteAddress == "0.0.0.0") { if (this.remoteAddress == "0.0.0.0") {
throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
} }
if (IPv6.isValid(this.remoteAddress)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]"""
}
}
handshake.init(this) handshake.init(this)
if (this.remoteAddress.isEmpty()) { if (this.remoteAddress.isEmpty()) {
@ -180,8 +164,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// config.aeronLogDirectory // config.aeronLogDirectory
// stream IDs are flipped for a client because we operate from the perspective of the server // stream IDs are flipped for a client because we operate from the perspective of the server
val handshakeConnection = IpcMediaDriverConnection( val handshakeConnection = IpcMediaDriverConnection(
streamId = IPC_HANDSHAKE_STREAM_ID_SUB, streamId = IPC_HANDSHAKE_STREAM_ID_SUB,
@ -209,20 +191,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// THIS IS A NETWORK ADDRESS // THIS IS A NETWORK ADDRESS
// initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER
val handshakeConnection = UdpMediaDriverConnection( val handshakeConnection = UdpMediaDriverConnection(address = this.remoteAddress,
address = this.remoteAddress, publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort, subscriptionPort = config.publicationPort,
publicationPort = config.subscriptionPort, streamId = UDP_HANDSHAKE_STREAM_ID,
streamId = UDP_HANDSHAKE_STREAM_ID, sessionId = RESERVED_SESSION_ID_INVALID,
sessionId = RESERVED_SESSION_ID_INVALID, connectionTimeoutMS = connectionTimeoutMS,
connectionTimeoutMS = connectionTimeoutMS, isReliable = reliable)
isReliable = reliable)
autoClosableObjects.add(handshakeConnection)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron) handshakeConnection.buildClient(aeron)
logger.debug(handshakeConnection.clientInfo()) logger.info(handshakeConnection.clientInfo())
// this will block until the connection timeout, and throw an exception if we were unable to connect with the server // this will block until the connection timeout, and throw an exception if we were unable to connect with the server
@ -232,19 +211,22 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// we are now connected, so we can connect to the NEW client-specific ports // we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection( val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
address = handshakeConnection.address, // NOTE: pub/sub must be switched!
subscriptionPort = connectionInfo.subscriptionPort, publicationPort = connectionInfo.subscriptionPort,
publicationPort = connectionInfo.publicationPort, subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId, streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId, sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS, connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable) isReliable = handshakeConnection.isReliable)
// VALIDATE:: check to see if the remote connection's public key has changed! // VALIDATE:: check to see if the remote connection's public key has changed!
if (!crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)) { val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")) if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
return handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")
listenerManager.notifyError(exception)
throw exception
} }
// VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the // VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the
@ -256,21 +238,22 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// does not need to do anything // does not need to do anything
// //
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.debug(reliableClientConnection.clientInfo()) logger.info(reliableClientConnection.clientInfo())
val newConnection = newConnection(this, reliableClientConnection) val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
autoClosableObjects.add(newConnection)
// VALIDATE are we allowed to connect to this server (now that we have the initial server information) // VALIDATE are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val permitConnection = listenerManager.notifyFilter(newConnection) val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) { if (!permitConnection) {
listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress was not permitted!")) handshakeConnection.close()
return val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!")
listenerManager.notifyError(exception)
throw exception
} }
connection = newConnection connection = newConnection
handshake.addConnection(newConnection) connections.add(newConnection)
// have to make a new thread to listen for incoming data! // have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
@ -287,24 +270,25 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// tell the server our connection handshake is done, and the connection can now listen for data. // tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold this connection open
handshakeConnection.close()
if (canFinishConnecting) { if (canFinishConnecting) {
isConnected.lazySet(true) isConnected = true
listenerManager.notifyConnect(newConnection)
actionDispatch.launch {
listenerManager.notifyConnect(newConnection)
}
} else { } else {
close() close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
listenerManager.notifyError(exception)
throw exception
} }
} }
} }
/**
* Checks to see if this client has connected yet or not.
*
* @return true if we are connected, false otherwise.
*/
override fun isConnected(): Boolean {
return isConnected.value
}
// override fun hasRemoteKeyChanged(): Boolean { // override fun hasRemoteKeyChanged(): Boolean {
// return connection!!.hasRemoteKeyChanged() // return connection!!.hasRemoteKeyChanged()
// } // }
@ -373,18 +357,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
} }
/**
* @throws ClientException when a message cannot be sent
*/
suspend fun send(message: Any, priority: Byte) {
val c = connection
if (c != null) {
c.send(message, priority)
} else {
throw ClientException("Cannot send a message when there is no connection!")
}
}
/** /**
* @throws ClientException when a ping cannot be sent * @throws ClientException when a ping cannot be sent
*/ */
@ -409,104 +381,24 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
} }
} }
// fun initClassRegistration(channel: Channel, registration: Registration): Boolean {
// val details = serialization.getKryoRegistrationDetails()
// val length = details.size
// if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) {
// // it is too large to send in a single packet
//
// // child arrays have index 0 also as their 'index' and 1 is the total number of fragments
// val fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE)
// if (fragments == null) {
// logger.error("Too many classes have been registered for Serialization. Please report this issue")
// return false
// }
// val allButLast = fragments.size - 1
// for (i in 0 until allButLast) {
// val fragment = fragments[i]
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragment
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// }
//
// // now tell the server we are done with the fragments
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragments[allButLast]
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// } else {
// registration.payload = details
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// registration.upgraded = true
// channel.writeAndFlush(registration)
// }
// return true
// }
// /**
// * Closes all connections ONLY (keeps the client running). To STOP the client, use stop().
// * <p/>
// * This is used, for example, when reconnecting to a server.
// */
// protected
// void closeConnection() {
// if (isConnected.get()) {
// // make sure we're not waiting on registration
// stopRegistration();
//
// // for the CLIENT only, we clear these connections! (the server only clears them on shutdown)
//
// // stop does the same as this + more. Only keep the listeners for connections IF we are the client. If we remove listeners as a client,
// // ALL of the client logic will be lost. The server is reactive, so listeners are added to connections as needed (instead of before startup)
// connectionManager.closeConnections(true);
//
// // Sometimes there might be "lingering" connections (ie, halfway though registration) that need to be closed.
// registrationWrapper.clearSessions();
//
//
// closeConnections(true);
// shutdownAllChannels();
// // shutdownEventLoops(); we don't do this here!
//
// connection = null;
// isConnected.set(false);
//
// previousClosedConnectionActivity = System.nanoTime();
// }
// }
// /**
// * Internal call to abort registration if the shutdown command is issued during channel registration.
// */
// @Suppress("unused")
// fun abortRegistration() {
// // make sure we're not waiting on registration
//// stopRegistration()
// }
override fun close() { override fun close() {
val con = connection val con = connection
connection = null connection = null
if (con != null) { if (con != null) {
handshake.removeConnection(con) connections.remove(con)
runBlocking {
con.close()
listenerManager.notifyDisconnect(con)
}
} }
super.close() super.close()
isConnected = false
} }
// RMI notes (in multiple places, copypasta, because this is confusing if not written down // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //
// only server can create a global object (in itself, via save) // only server can create a global object (in itself, via save)
// server // server

View File

@ -20,12 +20,14 @@ import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy
import dorkbox.network.serialization.NetworkSerializationManager import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.network.serialization.Serialization import dorkbox.network.serialization.Serialization
import dorkbox.network.store.PropertyStore import dorkbox.network.storage.PropertyStore
import dorkbox.network.store.SettingsStore import dorkbox.network.storage.SettingsStore
import dorkbox.os.OS
import dorkbox.util.storage.StorageBuilder import dorkbox.util.storage.StorageBuilder
import dorkbox.util.storage.StorageSystem import dorkbox.util.storage.StorageSystem
import io.aeron.driver.Configuration import io.aeron.driver.Configuration
import io.aeron.driver.ThreadingMode import io.aeron.driver.ThreadingMode
import mu.KLogger
import java.io.File import java.io.File
class ServerConfiguration : dorkbox.network.Configuration() { class ServerConfiguration : dorkbox.network.Configuration() {
@ -35,11 +37,6 @@ class ServerConfiguration : dorkbox.network.Configuration() {
*/ */
var listenIpAddress = "*" var listenIpAddress = "*"
/**
* The starting port for clients to use. The upper bound of this value is limited by the maximum number of clients allowed.
*/
var clientStartPort = 0
/** /**
* The maximum number of clients allowed for a server * The maximum number of clients allowed for a server
*/ */
@ -55,6 +52,8 @@ open class Configuration {
/** /**
* When connecting to a remote client/server, should connections be allowed if the remote machine signature has changed? * When connecting to a remote client/server, should connections be allowed if the remote machine signature has changed?
*
* Setting this to false is not recommended as it is a security risk
*/ */
var enableRemoteSignatureValidation: Boolean = true var enableRemoteSignatureValidation: Boolean = true
@ -103,8 +102,8 @@ open class Configuration {
* The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT. * The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT.
* *
* There are a couple strategies of importance to understand. * There are a couple strategies of importance to understand.
* * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. * - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less * - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* responsive to activity when idle for a little while. * responsive to activity when idle for a little while.
* *
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and * The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
@ -116,8 +115,8 @@ open class Configuration {
* The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT. * The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT.
* *
* There are a couple strategies of importance to understand. * There are a couple strategies of importance to understand.
* * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. * - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less * - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* responsive to activity when idle for a little while. * responsive to activity when idle for a little while.
* *
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and * The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
@ -126,26 +125,22 @@ open class Configuration {
var sendIdleStrategy: CoroutineIdleStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100) var sendIdleStrategy: CoroutineIdleStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100)
/** /**
* A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation. * ## A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation.
* *
* *
* There are three main Agents in the driver: * There are three main Agents in the driver:
* * - Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs,
*
* Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs,
* rotating buffers, etc. * rotating buffers, etc.
* Sender: Responsible for shovelling messages from publishers to the network. * - Sender: Responsible for shovelling messages from publishers to the network.
* Receiver: Responsible for shovelling messages from the network to subscribers. * - Receiver: Responsible for shovelling messages from the network to subscribers.
* *
* *
* This value can be one of: * This value can be one of:
* * - INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty
*
* INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty
* cycle directly. * cycle directly.
* SHARED: All Agents share a single thread. 1 thread in total. * - SHARED: All Agents share a single thread. 1 thread in total.
* SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total. * - SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total.
* DEDICATED: The default and dedicates one thread per Agent. 3 threads in total. * - DEDICATED: The default and dedicates one thread per Agent. 3 threads in total.
* *
* *
* For performance, it is recommended to use DEDICATED as long as the number of busy threads is less than or equal to the number of * For performance, it is recommended to use DEDICATED as long as the number of busy threads is less than or equal to the number of
@ -217,4 +212,31 @@ open class Configuration {
* A value of 0 will 'auto-configure' this setting. * A value of 0 will 'auto-configure' this setting.
*/ */
var receiveBufferSize = 0 var receiveBufferSize = 0
/**
* Depending on the OS, different base locations for the Aeron log directory are preferred.
*/
fun suggestAeronLogLocation(logger: KLogger): File {
return when {
OS.isMacOsX() -> {
// does the recommended location exist??
val suggestedLocation = File("/Volumes/DevShm")
if (suggestedLocation.exists()) {
suggestedLocation
}
else {
logger.info("It is recommended to create a RAM drive for best performance. For example\n" + "\$ diskutil erasevolume HFS+ \"DevShm\" `hdiutil attach -nomount ram://\$((2048 * 2048))`\n" + "\t After this, set config.aeronLogDirectory = \"/Volumes/DevShm\"")
File(System.getProperty("java.io.tmpdir"))
}
}
OS.isLinux() -> {
// this is significantly faster for linux than using the temp dir
File("/dev/shm/")
}
else -> {
File(System.getProperty("java.io.tmpdir"))
}
}
}
} }

View File

@ -16,6 +16,7 @@
package dorkbox.network package dorkbox.network
import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
@ -36,9 +37,6 @@ import java.net.InetSocketAddress
import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CopyOnWriteArrayList
/** /**
* NOTE: when using "server.publish(A)", this will go to ALL CLIENTS! add this to aeron via "publication.addDestination" so aeron manages it
*
*
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the * The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections()) * server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
* *
@ -75,7 +73,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
private var bindAlreadyCalled = false private var bindAlreadyCalled = false
override val handshake = ServerHandshake(logger, config, listenerManager) private val handshake = ServerHandshake(logger, config, listenerManager)
/** /**
* Maintains a thread-safe collection of rules used to define the connection type with this server. * Maintains a thread-safe collection of rules used to define the connection type with this server.
@ -90,25 +88,29 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
when (config.listenIpAddress) { when (config.listenIpAddress) {
"loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress "loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
else -> when { else -> when {
config.listenIpAddress.startsWith("127.") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
config.listenIpAddress.startsWith("::1") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress
else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address. else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address.
} }
} }
// if we are IPv6, the IP must be in '[]' // if we are IPv4 wildcard
if (config.listenIpAddress.count { it == ':' } > 1 &&
config.listenIpAddress.count { it == '[' } < 1 &&
config.listenIpAddress.count { it == ']' } < 1) {
config.listenIpAddress = """[${config.listenIpAddress}]"""
}
if (config.listenIpAddress == "0.0.0.0") { if (config.listenIpAddress == "0.0.0.0") {
// this will also fixup windows! // this will also fixup windows!
config.listenIpAddress = IPv4.WILDCARD config.listenIpAddress = IPv4.WILDCARD
} }
if (IPv6.isValid(config.listenIpAddress)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (config.listenIpAddress.count { it == '[' } < 1 &&
config.listenIpAddress.count { it == ']' } < 1) {
config.listenIpAddress = """[${config.listenIpAddress}]"""
}
}
if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") } if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") }
if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") } if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") }
@ -119,8 +121,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (config.networkMtuSize <= 0) { throw ServerException("configuration networkMtuSize must be > 0") } if (config.networkMtuSize <= 0) { throw ServerException("configuration networkMtuSize must be > 0") }
if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") } if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") }
autoClosableObjects.add(handshake) if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount}
} }
override fun newException(message: String, cause: Throwable?): Throwable { override fun newException(message: String, cause: Throwable?): Throwable {
return ServerException(message, cause) return ServerException(message, cause)
} }
@ -145,20 +148,18 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// setup the "HANDSHAKE" ports, for initial clients to connect. // setup the "HANDSHAKE" ports, for initial clients to connect.
// The is how clients then get the new ports to connect to + other configuration options // The is how clients then get the new ports to connect to + other configuration options
val handshakeDriver = UdpMediaDriverConnection( val handshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress,
address = config.listenIpAddress, publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort, subscriptionPort = config.subscriptionPort,
publicationPort = config.publicationPort, streamId = UDP_HANDSHAKE_STREAM_ID,
streamId = UDP_HANDSHAKE_STREAM_ID, sessionId = RESERVED_SESSION_ID_INVALID)
sessionId = RESERVED_SESSION_ID_INVALID)
handshakeDriver.buildServer(aeron) handshakeDriver.buildServer(aeron)
val handshakePublication = handshakeDriver.publication val handshakePublication = handshakeDriver.publication
val handshakeSubscription = handshakeDriver.subscription val handshakeSubscription = handshakeDriver.subscription
logger.debug(handshakeDriver.serverInfo()) logger.info(handshakeDriver.serverInfo())
logger.debug("Server listening for incoming clients on ${handshakePublication.localSocketAddresses()}")
val ipcHandshakeDriver = IpcMediaDriverConnection( val ipcHandshakeDriver = IpcMediaDriverConnection(
@ -197,7 +198,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
try { try {
var pollCount: Int var pollCount: Int
while (!isShutdown()) { while (!isShutdown()) {
// Get the current time, used to cleanup connections
val now = System.currentTimeMillis()
pollCount = 0 pollCount = 0
// this checks to see if there are NEW clients // this checks to see if there are NEW clients
@ -206,8 +211,38 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// this checks to see if there are NEW clients via IPC // this checks to see if there are NEW clients via IPC
// pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100) // pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100)
// this manages existing clients (for cleanup + connection polling) // this manages existing clients (for cleanup + connection polling)
pollCount += handshake.poll() connections.forEachWithCleanup({ connection ->
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (connection.isExpired(now)) {
logger.debug("[{}] connection expired", connection.sessionId)
shouldCleanupConnection = true
}
if (connection.isClosed()) {
logger.debug("[{}] connection closed", connection.sessionId)
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
true
}
else {
// Otherwise, poll the duologue for activity.
pollCount += connection.pollSubscriptions()
false
}
}, { connectionToClean ->
logger.debug("[{}] deleted connection", connectionToClean.sessionId)
// have to free up resources!
handshake.cleanup(connectionToClean)
connectionToClean.close()
listenerManager.notifyDisconnect(connectionToClean)
})
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
@ -229,7 +264,14 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
} }
} }
internal suspend fun poll(): Int {
var pollCount = 0
return pollCount
}
/** /**
@ -250,45 +292,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
connectionRules.addAll(listOf(*rules)) connectionRules.addAll(listOf(*rules))
} }
// verify the class ID registration details.
// the client will send their class registration data. VERIFY IT IS CORRECT!
// verify the class ID registration details.
// the client will send their class registration data. VERIFY IT IS CORRECT!
// var state: dorkbox.network.connection.RegistrationWrapper.STATE = registrationWrapper.verifyClassRegistration(metaChannel, registration)
// if (state == RegistrationWrapper.STATE.ERROR)
// {
// // abort! There was an error
// shutdown(channel, 0)
// return
// } else if (state == RegistrationWrapper.STATE.WAIT)
// {
// return
// }
/**
* Checks to seeOnce a server has connected to ANY client, it will always return true until server.close() is called
*
* @return true if we are connected, false otherwise.
*/
override fun isConnected(): Boolean {
return handshake.connectionCount() > 0
}
/** /**
* Safely sends objects to a destination * Safely sends objects to a destination
*/ */
suspend fun send(message: Any) { suspend fun send(message: Any) {
handshake.send(message) connections.send(message)
} }
/** /**
@ -348,6 +356,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/** /**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
* Adds a custom connection to the server. * Adds a custom connection to the server.
* *
* This should only be used in situations where there can be DIFFERENT types of connections (such as a 'web-based' connection) and * This should only be used in situations where there can be DIFFERENT types of connections (such as a 'web-based' connection) and
@ -356,10 +365,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param connection the connection to add * @param connection the connection to add
*/ */
fun addConnection(connection: CONNECTION) { fun addConnection(connection: CONNECTION) {
handshake.addConnection(connection) connections.add(connection)
} }
/** /**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
* Removes a custom connection to the server. * Removes a custom connection to the server.
* *
* *
@ -369,11 +379,10 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param connection the connection to remove * @param connection the connection to remove
*/ */
fun removeConnection(connection: CONNECTION) { fun removeConnection(connection: CONNECTION) {
handshake.removeConnection(connection) connections.remove(connection)
} }
/** /**
* Checks to see if a server (using the specified configuration) is running. * Checks to see if a server (using the specified configuration) is running.
* *
@ -467,7 +476,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// RMI notes (in multiple places, copypasta, because this is confusing if not written down // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //
// only server can create a global object (in itself, via save) // only server can create a global object (in itself, via save)
// server // server

View File

@ -3,7 +3,12 @@ package dorkbox.network.handshake
import dorkbox.network.Configuration import dorkbox.network.Configuration
import dorkbox.network.aeron.client.ClientException import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientTimedOutException import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.connection.* import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
@ -12,10 +17,10 @@ import mu.KLogger
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.security.SecureRandom import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger, internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogger,
config: Configuration, private val config: Configuration,
listenerManager: ListenerManager<CONNECTION>, private val crypto: CryptoManagement,
val crypto: CryptoManagement) : ConnectionManager<CONNECTION>(logger, config, listenerManager) { private val listenerManager: ListenerManager<CONNECTION>) {
// a one-time key for connecting // a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt() private val oneTimePad = SecureRandom().nextInt()
@ -25,8 +30,8 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
@Volatile @Volatile
var connectionDone = false var connectionDone = false
@Volatile
private var failed = false private var failed: Exception? = null
lateinit var handler: FragmentHandler lateinit var handler: FragmentHandler
lateinit var endPoint: EndPoint<*> lateinit var endPoint: EndPoint<*>
@ -38,37 +43,48 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// now we have a bi-directional connection with the server on the handshake "socket". // now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
endPoint.actionDispatch.launch { endPoint.actionDispatch.launch {
val sessionId = header.sessionId()
val message = endPoint.readHandshakeMessage(buffer, offset, length, header) val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
logger.debug("[{}] handshake response: {}", sessionId, message) logger.trace {
"[$sessionId] handshake response: $message"
}
// it must be a registration message // it must be a registration message
if (message !is Message) { if (message !is HandshakeMessage) {
logger.error("[{}] server returned unrecognized message: {}", sessionId, message) failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@launch return@launch
} }
if (message.sessionId != sessionId) { // this is an error message
logger.error("[{}] ignored message intended for another client", sessionId) if (message.sessionId == 0) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@launch
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@launch return@launch
} }
// it must be the correct state // it must be the correct state
when (message.state) { when (message.state) {
Message.HELLO_ACK -> { HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types. // The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED! // this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey) connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger) connectionHelloInfo!!.log(sessionId, logger)
} }
Message.DONE_ACK -> { HandshakeMessage.DONE_ACK -> {
connectionDone = true connectionDone = true
} }
else -> { else -> {
if (message.state != Message.HELLO_ACK) { if (message.state != HandshakeMessage.HELLO_ACK) {
logger.error("[{}] ignored message that is not HELLO_ACK", sessionId) failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
} else if (message.state != Message.DONE_ACK) { } else if (message.state != HandshakeMessage.DONE_ACK) {
logger.error("[{}] ignored message that is not INIT_ACK", sessionId) failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
} }
return@launch return@launch
@ -79,7 +95,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
} }
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = Message.helloFromClient( val registrationMessage = HandshakeMessage.helloFromClient(
oneTimePad = oneTimePad, oneTimePad = oneTimePad,
publicKey = config.settingsStore.getPublicKey()!!, publicKey = config.settingsStore.getPublicKey()!!,
registrationData = config.serialization.getKryoRegistrationDetails() registrationData = config.serialization.getKryoRegistrationDetails()
@ -93,7 +109,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// block until we receive the connection information from the server // block until we receive the connection information from the server
failed = false failed = null
var pollCount: Int var pollCount: Int
val subscription = handshakeConnection.subscription val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.config.pollIdleStrategy
@ -102,10 +118,10 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
pollCount = subscription.poll(handler, 1024) pollCount = subscription.poll(handler, 1024)
if (failed) { if (failed != null) {
// no longer necessary to hold this connection open // no longer necessary to hold this connection open
handshakeConnection.close() handshakeConnection.close()
throw ClientException("Server rejected this client") throw failed as Exception
} }
if (connectionHelloInfo != null) { if (connectionHelloInfo != null) {
@ -127,7 +143,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
} }
suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean { suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = Message.doneFromClient() val registrationMessage = HandshakeMessage.doneFromClient()
// Send the done message to the server. // Send the done message to the server.
endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage) endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage)
@ -135,7 +151,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// block until we receive the connection information from the server // block until we receive the connection information from the server
failed = false failed = null
var pollCount: Int var pollCount: Int
val subscription = mediaConnection.subscription val subscription = mediaConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy val pollIdleStrategy = endPoint.config.pollIdleStrategy
@ -144,10 +160,10 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
pollCount = subscription.poll(handler, 1024) pollCount = subscription.poll(handler, 1024)
if (failed) { if (failed != null) {
// no longer necessary to hold this connection open // no longer necessary to hold this connection open
mediaConnection.close() mediaConnection.close()
throw ClientException("Server rejected this client") throw failed as Exception
} }
if (connectionDone) { if (connectionDone) {
@ -164,9 +180,6 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
throw ClientTimedOutException("Waiting for registration response from server") throw ClientTimedOutException("Waiting for registration response from server")
} }
// no longer necessary to hold this connection open
mediaConnection.close()
return connectionDone return connectionDone
} }
} }

View File

@ -18,7 +18,7 @@ package dorkbox.network.handshake
/** /**
* Internal message to handle the connection registration process * Internal message to handle the connection registration process
*/ */
class Message private constructor() { internal class HandshakeMessage private constructor() {
// the public key is used to encrypt the data in the handshake // the public key is used to encrypt the data in the handshake
var publicKey: ByteArray? = null var publicKey: ByteArray? = null
@ -67,8 +67,8 @@ class Message private constructor() {
const val DONE = 2 const val DONE = 2
const val DONE_ACK = 3 const val DONE_ACK = 3
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): Message { fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): HandshakeMessage {
val hello = Message() val hello = HandshakeMessage()
hello.state = HELLO hello.state = HELLO
hello.oneTimePad = oneTimePad hello.oneTimePad = oneTimePad
hello.publicKey = publicKey hello.publicKey = publicKey
@ -76,28 +76,28 @@ class Message private constructor() {
return hello return hello
} }
fun helloAckToClient(sessionId: Int): Message { fun helloAckToClient(sessionId: Int): HandshakeMessage {
val hello = Message() val hello = HandshakeMessage()
hello.state = HELLO_ACK hello.state = HELLO_ACK
hello.sessionId = sessionId // has to be the same as before (the client expects this) hello.sessionId = sessionId // has to be the same as before (the client expects this)
return hello return hello
} }
fun doneFromClient(): Message { fun doneFromClient(): HandshakeMessage {
val hello = Message() val hello = HandshakeMessage()
hello.state = DONE hello.state = DONE
return hello return hello
} }
fun doneToClient(sessionId: Int): Message { fun doneToClient(sessionId: Int): HandshakeMessage {
val hello = Message() val hello = HandshakeMessage()
hello.state = DONE_ACK hello.state = DONE_ACK
hello.sessionId = sessionId hello.sessionId = sessionId
return hello return hello
} }
fun error(errorMessage: String?): Message { fun error(errorMessage: String): HandshakeMessage {
val error = Message() val error = HandshakeMessage()
error.state = INVALID error.state = INVALID
error.errorMessage = errorMessage error.errorMessage = errorMessage
return error return error
@ -105,6 +105,22 @@ class Message private constructor() {
} }
override fun toString(): String { override fun toString(): String {
return "Message(oneTimePad=$oneTimePad, state=$state)" val stateStr = when(state) {
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
val errorMsg = if (errorMessage == null) {
""
} else {
", Error: $errorMessage"
}
return "HandshakeMessage(oneTimePad=$oneTimePad, sid= $sessionId $stateStr$errorMsg)"
} }
} }

View File

@ -2,8 +2,16 @@ package dorkbox.network.handshake
import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv4
import dorkbox.network.ServerConfiguration import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.server.* import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.connection.* import dorkbox.network.aeron.server.AllocationException
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.Image import io.aeron.Image
import io.aeron.Publication import io.aeron.Publication
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
@ -16,27 +24,16 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.write import kotlin.concurrent.write
/** /**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
*
* @throws IllegalArgumentException If the port range is not valid * @throws IllegalArgumentException If the port range is not valid
*/ */
internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger, internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
config: ServerConfiguration, private val config: ServerConfiguration,
listenerManager: ListenerManager<CONNECTION>) : private val listenerManager: ListenerManager<CONNECTION>) {
ConnectionManager<CONNECTION>(logger, config, listenerManager) {
companion object {
// this is the number of ports used per client. Depending on how a client is configured, this number can change
const val portsPerClient = 2
}
private val pendingConnectionsLock = ReentrantReadWriteLock() private val pendingConnectionsLock = ReentrantReadWriteLock()
private val pendingConnections = Int2ObjectHashMap<CONNECTION>() private val pendingConnections = Int2ObjectHashMap<CONNECTION>()
private val portAllocator: PortAllocator
private val connectionsPerIpCounts = Int2IntCounterMap(0) private val connectionsPerIpCounts = Int2IntCounterMap(0)
// guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!) // guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!)
@ -44,15 +41,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
EndPoint.RESERVED_SESSION_ID_HIGH) EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
init {
val minPort = config.clientStartPort
val maxPortCount = portsPerClient * config.maxClientCount
portAllocator = PortAllocator(minPort, maxPortCount)
logger.info("Server connection port range [$minPort - ${minPort + maxPortCount}]")
}
// note: this is called in action dispatch // note: this is called in action dispatch
suspend fun receiveHandshakeMessageServer(handshakePublication: Publication, suspend fun receiveHandshakeMessageServer(handshakePublication: Publication,
buffer: DirectBuffer, offset: Int, length: Int, header: Header, buffer: DirectBuffer, offset: Int, length: Int, header: Header,
@ -70,28 +58,27 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
// val port = remoteIpAndPort.substring(splitPoint+1) // val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString) val clientAddress = IPv4.toInt(clientAddressString)
config as ServerConfiguration
val message = endPoint.readHandshakeMessage(buffer, offset, length, header) val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is Message) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request"))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection request")) endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
return return
} }
val clientPublicKeyBytes = message.publicKey val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
// check to see if this is a pending connection // check to see if this is a pending connection
if (message.state == Message.DONE) { if (message.state == HandshakeMessage.DONE) {
pendingConnectionsLock.write { pendingConnectionsLock.write {
val pendingConnection = pendingConnections.remove(sessionId) val pendingConnection = pendingConnections.remove(sessionId)
if (pendingConnection != null) { if (pendingConnection != null) {
logger.debug("Connection from client $clientAddressString ready") logger.debug("Connection from client $clientAddressString ready")
// now tell the client we are done // now tell the client we are done
endPoint.writeHandshakeMessage(handshakePublication, Message.doneToClient(sessionId)) endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
endPoint.actionDispatch.launch { endPoint.actionDispatch.launch {
listenerManager.notifyConnect(pendingConnection) listenerManager.notifyConnect(pendingConnection)
@ -104,14 +91,16 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
try { try {
// VALIDATE:: Check to see if there are already too many clients connected. // VALIDATE:: Check to see if there are already too many clients connected.
if (connectionCount() >= config.maxClientCount) { if (endPoint.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Server full. Max allowed is ${config.maxClientCount}"))
endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
return return
} }
// VALIDATE:: check to see if the remote connection's public key has changed! // VALIDATE:: check to see if the remote connection's public key has changed!
if (!endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)) { validateRemoteAddress = endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch.")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch."))
return return
} }
@ -125,17 +114,17 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
// VALIDATE:: we are now connected to the client and are going to create a new connection. // VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) { if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString")) listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
// decrement it now, since we aren't going to permit this connection (take the hit on failure, instead // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
endPoint.writeHandshakeMessage(handshakePublication, Message.error("too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}")) endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Too many connections for IP address"))
return return
} }
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
e)) endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection"))
return return
} }
@ -149,19 +138,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
///// /////
// allocate ports for the client
val connectionPorts: IntArray
try {
// throws exception if this is not possible
connectionPorts = portAllocator.allocate(portsPerClient)
} catch (e: IllegalArgumentException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate $portsPerClient ports for client connection!"))
return
}
// allocate session/stream id's // allocate session/stream id's
val connectionSessionId: Int val connectionSessionId: Int
try { try {
@ -169,9 +145,11 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
} catch (e: AllocationException) { } catch (e: AllocationException) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
return return
} }
@ -182,30 +160,34 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
} catch (e: AllocationException) { } catch (e: AllocationException) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
return return
} }
val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box? val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box?
val subscriptionPort = connectionPorts[0]
val publicationPort = connectionPorts[1] // the pub/sub do not necessarily have to be the same. The can be ANY port
val publicationPort = config.publicationPort
val subscriptionPort = config.subscriptionPort
// create a new connection. The session ID is encrypted. // create a new connection. The session ID is encrypted.
try { try {
// connection timeout of 0 doesn't matter. it is not used by the server // connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = UdpMediaDriverConnection(serverAddress, val clientConnection = UdpMediaDriverConnection(serverAddress,
subscriptionPort,
publicationPort, publicationPort,
subscriptionPort,
connectionStreamId, connectionStreamId,
connectionSessionId, connectionSessionId,
0, 0,
message.isReliable) message.isReliable)
val connection: Connection = endPoint.newConnection(endPoint, clientConnection) val connection: Connection = endPoint.newConnection(ConnectionParams(endPoint, clientConnection, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
@ -213,7 +195,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
if (!permitConnection) { if (!permitConnection) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId) streamIdAllocator.free(connectionStreamId)
@ -221,17 +202,15 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
listenerManager.notifyError(connection, listenerManager.notifyError(connection,
ClientRejectedException("Connection was not permitted!")) ClientRejectedException("Connection was not permitted!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
return return
} }
logger.info {
"Client connected [$clientAddressString:$subscriptionPort|$publicationPort] (session: $sessionId)"
}
logger.debug("Created new client connection sessionID {}", connectionSessionId)
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = Message.helloAckToClient(sessionId) val successMessage = HandshakeMessage.helloAckToClient(sessionId)
// now create the encrypted payload, using ECDH // now create the encrypted payload, using ECDH
successMessage.registrationData = endPoint.crypto.encrypt(publicationPort, successMessage.registrationData = endPoint.crypto.encrypt(publicationPort,
@ -242,19 +221,19 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
successMessage.publicKey = endPoint.crypto.publicKeyBytes successMessage.publicKey = endPoint.crypto.publicKeyBytes
// this tells the client all of the info to connect. // this enables the connection to start polling for messages
endPoint.writeHandshakeMessage(handshakePublication, successMessage) endPoint.connections.add(connection)
addConnection(connection)
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write { pendingConnectionsLock.write {
pendingConnections[sessionId] = connection pendingConnections[sessionId] = connection
} }
// this tells the client all of the info to connect.
endPoint.writeHandshakeMessage(handshakePublication, successMessage)
} catch (e: Exception) { } catch (e: Exception) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId) streamIdAllocator.free(connectionStreamId)
@ -262,49 +241,12 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
} }
} }
/**
suspend fun poll(): Int { * Free up resources from the closed connection
// Get the current time, used to cleanup connections */
val now = System.currentTimeMillis() fun cleanup(connection: CONNECTION) {
var pollCount = 0 connectionsPerIpCounts.getAndDecrement(connection.remoteAddressInt)
sessionIdAllocator.free(connection.sessionId)
forEachConnectionCleanup({ connection -> streamIdAllocator.free(connection.streamId)
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (connection.isExpired(now)) {
logger.debug("[{}] connection expired", connection.sessionId)
shouldCleanupConnection = true
}
if (connection.isClosed()) {
logger.debug("[{}] connection closed", connection.sessionId)
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
true
}
else {
// Otherwise, poll the duologue for activity.
pollCount += connection.pollSubscriptions()
false
}
}, { connectionToClean ->
logger.debug("[{}] deleted connection", connectionToClean.sessionId)
removeConnection(connectionToClean)
// have to free up resources!
connectionsPerIpCounts.getAndDecrement(connectionToClean.remoteAddressInt)
portAllocator.free(connectionToClean.subscriptionPort)
portAllocator.free(connectionToClean.publicationPort)
sessionIdAllocator.free(connectionToClean.sessionId)
streamIdAllocator.free(connectionToClean.streamId)
listenerManager.notifyDisconnect(connectionToClean)
connectionToClean.close()
})
return pollCount
} }
} }