separated handshake from connection management. cleaned up API

This commit is contained in:
nathan 2020-08-15 13:21:20 +02:00
parent 9eb3c122d7
commit 8d9bef0ccc
6 changed files with 337 additions and 443 deletions

View File

@ -18,17 +18,23 @@ package dorkbox.network
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.server.ClientRejectedException
import dorkbox.network.connection.*
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -40,37 +46,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* Gets the version number.
*/
const val version = "5.0"
/**
* Split array into chunks, max of 256 chunks.
* byte[0] = chunk ID
* byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte)
*/
private fun divideArray(source: ByteArray, chunksize: Int): Array<ByteArray>? {
val fragments = Math.ceil(source.size / chunksize.toDouble()).toInt()
if (fragments > 127) {
// cannot allow more than 127
return null
}
// pre-allocate the memory
val splitArray = Array(fragments) { ByteArray(chunksize + 2) }
var start = 0
for (i in splitArray.indices) {
var length: Int
length = if (start + chunksize > source.size) {
source.size - start
} else {
chunksize
}
splitArray[i] = ByteArray(length + 2)
splitArray[i][0] = i.toByte() // index
splitArray[i][1] = fragments.toByte() // total number of fragments
System.arraycopy(source, start, splitArray[i], 2, length)
start += chunksize
}
return splitArray
}
}
/**
@ -85,7 +60,8 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
*/
private var remoteAddress = ""
private val isConnected = atomic(false)
@Volatile
private var isConnected = false
// is valid when there is a connection to the server, otherwise it is null
private var connection: CONNECTION? = null
@ -95,7 +71,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
private val previousClosedConnectionActivity: Long = 0
override val handshake = ClientHandshake(logger, config, listenerManager, crypto)
private val handshake = ClientHandshake(logger, config, crypto, listenerManager)
private val rmiConnectionSupport = RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch)
init {
@ -108,8 +84,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
if (config.networkMtuSize <= 0) { throw ClientException("configuration networkMtuSize must be > 0") }
if (config.networkMtuSize >= 9 * 1024) { throw ClientException("configuration networkMtuSize must be < ${9 * 1024}") }
autoClosableObjects.add(handshake)
}
override fun newException(message: String, cause: Throwable?): Throwable {
@ -124,26 +98,29 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
/**
* Will attempt to connect to the server, with a default 30 second connection timeout and will BLOCK until completed
* Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed.
*
* For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* Default connection is to localhost
*
* For the IPC (Inter-Process-Communication) address. it must be:
* ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
*
* ### For the IPC (Inter-Process-Communication) address. it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc.
*
* Note: Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
* @param remoteAddress The network or IPC address for the client to connect to
* @param connectionTimeout wait for x milliseconds. 0 will wait indefinitely
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
suspend fun connect(remoteAddress: String = "", connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
if (isConnected.value) {
suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
if (isConnected) {
logger.error("Unable to connect when already connected!")
return
}
@ -151,26 +128,33 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
this.connectionTimeoutMS = connectionTimeoutMS
// localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) {
"loopback", "localhost", "lo" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
"loopback", "localhost", "lo", "" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
else -> when {
remoteAddress.startsWith("127.") -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
remoteAddress.startsWith("::1") -> this.remoteAddress = IPv6.LOCALHOST.hostAddress
else -> this.remoteAddress = remoteAddress
IPv4.isLoopback(remoteAddress) -> this.remoteAddress = IPv4.LOCALHOST.hostAddress
IPv6.isLoopback(remoteAddress) -> this.remoteAddress = IPv6.LOCALHOST.hostAddress
else -> this.remoteAddress = remoteAddress // might be IPC address!
}
}
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == ':' } > 1 &&
this.remoteAddress.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]"""
}
// if we are IPv4 wildcard
if (this.remoteAddress == "0.0.0.0") {
throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
}
if (IPv6.isValid(this.remoteAddress)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress.count { it == '[' } < 1 &&
this.remoteAddress.count { it == ']' } < 1) {
this.remoteAddress = """[${this.remoteAddress}]"""
}
}
handshake.init(this)
if (this.remoteAddress.isEmpty()) {
@ -180,8 +164,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// config.aeronLogDirectory
// stream IDs are flipped for a client because we operate from the perspective of the server
val handshakeConnection = IpcMediaDriverConnection(
streamId = IPC_HANDSHAKE_STREAM_ID_SUB,
@ -209,20 +191,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// THIS IS A NETWORK ADDRESS
// initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER
val handshakeConnection = UdpMediaDriverConnection(
address = this.remoteAddress,
subscriptionPort = config.publicationPort,
publicationPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = reliable)
autoClosableObjects.add(handshakeConnection)
val handshakeConnection = UdpMediaDriverConnection(address = this.remoteAddress,
publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = reliable)
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron)
logger.debug(handshakeConnection.clientInfo())
logger.info(handshakeConnection.clientInfo())
// this will block until the connection timeout, and throw an exception if we were unable to connect with the server
@ -232,19 +211,22 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = UdpMediaDriverConnection(
address = handshakeConnection.address,
subscriptionPort = connectionInfo.subscriptionPort,
publicationPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address,
// NOTE: pub/sub must be switched!
publicationPort = connectionInfo.subscriptionPort,
subscriptionPort = connectionInfo.publicationPort,
streamId = connectionInfo.streamId,
sessionId = connectionInfo.sessionId,
connectionTimeoutMS = connectionTimeoutMS,
isReliable = handshakeConnection.isReliable)
// VALIDATE:: check to see if the remote connection's public key has changed!
if (!crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)) {
listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch."))
return
val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")
listenerManager.notifyError(exception)
throw exception
}
// VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the
@ -256,21 +238,22 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// does not need to do anything
//
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports
logger.debug(reliableClientConnection.clientInfo())
logger.info(reliableClientConnection.clientInfo())
val newConnection = newConnection(this, reliableClientConnection)
autoClosableObjects.add(newConnection)
val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
// VALIDATE are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST")
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress was not permitted!"))
return
handshakeConnection.close()
val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!")
listenerManager.notifyError(exception)
throw exception
}
connection = newConnection
handshake.addConnection(newConnection)
connections.add(newConnection)
// have to make a new thread to listen for incoming data!
// SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them
@ -287,24 +270,25 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// tell the server our connection handshake is done, and the connection can now listen for data.
val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS)
// no longer necessary to hold this connection open
handshakeConnection.close()
if (canFinishConnecting) {
isConnected.lazySet(true)
listenerManager.notifyConnect(newConnection)
isConnected = true
actionDispatch.launch {
listenerManager.notifyConnect(newConnection)
}
} else {
close()
val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}")
listenerManager.notifyError(exception)
throw exception
}
}
}
/**
* Checks to see if this client has connected yet or not.
*
* @return true if we are connected, false otherwise.
*/
override fun isConnected(): Boolean {
return isConnected.value
}
// override fun hasRemoteKeyChanged(): Boolean {
// return connection!!.hasRemoteKeyChanged()
// }
@ -373,18 +357,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
}
/**
* @throws ClientException when a message cannot be sent
*/
suspend fun send(message: Any, priority: Byte) {
val c = connection
if (c != null) {
c.send(message, priority)
} else {
throw ClientException("Cannot send a message when there is no connection!")
}
}
/**
* @throws ClientException when a ping cannot be sent
*/
@ -409,104 +381,24 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
}
// fun initClassRegistration(channel: Channel, registration: Registration): Boolean {
// val details = serialization.getKryoRegistrationDetails()
// val length = details.size
// if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) {
// // it is too large to send in a single packet
//
// // child arrays have index 0 also as their 'index' and 1 is the total number of fragments
// val fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE)
// if (fragments == null) {
// logger.error("Too many classes have been registered for Serialization. Please report this issue")
// return false
// }
// val allButLast = fragments.size - 1
// for (i in 0 until allButLast) {
// val fragment = fragments[i]
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragment
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// }
//
// // now tell the server we are done with the fragments
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragments[allButLast]
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// } else {
// registration.payload = details
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// registration.upgraded = true
// channel.writeAndFlush(registration)
// }
// return true
// }
// /**
// * Closes all connections ONLY (keeps the client running). To STOP the client, use stop().
// * <p/>
// * This is used, for example, when reconnecting to a server.
// */
// protected
// void closeConnection() {
// if (isConnected.get()) {
// // make sure we're not waiting on registration
// stopRegistration();
//
// // for the CLIENT only, we clear these connections! (the server only clears them on shutdown)
//
// // stop does the same as this + more. Only keep the listeners for connections IF we are the client. If we remove listeners as a client,
// // ALL of the client logic will be lost. The server is reactive, so listeners are added to connections as needed (instead of before startup)
// connectionManager.closeConnections(true);
//
// // Sometimes there might be "lingering" connections (ie, halfway though registration) that need to be closed.
// registrationWrapper.clearSessions();
//
//
// closeConnections(true);
// shutdownAllChannels();
// // shutdownEventLoops(); we don't do this here!
//
// connection = null;
// isConnected.set(false);
//
// previousClosedConnectionActivity = System.nanoTime();
// }
// }
// /**
// * Internal call to abort registration if the shutdown command is issued during channel registration.
// */
// @Suppress("unused")
// fun abortRegistration() {
// // make sure we're not waiting on registration
//// stopRegistration()
// }
override fun close() {
val con = connection
connection = null
if (con != null) {
handshake.removeConnection(con)
connections.remove(con)
runBlocking {
con.close()
listenerManager.notifyDisconnect(con)
}
}
super.close()
isConnected = false
}
// RMI notes (in multiple places, copypasta, because this is confusing if not written down
// RMI notes (in multiple places, copypasta, because this is confusing if not written down)
//
// only server can create a global object (in itself, via save)
// server

View File

@ -20,12 +20,14 @@ import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy
import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.network.serialization.Serialization
import dorkbox.network.store.PropertyStore
import dorkbox.network.store.SettingsStore
import dorkbox.network.storage.PropertyStore
import dorkbox.network.storage.SettingsStore
import dorkbox.os.OS
import dorkbox.util.storage.StorageBuilder
import dorkbox.util.storage.StorageSystem
import io.aeron.driver.Configuration
import io.aeron.driver.ThreadingMode
import mu.KLogger
import java.io.File
class ServerConfiguration : dorkbox.network.Configuration() {
@ -35,11 +37,6 @@ class ServerConfiguration : dorkbox.network.Configuration() {
*/
var listenIpAddress = "*"
/**
* The starting port for clients to use. The upper bound of this value is limited by the maximum number of clients allowed.
*/
var clientStartPort = 0
/**
* The maximum number of clients allowed for a server
*/
@ -55,6 +52,8 @@ open class Configuration {
/**
* When connecting to a remote client/server, should connections be allowed if the remote machine signature has changed?
*
* Setting this to false is not recommended as it is a security risk
*/
var enableRemoteSignatureValidation: Boolean = true
@ -103,8 +102,8 @@ open class Configuration {
* The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT.
*
* There are a couple strategies of importance to understand.
* * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* responsive to activity when idle for a little while.
*
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
@ -116,8 +115,8 @@ open class Configuration {
* The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT.
*
* There are a couple strategies of importance to understand.
* * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default.
* - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less
* responsive to activity when idle for a little while.
*
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
@ -126,26 +125,22 @@ open class Configuration {
var sendIdleStrategy: CoroutineIdleStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100)
/**
* A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation.
* ## A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation.
*
*
* There are three main Agents in the driver:
*
*
* Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs,
* - Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs,
* rotating buffers, etc.
* Sender: Responsible for shovelling messages from publishers to the network.
* Receiver: Responsible for shovelling messages from the network to subscribers.
* - Sender: Responsible for shovelling messages from publishers to the network.
* - Receiver: Responsible for shovelling messages from the network to subscribers.
*
*
* This value can be one of:
*
*
* INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty
* - INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty
* cycle directly.
* SHARED: All Agents share a single thread. 1 thread in total.
* SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total.
* DEDICATED: The default and dedicates one thread per Agent. 3 threads in total.
* - SHARED: All Agents share a single thread. 1 thread in total.
* - SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total.
* - DEDICATED: The default and dedicates one thread per Agent. 3 threads in total.
*
*
* For performance, it is recommended to use DEDICATED as long as the number of busy threads is less than or equal to the number of
@ -217,4 +212,31 @@ open class Configuration {
* A value of 0 will 'auto-configure' this setting.
*/
var receiveBufferSize = 0
/**
* Depending on the OS, different base locations for the Aeron log directory are preferred.
*/
fun suggestAeronLogLocation(logger: KLogger): File {
return when {
OS.isMacOsX() -> {
// does the recommended location exist??
val suggestedLocation = File("/Volumes/DevShm")
if (suggestedLocation.exists()) {
suggestedLocation
}
else {
logger.info("It is recommended to create a RAM drive for best performance. For example\n" + "\$ diskutil erasevolume HFS+ \"DevShm\" `hdiutil attach -nomount ram://\$((2048 * 2048))`\n" + "\t After this, set config.aeronLogDirectory = \"/Volumes/DevShm\"")
File(System.getProperty("java.io.tmpdir"))
}
}
OS.isLinux() -> {
// this is significantly faster for linux than using the temp dir
File("/dev/shm/")
}
else -> {
File(System.getProperty("java.io.tmpdir"))
}
}
}
}

View File

@ -16,6 +16,7 @@
package dorkbox.network
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
@ -36,9 +37,6 @@ import java.net.InetSocketAddress
import java.util.concurrent.CopyOnWriteArrayList
/**
* NOTE: when using "server.publish(A)", this will go to ALL CLIENTS! add this to aeron via "publication.addDestination" so aeron manages it
*
*
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
*
@ -75,7 +73,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
private var bindAlreadyCalled = false
override val handshake = ServerHandshake(logger, config, listenerManager)
private val handshake = ServerHandshake(logger, config, listenerManager)
/**
* Maintains a thread-safe collection of rules used to define the connection type with this server.
@ -90,25 +88,29 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
when (config.listenIpAddress) {
"loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
else -> when {
config.listenIpAddress.startsWith("127.") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
config.listenIpAddress.startsWith("::1") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress
else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address.
}
}
// if we are IPv6, the IP must be in '[]'
if (config.listenIpAddress.count { it == ':' } > 1 &&
config.listenIpAddress.count { it == '[' } < 1 &&
config.listenIpAddress.count { it == ']' } < 1) {
config.listenIpAddress = """[${config.listenIpAddress}]"""
}
// if we are IPv4 wildcard
if (config.listenIpAddress == "0.0.0.0") {
// this will also fixup windows!
config.listenIpAddress = IPv4.WILDCARD
}
if (IPv6.isValid(config.listenIpAddress)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (config.listenIpAddress.count { it == '[' } < 1 &&
config.listenIpAddress.count { it == ']' } < 1) {
config.listenIpAddress = """[${config.listenIpAddress}]"""
}
}
if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") }
if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") }
@ -119,8 +121,9 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
if (config.networkMtuSize <= 0) { throw ServerException("configuration networkMtuSize must be > 0") }
if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") }
autoClosableObjects.add(handshake)
if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount}
}
override fun newException(message: String, cause: Throwable?): Throwable {
return ServerException(message, cause)
}
@ -145,20 +148,18 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// setup the "HANDSHAKE" ports, for initial clients to connect.
// The is how clients then get the new ports to connect to + other configuration options
val handshakeDriver = UdpMediaDriverConnection(
address = config.listenIpAddress,
subscriptionPort = config.subscriptionPort,
publicationPort = config.publicationPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
val handshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
handshakeDriver.buildServer(aeron)
val handshakePublication = handshakeDriver.publication
val handshakeSubscription = handshakeDriver.subscription
logger.debug(handshakeDriver.serverInfo())
logger.debug("Server listening for incoming clients on ${handshakePublication.localSocketAddresses()}")
logger.info(handshakeDriver.serverInfo())
val ipcHandshakeDriver = IpcMediaDriverConnection(
@ -197,7 +198,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
try {
var pollCount: Int
while (!isShutdown()) {
// Get the current time, used to cleanup connections
val now = System.currentTimeMillis()
pollCount = 0
// this checks to see if there are NEW clients
@ -206,8 +211,38 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// this checks to see if there are NEW clients via IPC
// pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100)
// this manages existing clients (for cleanup + connection polling)
pollCount += handshake.poll()
connections.forEachWithCleanup({ connection ->
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (connection.isExpired(now)) {
logger.debug("[{}] connection expired", connection.sessionId)
shouldCleanupConnection = true
}
if (connection.isClosed()) {
logger.debug("[{}] connection closed", connection.sessionId)
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
true
}
else {
// Otherwise, poll the duologue for activity.
pollCount += connection.pollSubscriptions()
false
}
}, { connectionToClean ->
logger.debug("[{}] deleted connection", connectionToClean.sessionId)
// have to free up resources!
handshake.cleanup(connectionToClean)
connectionToClean.close()
listenerManager.notifyDisconnect(connectionToClean)
})
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
@ -229,7 +264,14 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
}
}
internal suspend fun poll(): Int {
var pollCount = 0
return pollCount
}
/**
@ -250,45 +292,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
connectionRules.addAll(listOf(*rules))
}
// verify the class ID registration details.
// the client will send their class registration data. VERIFY IT IS CORRECT!
// verify the class ID registration details.
// the client will send their class registration data. VERIFY IT IS CORRECT!
// var state: dorkbox.network.connection.RegistrationWrapper.STATE = registrationWrapper.verifyClassRegistration(metaChannel, registration)
// if (state == RegistrationWrapper.STATE.ERROR)
// {
// // abort! There was an error
// shutdown(channel, 0)
// return
// } else if (state == RegistrationWrapper.STATE.WAIT)
// {
// return
// }
/**
* Checks to seeOnce a server has connected to ANY client, it will always return true until server.close() is called
*
* @return true if we are connected, false otherwise.
*/
override fun isConnected(): Boolean {
return handshake.connectionCount() > 0
}
/**
* Safely sends objects to a destination
*/
suspend fun send(message: Any) {
handshake.send(message)
connections.send(message)
}
/**
@ -348,6 +356,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
/**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
* Adds a custom connection to the server.
*
* This should only be used in situations where there can be DIFFERENT types of connections (such as a 'web-based' connection) and
@ -356,10 +365,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param connection the connection to add
*/
fun addConnection(connection: CONNECTION) {
handshake.addConnection(connection)
connections.add(connection)
}
/**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
* Removes a custom connection to the server.
*
*
@ -369,11 +379,10 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param connection the connection to remove
*/
fun removeConnection(connection: CONNECTION) {
handshake.removeConnection(connection)
connections.remove(connection)
}
/**
* Checks to see if a server (using the specified configuration) is running.
*
@ -467,7 +476,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// RMI notes (in multiple places, copypasta, because this is confusing if not written down
// RMI notes (in multiple places, copypasta, because this is confusing if not written down)
//
// only server can create a global object (in itself, via save)
// server

View File

@ -3,7 +3,12 @@ package dorkbox.network.handshake
import dorkbox.network.Configuration
import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.connection.*
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
@ -12,10 +17,10 @@ import mu.KLogger
import org.agrona.DirectBuffer
import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
config: Configuration,
listenerManager: ListenerManager<CONNECTION>,
val crypto: CryptoManagement) : ConnectionManager<CONNECTION>(logger, config, listenerManager) {
internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogger,
private val config: Configuration,
private val crypto: CryptoManagement,
private val listenerManager: ListenerManager<CONNECTION>) {
// a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt()
@ -25,8 +30,8 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
@Volatile
var connectionDone = false
private var failed = false
@Volatile
private var failed: Exception? = null
lateinit var handler: FragmentHandler
lateinit var endPoint: EndPoint<*>
@ -38,37 +43,48 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
endPoint.actionDispatch.launch {
val sessionId = header.sessionId()
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
logger.debug("[{}] handshake response: {}", sessionId, message)
logger.trace {
"[$sessionId] handshake response: $message"
}
// it must be a registration message
if (message !is Message) {
logger.error("[{}] server returned unrecognized message: {}", sessionId, message)
if (message !is HandshakeMessage) {
failed = ClientException("[$sessionId] server returned unrecognized message: $message")
return@launch
}
if (message.sessionId != sessionId) {
logger.error("[{}] ignored message intended for another client", sessionId)
// this is an error message
if (message.sessionId == 0) {
failed = ClientException("[$sessionId] error: ${message.errorMessage}")
return@launch
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}")
return@launch
}
// it must be the correct state
when (message.state) {
Message.HELLO_ACK -> {
HandshakeMessage.HELLO_ACK -> {
// The message was intended for this client. Try to parse it as one of the available message types.
// this message is ENCRYPTED!
connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey)
connectionHelloInfo!!.log(sessionId, logger)
}
Message.DONE_ACK -> {
HandshakeMessage.DONE_ACK -> {
connectionDone = true
}
else -> {
if (message.state != Message.HELLO_ACK) {
logger.error("[{}] ignored message that is not HELLO_ACK", sessionId)
} else if (message.state != Message.DONE_ACK) {
logger.error("[{}] ignored message that is not INIT_ACK", sessionId)
if (message.state != HandshakeMessage.HELLO_ACK) {
failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK")
} else if (message.state != HandshakeMessage.DONE_ACK) {
failed = ClientException("[$sessionId] ignored message that is not DONE_ACK")
}
return@launch
@ -79,7 +95,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
}
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = Message.helloFromClient(
val registrationMessage = HandshakeMessage.helloFromClient(
oneTimePad = oneTimePad,
publicKey = config.settingsStore.getPublicKey()!!,
registrationData = config.serialization.getKryoRegistrationDetails()
@ -93,7 +109,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// block until we receive the connection information from the server
failed = false
failed = null
var pollCount: Int
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
@ -102,10 +118,10 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
pollCount = subscription.poll(handler, 1024)
if (failed) {
if (failed != null) {
// no longer necessary to hold this connection open
handshakeConnection.close()
throw ClientException("Server rejected this client")
throw failed as Exception
}
if (connectionHelloInfo != null) {
@ -127,7 +143,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
}
suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = Message.doneFromClient()
val registrationMessage = HandshakeMessage.doneFromClient()
// Send the done message to the server.
endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage)
@ -135,7 +151,7 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
// block until we receive the connection information from the server
failed = false
failed = null
var pollCount: Int
val subscription = mediaConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
@ -144,10 +160,10 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
while (System.currentTimeMillis() - startTime < connectionTimeoutMS) {
pollCount = subscription.poll(handler, 1024)
if (failed) {
if (failed != null) {
// no longer necessary to hold this connection open
mediaConnection.close()
throw ClientException("Server rejected this client")
throw failed as Exception
}
if (connectionDone) {
@ -164,9 +180,6 @@ internal class ClientHandshake<CONNECTION: Connection>(logger: KLogger,
throw ClientTimedOutException("Waiting for registration response from server")
}
// no longer necessary to hold this connection open
mediaConnection.close()
return connectionDone
}
}

View File

@ -18,7 +18,7 @@ package dorkbox.network.handshake
/**
* Internal message to handle the connection registration process
*/
class Message private constructor() {
internal class HandshakeMessage private constructor() {
// the public key is used to encrypt the data in the handshake
var publicKey: ByteArray? = null
@ -67,8 +67,8 @@ class Message private constructor() {
const val DONE = 2
const val DONE_ACK = 3
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): Message {
val hello = Message()
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.oneTimePad = oneTimePad
hello.publicKey = publicKey
@ -76,28 +76,28 @@ class Message private constructor() {
return hello
}
fun helloAckToClient(sessionId: Int): Message {
val hello = Message()
fun helloAckToClient(sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK
hello.sessionId = sessionId // has to be the same as before (the client expects this)
return hello
}
fun doneFromClient(): Message {
val hello = Message()
fun doneFromClient(): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE
return hello
}
fun doneToClient(sessionId: Int): Message {
val hello = Message()
fun doneToClient(sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE_ACK
hello.sessionId = sessionId
return hello
}
fun error(errorMessage: String?): Message {
val error = Message()
fun error(errorMessage: String): HandshakeMessage {
val error = HandshakeMessage()
error.state = INVALID
error.errorMessage = errorMessage
return error
@ -105,6 +105,22 @@ class Message private constructor() {
}
override fun toString(): String {
return "Message(oneTimePad=$oneTimePad, state=$state)"
val stateStr = when(state) {
INVALID -> "INVALID"
HELLO -> "HELLO"
HELLO_ACK -> "HELLO_ACK"
DONE -> "DONE"
DONE_ACK -> "DONE_ACK"
else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!"
}
val errorMsg = if (errorMessage == null) {
""
} else {
", Error: $errorMessage"
}
return "HandshakeMessage(oneTimePad=$oneTimePad, sid= $sessionId $stateStr$errorMsg)"
}
}

View File

@ -2,8 +2,16 @@ package dorkbox.network.handshake
import dorkbox.netUtil.IPv4
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.server.*
import dorkbox.network.connection.*
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.server.AllocationException
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import io.aeron.Image
import io.aeron.Publication
import io.aeron.logbuffer.Header
@ -16,27 +24,16 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.write
/**
* TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
*
* @throws IllegalArgumentException If the port range is not valid
*/
internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
config: ServerConfiguration,
listenerManager: ListenerManager<CONNECTION>) :
ConnectionManager<CONNECTION>(logger, config, listenerManager) {
companion object {
// this is the number of ports used per client. Depending on how a client is configured, this number can change
const val portsPerClient = 2
}
internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>) {
private val pendingConnectionsLock = ReentrantReadWriteLock()
private val pendingConnections = Int2ObjectHashMap<CONNECTION>()
private val portAllocator: PortAllocator
private val connectionsPerIpCounts = Int2IntCounterMap(0)
// guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!)
@ -44,15 +41,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
init {
val minPort = config.clientStartPort
val maxPortCount = portsPerClient * config.maxClientCount
portAllocator = PortAllocator(minPort, maxPortCount)
logger.info("Server connection port range [$minPort - ${minPort + maxPortCount}]")
}
// note: this is called in action dispatch
suspend fun receiveHandshakeMessageServer(handshakePublication: Publication,
buffer: DirectBuffer, offset: Int, length: Int, header: Header,
@ -70,28 +58,27 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
config as ServerConfiguration
val message = endPoint.readHandshakeMessage(buffer, offset, length, header)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is Message) {
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request"))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection request"))
endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
return
}
val clientPublicKeyBytes = message.publicKey
val validateRemoteAddress: PublicKeyValidationState
// check to see if this is a pending connection
if (message.state == Message.DONE) {
if (message.state == HandshakeMessage.DONE) {
pendingConnectionsLock.write {
val pendingConnection = pendingConnections.remove(sessionId)
if (pendingConnection != null) {
logger.debug("Connection from client $clientAddressString ready")
// now tell the client we are done
endPoint.writeHandshakeMessage(handshakePublication, Message.doneToClient(sessionId))
endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
endPoint.actionDispatch.launch {
listenerManager.notifyConnect(pendingConnection)
@ -104,14 +91,16 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
try {
// VALIDATE:: Check to see if there are already too many clients connected.
if (connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full"))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Server full. Max allowed is ${config.maxClientCount}"))
if (endPoint.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
return
}
// VALIDATE:: check to see if the remote connection's public key has changed!
if (!endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)) {
validateRemoteAddress = endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch."))
return
}
@ -125,17 +114,17 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString"))
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
// decrement it now, since we aren't going to permit this connection (take the hit on failure, instead
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress)
endPoint.writeHandshakeMessage(handshakePublication, Message.error("too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Too many connections for IP address"))
return
}
} catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message",
e))
endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection"))
listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
return
}
@ -149,19 +138,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
/////
// allocate ports for the client
val connectionPorts: IntArray
try {
// throws exception if this is not possible
connectionPorts = portAllocator.allocate(portsPerClient)
} catch (e: IllegalArgumentException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate $portsPerClient ports for client connection!"))
return
}
// allocate session/stream id's
val connectionSessionId: Int
try {
@ -169,9 +145,11 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
} catch (e: AllocationException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
return
}
@ -182,30 +160,34 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
} catch (e: AllocationException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
return
}
val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box?
val subscriptionPort = connectionPorts[0]
val publicationPort = connectionPorts[1]
// the pub/sub do not necessarily have to be the same. The can be ANY port
val publicationPort = config.publicationPort
val subscriptionPort = config.subscriptionPort
// create a new connection. The session ID is encrypted.
try {
// connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = UdpMediaDriverConnection(serverAddress,
subscriptionPort,
publicationPort,
subscriptionPort,
connectionStreamId,
connectionSessionId,
0,
message.isReliable)
val connection: Connection = endPoint.newConnection(endPoint, clientConnection)
val connection: Connection = endPoint.newConnection(ConnectionParams(endPoint, clientConnection, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@Suppress("UNCHECKED_CAST")
@ -213,7 +195,6 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
if (!permitConnection) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
@ -221,17 +202,15 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
listenerManager.notifyError(connection,
ClientRejectedException("Connection was not permitted!"))
endPoint.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
return
}
logger.info {
"Client connected [$clientAddressString:$subscriptionPort|$publicationPort] (session: $sessionId)"
}
logger.debug("Created new client connection sessionID {}", connectionSessionId)
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = Message.helloAckToClient(sessionId)
val successMessage = HandshakeMessage.helloAckToClient(sessionId)
// now create the encrypted payload, using ECDH
successMessage.registrationData = endPoint.crypto.encrypt(publicationPort,
@ -242,19 +221,19 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
successMessage.publicKey = endPoint.crypto.publicKeyBytes
// this tells the client all of the info to connect.
endPoint.writeHandshakeMessage(handshakePublication, successMessage)
addConnection(connection)
// this enables the connection to start polling for messages
endPoint.connections.add(connection)
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnectionsLock.write {
pendingConnections[sessionId] = connection
}
// this tells the client all of the info to connect.
endPoint.writeHandshakeMessage(handshakePublication, successMessage)
} catch (e: Exception) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
portAllocator.free(connectionPorts)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
@ -262,49 +241,12 @@ internal class ServerHandshake<CONNECTION : Connection>(logger: KLogger,
}
}
suspend fun poll(): Int {
// Get the current time, used to cleanup connections
val now = System.currentTimeMillis()
var pollCount = 0
forEachConnectionCleanup({ connection ->
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
var shouldCleanupConnection = false
if (connection.isExpired(now)) {
logger.debug("[{}] connection expired", connection.sessionId)
shouldCleanupConnection = true
}
if (connection.isClosed()) {
logger.debug("[{}] connection closed", connection.sessionId)
shouldCleanupConnection = true
}
if (shouldCleanupConnection) {
true
}
else {
// Otherwise, poll the duologue for activity.
pollCount += connection.pollSubscriptions()
false
}
}, { connectionToClean ->
logger.debug("[{}] deleted connection", connectionToClean.sessionId)
removeConnection(connectionToClean)
// have to free up resources!
connectionsPerIpCounts.getAndDecrement(connectionToClean.remoteAddressInt)
portAllocator.free(connectionToClean.subscriptionPort)
portAllocator.free(connectionToClean.publicationPort)
sessionIdAllocator.free(connectionToClean.sessionId)
streamIdAllocator.free(connectionToClean.streamId)
listenerManager.notifyDisconnect(connectionToClean)
connectionToClean.close()
})
return pollCount
/**
* Free up resources from the closed connection
*/
fun cleanup(connection: CONNECTION) {
connectionsPerIpCounts.getAndDecrement(connection.remoteAddressInt)
sessionIdAllocator.free(connection.sessionId)
streamIdAllocator.free(connection.streamId)
}
}