diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 1240e1ed..7f926aed 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -18,17 +18,23 @@ package dorkbox.network import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv6 import dorkbox.network.aeron.client.ClientException +import dorkbox.network.aeron.client.ClientRejectedException import dorkbox.network.aeron.client.ClientTimedOutException -import dorkbox.network.aeron.server.ClientRejectedException -import dorkbox.network.connection.* +import dorkbox.network.connection.Connection +import dorkbox.network.connection.ConnectionParams +import dorkbox.network.connection.EndPoint +import dorkbox.network.connection.IpcMediaDriverConnection +import dorkbox.network.connection.Ping +import dorkbox.network.connection.PublicKeyValidationState +import dorkbox.network.connection.UdpMediaDriverConnection import dorkbox.network.handshake.ClientHandshake import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RmiSupportConnection import dorkbox.network.rmi.TimeoutException import dorkbox.util.exceptions.SecurityException -import kotlinx.atomicfu.atomic import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking /** * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's @@ -40,37 +46,6 @@ open class Client(config: Configuration = Configuration * Gets the version number. */ const val version = "5.0" - - /** - * Split array into chunks, max of 256 chunks. - * byte[0] = chunk ID - * byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte) - */ - private fun divideArray(source: ByteArray, chunksize: Int): Array? { - val fragments = Math.ceil(source.size / chunksize.toDouble()).toInt() - if (fragments > 127) { - // cannot allow more than 127 - return null - } - - // pre-allocate the memory - val splitArray = Array(fragments) { ByteArray(chunksize + 2) } - var start = 0 - for (i in splitArray.indices) { - var length: Int - length = if (start + chunksize > source.size) { - source.size - start - } else { - chunksize - } - splitArray[i] = ByteArray(length + 2) - splitArray[i][0] = i.toByte() // index - splitArray[i][1] = fragments.toByte() // total number of fragments - System.arraycopy(source, start, splitArray[i], 2, length) - start += chunksize - } - return splitArray - } } /** @@ -85,7 +60,8 @@ open class Client(config: Configuration = Configuration */ private var remoteAddress = "" - private val isConnected = atomic(false) + @Volatile + private var isConnected = false // is valid when there is a connection to the server, otherwise it is null private var connection: CONNECTION? = null @@ -95,7 +71,7 @@ open class Client(config: Configuration = Configuration private val previousClosedConnectionActivity: Long = 0 - override val handshake = ClientHandshake(logger, config, listenerManager, crypto) + private val handshake = ClientHandshake(logger, config, crypto, listenerManager) private val rmiConnectionSupport = RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch) init { @@ -108,8 +84,6 @@ open class Client(config: Configuration = Configuration if (config.networkMtuSize <= 0) { throw ClientException("configuration networkMtuSize must be > 0") } if (config.networkMtuSize >= 9 * 1024) { throw ClientException("configuration networkMtuSize must be < ${9 * 1024}") } - - autoClosableObjects.add(handshake) } override fun newException(message: String, cause: Throwable?): Throwable { @@ -124,26 +98,29 @@ open class Client(config: Configuration = Configuration } /** - * Will attempt to connect to the server, with a default 30 second connection timeout and will BLOCK until completed + * Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed. * - * For a network address, it can be: - * - a network name ("localhost", "loopback", "lo", "bob.example.org") - * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * Default connection is to localhost * - * For the IPC (Inter-Process-Communication) address. it must be: + * ### For a network address, it can be: + * - a network name ("localhost", "loopback", "lo", "bob.example.org") + * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * + * ### For the IPC (Inter-Process-Communication) address. it must be: * - the IPC integer ID, "0x1337c0de", "0x12312312", etc. * - * Note: Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') - * + * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * * @param remoteAddress The network or IPC address for the client to connect to - * @param connectionTimeout wait for x milliseconds. 0 will wait indefinitely + * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely * @param reliable true if we want to create a reliable connection. IPC connections are always reliable * + * @throws IllegalArgumentException if the remote address is invalid * @throws ClientTimedOutException if the client is unable to connect in x amount of time + * @throws ClientRejectedException if the client connection is rejected */ - suspend fun connect(remoteAddress: String = "", connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { - if (isConnected.value) { + suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + if (isConnected) { logger.error("Unable to connect when already connected!") return } @@ -151,26 +128,33 @@ open class Client(config: Configuration = Configuration this.connectionTimeoutMS = connectionTimeoutMS // localhost/loopback IP might not always be 127.0.0.1 or ::1 when (remoteAddress) { - "loopback", "localhost", "lo" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress + "loopback", "localhost", "lo", "" -> this.remoteAddress = IPv4.LOCALHOST.hostAddress else -> when { - remoteAddress.startsWith("127.") -> this.remoteAddress = IPv4.LOCALHOST.hostAddress - remoteAddress.startsWith("::1") -> this.remoteAddress = IPv6.LOCALHOST.hostAddress - else -> this.remoteAddress = remoteAddress + IPv4.isLoopback(remoteAddress) -> this.remoteAddress = IPv4.LOCALHOST.hostAddress + IPv6.isLoopback(remoteAddress) -> this.remoteAddress = IPv6.LOCALHOST.hostAddress + else -> this.remoteAddress = remoteAddress // might be IPC address! } } - // if we are IPv6, the IP must be in '[]' - if (this.remoteAddress.count { it == ':' } > 1 && - this.remoteAddress.count { it == '[' } < 1 && - this.remoteAddress.count { it == ']' } < 1) { - - this.remoteAddress = """[${this.remoteAddress}]""" - } + // if we are IPv4 wildcard if (this.remoteAddress == "0.0.0.0") { throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") } + + if (IPv6.isValid(this.remoteAddress)) { + // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so + + // if we are IPv6, the IP must be in '[]' + if (this.remoteAddress.count { it == '[' } < 1 && + this.remoteAddress.count { it == ']' } < 1) { + + this.remoteAddress = """[${this.remoteAddress}]""" + } + } + + handshake.init(this) if (this.remoteAddress.isEmpty()) { @@ -180,8 +164,6 @@ open class Client(config: Configuration = Configuration // config.aeronLogDirectory - - // stream IDs are flipped for a client because we operate from the perspective of the server val handshakeConnection = IpcMediaDriverConnection( streamId = IPC_HANDSHAKE_STREAM_ID_SUB, @@ -209,20 +191,17 @@ open class Client(config: Configuration = Configuration // THIS IS A NETWORK ADDRESS // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER - val handshakeConnection = UdpMediaDriverConnection( - address = this.remoteAddress, - subscriptionPort = config.publicationPort, - publicationPort = config.subscriptionPort, - streamId = UDP_HANDSHAKE_STREAM_ID, - sessionId = RESERVED_SESSION_ID_INVALID, - connectionTimeoutMS = connectionTimeoutMS, - isReliable = reliable) - - autoClosableObjects.add(handshakeConnection) + val handshakeConnection = UdpMediaDriverConnection(address = this.remoteAddress, + publicationPort = config.subscriptionPort, + subscriptionPort = config.publicationPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID, + connectionTimeoutMS = connectionTimeoutMS, + isReliable = reliable) // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports handshakeConnection.buildClient(aeron) - logger.debug(handshakeConnection.clientInfo()) + logger.info(handshakeConnection.clientInfo()) // this will block until the connection timeout, and throw an exception if we were unable to connect with the server @@ -232,19 +211,22 @@ open class Client(config: Configuration = Configuration // we are now connected, so we can connect to the NEW client-specific ports - val reliableClientConnection = UdpMediaDriverConnection( - address = handshakeConnection.address, - subscriptionPort = connectionInfo.subscriptionPort, - publicationPort = connectionInfo.publicationPort, - streamId = connectionInfo.streamId, - sessionId = connectionInfo.sessionId, - connectionTimeoutMS = connectionTimeoutMS, - isReliable = handshakeConnection.isReliable) + val reliableClientConnection = UdpMediaDriverConnection(address = handshakeConnection.address, + // NOTE: pub/sub must be switched! + publicationPort = connectionInfo.subscriptionPort, + subscriptionPort = connectionInfo.publicationPort, + streamId = connectionInfo.streamId, + sessionId = connectionInfo.sessionId, + connectionTimeoutMS = connectionTimeoutMS, + isReliable = handshakeConnection.isReliable) // VALIDATE:: check to see if the remote connection's public key has changed! - if (!crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey)) { - listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.")) - return + val validateRemoteAddress = crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress), connectionInfo.publicKey) + if (validateRemoteAddress == PublicKeyValidationState.INVALID) { + handshakeConnection.close() + val exception = ClientRejectedException("Connection to $remoteAddress not allowed! Public key mismatch.") + listenerManager.notifyError(exception) + throw exception } // VALIDATE:: If the the serialization DOES NOT match between the client/server, then the server will emit a log, and the @@ -256,21 +238,22 @@ open class Client(config: Configuration = Configuration // does not need to do anything // // throws a ConnectTimedOutException if the client cannot connect for any reason to the server-assigned client ports - logger.debug(reliableClientConnection.clientInfo()) + logger.info(reliableClientConnection.clientInfo()) - val newConnection = newConnection(this, reliableClientConnection) - autoClosableObjects.add(newConnection) + val newConnection = newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress)) // VALIDATE are we allowed to connect to this server (now that we have the initial server information) @Suppress("UNCHECKED_CAST") val permitConnection = listenerManager.notifyFilter(newConnection) if (!permitConnection) { - listenerManager.notifyError(ClientRejectedException("Connection to $remoteAddress was not permitted!")) - return + handshakeConnection.close() + val exception = ClientRejectedException("Connection to $remoteAddress was not permitted!") + listenerManager.notifyError(exception) + throw exception } connection = newConnection - handshake.addConnection(newConnection) + connections.add(newConnection) // have to make a new thread to listen for incoming data! // SUBSCRIPTIONS ARE NOT THREAD SAFE! Only one thread at a time can poll them @@ -287,24 +270,25 @@ open class Client(config: Configuration = Configuration // tell the server our connection handshake is done, and the connection can now listen for data. val canFinishConnecting = handshake.handshakeDone(handshakeConnection, connectionTimeoutMS) + + // no longer necessary to hold this connection open + handshakeConnection.close() + if (canFinishConnecting) { - isConnected.lazySet(true) - listenerManager.notifyConnect(newConnection) + isConnected = true + + actionDispatch.launch { + listenerManager.notifyConnect(newConnection) + } } else { close() + val exception = ClientRejectedException("Unable to connect with server ${handshakeConnection.clientInfo()}") + listenerManager.notifyError(exception) + throw exception } } } - /** - * Checks to see if this client has connected yet or not. - * - * @return true if we are connected, false otherwise. - */ - override fun isConnected(): Boolean { - return isConnected.value - } - // override fun hasRemoteKeyChanged(): Boolean { // return connection!!.hasRemoteKeyChanged() // } @@ -373,18 +357,6 @@ open class Client(config: Configuration = Configuration } } - /** - * @throws ClientException when a message cannot be sent - */ - suspend fun send(message: Any, priority: Byte) { - val c = connection - if (c != null) { - c.send(message, priority) - } else { - throw ClientException("Cannot send a message when there is no connection!") - } - } - /** * @throws ClientException when a ping cannot be sent */ @@ -409,104 +381,24 @@ open class Client(config: Configuration = Configuration } } -// fun initClassRegistration(channel: Channel, registration: Registration): Boolean { -// val details = serialization.getKryoRegistrationDetails() -// val length = details.size -// if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) { -// // it is too large to send in a single packet -// -// // child arrays have index 0 also as their 'index' and 1 is the total number of fragments -// val fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) -// if (fragments == null) { -// logger.error("Too many classes have been registered for Serialization. Please report this issue") -// return false -// } -// val allButLast = fragments.size - 1 -// for (i in 0 until allButLast) { -// val fragment = fragments[i] -// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey()) -// fragmentedRegistration.payload = fragment -// -// // tell the server we are fragmented -// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED -// -// // tell the server we are upgraded (it will bounce back telling us to connect) -// fragmentedRegistration.upgraded = true -// channel.writeAndFlush(fragmentedRegistration) -// } -// -// // now tell the server we are done with the fragments -// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey()) -// fragmentedRegistration.payload = fragments[allButLast] -// -// // tell the server we are fragmented -// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED -// -// // tell the server we are upgraded (it will bounce back telling us to connect) -// fragmentedRegistration.upgraded = true -// channel.writeAndFlush(fragmentedRegistration) -// } else { -// registration.payload = details -// -// // tell the server we are upgraded (it will bounce back telling us to connect) -// registration.upgraded = true -// channel.writeAndFlush(registration) -// } -// return true -// } - - // /** - // * Closes all connections ONLY (keeps the client running). To STOP the client, use stop(). - // *

- // * This is used, for example, when reconnecting to a server. - // */ - // protected - // void closeConnection() { - // if (isConnected.get()) { - // // make sure we're not waiting on registration - // stopRegistration(); - // - // // for the CLIENT only, we clear these connections! (the server only clears them on shutdown) - // - // // stop does the same as this + more. Only keep the listeners for connections IF we are the client. If we remove listeners as a client, - // // ALL of the client logic will be lost. The server is reactive, so listeners are added to connections as needed (instead of before startup) - // connectionManager.closeConnections(true); - // - // // Sometimes there might be "lingering" connections (ie, halfway though registration) that need to be closed. - // registrationWrapper.clearSessions(); - // - // - // closeConnections(true); - // shutdownAllChannels(); - // // shutdownEventLoops(); we don't do this here! - // - // connection = null; - // isConnected.set(false); - // - // previousClosedConnectionActivity = System.nanoTime(); - // } - // } -// /** -// * Internal call to abort registration if the shutdown command is issued during channel registration. -// */ -// @Suppress("unused") -// fun abortRegistration() { -// // make sure we're not waiting on registration -//// stopRegistration() -// } - override fun close() { val con = connection connection = null if (con != null) { - handshake.removeConnection(con) + connections.remove(con) + + runBlocking { + con.close() + listenerManager.notifyDisconnect(con) + } } super.close() + isConnected = false } - // RMI notes (in multiple places, copypasta, because this is confusing if not written down + // RMI notes (in multiple places, copypasta, because this is confusing if not written down) // // only server can create a global object (in itself, via save) // server diff --git a/src/dorkbox/network/Configuration.kt b/src/dorkbox/network/Configuration.kt index d88f07ae..acc5409f 100644 --- a/src/dorkbox/network/Configuration.kt +++ b/src/dorkbox/network/Configuration.kt @@ -20,12 +20,14 @@ import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy import dorkbox.network.serialization.NetworkSerializationManager import dorkbox.network.serialization.Serialization -import dorkbox.network.store.PropertyStore -import dorkbox.network.store.SettingsStore +import dorkbox.network.storage.PropertyStore +import dorkbox.network.storage.SettingsStore +import dorkbox.os.OS import dorkbox.util.storage.StorageBuilder import dorkbox.util.storage.StorageSystem import io.aeron.driver.Configuration import io.aeron.driver.ThreadingMode +import mu.KLogger import java.io.File class ServerConfiguration : dorkbox.network.Configuration() { @@ -35,11 +37,6 @@ class ServerConfiguration : dorkbox.network.Configuration() { */ var listenIpAddress = "*" - /** - * The starting port for clients to use. The upper bound of this value is limited by the maximum number of clients allowed. - */ - var clientStartPort = 0 - /** * The maximum number of clients allowed for a server */ @@ -55,6 +52,8 @@ open class Configuration { /** * When connecting to a remote client/server, should connections be allowed if the remote machine signature has changed? + * + * Setting this to false is not recommended as it is a security risk */ var enableRemoteSignatureValidation: Boolean = true @@ -103,8 +102,8 @@ open class Configuration { * The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT. * * There are a couple strategies of importance to understand. - * * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. - * * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less + * - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. + * - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less * responsive to activity when idle for a little while. * * The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and @@ -116,8 +115,8 @@ open class Configuration { * The idle strategy used when polling the Media Driver for new messages. BackOffIdleStrategy is the DEFAULT. * * There are a couple strategies of importance to understand. - * * BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. - * * BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less + * - BusySpinIdleStrategy uses a busy spin as an idle and will eat up CPU by default. + * - BackOffIdleStrategy uses a backoff strategy of spinning, yielding, and parking to be kinder to the CPU, but to be less * responsive to activity when idle for a little while. * * The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and @@ -126,26 +125,22 @@ open class Configuration { var sendIdleStrategy: CoroutineIdleStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100) /** - * A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation. + * ## A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation. * * * There are three main Agents in the driver: - * - * - * Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs, + * - Conductor: Responsible for reacting to client requests and house keeping duties as well as detecting loss, sending NAKs, * rotating buffers, etc. - * Sender: Responsible for shovelling messages from publishers to the network. - * Receiver: Responsible for shovelling messages from the network to subscribers. + * - Sender: Responsible for shovelling messages from publishers to the network. + * - Receiver: Responsible for shovelling messages from the network to subscribers. * * * This value can be one of: - * - * - * INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty + * - INVOKER: No threads. The client is responsible for using the MediaDriver.Context.driverAgentInvoker() to invoke the duty * cycle directly. - * SHARED: All Agents share a single thread. 1 thread in total. - * SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total. - * DEDICATED: The default and dedicates one thread per Agent. 3 threads in total. + * - SHARED: All Agents share a single thread. 1 thread in total. + * - SHARED_NETWORK: Sender and Receiver shares a thread, conductor has its own thread. 2 threads in total. + * - DEDICATED: The default and dedicates one thread per Agent. 3 threads in total. * * * For performance, it is recommended to use DEDICATED as long as the number of busy threads is less than or equal to the number of @@ -217,4 +212,31 @@ open class Configuration { * A value of 0 will 'auto-configure' this setting. */ var receiveBufferSize = 0 + + /** + * Depending on the OS, different base locations for the Aeron log directory are preferred. + */ + fun suggestAeronLogLocation(logger: KLogger): File { + return when { + OS.isMacOsX() -> { + // does the recommended location exist?? + val suggestedLocation = File("/Volumes/DevShm") + if (suggestedLocation.exists()) { + suggestedLocation + } + else { + logger.info("It is recommended to create a RAM drive for best performance. For example\n" + "\$ diskutil erasevolume HFS+ \"DevShm\" `hdiutil attach -nomount ram://\$((2048 * 2048))`\n" + "\t After this, set config.aeronLogDirectory = \"/Volumes/DevShm\"") + + File(System.getProperty("java.io.tmpdir")) + } + } + OS.isLinux() -> { + // this is significantly faster for linux than using the temp dir + File("/dev/shm/") + } + else -> { + File(System.getProperty("java.io.tmpdir")) + } + } + } } diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index ec7c1823..699293fd 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -16,6 +16,7 @@ package dorkbox.network import dorkbox.netUtil.IPv4 +import dorkbox.netUtil.IPv6 import dorkbox.network.aeron.server.ServerException import dorkbox.network.connection.Connection import dorkbox.network.connection.EndPoint @@ -36,9 +37,6 @@ import java.net.InetSocketAddress import java.util.concurrent.CopyOnWriteArrayList /** - * NOTE: when using "server.publish(A)", this will go to ALL CLIENTS! add this to aeron via "publication.addDestination" so aeron manages it - * - * * The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the * server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections()) * @@ -75,7 +73,7 @@ open class Server(config: ServerConfiguration = ServerC private var bindAlreadyCalled = false - override val handshake = ServerHandshake(logger, config, listenerManager) + private val handshake = ServerHandshake(logger, config, listenerManager) /** * Maintains a thread-safe collection of rules used to define the connection type with this server. @@ -90,25 +88,29 @@ open class Server(config: ServerConfiguration = ServerC when (config.listenIpAddress) { "loopback", "localhost", "lo" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress else -> when { - config.listenIpAddress.startsWith("127.") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress - config.listenIpAddress.startsWith("::1") -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress + IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress + IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address. } } - // if we are IPv6, the IP must be in '[]' - if (config.listenIpAddress.count { it == ':' } > 1 && - config.listenIpAddress.count { it == '[' } < 1 && - config.listenIpAddress.count { it == ']' } < 1) { - - config.listenIpAddress = """[${config.listenIpAddress}]""" - } - + // if we are IPv4 wildcard if (config.listenIpAddress == "0.0.0.0") { // this will also fixup windows! config.listenIpAddress = IPv4.WILDCARD } + if (IPv6.isValid(config.listenIpAddress)) { + // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so + + // if we are IPv6, the IP must be in '[]' + if (config.listenIpAddress.count { it == '[' } < 1 && + config.listenIpAddress.count { it == ']' } < 1) { + + config.listenIpAddress = """[${config.listenIpAddress}]""" + } + } + if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") } if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") } @@ -119,8 +121,9 @@ open class Server(config: ServerConfiguration = ServerC if (config.networkMtuSize <= 0) { throw ServerException("configuration networkMtuSize must be > 0") } if (config.networkMtuSize >= 9 * 1024) { throw ServerException("configuration networkMtuSize must be < ${9 * 1024}") } - autoClosableObjects.add(handshake) + if (config.maxConnectionsPerIpAddress == 0) { config.maxConnectionsPerIpAddress = config.maxClientCount} } + override fun newException(message: String, cause: Throwable?): Throwable { return ServerException(message, cause) } @@ -145,20 +148,18 @@ open class Server(config: ServerConfiguration = ServerC // setup the "HANDSHAKE" ports, for initial clients to connect. // The is how clients then get the new ports to connect to + other configuration options - val handshakeDriver = UdpMediaDriverConnection( - address = config.listenIpAddress, - subscriptionPort = config.subscriptionPort, - publicationPort = config.publicationPort, - streamId = UDP_HANDSHAKE_STREAM_ID, - sessionId = RESERVED_SESSION_ID_INVALID) + val handshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress, + publicationPort = config.publicationPort, + subscriptionPort = config.subscriptionPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID) handshakeDriver.buildServer(aeron) val handshakePublication = handshakeDriver.publication val handshakeSubscription = handshakeDriver.subscription - logger.debug(handshakeDriver.serverInfo()) - logger.debug("Server listening for incoming clients on ${handshakePublication.localSocketAddresses()}") + logger.info(handshakeDriver.serverInfo()) val ipcHandshakeDriver = IpcMediaDriverConnection( @@ -197,7 +198,11 @@ open class Server(config: ServerConfiguration = ServerC try { var pollCount: Int + while (!isShutdown()) { + // Get the current time, used to cleanup connections + val now = System.currentTimeMillis() + pollCount = 0 // this checks to see if there are NEW clients @@ -206,8 +211,38 @@ open class Server(config: ServerConfiguration = ServerC // this checks to see if there are NEW clients via IPC // pollCount += ipcHandshakeSubscription.poll(ipcInitialConnectionHandler, 100) + // this manages existing clients (for cleanup + connection polling) - pollCount += handshake.poll() + connections.forEachWithCleanup({ connection -> + // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. + var shouldCleanupConnection = false + + if (connection.isExpired(now)) { + logger.debug("[{}] connection expired", connection.sessionId) + shouldCleanupConnection = true + } + + if (connection.isClosed()) { + logger.debug("[{}] connection closed", connection.sessionId) + shouldCleanupConnection = true + } + if (shouldCleanupConnection) { + true + } + else { + // Otherwise, poll the duologue for activity. + pollCount += connection.pollSubscriptions() + false + } + }, { connectionToClean -> + logger.debug("[{}] deleted connection", connectionToClean.sessionId) + + // have to free up resources! + handshake.cleanup(connectionToClean) + + connectionToClean.close() + listenerManager.notifyDisconnect(connectionToClean) + }) // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) @@ -229,7 +264,14 @@ open class Server(config: ServerConfiguration = ServerC } } + internal suspend fun poll(): Int { + var pollCount = 0 + + + + return pollCount + } /** @@ -250,45 +292,11 @@ open class Server(config: ServerConfiguration = ServerC connectionRules.addAll(listOf(*rules)) } - - - - - - // verify the class ID registration details. - // the client will send their class registration data. VERIFY IT IS CORRECT! - - // verify the class ID registration details. - // the client will send their class registration data. VERIFY IT IS CORRECT! -// var state: dorkbox.network.connection.RegistrationWrapper.STATE = registrationWrapper.verifyClassRegistration(metaChannel, registration) -// if (state == RegistrationWrapper.STATE.ERROR) -// { -// // abort! There was an error -// shutdown(channel, 0) -// return -// } else if (state == RegistrationWrapper.STATE.WAIT) -// { -// return -// } - - - - - /** - * Checks to seeOnce a server has connected to ANY client, it will always return true until server.close() is called - * - * @return true if we are connected, false otherwise. - */ - override fun isConnected(): Boolean { - return handshake.connectionCount() > 0 - } - - /** * Safely sends objects to a destination */ suspend fun send(message: Any) { - handshake.send(message) + connections.send(message) } /** @@ -348,6 +356,7 @@ open class Server(config: ServerConfiguration = ServerC /** + * TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of) * Adds a custom connection to the server. * * This should only be used in situations where there can be DIFFERENT types of connections (such as a 'web-based' connection) and @@ -356,10 +365,11 @@ open class Server(config: ServerConfiguration = ServerC * @param connection the connection to add */ fun addConnection(connection: CONNECTION) { - handshake.addConnection(connection) + connections.add(connection) } /** + * TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of) * Removes a custom connection to the server. * * @@ -369,11 +379,10 @@ open class Server(config: ServerConfiguration = ServerC * @param connection the connection to remove */ fun removeConnection(connection: CONNECTION) { - handshake.removeConnection(connection) + connections.remove(connection) } - /** * Checks to see if a server (using the specified configuration) is running. * @@ -467,7 +476,7 @@ open class Server(config: ServerConfiguration = ServerC - // RMI notes (in multiple places, copypasta, because this is confusing if not written down + // RMI notes (in multiple places, copypasta, because this is confusing if not written down) // // only server can create a global object (in itself, via save) // server diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index 1a0aa911..5a16f34f 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -3,7 +3,12 @@ package dorkbox.network.handshake import dorkbox.network.Configuration import dorkbox.network.aeron.client.ClientException import dorkbox.network.aeron.client.ClientTimedOutException -import dorkbox.network.connection.* +import dorkbox.network.connection.Connection +import dorkbox.network.connection.CryptoManagement +import dorkbox.network.connection.EndPoint +import dorkbox.network.connection.ListenerManager +import dorkbox.network.connection.MediaDriverConnection +import dorkbox.network.connection.UdpMediaDriverConnection import io.aeron.FragmentAssembler import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.Header @@ -12,10 +17,10 @@ import mu.KLogger import org.agrona.DirectBuffer import java.security.SecureRandom -internal class ClientHandshake(logger: KLogger, - config: Configuration, - listenerManager: ListenerManager, - val crypto: CryptoManagement) : ConnectionManager(logger, config, listenerManager) { +internal class ClientHandshake(private val logger: KLogger, + private val config: Configuration, + private val crypto: CryptoManagement, + private val listenerManager: ListenerManager) { // a one-time key for connecting private val oneTimePad = SecureRandom().nextInt() @@ -25,8 +30,8 @@ internal class ClientHandshake(logger: KLogger, @Volatile var connectionDone = false - - private var failed = false + @Volatile + private var failed: Exception? = null lateinit var handler: FragmentHandler lateinit var endPoint: EndPoint<*> @@ -38,37 +43,48 @@ internal class ClientHandshake(logger: KLogger, // now we have a bi-directional connection with the server on the handshake "socket". handler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> endPoint.actionDispatch.launch { + val sessionId = header.sessionId() + val message = endPoint.readHandshakeMessage(buffer, offset, length, header) - logger.debug("[{}] handshake response: {}", sessionId, message) + logger.trace { + "[$sessionId] handshake response: $message" + } // it must be a registration message - if (message !is Message) { - logger.error("[{}] server returned unrecognized message: {}", sessionId, message) + if (message !is HandshakeMessage) { + failed = ClientException("[$sessionId] server returned unrecognized message: $message") return@launch } - if (message.sessionId != sessionId) { - logger.error("[{}] ignored message intended for another client", sessionId) + // this is an error message + if (message.sessionId == 0) { + failed = ClientException("[$sessionId] error: ${message.errorMessage}") + return@launch + } + + + if (this@ClientHandshake.sessionId != message.sessionId) { + failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: ${this@ClientHandshake.sessionId}") return@launch } // it must be the correct state when (message.state) { - Message.HELLO_ACK -> { + HandshakeMessage.HELLO_ACK -> { // The message was intended for this client. Try to parse it as one of the available message types. // this message is ENCRYPTED! connectionHelloInfo = crypto.decrypt(message.registrationData, message.publicKey) connectionHelloInfo!!.log(sessionId, logger) } - Message.DONE_ACK -> { + HandshakeMessage.DONE_ACK -> { connectionDone = true } else -> { - if (message.state != Message.HELLO_ACK) { - logger.error("[{}] ignored message that is not HELLO_ACK", sessionId) - } else if (message.state != Message.DONE_ACK) { - logger.error("[{}] ignored message that is not INIT_ACK", sessionId) + if (message.state != HandshakeMessage.HELLO_ACK) { + failed = ClientException("[$sessionId] ignored message that is not HELLO_ACK") + } else if (message.state != HandshakeMessage.DONE_ACK) { + failed = ClientException("[$sessionId] ignored message that is not DONE_ACK") } return@launch @@ -79,7 +95,7 @@ internal class ClientHandshake(logger: KLogger, } suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { - val registrationMessage = Message.helloFromClient( + val registrationMessage = HandshakeMessage.helloFromClient( oneTimePad = oneTimePad, publicKey = config.settingsStore.getPublicKey()!!, registrationData = config.serialization.getKryoRegistrationDetails() @@ -93,7 +109,7 @@ internal class ClientHandshake(logger: KLogger, // block until we receive the connection information from the server - failed = false + failed = null var pollCount: Int val subscription = handshakeConnection.subscription val pollIdleStrategy = endPoint.config.pollIdleStrategy @@ -102,10 +118,10 @@ internal class ClientHandshake(logger: KLogger, while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { pollCount = subscription.poll(handler, 1024) - if (failed) { + if (failed != null) { // no longer necessary to hold this connection open handshakeConnection.close() - throw ClientException("Server rejected this client") + throw failed as Exception } if (connectionHelloInfo != null) { @@ -127,7 +143,7 @@ internal class ClientHandshake(logger: KLogger, } suspend fun handshakeDone(mediaConnection: UdpMediaDriverConnection, connectionTimeoutMS: Long): Boolean { - val registrationMessage = Message.doneFromClient() + val registrationMessage = HandshakeMessage.doneFromClient() // Send the done message to the server. endPoint.writeHandshakeMessage(mediaConnection.publication, registrationMessage) @@ -135,7 +151,7 @@ internal class ClientHandshake(logger: KLogger, // block until we receive the connection information from the server - failed = false + failed = null var pollCount: Int val subscription = mediaConnection.subscription val pollIdleStrategy = endPoint.config.pollIdleStrategy @@ -144,10 +160,10 @@ internal class ClientHandshake(logger: KLogger, while (System.currentTimeMillis() - startTime < connectionTimeoutMS) { pollCount = subscription.poll(handler, 1024) - if (failed) { + if (failed != null) { // no longer necessary to hold this connection open mediaConnection.close() - throw ClientException("Server rejected this client") + throw failed as Exception } if (connectionDone) { @@ -164,9 +180,6 @@ internal class ClientHandshake(logger: KLogger, throw ClientTimedOutException("Waiting for registration response from server") } - // no longer necessary to hold this connection open - mediaConnection.close() - return connectionDone } } diff --git a/src/dorkbox/network/handshake/Message.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt similarity index 74% rename from src/dorkbox/network/handshake/Message.kt rename to src/dorkbox/network/handshake/HandshakeMessage.kt index 1ee625d5..3de55bc6 100644 --- a/src/dorkbox/network/handshake/Message.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -18,7 +18,7 @@ package dorkbox.network.handshake /** * Internal message to handle the connection registration process */ -class Message private constructor() { +internal class HandshakeMessage private constructor() { // the public key is used to encrypt the data in the handshake var publicKey: ByteArray? = null @@ -67,8 +67,8 @@ class Message private constructor() { const val DONE = 2 const val DONE_ACK = 3 - fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): Message { - val hello = Message() + fun helloFromClient(oneTimePad: Int, publicKey: ByteArray, registrationData: ByteArray): HandshakeMessage { + val hello = HandshakeMessage() hello.state = HELLO hello.oneTimePad = oneTimePad hello.publicKey = publicKey @@ -76,28 +76,28 @@ class Message private constructor() { return hello } - fun helloAckToClient(sessionId: Int): Message { - val hello = Message() + fun helloAckToClient(sessionId: Int): HandshakeMessage { + val hello = HandshakeMessage() hello.state = HELLO_ACK hello.sessionId = sessionId // has to be the same as before (the client expects this) return hello } - fun doneFromClient(): Message { - val hello = Message() + fun doneFromClient(): HandshakeMessage { + val hello = HandshakeMessage() hello.state = DONE return hello } - fun doneToClient(sessionId: Int): Message { - val hello = Message() + fun doneToClient(sessionId: Int): HandshakeMessage { + val hello = HandshakeMessage() hello.state = DONE_ACK hello.sessionId = sessionId return hello } - fun error(errorMessage: String?): Message { - val error = Message() + fun error(errorMessage: String): HandshakeMessage { + val error = HandshakeMessage() error.state = INVALID error.errorMessage = errorMessage return error @@ -105,6 +105,22 @@ class Message private constructor() { } override fun toString(): String { - return "Message(oneTimePad=$oneTimePad, state=$state)" + val stateStr = when(state) { + INVALID -> "INVALID" + HELLO -> "HELLO" + HELLO_ACK -> "HELLO_ACK" + DONE -> "DONE" + DONE_ACK -> "DONE_ACK" + else -> "ERROR. THIS SHOULD NEVER HAPPEN FOR STATE!" + } + + val errorMsg = if (errorMessage == null) { + "" + } else { + ", Error: $errorMessage" + } + + + return "HandshakeMessage(oneTimePad=$oneTimePad, sid= $sessionId $stateStr$errorMsg)" } } diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index aa73eb89..1dd51e46 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -2,8 +2,16 @@ package dorkbox.network.handshake import dorkbox.netUtil.IPv4 import dorkbox.network.ServerConfiguration -import dorkbox.network.aeron.server.* -import dorkbox.network.connection.* +import dorkbox.network.aeron.client.ClientRejectedException +import dorkbox.network.aeron.server.AllocationException +import dorkbox.network.aeron.server.RandomIdAllocator +import dorkbox.network.aeron.server.ServerException +import dorkbox.network.connection.Connection +import dorkbox.network.connection.ConnectionParams +import dorkbox.network.connection.EndPoint +import dorkbox.network.connection.ListenerManager +import dorkbox.network.connection.PublicKeyValidationState +import dorkbox.network.connection.UdpMediaDriverConnection import io.aeron.Image import io.aeron.Publication import io.aeron.logbuffer.Header @@ -16,27 +24,16 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import kotlin.concurrent.write - /** - * TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of) - * * @throws IllegalArgumentException If the port range is not valid */ -internal class ServerHandshake(logger: KLogger, - config: ServerConfiguration, - listenerManager: ListenerManager) : - ConnectionManager(logger, config, listenerManager) { - - companion object { - // this is the number of ports used per client. Depending on how a client is configured, this number can change - const val portsPerClient = 2 - } +internal class ServerHandshake(private val logger: KLogger, + private val config: ServerConfiguration, + private val listenerManager: ListenerManager) { private val pendingConnectionsLock = ReentrantReadWriteLock() private val pendingConnections = Int2ObjectHashMap() - - private val portAllocator: PortAllocator private val connectionsPerIpCounts = Int2IntCounterMap(0) // guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!) @@ -44,15 +41,6 @@ internal class ServerHandshake(logger: KLogger, EndPoint.RESERVED_SESSION_ID_HIGH) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) - init { - val minPort = config.clientStartPort - val maxPortCount = portsPerClient * config.maxClientCount - portAllocator = PortAllocator(minPort, maxPortCount) - - logger.info("Server connection port range [$minPort - ${minPort + maxPortCount}]") - } - - // note: this is called in action dispatch suspend fun receiveHandshakeMessageServer(handshakePublication: Publication, buffer: DirectBuffer, offset: Int, length: Int, header: Header, @@ -70,28 +58,27 @@ internal class ServerHandshake(logger: KLogger, // val port = remoteIpAndPort.substring(splitPoint+1) val clientAddress = IPv4.toInt(clientAddressString) - config as ServerConfiguration - val message = endPoint.readHandshakeMessage(buffer, offset, length, header) // VALIDATE:: a Registration object is the only acceptable message during the connection phase - if (message !is Message) { + if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) - endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection request")) + endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) return } val clientPublicKeyBytes = message.publicKey + val validateRemoteAddress: PublicKeyValidationState // check to see if this is a pending connection - if (message.state == Message.DONE) { + if (message.state == HandshakeMessage.DONE) { pendingConnectionsLock.write { val pendingConnection = pendingConnections.remove(sessionId) if (pendingConnection != null) { logger.debug("Connection from client $clientAddressString ready") // now tell the client we are done - endPoint.writeHandshakeMessage(handshakePublication, Message.doneToClient(sessionId)) + endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) endPoint.actionDispatch.launch { listenerManager.notifyConnect(pendingConnection) @@ -104,14 +91,16 @@ internal class ServerHandshake(logger: KLogger, try { // VALIDATE:: Check to see if there are already too many clients connected. - if (connectionCount() >= config.maxClientCount) { - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full")) - endPoint.writeHandshakeMessage(handshakePublication, Message.error("Server full. Max allowed is ${config.maxClientCount}")) + if (endPoint.connections.connectionCount() >= config.maxClientCount) { + listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) + + endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) return } // VALIDATE:: check to see if the remote connection's public key has changed! - if (!endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes)) { + validateRemoteAddress = endPoint.crypto.validateRemoteAddress(clientAddress, clientPublicKeyBytes) + if (validateRemoteAddress == PublicKeyValidationState.INVALID) { listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Public key mismatch.")) return } @@ -125,17 +114,17 @@ internal class ServerHandshake(logger: KLogger, // VALIDATE:: we are now connected to the client and are going to create a new connection. val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) if (currentCountForIp >= config.maxConnectionsPerIpAddress) { - listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString")) + listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) - // decrement it now, since we aren't going to permit this connection (take the hit on failure, instead + // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) connectionsPerIpCounts.getAndDecrement(clientAddress) - endPoint.writeHandshakeMessage(handshakePublication, Message.error("too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}")) + endPoint.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Too many connections for IP address")) return } } catch (e: Exception) { - listenerManager.notifyError(ClientRejectedException("could not validate client message", - e)) - endPoint.writeHandshakeMessage(handshakePublication, Message.error("Invalid connection")) + listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) + endPoint.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) return } @@ -149,19 +138,6 @@ internal class ServerHandshake(logger: KLogger, ///// - // allocate ports for the client - val connectionPorts: IntArray - - try { - // throws exception if this is not possible - connectionPorts = portAllocator.allocate(portsPerClient) - } catch (e: IllegalArgumentException) { - // have to unwind actions! - connectionsPerIpCounts.getAndDecrement(clientAddress) - listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate $portsPerClient ports for client connection!")) - return - } - // allocate session/stream id's val connectionSessionId: Int try { @@ -169,9 +145,11 @@ internal class ServerHandshake(logger: KLogger, } catch (e: AllocationException) { // have to unwind actions! connectionsPerIpCounts.getAndDecrement(clientAddress) - portAllocator.free(connectionPorts) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) + + endPoint.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Connection error!")) return } @@ -182,30 +160,34 @@ internal class ServerHandshake(logger: KLogger, } catch (e: AllocationException) { // have to unwind actions! connectionsPerIpCounts.getAndDecrement(clientAddress) - portAllocator.free(connectionPorts) sessionIdAllocator.free(connectionSessionId) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) + + endPoint.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Connection error!")) return } val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box? - val subscriptionPort = connectionPorts[0] - val publicationPort = connectionPorts[1] + + // the pub/sub do not necessarily have to be the same. The can be ANY port + val publicationPort = config.publicationPort + val subscriptionPort = config.subscriptionPort // create a new connection. The session ID is encrypted. try { // connection timeout of 0 doesn't matter. it is not used by the server val clientConnection = UdpMediaDriverConnection(serverAddress, - subscriptionPort, publicationPort, + subscriptionPort, connectionStreamId, connectionSessionId, 0, message.isReliable) - val connection: Connection = endPoint.newConnection(endPoint, clientConnection) + val connection: Connection = endPoint.newConnection(ConnectionParams(endPoint, clientConnection, validateRemoteAddress)) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) @Suppress("UNCHECKED_CAST") @@ -213,7 +195,6 @@ internal class ServerHandshake(logger: KLogger, if (!permitConnection) { // have to unwind actions! connectionsPerIpCounts.getAndDecrement(clientAddress) - portAllocator.free(connectionPorts) sessionIdAllocator.free(connectionSessionId) streamIdAllocator.free(connectionStreamId) @@ -221,17 +202,15 @@ internal class ServerHandshake(logger: KLogger, listenerManager.notifyError(connection, ClientRejectedException("Connection was not permitted!")) + + endPoint.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Connection was not permitted!")) return } - logger.info { - "Client connected [$clientAddressString:$subscriptionPort|$publicationPort] (session: $sessionId)" - } - - logger.debug("Created new client connection sessionID {}", connectionSessionId) // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! - val successMessage = Message.helloAckToClient(sessionId) + val successMessage = HandshakeMessage.helloAckToClient(sessionId) // now create the encrypted payload, using ECDH successMessage.registrationData = endPoint.crypto.encrypt(publicationPort, @@ -242,19 +221,19 @@ internal class ServerHandshake(logger: KLogger, successMessage.publicKey = endPoint.crypto.publicKeyBytes - // this tells the client all of the info to connect. - endPoint.writeHandshakeMessage(handshakePublication, successMessage) - - addConnection(connection) + // this enables the connection to start polling for messages + endPoint.connections.add(connection) // before we notify connect, we have to wait for the client to tell us that they can receive data pendingConnectionsLock.write { pendingConnections[sessionId] = connection } + + // this tells the client all of the info to connect. + endPoint.writeHandshakeMessage(handshakePublication, successMessage) } catch (e: Exception) { // have to unwind actions! connectionsPerIpCounts.getAndDecrement(clientAddress) - portAllocator.free(connectionPorts) sessionIdAllocator.free(connectionSessionId) streamIdAllocator.free(connectionStreamId) @@ -262,49 +241,12 @@ internal class ServerHandshake(logger: KLogger, } } - - suspend fun poll(): Int { - // Get the current time, used to cleanup connections - val now = System.currentTimeMillis() - var pollCount = 0 - - forEachConnectionCleanup({ connection -> - // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted. - var shouldCleanupConnection = false - - if (connection.isExpired(now)) { - logger.debug("[{}] connection expired", connection.sessionId) - shouldCleanupConnection = true - } - - if (connection.isClosed()) { - logger.debug("[{}] connection closed", connection.sessionId) - shouldCleanupConnection = true - } - if (shouldCleanupConnection) { - true - } - else { - // Otherwise, poll the duologue for activity. - pollCount += connection.pollSubscriptions() - false - } - }, { connectionToClean -> - logger.debug("[{}] deleted connection", connectionToClean.sessionId) - - removeConnection(connectionToClean) - - // have to free up resources! - connectionsPerIpCounts.getAndDecrement(connectionToClean.remoteAddressInt) - portAllocator.free(connectionToClean.subscriptionPort) - portAllocator.free(connectionToClean.publicationPort) - sessionIdAllocator.free(connectionToClean.sessionId) - streamIdAllocator.free(connectionToClean.streamId) - - listenerManager.notifyDisconnect(connectionToClean) - connectionToClean.close() - }) - - return pollCount + /** + * Free up resources from the closed connection + */ + fun cleanup(connection: CONNECTION) { + connectionsPerIpCounts.getAndDecrement(connection.remoteAddressInt) + sessionIdAllocator.free(connection.sessionId) + streamIdAllocator.free(connection.streamId) } }