diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index 97d65e68..8cc39098 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -16,26 +16,26 @@ package dorkbox.network import dorkbox.netUtil.IPv4 -import dorkbox.netUtil.IPv6 -import dorkbox.network.aeron.client.ClientException -import dorkbox.network.aeron.client.ClientRejectedException -import dorkbox.network.aeron.client.ClientTimedOutException +import dorkbox.network.aeron.IpcMediaDriverConnection +import dorkbox.network.aeron.UdpMediaDriverConnection import dorkbox.network.connection.Connection import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.EndPoint -import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.PublicKeyValidationState -import dorkbox.network.connection.UdpMediaDriverConnection +import dorkbox.network.exceptions.ClientException +import dorkbox.network.exceptions.ClientRejectedException +import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.handshake.ClientHandshake import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.TimeoutException -import dorkbox.util.exceptions.SecurityException import kotlinx.atomicfu.atomic import kotlinx.coroutines.launch +import java.net.Inet4Address +import java.net.InetAddress /** * The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's @@ -59,7 +59,7 @@ open class Client(config: Configuration = Configuration * For the IPC (Inter-Process-Communication) address. it must be: * - the IPC integer ID, "0x1337c0de", "0x12312312", etc. */ - private var remoteAddress0 = "" + private var remoteAddress0: InetAddress? = IPv4.LOCALHOST @Volatile private var isConnected = false @@ -67,9 +67,6 @@ open class Client(config: Configuration = Configuration // is valid when there is a connection to the server, otherwise it is null private var connection0: CONNECTION? = null - @Volatile - private var connectionTimeoutMS: Long = 5_000 // default is 5 seconds - private val previousClosedConnectionActivity: Long = 0 private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization) @@ -107,9 +104,12 @@ open class Client(config: Configuration = Configuration * ### For a network address, it can be: * - a network name ("localhost", "loopback", "lo", "bob.example.org") * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * - an InetAddress address * * ### For the IPC (Inter-Process-Communication) it must be: - * - EMPTY. ie: just call `connect()` + * - EMPTY. + * - `connect()` + * - `connect("")` * * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') * @@ -121,8 +121,113 @@ open class Client(config: Configuration = Configuration * @throws ClientTimedOutException if the client is unable to connect in x amount of time * @throws ClientRejectedException if the client connection is rejected */ + suspend fun connect(remoteAddress: String, + connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + when { + // this is default IPC settings + remoteAddress.isEmpty() -> connect(connectionTimeoutMS = connectionTimeoutMS) + + IPv4.isPreferred -> connect(remoteAddress = Inet4Address.getAllByName(remoteAddress)[0], + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable) + + else -> connect(remoteAddress = Inet4Address.getAllByName(remoteAddress)[0], + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable) + } + } + + /** + * Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed. + * + * Default connection is to localhost + * + * ### For a network address, it can be: + * - a network name ("localhost", "loopback", "lo", "bob.example.org") + * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * - an InetAddress address + * + * ### For the IPC (Inter-Process-Communication) it must be: + * - EMPTY. + * - `connect()` + * - `connect("")` + * + * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') + * + * @param remoteAddress The network or if localhost, IPC address for the client to connect to + * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely + * @param reliable true if we want to create a reliable connection. IPC connections are always reliable + * + * @throws IllegalArgumentException if the remote address is invalid + * @throws ClientTimedOutException if the client is unable to connect in x amount of time + * @throws ClientRejectedException if the client connection is rejected + */ + suspend fun connect(remoteAddress: InetAddress, + connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + // Default IPC ports are flipped because they are in the perspective of the SERVER + connect(remoteAddress = remoteAddress, + ipcPublicationId = IPC_HANDSHAKE_STREAM_ID_SUB, + ipcSubscriptionId = IPC_HANDSHAKE_STREAM_ID_PUB, + connectionTimeoutMS = connectionTimeoutMS, + reliable = reliable) + } + + /** + * Will attempt to connect to the server via IPC, with a default 30 second connection timeout and will block until completed. + * + * @param ipcPublicationId The IPC publication address for the client to connect to + * @param ipcSubscriptionId The IPC subscription address for the client to connect to + * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely. + * + * @throws IllegalArgumentException if the remote address is invalid + * @throws ClientTimedOutException if the client is unable to connect in x amount of time + * @throws ClientRejectedException if the client connection is rejected + */ @Suppress("DuplicatedCode") - suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { + suspend fun connect(ipcPublicationId: Int = IPC_HANDSHAKE_STREAM_ID_SUB, + ipcSubscriptionId: Int = IPC_HANDSHAKE_STREAM_ID_PUB, + connectionTimeoutMS: Long = 30_000L) { + // Default IPC ports are flipped because they are in the perspective of the SERVER + + require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." } + + connect(remoteAddress = null, // required! + ipcPublicationId = ipcPublicationId, + ipcSubscriptionId = ipcSubscriptionId, + connectionTimeoutMS = connectionTimeoutMS) + } + + /** + * Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed. + * + * Default connection is to localhost + * + * ### For a network address, it can be: + * - a network name ("localhost", "loopback", "lo", "bob.example.org") + * - an IP address ("127.0.0.1", "123.123.123.123", "::1") + * + * ### For the IPC (Inter-Process-Communication) it must be: + * - EMPTY. ie: just call `connect()` + * - Specified EMPTY. ie: just call `connect()` + * + * ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x') + * + * @param remoteAddress The network or if localhost, IPC address for the client to connect to + * @param ipcPublicationId The IPC publication address for the client to connect to + * @param ipcSubscriptionId The IPC subscription address for the client to connect to + * @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely. + * @param reliable true if we want to create a reliable connection. IPC connections are always reliable + * + * @throws IllegalArgumentException if the remote address is invalid + * @throws ClientTimedOutException if the client is unable to connect in x amount of time + * @throws ClientRejectedException if the client connection is rejected + */ + @Suppress("DuplicatedCode") + private suspend fun connect(remoteAddress: InetAddress? = null, + // Default IPC ports are flipped because they are in the perspective of the SERVER + ipcPublicationId: Int = IPC_HANDSHAKE_STREAM_ID_SUB, + ipcSubscriptionId: Int = IPC_HANDSHAKE_STREAM_ID_PUB, + connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) { // this will exist ONLY if we are reconnecting via a "disconnect" callback lockStepForReconnect.value?.doWait() @@ -143,80 +248,31 @@ open class Client(config: Configuration = Configuration logger.info("Media driver is running. Support for enable auto-switch from LOCALHOST -> IPC enabled") } - this.connectionTimeoutMS = connectionTimeoutMS - val isIpcConnection: Boolean - // NETWORK OR IPC ADDRESS - // if we connect to "loopback", then we substitute if for IPC (with log message) + // if we connect to "loopback", then MAYBE we substitute if for IPC (with log message) // localhost/loopback IP might not always be 127.0.0.1 or ::1 - when (remoteAddress) { - "0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") - "loopback", "localhost", "lo", "" -> { - if (canAutoChangeToIpc) { - isIpcConnection = true - logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } - this.remoteAddress0 = "ipc" - } else { - isIpcConnection = false - this.remoteAddress0 = IPv4.LOCALHOST.hostAddress - } - } - "0x" -> { - isIpcConnection = true - this.remoteAddress0 = "ipc" - } - else -> when { - IPv4.isLoopback(remoteAddress) -> { - if (canAutoChangeToIpc) { - isIpcConnection = true - logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } - this.remoteAddress0 = "ipc" - } else { - isIpcConnection = false - this.remoteAddress0 = IPv4.LOCALHOST.hostAddress - } - } - IPv6.isLoopback(remoteAddress) -> { - if (canAutoChangeToIpc) { - isIpcConnection = true - logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } - this.remoteAddress0 = "ipc" - } else { - isIpcConnection = false - this.remoteAddress0 = IPv6.LOCALHOST.hostAddress - } - } - else -> { - isIpcConnection = false - this.remoteAddress0 = remoteAddress - } + when { + remoteAddress == null -> this.remoteAddress0 = null + remoteAddress.isAnyLocalAddress -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!") + canAutoChangeToIpc && remoteAddress.isLoopbackAddress -> { + logger.info { "Auto-changing network connection from $remoteAddress -> IPC" } + this.remoteAddress0 = null } + else -> this.remoteAddress0 = remoteAddress } - if (IPv6.isValid(this.remoteAddress0)) { - // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so - - // if we are IPv6, the IP must be in '[]' - if (this.remoteAddress0.count { it == '[' } < 1 && - this.remoteAddress0.count { it == ']' } < 1) { - - this.remoteAddress0 = """[${this.remoteAddress0}]""" - } - } - val handshake = ClientHandshake(logger, config, crypto, this) - // initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER - val handshakeConnection = if (isIpcConnection) { - IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB, - streamId = IPC_HANDSHAKE_STREAM_ID_SUB, + val handshakeConnection = if (this.remoteAddress0 == null) { + IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId, + streamId = ipcPublicationId, sessionId = RESERVED_SESSION_ID_INVALID) } else { - UdpMediaDriverConnection(address = this.remoteAddress0, + UdpMediaDriverConnection(address = this.remoteAddress0!!, publicationPort = config.subscriptionPort, subscriptionPort = config.publicationPort, streamId = UDP_HANDSHAKE_STREAM_ID, @@ -227,7 +283,7 @@ open class Client(config: Configuration = Configuration // throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports - handshakeConnection.buildClient(aeron) + handshakeConnection.buildClient(aeron, logger) logger.info(handshakeConnection.clientInfo()) @@ -238,10 +294,10 @@ open class Client(config: Configuration = Configuration // VALIDATE:: check to see if the remote connection's public key has changed! - val validateRemoteAddress = if (isIpcConnection) { + val validateRemoteAddress = if (this.remoteAddress0 == null) { PublicKeyValidationState.VALID } else { - crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress0), connectionInfo.publicKey) + crypto.validateRemoteAddress(this.remoteAddress0!!, connectionInfo.publicKey) } if (validateRemoteAddress == PublicKeyValidationState.INVALID) { @@ -258,7 +314,7 @@ open class Client(config: Configuration = Configuration // we are now connected, so we can connect to the NEW client-specific ports - val reliableClientConnection = if (isIpcConnection) { + val reliableClientConnection = if (this.remoteAddress0 == null) { IpcMediaDriverConnection(sessionId = connectionInfo.sessionId, // NOTE: pub/sub must be switched! streamIdSubscription = connectionInfo.publicationPort, @@ -266,7 +322,7 @@ open class Client(config: Configuration = Configuration connectionTimeoutMS = connectionTimeoutMS) } else { - UdpMediaDriverConnection(address = handshakeConnection.address, + UdpMediaDriverConnection(address = handshakeConnection.address!!, // NOTE: pub/sub must be switched! subscriptionPort = connectionInfo.publicationPort, publicationPort = connectionInfo.subscriptionPort, @@ -277,7 +333,7 @@ open class Client(config: Configuration = Configuration } // we have to construct how the connection will communicate! - reliableClientConnection.buildClient(aeron) + reliableClientConnection.buildClient(aeron, logger) // only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object) // does not need to do anything @@ -302,7 +358,7 @@ open class Client(config: Configuration = Configuration } - val newConnection = if (isIpcConnection) { + val newConnection = if (this.remoteAddress0 == null) { newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID)) } else { newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress)) @@ -406,11 +462,17 @@ open class Client(config: Configuration = Configuration val remoteKeyHasChanged: Boolean get() = connection.hasRemoteKeyChanged() + /** + * the remote address + */ + val remoteAddress: InetAddress? + get() = remoteAddress0 + /** * the remote address, as a string. */ - val remoteAddress: String - get() = remoteAddress0 + val remoteAddressString: String + get() = remoteAddress0?.hostAddress ?: "ipc" /** * true if this connection is an IPC connection @@ -464,12 +526,10 @@ open class Client(config: Configuration = Configuration /** * Removes the specified host address from the list of registered server keys. */ - @Throws(SecurityException::class) - fun removeRegisteredServerKey(hostAddress: String) { - val address = IPv4.toInt(hostAddress) + fun removeRegisteredServerKey(address: InetAddress) { val savedPublicKey = settingsStore.getRegisteredServerKey(address) if (savedPublicKey != null) { - logger.debug { "Deleting remote IP address key $hostAddress" } + logger.debug { "Deleting remote IP address key $address" } settingsStore.removeRegisteredServerKey(address) } } diff --git a/src/dorkbox/network/Configuration.kt b/src/dorkbox/network/Configuration.kt index ebe6212e..bc9a6e24 100644 --- a/src/dorkbox/network/Configuration.kt +++ b/src/dorkbox/network/Configuration.kt @@ -18,6 +18,7 @@ package dorkbox.network import dorkbox.network.aeron.CoroutineBackoffIdleStrategy import dorkbox.network.aeron.CoroutineIdleStrategy import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy +import dorkbox.network.connection.EndPoint import dorkbox.network.serialization.Serialization import dorkbox.network.storage.PropertyStore import dorkbox.network.storage.SettingsStore @@ -30,6 +31,21 @@ import mu.KLogger import java.io.File class ServerConfiguration : dorkbox.network.Configuration() { + /** + * Enables the ability to use the IPv4 network stack. + */ + var enableIPv4 = true + + /** + * Enables the ability to use the IPv6 network stack. + */ + var enableIPv6 = true + + /** + * Enables the ability use IPC (Inter Process Communication) + */ + var enableIPC = true + /** * The address for the server to listen on. "*" will accept connections from all interfaces, otherwise specify * the hostname (or IP) to bind to. @@ -37,14 +53,24 @@ class ServerConfiguration : dorkbox.network.Configuration() { var listenIpAddress = "*" /** - * The maximum number of clients allowed for a server + * The maximum number of clients allowed for a server. IPC is unlimited */ var maxClientCount = 0 /** - * The maximum number of client connection allowed per IP address + * The maximum number of client connection allowed per IP address. IPC is unlimited */ var maxConnectionsPerIpAddress = 0 + + /** + * The IPC Publication ID is used to define what ID the server will send data on. The client IPC subscription ID must match this value. + */ + var ipcPublicationId = EndPoint.IPC_HANDSHAKE_STREAM_ID_PUB + + /** + * The IPC Subscription ID is used to define what ID the server will receive data on. The client IPC publication ID must match this value. + */ + var ipcSubscriptionId = EndPoint.IPC_HANDSHAKE_STREAM_ID_SUB } open class Configuration { diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index d3a706f0..2babddac 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -15,19 +15,23 @@ */ package dorkbox.network +import dorkbox.netUtil.IP import dorkbox.netUtil.IPv4 import dorkbox.netUtil.IPv6 -import dorkbox.network.aeron.server.ServerException +import dorkbox.network.aeron.AeronPoller +import dorkbox.network.aeron.IpcMediaDriverConnection +import dorkbox.network.aeron.UdpMediaDriverConnection import dorkbox.network.connection.Connection import dorkbox.network.connection.EndPoint -import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager -import dorkbox.network.connection.UdpMediaDriverConnection -import dorkbox.network.connection.connectionType.ConnectionRule +import dorkbox.network.connectionType.ConnectionRule +import dorkbox.network.exceptions.ServerException import dorkbox.network.handshake.ServerHandshake +import dorkbox.network.other.coroutines.SuspendWaiter import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.TimeoutException +import io.aeron.Aeron import io.aeron.FragmentAssembler import io.aeron.Image import io.aeron.logbuffer.Header @@ -35,6 +39,9 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.agrona.DirectBuffer +import java.net.Inet4Address +import java.net.Inet6Address +import java.net.InetAddress import java.util.concurrent.CopyOnWriteArrayList /** @@ -72,6 +79,13 @@ open class Server(config: ServerConfiguration = ServerC @Volatile private var bindAlreadyCalled = false + /** + * These are run in lock-step to shutdown/close the server. Afterwards, bind() can be called again + */ + private val shutdownPollWaiter = SuspendWaiter() + private val shutdownEventWaiter = SuspendWaiter() + + /** * Used for handshake connections */ @@ -82,38 +96,43 @@ open class Server(config: ServerConfiguration = ServerC */ private val connectionRules = CopyOnWriteArrayList() + internal val listenIPv4Address: InetAddress? + internal val listenIPv6Address: InetAddress? + init { // have to do some basic validation of our configuration config.listenIpAddress = config.listenIpAddress.toLowerCase() + require(config.listenIpAddress.isNotBlank()) { "Blank listen IP address, cannot continue"} + // localhost/loopback IP might not always be 127.0.0.1 or ::1 - when (config.listenIpAddress) { - "loopback", "localhost", "lo", "" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress - else -> when { - IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress - IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress - else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address. + // We want to listen on BOTH IPv4 and IPv6 (config option lets us configure this) + listenIPv4Address = if (!config.enableIPv4) { + null + } else { + when (config.listenIpAddress) { + "loopback", "localhost", "lo" -> IPv4.LOCALHOST + "0", "::", "0.0.0.0", "*" -> { + // this is the "wildcard" address. Windows has problems with this. + InetAddress.getByAddress("", byteArrayOf(0, 0, 0, 0)) + } + else -> Inet4Address.getAllByName(config.listenIpAddress)[0] } } - // if we are IPv4 wildcard - if (config.listenIpAddress == "0.0.0.0") { - // this will also fixup windows! - config.listenIpAddress = IPv4.WILDCARD - } - - if (IPv6.isValid(config.listenIpAddress)) { - // "[" and "]" are valid for ipv6 addresses... we want to make sure it is so - - // if we are IPv6, the IP must be in '[]' - if (config.listenIpAddress.count { it == '[' } < 1 && - config.listenIpAddress.count { it == ']' } < 1) { - - config.listenIpAddress = """[${config.listenIpAddress}]""" + listenIPv6Address = if (!config.enableIPv6) { + null + } else { + when (config.listenIpAddress) { + "loopback", "localhost", "lo" -> IPv6.LOCALHOST + "0", "::", "0.0.0.0", "*" -> { + // this is the "wildcard" address. Windows has problems with this. + InetAddress.getByAddress("", byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + } + else -> Inet6Address.getAllByName(config.listenIpAddress)[0] } } - if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") } if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") } @@ -133,6 +152,253 @@ open class Server(config: ServerConfiguration = ServerC return ServerException(message, cause) } + + private fun getIpcPoller(aeron: Aeron, config: ServerConfiguration): AeronPoller { + val poller = if (config.enableIPC) { + val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId, + streamId = config.ipcPublicationId, + sessionId = RESERVED_SESSION_ID_INVALID) + driver.buildServer(aeron, logger) + val publication = driver.publication + val subscription = driver.subscription + + object : AeronPoller { + val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! + + // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. + // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE + val sessionId = header.sessionId() + + val message = readHandshakeMessage(buffer, offset, length, header) + handshake.processIpcHandshakeMessageServer(this@Server, + publication, + sessionId, + message, + aeron) + } + + override fun poll(): Int { return subscription.poll(handler, 1) } + override fun close() { driver.close() } + override fun serverInfo(): String { return driver.serverInfo() } + } + } else { + object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPC Disabled" } + } + } + + logger.info(poller.serverInfo()) + return poller + } + + private fun getIpv4Poller(aeron: Aeron, config: ServerConfiguration): AeronPoller { + val poller = if (config.enableIPv4) { + val driver = UdpMediaDriverConnection(address = listenIPv4Address!!, + publicationPort = config.publicationPort, + subscriptionPort = config.subscriptionPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID) + + driver.buildServer(aeron, logger) + val publication = driver.publication + val subscription = driver.subscription + + object : AeronPoller { + /** + * Note: + * Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is + * desired, then limiting message sizes to MTU size is a good practice. + * + * There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB. + * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery + * properties from failure and streams with mechanical sympathy. + */ + val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! + + // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. + // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE + val sessionId = header.sessionId() + + // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!) + val remoteIpAndPort = (header.context() as Image).sourceIdentity() + + // split + val splitPoint = remoteIpAndPort.lastIndexOf(':') + val clientAddressString = remoteIpAndPort.substring(0, splitPoint) + // val port = remoteIpAndPort.substring(splitPoint+1) + + // this should never be null, because we are feeding it a valid IP address from aeron + val clientAddress = IPv4.getByNameUnsafe(clientAddressString) + + + val message = readHandshakeMessage(buffer, offset, length, header) + handshake.processUdpHandshakeMessageServer(this@Server, + publication, + sessionId, + clientAddressString, + clientAddress, + message, + aeron, + false) + } + + override fun poll(): Int { return subscription.poll(handler, 1) } + override fun close() { driver.close() } + override fun serverInfo(): String { return driver.serverInfo() } + } + } else { + object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPv4 Disabled" } + } + } + + logger.info(poller.serverInfo()) + return poller + } + + private fun getIpv6Poller(aeron: Aeron, config: ServerConfiguration): AeronPoller { + val poller = if (config.enableIPv6) { + val driver = UdpMediaDriverConnection(address = listenIPv6Address!!, + publicationPort = config.publicationPort, + subscriptionPort = config.subscriptionPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID) + + driver.buildServer(aeron, logger) + val publication = driver.publication + val subscription = driver.subscription + + object : AeronPoller { + /** + * Note: + * Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is + * desired, then limiting message sizes to MTU size is a good practice. + * + * There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB. + * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery + * properties from failure and streams with mechanical sympathy. + */ + val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! + + // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. + // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE + val sessionId = header.sessionId() + + // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!) + val remoteIpAndPort = (header.context() as Image).sourceIdentity() + + // split + val splitPoint = remoteIpAndPort.lastIndexOf(':') + val clientAddressString = remoteIpAndPort.substring(0, splitPoint) + // val port = remoteIpAndPort.substring(splitPoint+1) + + // this should never be null, because we are feeding it a valid IP address from aeron + val clientAddress = IPv6.getByName(clientAddressString)!! + + + val message = readHandshakeMessage(buffer, offset, length, header) + handshake.processUdpHandshakeMessageServer(this@Server, + publication, + sessionId, + clientAddressString, + clientAddress, + message, + aeron, + false) + } + + override fun poll(): Int { return subscription.poll(handler, 1) } + override fun close() { driver.close() } + override fun serverInfo(): String { return driver.serverInfo() } + } + } else { + object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPv6 Disabled" } + } + } + + logger.info(poller.serverInfo()) + return poller + } + + private fun getIpv6WildcardPoller(aeron: Aeron, config: ServerConfiguration): AeronPoller { + val poller = if (config.enableIPv6) { + val driver = UdpMediaDriverConnection(address = listenIPv6Address!!, + publicationPort = config.publicationPort, + subscriptionPort = config.subscriptionPort, + streamId = UDP_HANDSHAKE_STREAM_ID, + sessionId = RESERVED_SESSION_ID_INVALID) + + driver.buildServer(aeron, logger) + val publication = driver.publication + val subscription = driver.subscription + + object : AeronPoller { + /** + * Note: + * Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is + * desired, then limiting message sizes to MTU size is a good practice. + * + * There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB. + * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery + * properties from failure and streams with mechanical sympathy. + */ + val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> + // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! + + // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. + // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE + val sessionId = header.sessionId() + + // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!) + val remoteIpAndPort = (header.context() as Image).sourceIdentity() + + // split + val splitPoint = remoteIpAndPort.lastIndexOf(':') + val clientAddressString = remoteIpAndPort.substring(0, splitPoint) + // val port = remoteIpAndPort.substring(splitPoint+1) + + // this should never be null, because we are feeding it a valid IP address from aeron + // maybe IPv4, maybe IPv6!!! + val clientAddress = IP.getByName(clientAddressString)!! + + + val message = readHandshakeMessage(buffer, offset, length, header) + handshake.processUdpHandshakeMessageServer(this@Server, + publication, + sessionId, + clientAddressString, + clientAddress, + message, + aeron, + true) + } + + override fun poll(): Int { return subscription.poll(handler, 1) } + override fun close() { driver.close() } + override fun serverInfo(): String { return driver.serverInfo() } + } + } else { + object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPv6 Disabled" } + } + } + + logger.info(poller.serverInfo()) + return poller + } + /** * Binds the server to AERON configuration */ @@ -150,79 +416,32 @@ open class Server(config: ServerConfiguration = ServerC config as ServerConfiguration - val ipcHandshakeDriver = IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB, - streamId = IPC_HANDSHAKE_STREAM_ID_PUB, - sessionId = RESERVED_SESSION_ID_INVALID) - ipcHandshakeDriver.buildServer(aeron) - val ipcHandshakePublication = ipcHandshakeDriver.publication - val ipcHandshakeSubscription = ipcHandshakeDriver.subscription + val ipcPoller: AeronPoller = getIpcPoller(aeron, config) + // if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled! + val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD + val ipv4Poller: AeronPoller + val ipv6Poller: AeronPoller - - val udpHandshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress, - publicationPort = config.publicationPort, - subscriptionPort = config.subscriptionPort, - streamId = UDP_HANDSHAKE_STREAM_ID, - sessionId = RESERVED_SESSION_ID_INVALID) - - udpHandshakeDriver.buildServer(aeron) - val handshakePublication = udpHandshakeDriver.publication - val handshakeSubscription = udpHandshakeDriver.subscription - - - logger.info(ipcHandshakeDriver.serverInfo()) - logger.info(udpHandshakeDriver.serverInfo()) - - - /** - * Note: - * Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is - * desired, then limiting message sizes to MTU size is a good practice. - * - * There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB. - * Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery - * properties from failure and streams with mechanical sympathy. - */ - val udpHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> - // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! - - // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. - // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE - val sessionId = header.sessionId() - - // note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!) - val remoteIpAndPort = (header.context() as Image).sourceIdentity() - - // split - val splitPoint = remoteIpAndPort.lastIndexOf(':') - val clientAddressString = remoteIpAndPort.substring(0, splitPoint) - // val port = remoteIpAndPort.substring(splitPoint+1) - val clientAddress = IPv4.toInt(clientAddressString) - - val message = readHandshakeMessage(buffer, offset, length, header) - handshake.processUdpHandshakeMessageServer(this@Server, - handshakePublication, - sessionId, - clientAddressString, - clientAddress, - message, - aeron) + if (isWildcard) { + // IPv6 will bind to IPv4 wildcard as well!! + if (config.enableIPv4 && config.enableIPv6) { + ipv4Poller = object : AeronPoller { + override fun poll(): Int { return 0 } + override fun close() {} + override fun serverInfo(): String { return "IPv4 Disabled" } + } + ipv6Poller = getIpv6WildcardPoller(aeron, config) + } else { + // only 1 will be a real poller + ipv4Poller = getIpv4Poller(aeron, config) + ipv6Poller = getIpv6Poller(aeron, config) + } + } else { + ipv4Poller = getIpv4Poller(aeron, config) + ipv6Poller = getIpv6Poller(aeron, config) } - val ipcHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> - // this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe! - - // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity. - // for the handshake, the sessionId IS NOT GLOBALLY UNIQUE - val sessionId = header.sessionId() - - val message = readHandshakeMessage(buffer, offset, length, header) - handshake.processIpcHandshakeMessageServer(this@Server, - ipcHandshakePublication, - sessionId, - message, - aeron) - } actionDispatch.launch { val pollIdleStrategy = config.pollIdleStrategy @@ -237,10 +456,11 @@ open class Server(config: ServerConfiguration = ServerC // `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)` // this checks to see if there are NEW clients on the handshake ports - pollCount += handshakeSubscription.poll(udpHandshakeHandler, 1) + pollCount += ipv4Poller.poll() + pollCount += ipv6Poller.poll() // this checks to see if there are NEW clients via IPC - pollCount += ipcHandshakeSubscription.poll(ipcHandshakeHandler, 1) + pollCount += ipcPoller.poll() // this manages existing clients (for cleanup + connection polling) @@ -291,12 +511,53 @@ open class Server(config: ServerConfiguration = ServerC // 0 means we idle. >0 means reset and don't idle (because there are likely more poll events) pollIdleStrategy.idle(pollCount) } - } finally { - handshakePublication.close() - handshakeSubscription.close() - ipcHandshakePublication.close() - ipcHandshakeSubscription.close() + // we want to process **actual** close cleanup events on this thread as well, otherwise we will have threading problems + shutdownPollWaiter.doWait() + + // we have to manually cleanup the connections and call server-notifyDisconnect because otherwise this will never get called + val jobs = mutableListOf() + + // we want to clear all the connections FIRST (since we are shutting down) + val cons = mutableListOf() + connections.forEach { cons.add(it) } + connections.clear() + + cons.forEach { connection -> + logger.error("${connection.id} cleanup") + // have to free up resources! + // NOTE: This can only occur on the polling dispatch thread!! + handshake.cleanup(connection) + + // make sure the connection is closed (close can only happen once, so a duplicate call does nothing!) + connection.close() + + // have to manually notify the server-listenerManager that this connection was closed + // if the connection was MANUALLY closed (via calling connection.close()), then the connection-listenermanager is + // instantly notified and on cleanup, the server-listenermanager is called + // NOTE: this must be the LAST thing happening! + + // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback + val job = actionDispatch.launch { + listenerManager.notifyDisconnect(connection) + } + jobs.add(job) + } + + // reset all of the handshake info + handshake.clear() + + // when we close a client or a server, we want to make sure that ALL notifications are finished. + // when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown + jobs.forEach { it.join() } + + } finally { + ipv4Poller.close() + ipv6Poller.close() + ipcPoller.close() + + // finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close() + shutdownEventWaiter.doNotify() } } } @@ -362,41 +623,13 @@ open class Server(config: ServerConfiguration = ServerC override fun close0() { bindAlreadyCalled = false - // when we call close, it will shutdown the polling mechanism, so we have to manually cleanup the connections and call server-notifyDisconnect - // on them - + // when we call close, it will shutdown the polling mechanism then wait for us to tell it to cleanup connections. + // + // Aeron + the Media Driver will have already been shutdown at this point. runBlocking { - val jobs = mutableListOf() - - // we want to clear all the connections FIRST (since we are shutting down) - val cons = mutableListOf() - connections.forEach { cons.add(it) } - connections.clear() - - cons.forEach { connection -> - logger.error("${connection.id} cleanup") - // have to free up resources! - handshake.cleanup(connection) - - // make sure the connection is closed (close can only happen once, so a duplicate call does nothing!) - connection.close() - - // have to manually notify the server-listenerManager that this connection was closed - // if the connection was MANUALLY closed (via calling connection.close()), then the connection-listenermanager is - // instantly notified and on cleanup, the server-listenermanager is called - // NOTE: this must be the LAST thing happening! - - // this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback - val job = actionDispatch.launch { - listenerManager.notifyDisconnect(connection) - } - jobs.add(job) - } - - - // when we close a client or a server, we want to make sure that ALL notifications are finished. - // when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown - jobs.forEach { it.join() } + // These are run in lock-step + shutdownPollWaiter.doNotify() + shutdownEventWaiter.doWait() } } diff --git a/src/dorkbox/network/aeron/AeronPoller.kt b/src/dorkbox/network/aeron/AeronPoller.kt new file mode 100644 index 00000000..61ef6dbe --- /dev/null +++ b/src/dorkbox/network/aeron/AeronPoller.kt @@ -0,0 +1,8 @@ +package dorkbox.network.aeron + +internal interface AeronPoller { + fun poll(): Int + fun close() + fun serverInfo(): String +} + diff --git a/src/dorkbox/network/connection/MediaDriverConnection.kt b/src/dorkbox/network/aeron/MediaDriverConnection.kt similarity index 78% rename from src/dorkbox/network/connection/MediaDriverConnection.kt rename to src/dorkbox/network/aeron/MediaDriverConnection.kt index 8e588094..9753c2e0 100644 --- a/src/dorkbox/network/connection/MediaDriverConnection.kt +++ b/src/dorkbox/network/aeron/MediaDriverConnection.kt @@ -13,19 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection +@file:Suppress("DuplicatedCode") -import dorkbox.network.aeron.client.ClientException -import dorkbox.network.aeron.client.ClientTimedOutException -import dorkbox.network.aeron.server.ServerException +package dorkbox.network.aeron + +import dorkbox.network.connection.EndPoint +import dorkbox.network.exceptions.ClientTimedOutException import io.aeron.Aeron import io.aeron.ChannelUriStringBuilder import io.aeron.Publication import io.aeron.Subscription import kotlinx.coroutines.delay +import mu.KLogger +import java.net.Inet4Address +import java.net.InetAddress interface MediaDriverConnection : AutoCloseable { - val address: String + val address: InetAddress? val streamId: Int val sessionId: Int @@ -38,9 +42,8 @@ interface MediaDriverConnection : AutoCloseable { val isReliable: Boolean @Throws(ClientTimedOutException::class) - suspend fun buildClient(aeron: Aeron) - - fun buildServer(aeron: Aeron) + suspend fun buildClient(aeron: Aeron, logger: KLogger) + fun buildServer(aeron: Aeron, logger: KLogger) fun clientInfo() : String fun serverInfo() : String @@ -49,7 +52,7 @@ interface MediaDriverConnection : AutoCloseable { /** * For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER */ -class UdpMediaDriverConnection(override val address: String, +class UdpMediaDriverConnection(override val address: InetAddress, override val publicationPort: Int, override val subscriptionPort: Int, override val streamId: Int, @@ -62,6 +65,19 @@ class UdpMediaDriverConnection(override val address: String, var success: Boolean = false + val addressString: String by lazy { + if (address is Inet4Address) { + address.hostAddress + } else { + // IPv6 requires the address to be bracketed by [...] + val host = address.hostAddress + if (host[0] == '[') { + host + } else { + "[${address.hostAddress}]" + } + } + } private fun uri(): ChannelUriStringBuilder { val builder = ChannelUriStringBuilder().reliable(isReliable).media("udp") @@ -73,27 +89,33 @@ class UdpMediaDriverConnection(override val address: String, } @Suppress("DuplicatedCode") - override suspend fun buildClient(aeron: Aeron) { - if (address.isEmpty()) { - throw ClientException("Invalid address : '$address'") - } - - // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. - val subscriptionUri = uri() - .controlEndpoint("$address:$subscriptionPort") - .controlMode("dynamic") - - + override suspend fun buildClient(aeron: Aeron, logger: KLogger) { // Create a publication at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() - .endpoint("$address:$publicationPort") + .endpoint("$addressString:$publicationPort") + // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. + val subscriptionUri = uri() + .controlEndpoint("$addressString:$subscriptionPort") + .controlMode("dynamic") + + + + if (logger.isTraceEnabled) { + if (address is Inet4Address) { + logger.trace("IPV4 client pub URI: ${publicationUri.build()}") + logger.trace("IPV4 client sub URI: ${subscriptionUri.build()}") + } else { + logger.trace("IPV6 client pub URI: ${publicationUri.build()}") + logger.trace("IPV6 client sub URI: ${subscriptionUri.build()}") + } + } // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. - val subscription = aeron.addSubscription(subscriptionUri.build(), streamId) val publication = aeron.addPublication(publicationUri.build(), streamId) + val subscription = aeron.addSubscription(subscriptionUri.build(), streamId) var success = false @@ -139,27 +161,33 @@ class UdpMediaDriverConnection(override val address: String, this.publication = publication } - override fun buildServer(aeron: Aeron) { - if (address.isEmpty()) { - throw ServerException("Invalid address. It is empty!") - } - - // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. - val subscriptionUri = uri() - .endpoint("$address:$subscriptionPort") - - + override fun buildServer(aeron: Aeron, logger: KLogger) { // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() - .controlEndpoint("$address:$publicationPort") + .controlEndpoint("$addressString:$publicationPort") .controlMode("dynamic") + // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. + val subscriptionUri = uri() + .endpoint("$addressString:$subscriptionPort") + + + + if (logger.isTraceEnabled) { + if (address is Inet4Address) { + logger.trace("IPV4 server pub URI: ${publicationUri.build()}") + logger.trace("IPV4 server sub URI: ${subscriptionUri.build()}") + } else { + logger.trace("IPV6 server pub URI: ${publicationUri.build()}") + logger.trace("IPV6 server sub URI: ${subscriptionUri.build()}") + } + } // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. - subscription = aeron.addSubscription(subscriptionUri.build(), streamId) publication = aeron.addPublication(publicationUri.build(), streamId) + subscription = aeron.addSubscription(subscriptionUri.build(), streamId) } @@ -187,7 +215,7 @@ class UdpMediaDriverConnection(override val address: String, } override fun toString(): String { - return "$address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" + return "$addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" } } @@ -200,8 +228,8 @@ class IpcMediaDriverConnection(override val streamId: Int, private val connectionTimeoutMS: Long = 30_000, ) : MediaDriverConnection { + override val address: InetAddress? = null override val isReliable = true - override val address = "ipc" override val subscriptionPort = 0 override val publicationPort = 0 @@ -220,19 +248,24 @@ class IpcMediaDriverConnection(override val streamId: Int, } @Throws(ClientTimedOutException::class) - override suspend fun buildClient(aeron: Aeron) { - // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. - val subscriptionUri = uri() - + override suspend fun buildClient(aeron: Aeron, logger: KLogger) { // Create a publication at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() + // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. + val subscriptionUri = uri() + + + if (logger.isTraceEnabled) { + logger.trace("IPC client pub URI: ${publicationUri.build()}") + logger.trace("IPC server sub URI: ${subscriptionUri.build()}") + } // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. - val subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription) val publication = aeron.addPublication(publicationUri.build(), streamId) + val subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription) var success = false @@ -278,25 +311,31 @@ class IpcMediaDriverConnection(override val streamId: Int, this.publication = publication } - override fun buildServer(aeron: Aeron) { - // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. - val subscriptionUri = uri() - + override fun buildServer(aeron: Aeron, logger: KLogger) { // Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID. // Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs. val publicationUri = uri() + // Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID. + val subscriptionUri = uri() + + + if (logger.isTraceEnabled) { + logger.trace("IPC server pub URI: ${publicationUri.build()}") + logger.trace("IPC server sub URI: ${subscriptionUri.build()}") + } + // NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe // publication of any state to other threads and not be long running or re-entrant with the client. - subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription) publication = aeron.addPublication(publicationUri.build(), streamId) + subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription) } override fun clientInfo() : String { return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { "[$sessionId] aeron connection established to [$streamIdSubscription|$streamId]" } else { - "Connecting IPC with handshake to [$streamIdSubscription|$streamId]" + "Connecting handshake to IPC [$streamIdSubscription|$streamId]" } } @@ -304,7 +343,7 @@ class IpcMediaDriverConnection(override val streamId: Int, return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { "[$sessionId] IPC listening on [$streamIdSubscription|$streamId] " } else { - "IPC listening with handshake on [$streamIdSubscription|$streamId]" + "Listening handshake on IPC [$streamIdSubscription|$streamId]" } } diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index eeb6e302..67297ad4 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -15,10 +15,13 @@ */ package dorkbox.network.connection -import dorkbox.netUtil.IPv4 -import dorkbox.network.aeron.server.RandomIdAllocator -import dorkbox.network.connection.ping.PingFuture -import dorkbox.network.connection.ping.PingMessage +import dorkbox.network.aeron.IpcMediaDriverConnection +import dorkbox.network.aeron.UdpMediaDriverConnection +import dorkbox.network.handshake.ConnectionCounts +import dorkbox.network.handshake.RandomIdAllocator +import dorkbox.network.ping.Ping +import dorkbox.network.ping.PingFuture +import dorkbox.network.ping.PingMessage import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.TimeoutException @@ -33,8 +36,8 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.agrona.DirectBuffer -import org.agrona.collections.Int2IntCounterMap import java.io.IOException +import java.net.InetAddress import java.util.concurrent.TimeUnit /** @@ -62,14 +65,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) { val id: Int /** - * the remote address, as a string. Will be "ipc" for IPC connections + * the remote address, as a string. Will be null for IPC connections */ - val remoteAddress: String + val remoteAddress: InetAddress? /** - * the remote address, as an integer. Can be 0 for IPC connections + * the remote address, as a string. Will be "ipc" for IPC connections */ - private val remoteAddressInt: Int + val remoteAddressString: String /** * @return true if this connection is an IPC connection @@ -125,7 +128,8 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // a record of how many messages are in progress of being sent. When closing the connection, this number must be 0 private val messagesInProgress = atomic(0) - val toString0: () -> String + // we customize the toString() value for this connection, and it's just better to cache it's value (since it's a modestly complex string) + private val toString0: String init { val mediaDriverConnection = connectionParameters.mediaDriverConnection @@ -141,16 +145,19 @@ open class Connection(connectionParameters: ConnectionParams<*>) { streamId = 0 // this is because with IPC, we have stream sub/pub (which are replaced as port sub/pub) subscriptionPort = mediaDriverConnection.streamIdSubscription publicationPort = mediaDriverConnection.streamId - remoteAddressInt = 0 + remoteAddressString = "ipc" - toString0 = { "[$id] IPC [$subscriptionPort|$publicationPort]" } + toString0 = "[$id] IPC [$subscriptionPort|$publicationPort]" } else { + mediaDriverConnection as UdpMediaDriverConnection + streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server! subscriptionPort = mediaDriverConnection.subscriptionPort publicationPort = mediaDriverConnection.publicationPort - remoteAddressInt = IPv4.toInt(mediaDriverConnection.address) - toString0 = { "[$id] $remoteAddress [$publicationPort|$subscriptionPort]" } + remoteAddressString = mediaDriverConnection.addressString + + toString0 = "[$id] $remoteAddressString [$publicationPort|$subscriptionPort]" } @@ -412,7 +419,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // // override fun toString(): String { - return toString0() + return toString0 } override fun hashCode(): Int { @@ -435,13 +442,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) { } // cleans up the connection information - fun cleanup(connectionsPerIpCounts: Int2IntCounterMap, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) { + internal fun cleanup(connectionsPerIpCounts: ConnectionCounts, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) { + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD if (isIpc) { sessionIdAllocator.free(subscriptionPort) sessionIdAllocator.free(publicationPort) streamIdAllocator.free(streamId) } else { - connectionsPerIpCounts.getAndDecrement(remoteAddressInt) + connectionsPerIpCounts.decrementSlow(remoteAddress!!) sessionIdAllocator.free(id) streamIdAllocator.free(streamId) } diff --git a/src/dorkbox/network/connection/ConnectionParams.kt b/src/dorkbox/network/connection/ConnectionParams.kt index 0212ce5e..c92ed338 100644 --- a/src/dorkbox/network/connection/ConnectionParams.kt +++ b/src/dorkbox/network/connection/ConnectionParams.kt @@ -15,6 +15,8 @@ */ package dorkbox.network.connection +import dorkbox.network.aeron.MediaDriverConnection + data class ConnectionParams(val endPoint: EndPoint, val mediaDriverConnection: MediaDriverConnection, val publicKeyValidation: PublicKeyValidationState) diff --git a/src/dorkbox/network/connection/CryptoManagement.kt b/src/dorkbox/network/connection/CryptoManagement.kt index 9e58ec5d..5d4e8367 100644 --- a/src/dorkbox/network/connection/CryptoManagement.kt +++ b/src/dorkbox/network/connection/CryptoManagement.kt @@ -15,7 +15,7 @@ */ package dorkbox.network.connection -import dorkbox.netUtil.IPv4 +import dorkbox.netUtil.IP import dorkbox.network.Configuration import dorkbox.network.handshake.ClientConnectionInfo import dorkbox.network.other.CryptoEccNative @@ -26,6 +26,7 @@ import dorkbox.util.Sys import dorkbox.util.entropy.Entropy import dorkbox.util.exceptions.SecurityException import mu.KLogger +import java.net.InetAddress import java.security.KeyFactory import java.security.KeyPair import java.security.KeyPairGenerator @@ -122,21 +123,21 @@ internal class CryptoManagement(val logger: KLogger, /** - * If the key does not match AND we have disabled remote key validation, then metachannel.changedRemoteKey = true. OTHERWISE, key validation is REQUIRED! + * If the key does not match AND we have disabled remote key validation -- key validation is REQUIRED! * * @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.) * false if we should abort */ - internal fun validateRemoteAddress(remoteAddress: Int, publicKey: ByteArray?): PublicKeyValidationState { + internal fun validateRemoteAddress(remoteAddress: InetAddress, publicKey: ByteArray?): PublicKeyValidationState { if (publicKey == null) { - logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}! It was null (and should not have been)") + logger.error("Error validating public key for ${IP.toString(remoteAddress)}! It was null (and should not have been)") return PublicKeyValidationState.INVALID } try { val savedPublicKey = settingsStore.getRegisteredServerKey(remoteAddress) if (savedPublicKey == null) { - logger.info("Adding new remote IP address key for ${IPv4.toString(remoteAddress)} : ${Sys.bytesToHex(publicKey)}") + logger.info("Adding new signature for ${IP.toString(remoteAddress)} : ${Sys.bytesToHex(publicKey)}") settingsStore.addRegisteredServerKey(remoteAddress, publicKey) } else { @@ -144,18 +145,18 @@ internal class CryptoManagement(val logger: KLogger, if (!publicKey.contentEquals(savedPublicKey)) { return if (enableRemoteSignatureValidation) { // keys do not match, abort! - logger.error("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Denying connection attempt") + logger.error("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Denying connection attempt") PublicKeyValidationState.INVALID } else { - logger.warn("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Permitting connection attempt.") + logger.warn("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Permitting connection attempt.") PublicKeyValidationState.TAMPERED } } } } catch (e: SecurityException) { // keys do not match, abort! - logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}!", e) + logger.error("Error validating public key for ${IP.toString(remoteAddress)}!", e) return PublicKeyValidationState.INVALID } diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 69f95455..ec90e8b9 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -20,10 +20,11 @@ import dorkbox.network.Configuration import dorkbox.network.Server import dorkbox.network.ServerConfiguration import dorkbox.network.aeron.CoroutineIdleStrategy -import dorkbox.network.connection.ping.PingMessage +import dorkbox.network.exceptions.MessageNotRegisteredException import dorkbox.network.handshake.HandshakeMessage import dorkbox.network.ipFilter.IpFilterRule import dorkbox.network.other.coroutines.SuspendWaiter +import dorkbox.network.ping.PingMessage import dorkbox.network.rmi.RmiManagerConnections import dorkbox.network.rmi.RmiManagerGlobal import dorkbox.network.rmi.messages.RmiMessage @@ -145,8 +146,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A internal val rmiGlobalSupport = RmiManagerGlobal(logger, actionDispatch, config.serialization) init { - logger.error("NETWORK STACK IS ONLY IPV4 AT THE MOMENT. IPV6 is in progress!") - runBlocking { // our default onError handler. All error messages go though this listenerManager.onError { throwable -> diff --git a/src/dorkbox/network/connection/MediaDriverType.kt b/src/dorkbox/network/connection/MediaDriverType.kt deleted file mode 100644 index 89d00ed0..00000000 --- a/src/dorkbox/network/connection/MediaDriverType.kt +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright 2020 dorkbox, llc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package dorkbox.network.connection - -enum class MediaDriverType(private val type: String) { - IPC("ipc"), UDP("udp"); - - override fun toString(): String { - return type - } -} diff --git a/src/dorkbox/network/connection/connectionType/ConnectionProperties.kt b/src/dorkbox/network/connectionType/ConnectionProperties.kt similarity index 95% rename from src/dorkbox/network/connection/connectionType/ConnectionProperties.kt rename to src/dorkbox/network/connectionType/ConnectionProperties.kt index 8c61e0f3..260ea9be 100644 --- a/src/dorkbox/network/connection/connectionType/ConnectionProperties.kt +++ b/src/dorkbox/network/connectionType/ConnectionProperties.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.connectionType +package dorkbox.network.connectionType import dorkbox.network.handshake.UpgradeType diff --git a/src/dorkbox/network/connection/connectionType/ConnectionRule.kt b/src/dorkbox/network/connectionType/ConnectionRule.kt similarity index 99% rename from src/dorkbox/network/connection/connectionType/ConnectionRule.kt rename to src/dorkbox/network/connectionType/ConnectionRule.kt index f066914e..137208a1 100644 --- a/src/dorkbox/network/connection/connectionType/ConnectionRule.kt +++ b/src/dorkbox/network/connectionType/ConnectionRule.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.connectionType +package dorkbox.network.connectionType import java.math.BigInteger import java.net.Inet4Address diff --git a/src/dorkbox/network/connection/connectionType/IpConnectionTypeRule.kt b/src/dorkbox/network/connectionType/IpConnectionTypeRule.kt similarity index 96% rename from src/dorkbox/network/connection/connectionType/IpConnectionTypeRule.kt rename to src/dorkbox/network/connectionType/IpConnectionTypeRule.kt index 16f2f543..bf23d5cb 100644 --- a/src/dorkbox/network/connection/connectionType/IpConnectionTypeRule.kt +++ b/src/dorkbox/network/connectionType/IpConnectionTypeRule.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.connectionType +package dorkbox.network.connectionType import java.net.InetSocketAddress diff --git a/src/dorkbox/network/aeron/server/AllocationException.kt b/src/dorkbox/network/exceptions/AllocationException.kt similarity index 95% rename from src/dorkbox/network/aeron/server/AllocationException.kt rename to src/dorkbox/network/exceptions/AllocationException.kt index 7557b228..585e9674 100644 --- a/src/dorkbox/network/aeron/server/AllocationException.kt +++ b/src/dorkbox/network/exceptions/AllocationException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.server +package dorkbox.network.exceptions /** * A session/stream could not be allocated. diff --git a/src/dorkbox/network/aeron/client/ClientException.kt b/src/dorkbox/network/exceptions/ClientException.kt similarity index 96% rename from src/dorkbox/network/aeron/client/ClientException.kt rename to src/dorkbox/network/exceptions/ClientException.kt index 6aab8d62..ace7320a 100644 --- a/src/dorkbox/network/aeron/client/ClientException.kt +++ b/src/dorkbox/network/exceptions/ClientException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.client +package dorkbox.network.exceptions /** * The type of exceptions raised by the client. diff --git a/src/dorkbox/network/aeron/client/ClientRejectedException.kt b/src/dorkbox/network/exceptions/ClientRejectedException.kt similarity index 95% rename from src/dorkbox/network/aeron/client/ClientRejectedException.kt rename to src/dorkbox/network/exceptions/ClientRejectedException.kt index dc3a5810..954956be 100644 --- a/src/dorkbox/network/aeron/client/ClientRejectedException.kt +++ b/src/dorkbox/network/exceptions/ClientRejectedException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.client +package dorkbox.network.exceptions /** * The server rejected this client when it tried to connect. diff --git a/src/dorkbox/network/aeron/client/ClientTimedOutException.kt b/src/dorkbox/network/exceptions/ClientTimedOutException.kt similarity index 96% rename from src/dorkbox/network/aeron/client/ClientTimedOutException.kt rename to src/dorkbox/network/exceptions/ClientTimedOutException.kt index 6110e0c4..2f77ec84 100644 --- a/src/dorkbox/network/aeron/client/ClientTimedOutException.kt +++ b/src/dorkbox/network/exceptions/ClientTimedOutException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.client +package dorkbox.network.exceptions /** * The client timed out when it attempted to connect to the server. diff --git a/src/dorkbox/network/connection/MessageNotRegisteredException.kt b/src/dorkbox/network/exceptions/MessageNotRegisteredException.kt similarity index 95% rename from src/dorkbox/network/connection/MessageNotRegisteredException.kt rename to src/dorkbox/network/exceptions/MessageNotRegisteredException.kt index 1e44e5ea..45c92f1b 100644 --- a/src/dorkbox/network/connection/MessageNotRegisteredException.kt +++ b/src/dorkbox/network/exceptions/MessageNotRegisteredException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection +package dorkbox.network.exceptions /** * thrown when a message is received, and does not have any registered 'onMessage' handlers. diff --git a/src/dorkbox/network/aeron/server/PortAllocationException.kt b/src/dorkbox/network/exceptions/PortAllocationException.kt similarity index 95% rename from src/dorkbox/network/aeron/server/PortAllocationException.kt rename to src/dorkbox/network/exceptions/PortAllocationException.kt index 519d7788..30faa898 100644 --- a/src/dorkbox/network/aeron/server/PortAllocationException.kt +++ b/src/dorkbox/network/exceptions/PortAllocationException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.server +package dorkbox.network.exceptions /** * A port could not be allocated. diff --git a/src/dorkbox/network/aeron/server/ServerException.kt b/src/dorkbox/network/exceptions/ServerException.kt similarity index 96% rename from src/dorkbox/network/aeron/server/ServerException.kt rename to src/dorkbox/network/exceptions/ServerException.kt index 76706800..7006ba26 100644 --- a/src/dorkbox/network/aeron/server/ServerException.kt +++ b/src/dorkbox/network/exceptions/ServerException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.server +package dorkbox.network.exceptions /** * The type of exceptions raised by the server. diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index 71372b18..9916a0b0 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -16,12 +16,12 @@ package dorkbox.network.handshake import dorkbox.network.Configuration -import dorkbox.network.aeron.client.ClientException -import dorkbox.network.aeron.client.ClientTimedOutException +import dorkbox.network.aeron.MediaDriverConnection import dorkbox.network.connection.Connection import dorkbox.network.connection.CryptoManagement import dorkbox.network.connection.EndPoint -import dorkbox.network.connection.MediaDriverConnection +import dorkbox.network.exceptions.ClientException +import dorkbox.network.exceptions.ClientTimedOutException import io.aeron.FragmentAssembler import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.Header diff --git a/src/dorkbox/network/handshake/ConnectionCounts.kt b/src/dorkbox/network/handshake/ConnectionCounts.kt new file mode 100644 index 00000000..a1bd2803 --- /dev/null +++ b/src/dorkbox/network/handshake/ConnectionCounts.kt @@ -0,0 +1,30 @@ +package dorkbox.network.handshake + +import org.agrona.collections.Object2IntHashMap +import java.net.InetAddress + +/** + * + */ +internal class ConnectionCounts { + private val connectionsPerIpCounts = Object2IntHashMap(-1) + + fun get(inetAddress: InetAddress): Int { + return connectionsPerIpCounts.getOrPut(inetAddress) { 0 } + } + + fun increment(inetAddress: InetAddress, currentCount: Int) { + connectionsPerIpCounts[inetAddress] = currentCount + 1 + } + + fun decrement(inetAddress: InetAddress, currentCount: Int) { + connectionsPerIpCounts[inetAddress] = currentCount - 1 + } + + fun decrementSlow(inetAddress: InetAddress) { + if (connectionsPerIpCounts.containsKey(inetAddress)) { + val defaultVal = connectionsPerIpCounts.getValue(inetAddress) + connectionsPerIpCounts[inetAddress] = defaultVal - 1 + } + } +} diff --git a/src/dorkbox/network/aeron/server/PortAllocator.kt b/src/dorkbox/network/handshake/PortAllocator.kt similarity index 99% rename from src/dorkbox/network/aeron/server/PortAllocator.kt rename to src/dorkbox/network/handshake/PortAllocator.kt index d3bf7e42..8c7c3354 100644 --- a/src/dorkbox/network/aeron/server/PortAllocator.kt +++ b/src/dorkbox/network/handshake/PortAllocator.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.server +package dorkbox.network.handshake import org.agrona.collections.IntArrayList diff --git a/src/dorkbox/network/aeron/server/RandomIdAllocator.kt b/src/dorkbox/network/handshake/RandomIdAllocator.kt similarity index 85% rename from src/dorkbox/network/aeron/server/RandomIdAllocator.kt rename to src/dorkbox/network/handshake/RandomIdAllocator.kt index 5fee7812..6040b0f8 100644 --- a/src/dorkbox/network/aeron/server/RandomIdAllocator.kt +++ b/src/dorkbox/network/handshake/RandomIdAllocator.kt @@ -13,23 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.aeron.server +package dorkbox.network.handshake +import dorkbox.network.exceptions.AllocationException import org.agrona.collections.IntHashSet import java.security.SecureRandom /** - * An allocator for session IDs. The allocator randomly selects values from - * the given range `[min, max]` and will not return a previously-returned value `x` - * until `x` has been freed with `{ SessionAllocator#free(int)}. - *

+ * An allocator for session IDs. + * + * The allocator randomly selects values from the given range `[min, max]` and will not return a previously-returned value `x` + * until `x` has been freed with `{ SessionAllocator#free(int)}. * - *

* This implementation uses storage proportional to the number of currently-allocated * values. Allocation time is bounded by { max - min}, will be { O(1)} * with no allocated values, and will increase to { O(n)} as the number * of allocated values approached { max - min}. - *

` + * + * NOTE: THIS IS NOT THREAD SAFE! * * @param min The minimum session ID (inclusive) * @param max The maximum session ID (exclusive) @@ -55,7 +56,6 @@ class RandomIdAllocator(private val min: Int, max: Int) { * * @throws AllocationException If there are no non-allocated sessions left */ - @Throws(AllocationException::class) fun allocate(): Int { if (used.size == maxAssignments) { throw AllocationException("No session IDs left to allocate") @@ -81,4 +81,12 @@ class RandomIdAllocator(private val min: Int, max: Int) { fun free(session: Int) { used.remove(session) } + + + /** + * Removes all used sessions from the internal data structures + */ + fun clear() { + used.clear() + } } diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index d97ac301..cc4f273e 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -21,24 +21,24 @@ import com.github.benmanes.caffeine.cache.RemovalCause import com.github.benmanes.caffeine.cache.RemovalListener import dorkbox.network.Server import dorkbox.network.ServerConfiguration -import dorkbox.network.aeron.client.ClientRejectedException -import dorkbox.network.aeron.client.ClientTimedOutException -import dorkbox.network.aeron.server.AllocationException -import dorkbox.network.aeron.server.RandomIdAllocator -import dorkbox.network.aeron.server.ServerException +import dorkbox.network.aeron.IpcMediaDriverConnection +import dorkbox.network.aeron.UdpMediaDriverConnection import dorkbox.network.connection.Connection import dorkbox.network.connection.ConnectionParams import dorkbox.network.connection.EndPoint -import dorkbox.network.connection.IpcMediaDriverConnection import dorkbox.network.connection.ListenerManager import dorkbox.network.connection.PublicKeyValidationState -import dorkbox.network.connection.UdpMediaDriverConnection +import dorkbox.network.exceptions.AllocationException +import dorkbox.network.exceptions.ClientRejectedException +import dorkbox.network.exceptions.ClientTimedOutException +import dorkbox.network.exceptions.ServerException import io.aeron.Aeron import io.aeron.Publication import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import mu.KLogger -import org.agrona.collections.Int2IntCounterMap +import java.net.Inet4Address +import java.net.InetAddress import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantReadWriteLock import kotlin.concurrent.write @@ -53,7 +53,7 @@ internal class ServerHandshake(private val logger: KLog private val listenerManager: ListenerManager) { private val pendingConnectionsLock = ReentrantReadWriteLock() - private val pendingConnections: Cache = Caffeine.newBuilder() + private val pendingConnections: Cache = Caffeine.newBuilder() .expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS) .removalListener(RemovalListener { _, value, cause -> if (cause == RemovalCause.EXPIRED) { @@ -67,17 +67,17 @@ internal class ServerHandshake(private val logger: KLog } }).build() - private val connectionsPerIpCounts = Int2IntCounterMap(0) + private val connectionsPerIpCounts = ConnectionCounts() // guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!) - private val sessionIdAllocator = RandomIdAllocator(EndPoint.RESERVED_SESSION_ID_LOW, - EndPoint.RESERVED_SESSION_ID_HIGH) + private val sessionIdAllocator = RandomIdAllocator(EndPoint.RESERVED_SESSION_ID_LOW, EndPoint.RESERVED_SESSION_ID_HIGH) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) /** * @return true if we should continue parsing the incoming message, false if we should abort */ + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD private fun validateMessageTypeAndDoPending(server: Server, handshakePublication: Publication, message: Any?, @@ -126,11 +126,17 @@ internal class ServerHandshake(private val logger: KLog /** * @return true if we should continue parsing the incoming message, false if we should abort */ - private fun validateConnectionInfo(server: Server, - handshakePublication: Publication, - config: ServerConfiguration, - clientAddressString: String, - clientAddress: Int): Boolean { + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD + private fun validateUdpConnectionInfo(server: Server, + handshakePublication: Publication, + config: ServerConfiguration, + clientAddressString: String, + clientAddress: InetAddress): Boolean { + + if (clientAddress.isLoopbackAddress) { + // we do not want to limit loopback addresses + return true + } try { // VALIDATE:: Check to see if there are already too many clients connected. @@ -143,11 +149,12 @@ internal class ServerHandshake(private val logger: KLog return false } + // VALIDATE:: we are now connected to the client and are going to create a new connection. - val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) + val currentCountForIp = connectionsPerIpCounts.get(clientAddress) if (currentCountForIp >= config.maxConnectionsPerIpAddress) { // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) - connectionsPerIpCounts.getAndDecrement(clientAddress) + connectionsPerIpCounts.decrement(clientAddress, currentCountForIp) listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) server.actionDispatch.launch { @@ -156,6 +163,7 @@ internal class ServerHandshake(private val logger: KLog return false } + connectionsPerIpCounts.increment(clientAddress, currentCountForIp) } catch (e: Exception) { listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) server.actionDispatch.launch { @@ -168,7 +176,7 @@ internal class ServerHandshake(private val logger: KLog } - // note: CANNOT be called in action dispatch + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD fun processIpcHandshakeMessageServer(server: Server, handshakePublication: Publication, sessionId: Int, @@ -243,7 +251,7 @@ internal class ServerHandshake(private val logger: KLog connectionTimeoutMS = 0) // we have to construct how the connection will communicate! - clientConnection.buildServer(aeron) + clientConnection.buildServer(aeron, logger) logger.info { "[${clientConnection.sessionId}] aeron IPC connection established to $clientConnection" @@ -264,8 +272,7 @@ internal class ServerHandshake(private val logger: KLog listenerManager.notifyError(connection, exception) server.actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Connection was not permitted!")) + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) } return @@ -320,14 +327,15 @@ internal class ServerHandshake(private val logger: KLog } - // note: CANNOT be called in action dispatch + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD fun processUdpHandshakeMessageServer(server: Server, handshakePublication: Publication, sessionId: Int, clientAddressString: String, - clientAddress: Int, + clientAddress: InetAddress, message: Any?, - aeron: Aeron) { + aeron: Aeron, + isIpv6Wildcard: Boolean) { if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) { return @@ -345,7 +353,7 @@ internal class ServerHandshake(private val logger: KLog return } - if (!validateConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) { + if (!validateUdpConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) { return } @@ -363,7 +371,7 @@ internal class ServerHandshake(private val logger: KLog connectionSessionId = sessionIdAllocator.allocate() } catch (e: AllocationException) { // have to unwind actions! - connectionsPerIpCounts.getAndDecrement(clientAddress) + connectionsPerIpCounts.decrementSlow(clientAddress) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) server.actionDispatch.launch { @@ -378,7 +386,7 @@ internal class ServerHandshake(private val logger: KLog connectionStreamId = streamIdAllocator.allocate() } catch (e: AllocationException) { // have to unwind actions! - connectionsPerIpCounts.getAndDecrement(clientAddress) + connectionsPerIpCounts.decrementSlow(clientAddress) sessionIdAllocator.free(connectionSessionId) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) @@ -388,8 +396,6 @@ internal class ServerHandshake(private val logger: KLog return } - val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box? - // the pub/sub do not necessarily have to be the same. The can be ANY port val publicationPort = config.publicationPort val subscriptionPort = config.subscriptionPort @@ -398,16 +404,29 @@ internal class ServerHandshake(private val logger: KLog // create a new connection. The session ID is encrypted. try { // connection timeout of 0 doesn't matter. it is not used by the server - val clientConnection = UdpMediaDriverConnection(serverAddress, - publicationPort, - subscriptionPort, - connectionStreamId, - connectionSessionId, - 0, - message.isReliable) + // the client address WILL BE either IPv4 or IPv6 + + val clientConnection = if (clientAddress is Inet4Address && !isIpv6Wildcard) { + UdpMediaDriverConnection(server.listenIPv4Address!!, + publicationPort, + subscriptionPort, + connectionStreamId, + connectionSessionId, + 0, + message.isReliable) + } else { + // wildcard is SPECIAL, in that if we bind wildcard, it will ALSO bind to IPv4, so we can't bind both! + UdpMediaDriverConnection(server.listenIPv6Address!!, + publicationPort, + subscriptionPort, + connectionStreamId, + connectionSessionId, + 0, + message.isReliable) + } // we have to construct how the connection will communicate! - clientConnection.buildServer(aeron) + clientConnection.buildServer(aeron, logger) logger.info { "Creating new connection from $clientConnection" @@ -420,7 +439,7 @@ internal class ServerHandshake(private val logger: KLog val permitConnection = listenerManager.notifyFilter(connection) if (!permitConnection) { // have to unwind actions! - connectionsPerIpCounts.getAndDecrement(clientAddress) + connectionsPerIpCounts.decrementSlow(clientAddress) sessionIdAllocator.free(connectionSessionId) streamIdAllocator.free(connectionStreamId) @@ -429,8 +448,7 @@ internal class ServerHandshake(private val logger: KLog listenerManager.notifyError(connection, exception) server.actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Connection was not permitted!")) + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) } return @@ -470,7 +488,7 @@ internal class ServerHandshake(private val logger: KLog } } catch (e: Exception) { // have to unwind actions! - connectionsPerIpCounts.getAndDecrement(clientAddress) + connectionsPerIpCounts.decrementSlow(clientAddress) sessionIdAllocator.free(connectionSessionId) streamIdAllocator.free(connectionStreamId) @@ -482,7 +500,17 @@ internal class ServerHandshake(private val logger: KLog * Free up resources from the closed connection */ fun cleanup(connection: CONNECTION) { + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD connection.cleanup(connectionsPerIpCounts, sessionIdAllocator, streamIdAllocator) + } + + /** + * Reset and clear all connection information + */ + fun clear() { + // note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD + sessionIdAllocator.clear() + streamIdAllocator.clear() pendingConnections.invalidateAll() } } diff --git a/src/dorkbox/network/connection/Ping.kt b/src/dorkbox/network/ping/Ping.kt similarity index 94% rename from src/dorkbox/network/connection/Ping.kt rename to src/dorkbox/network/ping/Ping.kt index a9cb99f0..620c0090 100644 --- a/src/dorkbox/network/connection/Ping.kt +++ b/src/dorkbox/network/ping/Ping.kt @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection +package dorkbox.network.ping + +import dorkbox.network.connection.Connection interface Ping { /** diff --git a/src/dorkbox/network/connection/ping/PingCanceledException.kt b/src/dorkbox/network/ping/PingCanceledException.kt similarity index 94% rename from src/dorkbox/network/connection/ping/PingCanceledException.kt rename to src/dorkbox/network/ping/PingCanceledException.kt index 3b8a3580..3ad424a9 100644 --- a/src/dorkbox/network/connection/ping/PingCanceledException.kt +++ b/src/dorkbox/network/ping/PingCanceledException.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.ping +package dorkbox.network.ping import java.io.IOException diff --git a/src/dorkbox/network/connection/ping/PingFuture.kt b/src/dorkbox/network/ping/PingFuture.kt similarity index 97% rename from src/dorkbox/network/connection/ping/PingFuture.kt rename to src/dorkbox/network/ping/PingFuture.kt index 4ed30ff3..36dbc13e 100644 --- a/src/dorkbox/network/connection/ping/PingFuture.kt +++ b/src/dorkbox/network/ping/PingFuture.kt @@ -13,11 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.ping +package dorkbox.network.ping import dorkbox.network.connection.Connection -import dorkbox.network.connection.Ping -import dorkbox.network.connection.PingListener import java.util.concurrent.atomic.AtomicInteger class PingFuture internal constructor() : Ping { diff --git a/src/dorkbox/network/connection/PingListener.java b/src/dorkbox/network/ping/PingListener.java similarity index 94% rename from src/dorkbox/network/connection/PingListener.java rename to src/dorkbox/network/ping/PingListener.java index d238d4cf..b2195b8c 100644 --- a/src/dorkbox/network/connection/PingListener.java +++ b/src/dorkbox/network/ping/PingListener.java @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection; +package dorkbox.network.ping; + +import dorkbox.network.connection.Connection; // note that we specifically DO NOT implement equals/hashCode, because we cannot create two separate // listeners that are somehow equal to each other. diff --git a/src/dorkbox/network/connection/ping/PingMessage.kt b/src/dorkbox/network/ping/PingMessage.kt similarity index 94% rename from src/dorkbox/network/connection/ping/PingMessage.kt rename to src/dorkbox/network/ping/PingMessage.kt index 5b06ed9a..93b83e07 100644 --- a/src/dorkbox/network/connection/ping/PingMessage.kt +++ b/src/dorkbox/network/ping/PingMessage.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.ping +package dorkbox.network.ping /** * Internal message to determine round trip time. diff --git a/src/dorkbox/network/connection/ping/PingTuple.kt b/src/dorkbox/network/ping/PingTuple.kt similarity index 94% rename from src/dorkbox/network/connection/ping/PingTuple.kt rename to src/dorkbox/network/ping/PingTuple.kt index 2845b22e..74fadcec 100644 --- a/src/dorkbox/network/connection/ping/PingTuple.kt +++ b/src/dorkbox/network/ping/PingTuple.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package dorkbox.network.connection.ping +package dorkbox.network.ping import dorkbox.network.connection.Connection diff --git a/src/dorkbox/network/storage/DB_Server.kt b/src/dorkbox/network/storage/DB_Server.kt index 29f71009..92f992d2 100644 --- a/src/dorkbox/network/storage/DB_Server.kt +++ b/src/dorkbox/network/storage/DB_Server.kt @@ -24,11 +24,6 @@ class DB_Server { * The storage key used to save all server connections */ val STORAGE_KEY = StorageKey("servers") - - /** - * Address 0.0.0.0/32 may be used as a source address for this host on this network. - */ - const val IP_SELF = 0 } diff --git a/src/dorkbox/network/storage/NullSettingsStore.kt b/src/dorkbox/network/storage/NullSettingsStore.kt index bd76d859..6b2b9092 100644 --- a/src/dorkbox/network/storage/NullSettingsStore.kt +++ b/src/dorkbox/network/storage/NullSettingsStore.kt @@ -18,6 +18,7 @@ package dorkbox.network.storage import dorkbox.network.serialization.Serialization import dorkbox.util.exceptions.SecurityException import dorkbox.util.storage.Storage +import java.net.InetAddress import java.security.SecureRandom class NullSettingsStore : SettingsStore() { @@ -54,17 +55,17 @@ class NullSettingsStore : SettingsStore() { } @Throws(SecurityException::class) - override fun getRegisteredServerKey(hostAddress: Int): ByteArray { + override fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray { TODO("not impl") } @Throws(SecurityException::class) - override fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray) { + override fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray) { TODO("not impl") } @Throws(SecurityException::class) - override fun removeRegisteredServerKey(hostAddress: Int): Boolean { + override fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean { return true } diff --git a/src/dorkbox/network/storage/PropertyStore.kt b/src/dorkbox/network/storage/PropertyStore.kt index 031edfc5..b8df9ff0 100644 --- a/src/dorkbox/network/storage/PropertyStore.kt +++ b/src/dorkbox/network/storage/PropertyStore.kt @@ -15,10 +15,13 @@ */ package dorkbox.network.storage +import dorkbox.netUtil.IPv4 +import dorkbox.netUtil.IPv6 import dorkbox.network.connection.CryptoManagement import dorkbox.network.serialization.Serialization import dorkbox.util.storage.Storage -import org.agrona.collections.Int2ObjectHashMap +import org.agrona.collections.Object2NullableObjectHashMap +import java.net.InetAddress import java.security.SecureRandom /** @@ -26,7 +29,15 @@ import java.security.SecureRandom */ class PropertyStore : SettingsStore() { private lateinit var storage: Storage - private lateinit var servers: Int2ObjectHashMap + private lateinit var servers: Object2NullableObjectHashMap + + /** + * Address 0.0.0.0 or ::0 may be used as a source address for this host on this network. + * + * Because we assigned BOTH to the same thing, it doesn't matter which one we use + */ + private val ipv4Host = IPv4.WILDCARD + private val ipv6Host = IPv6.WILDCARD /** * Method of preference for creating/getting this connection store. @@ -35,13 +46,20 @@ class PropertyStore : SettingsStore() { */ override fun init(serializationManager: Serialization, storage: Storage) { this.storage = storage - servers = this.storage.get(DB_Server.STORAGE_KEY, Int2ObjectHashMap()) + servers = this.storage.get(DB_Server.STORAGE_KEY, Object2NullableObjectHashMap()) // this will always be null and is here to help people that copy/paste code - var localServer = servers[DB_Server.IP_SELF] + var localServer = servers[ipv4Host] if (localServer == null) { localServer = DB_Server() - servers[DB_Server.IP_SELF] = localServer + servers[ipv4Host] = localServer + + // have to always specify what we are saving + this.storage.put(DB_Server.STORAGE_KEY, servers) + } + + if (servers[ipv6Host] == null) { + servers[ipv6Host] = localServer // have to always specify what we are saving this.storage.put(DB_Server.STORAGE_KEY, servers) @@ -54,7 +72,7 @@ class PropertyStore : SettingsStore() { @Synchronized override fun getPrivateKey(): ByteArray? { checkAccess(CryptoManagement::class.java) - return servers[DB_Server.IP_SELF]!!.privateKey + return servers[ipv4Host]!!.privateKey } /** @@ -63,7 +81,7 @@ class PropertyStore : SettingsStore() { @Synchronized override fun savePrivateKey(serverPrivateKey: ByteArray) { checkAccess(CryptoManagement::class.java) - servers[DB_Server.IP_SELF]!!.privateKey = serverPrivateKey + servers[ipv4Host]!!.privateKey = serverPrivateKey // have to always specify what we are saving storage.put(DB_Server.STORAGE_KEY, servers) @@ -74,7 +92,7 @@ class PropertyStore : SettingsStore() { */ @Synchronized override fun getPublicKey(): ByteArray? { - return servers[DB_Server.IP_SELF]!!.publicKey + return servers[ipv4Host]!!.publicKey } /** @@ -83,7 +101,7 @@ class PropertyStore : SettingsStore() { @Synchronized override fun savePublicKey(serverPublicKey: ByteArray) { checkAccess(CryptoManagement::class.java) - servers[DB_Server.IP_SELF]!!.publicKey = serverPublicKey + servers[ipv4Host]!!.publicKey = serverPublicKey // have to always specify what we are saving storage.put(DB_Server.STORAGE_KEY, servers) @@ -94,7 +112,7 @@ class PropertyStore : SettingsStore() { */ @Synchronized override fun getSalt(): ByteArray { - val localServer = servers[DB_Server.IP_SELF] + val localServer = servers[ipv4Host] var salt = localServer!!.salt // we don't care who gets the server salt @@ -118,7 +136,7 @@ class PropertyStore : SettingsStore() { * Simple, property based method to getting a connected computer by host IP address */ @Synchronized - override fun getRegisteredServerKey(hostAddress: Int): ByteArray? { + override fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray? { return servers[hostAddress]?.publicKey } @@ -126,7 +144,7 @@ class PropertyStore : SettingsStore() { * Saves a connected computer by host IP address and public key */ @Synchronized - override fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray) { + override fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray) { // checkAccess(RegistrationWrapper.class); var db_server = servers[hostAddress] if (db_server == null) { @@ -144,7 +162,7 @@ class PropertyStore : SettingsStore() { * Deletes a registered computer by host IP address */ @Synchronized - override fun removeRegisteredServerKey(hostAddress: Int): Boolean { + override fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean { // checkAccess(RegistrationWrapper.class); val db_server = servers.remove(hostAddress) diff --git a/src/dorkbox/network/storage/SettingsStore.kt b/src/dorkbox/network/storage/SettingsStore.kt index d7daede4..54020e7a 100644 --- a/src/dorkbox/network/storage/SettingsStore.kt +++ b/src/dorkbox/network/storage/SettingsStore.kt @@ -20,6 +20,7 @@ import dorkbox.util.bytes.ByteArrayWrapper import dorkbox.util.exceptions.SecurityException import dorkbox.util.storage.Storage import org.slf4j.LoggerFactory +import java.net.InetAddress import java.util.* /** @@ -77,13 +78,13 @@ abstract class SettingsStore : AutoCloseable { * Gets a previously registered computer by host IP address */ @Throws(SecurityException::class) - abstract fun getRegisteredServerKey(hostAddress: Int): ByteArray? + abstract fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray? /** * Saves a registered computer by host IP address and public key */ @Throws(SecurityException::class) - abstract fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray) + abstract fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray) /** * Deletes a registered computer by host IP address @@ -91,7 +92,7 @@ abstract class SettingsStore : AutoCloseable { * @return true if successful, false if there were problems (or it didn't exist) */ @Throws(SecurityException::class) - abstract fun removeRegisteredServerKey(hostAddress: Int): Boolean + abstract fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean /** * Take the proper steps to close the storage system. diff --git a/src/io/aeron/driver/media/UdpChannel.java b/src/io/aeron/driver/media/UdpChannel.java new file mode 100644 index 00000000..9214bb97 --- /dev/null +++ b/src/io/aeron/driver/media/UdpChannel.java @@ -0,0 +1,918 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// REMOVE WHEN ISSUE https://github.com/real-logic/aeron/issues/1057 is resolved! +package io.aeron.driver.media; + +import static io.aeron.driver.media.NetworkUtil.filterBySubnet; +import static io.aeron.driver.media.NetworkUtil.findAddressOnInterface; +import static io.aeron.driver.media.NetworkUtil.getProtocolFamily; +import static java.lang.System.lineSeparator; +import static java.net.InetAddress.getByAddress; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.ProtocolFamily; +import java.net.SocketException; +import java.net.UnknownHostException; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; + +import org.agrona.BitUtil; +import org.agrona.LangUtil; + +import io.aeron.ChannelUri; +import io.aeron.CommonContext; +import io.aeron.driver.DefaultNameResolver; +import io.aeron.driver.NameResolver; +import io.aeron.driver.exceptions.InvalidChannelException; + +/** + * The media configuration for Aeron UDP channels as an instantiation of the socket addresses for a {@link ChannelUri}. + * + * @see ChannelUri + * @see io.aeron.ChannelUriStringBuilder + */ +public final class UdpChannel +{ + private static final AtomicInteger UNIQUE_CANONICAL_FORM_VALUE = new AtomicInteger(); + + private final boolean isManualControlMode; + private final boolean isDynamicControlMode; + private final boolean hasExplicitControl; + private final boolean hasExplicitEndpoint; + private final boolean isMulticast; + private final boolean hasMulticastTtl; + private final boolean hasTag; + private final int multicastTtl; + private final long tag; + private final InetSocketAddress remoteData; + private final InetSocketAddress localData; + private final InetSocketAddress remoteControl; + private final InetSocketAddress localControl; + private final String uriStr; + private final String canonicalForm; + private final NetworkInterface localInterface; + private final ProtocolFamily protocolFamily; + private final ChannelUri channelUri; + + private UdpChannel(final Context context) + { + isManualControlMode = context.isManualControlMode; + isDynamicControlMode = context.isDynamicControlMode; + hasExplicitEndpoint = context.hasExplicitEndpoint; + hasExplicitControl = context.hasExplicitControl; + isMulticast = context.isMulticast; + hasTag = context.hasTagId; + tag = context.tagId; + hasMulticastTtl = context.hasMulticastTtl; + multicastTtl = context.multicastTtl; + remoteData = context.remoteData; + localData = context.localData; + remoteControl = context.remoteControl; + localControl = context.localControl; + uriStr = context.uriStr; + canonicalForm = context.canonicalForm; + localInterface = context.localInterface; + protocolFamily = context.protocolFamily; + channelUri = context.channelUri; + } + + /** + * Parse channel URI and create a {@link UdpChannel}. + * + * @param channelUriString to parse. + * @return a new {@link UdpChannel} as the result of parsing. + * @throws InvalidChannelException if an error occurs. + */ + public static UdpChannel parse(final String channelUriString) + { + return parse(channelUriString, DefaultNameResolver.INSTANCE); + } + + /** + * Parse channel URI and create a {@link UdpChannel}. + * + * @param channelUriString to parse. + * @param nameResolver to use for resolving names + * @return a new {@link UdpChannel} as the result of parsing. + * @throws InvalidChannelException if an error occurs. + */ + @SuppressWarnings("MethodLength") + public static UdpChannel parse(final String channelUriString, final NameResolver nameResolver) + { + try + { + final ChannelUri channelUri = ChannelUri.parse(channelUriString); + validateConfiguration(channelUri); + + InetSocketAddress endpointAddress = getEndpointAddress(channelUri, nameResolver); + final InetSocketAddress explicitControlAddress = getExplicitControlAddress(channelUri, nameResolver); + + final String tagIdStr = channelUri.channelTag(); + final String controlMode = channelUri.get(CommonContext.MDC_CONTROL_MODE_PARAM_NAME); + final boolean isManualControlMode = CommonContext.MDC_CONTROL_MODE_MANUAL.equals(controlMode); + final boolean isDynamicControlMode = CommonContext.MDC_CONTROL_MODE_DYNAMIC.equals(controlMode); + + final boolean requiresAdditionalSuffix = + null == endpointAddress && null == explicitControlAddress || + (null != endpointAddress && endpointAddress.getPort() == 0) || + (null != explicitControlAddress && explicitControlAddress.getPort() == 0); + + final boolean hasNoDistinguishingCharacteristic = + null == endpointAddress && null == explicitControlAddress && null == tagIdStr; + + if (isDynamicControlMode && null == explicitControlAddress) + { + throw new IllegalArgumentException( + "explicit control expected with dynamic control mode: " + channelUriString); + } + + if (hasNoDistinguishingCharacteristic && !isManualControlMode) + { + throw new IllegalArgumentException( + "URIs for UDP must specify an endpoint, control, tags, or control-mode=manual: " + + channelUriString); + } + + if (null != endpointAddress && endpointAddress.isUnresolved()) + { + throw new UnknownHostException("could not resolve endpoint address: " + endpointAddress); + } + + if (null != explicitControlAddress && explicitControlAddress.isUnresolved()) + { + throw new UnknownHostException("could not resolve control address: " + explicitControlAddress); + } + + boolean hasExplicitEndpoint = true; + if (null == endpointAddress) + { + hasExplicitEndpoint = false; + if (explicitControlAddress == null || explicitControlAddress.getAddress() instanceof Inet4Address) { + endpointAddress = new InetSocketAddress(InetAddress.getByAddress("", new byte[]{0,0,0,0}), 0); + } else { + endpointAddress = new InetSocketAddress(InetAddress.getByAddress("", new byte[]{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}), 0); + } + } + + final Context context = new Context() + .uriStr(channelUriString) + .channelUri(channelUri) + .isManualControlMode(isManualControlMode) + .isDynamicControlMode(isDynamicControlMode) + .hasExplicitEndpoint(hasExplicitEndpoint) + .hasNoDistinguishingCharacteristic(hasNoDistinguishingCharacteristic); + + if (null != tagIdStr) + { + context.hasTagId(true).tagId(Long.parseLong(tagIdStr)); + } + + if (endpointAddress.getAddress().isMulticastAddress()) + { + final InetSocketAddress controlAddress = getMulticastControlAddress(endpointAddress); + final InterfaceSearchAddress searchAddress = getInterfaceSearchAddress(channelUri); + final NetworkInterface localInterface = findInterface(searchAddress); + final InetSocketAddress resolvedAddress = resolveToAddressOfInterface(localInterface, searchAddress); + + context + .isMulticast(true) + .localControlAddress(resolvedAddress) + .remoteControlAddress(controlAddress) + .localDataAddress(resolvedAddress) + .remoteDataAddress(endpointAddress) + .localInterface(localInterface) + .protocolFamily(getProtocolFamily(endpointAddress.getAddress())) + .canonicalForm(canonicalise(null, resolvedAddress, null, endpointAddress)); + + final String ttlValue = channelUri.get(CommonContext.TTL_PARAM_NAME); + if (null != ttlValue) + { + context.hasMulticastTtl(true).multicastTtl(Integer.parseInt(ttlValue)); + } + } + else if (null != explicitControlAddress) + { + final String controlVal = channelUri.get(CommonContext.MDC_CONTROL_PARAM_NAME); + final String endpointVal = channelUri.get(CommonContext.ENDPOINT_PARAM_NAME); + + String suffix = ""; + if (requiresAdditionalSuffix) + { + suffix = (null != tagIdStr) ? "#" + tagIdStr : ("-" + UNIQUE_CANONICAL_FORM_VALUE.getAndAdd(1)); + } + + final String canonicalForm = canonicalise( + controlVal, explicitControlAddress, endpointVal, endpointAddress) + suffix; + + context + .hasExplicitControl(true) + .remoteControlAddress(endpointAddress) + .remoteDataAddress(endpointAddress) + .localControlAddress(explicitControlAddress) + .localDataAddress(explicitControlAddress) + .protocolFamily(getProtocolFamily(endpointAddress.getAddress())) + .canonicalForm(canonicalForm); + } + else + { + final InterfaceSearchAddress searchAddress = getInterfaceSearchAddress(channelUri); + + final InetSocketAddress localAddress = searchAddress.getInetAddress().isAnyLocalAddress() ? + searchAddress.getAddress() : + resolveToAddressOfInterface(findInterface(searchAddress), searchAddress); + + final String endpointVal = channelUri.get(CommonContext.ENDPOINT_PARAM_NAME); + String suffix = ""; + if (requiresAdditionalSuffix) + { + suffix = (null != tagIdStr) ? "#" + tagIdStr : ("-" + UNIQUE_CANONICAL_FORM_VALUE.getAndAdd(1)); + } + + context + .remoteControlAddress(endpointAddress) + .remoteDataAddress(endpointAddress) + .localControlAddress(localAddress) + .localDataAddress(localAddress) + .protocolFamily(getProtocolFamily(endpointAddress.getAddress())) + .canonicalForm(canonicalise(null, localAddress, endpointVal, endpointAddress) + suffix); + } + + return new UdpChannel(context); + } + catch (final Exception ex) + { + throw new InvalidChannelException(ex); + } + } + + /** + * Return a string which is a canonical form of the channel suitable for use as a file or directory + * name and also as a method of hashing, etc. + *

+ * The general format is: + * UDP-interface:localPort-remoteAddress:remotePort + * + * @param localParamValue interface or MDC control param value or null for not set. + * @param localData address/interface for the channel. + * @param remoteParamValue endpoint param value or null if not set. + * @param remoteData address for the channel. + * @return canonical representation as a string. + */ + public static String canonicalise( + final String localParamValue, + final InetSocketAddress localData, + final String remoteParamValue, + final InetSocketAddress remoteData) + { + final StringBuilder builder = new StringBuilder(48); + + builder.append("UDP-"); + + if (null == localParamValue) + { + builder.append(localData.getHostString()) + .append(':') + .append(localData.getPort()); + } + else + { + builder.append(localParamValue); + } + + builder.append('-'); + + if (null == remoteParamValue) + { + builder.append(remoteData.getHostString()) + .append(':') + .append(remoteData.getPort()); + } + else + { + builder.append(remoteParamValue); + } + + return builder.toString(); + } + + /** + * Remote data address and port. + * + * @return remote data address and port. + */ + public InetSocketAddress remoteData() + { + return remoteData; + } + + /** + * Local data address and port. + * + * @return local data address port. + */ + public InetSocketAddress localData() + { + return localData; + } + + /** + * Remote control address information + * + * @return remote control address information + */ + public InetSocketAddress remoteControl() + { + return remoteControl; + } + + /** + * Local control address and port. + * + * @return local control address and port. + */ + public InetSocketAddress localControl() + { + return localControl; + } + + /** + * Get the {@link ChannelUri} for this channel. + * + * @return the {@link ChannelUri} for this channel. + */ + public ChannelUri channelUri() + { + return channelUri; + } + + /** + * Has this channel got a multicast TTL value set so that {@link #multicastTtl()} is valid. + * + * @return true if this channel is a multicast TTL set otherwise false. + */ + public boolean hasMulticastTtl() + { + return hasMulticastTtl; + } + + /** + * Multicast TTL value. + * + * @return multicast TTL value. + */ + public int multicastTtl() + { + return multicastTtl; + } + + /** + * The canonical form for the channel + *

+ * {@link UdpChannel#canonicalise} + * + * @return canonical form for channel. + */ + public String canonicalForm() + { + return canonicalForm; + } + + /** + * The {@link #canonicalForm()} for the channel. + * + * @return the {@link #canonicalForm()} for the channel. + */ + @Override + public String toString() + { + return canonicalForm; + } + + /** + * Is the channel UDP multicast. + * + * @return true if the channel is UDP multicast. + */ + public boolean isMulticast() + { + return isMulticast; + } + + /** + * Local interface to be used by the channel. + * + * @return {@link NetworkInterface} for the local interface used by the channel. + */ + public NetworkInterface localInterface() + { + return localInterface; + } + + /** + * Original URI of the channel URI. + * + * @return the original uri string from the client. + */ + public String originalUriString() + { + return uriStr; + } + + /** + * Get the {@link ProtocolFamily} for this channel. + * + * @return the {@link ProtocolFamily} for this channel. + */ + public ProtocolFamily protocolFamily() + { + return protocolFamily; + } + + /** + * Get the tag value on the channel which is only valid if {@link #hasTag()} is true. + * + * @return the tag value on the channel. + */ + public long tag() + { + return tag; + } + + /** + * Does the channel have manual control mode specified. + * + * @return does channel have manual control mode specified. + */ + public boolean isManualControlMode() + { + return isManualControlMode; + } + + /** + * Does the channel have dynamic control mode specified. + * + * @return does channel have dynamic control mode specified. + */ + public boolean isDynamicControlMode() + { + return isDynamicControlMode; + } + + /** + * Does the channel have an explicit endpoint address? + * + * @return does channel have an explicit endpoint address or not? + */ + public boolean hasExplicitEndpoint() + { + return hasExplicitEndpoint; + } + + /** + * Does the channel have an explicit control address as used with multi-destination-cast or not? + * + * @return does channel have an explicit control address or not? + */ + public boolean hasExplicitControl() + { + return hasExplicitControl; + } + + /** + * Has the URI a tag to indicate entity relationships and if {@link #tag()} is valid. + * + * @return true if the channel has a tag. + */ + public boolean hasTag() + { + return hasTag; + } + + /** + * Is the channel configured as multi-destination. + * + * @return true if he channel configured as multi-destination. + */ + public boolean isMultiDestination() + { + return isDynamicControlMode || isManualControlMode || hasExplicitControl; + } + + /** + * Does this channel have a tag match to another channel including endpoints. + * + * @param udpChannel to match against. + * @return true if there is a match otherwise false. + */ + public boolean matchesTag(final UdpChannel udpChannel) + { + if (!hasTag || !udpChannel.hasTag() || tag != udpChannel.tag()) + { + return false; + } + + if (udpChannel.remoteData().getAddress().isAnyLocalAddress() && + udpChannel.remoteData().getPort() == 0 && + udpChannel.localData().getAddress().isAnyLocalAddress() && + udpChannel.localData().getPort() == 0) + { + return true; + } + + throw new IllegalArgumentException( + "matching tag has set endpoint or control address - " + uriStr + " <> " + udpChannel.uriStr); + } + + /** + * Used for debugging to get a human readable description of the channel. + * + * @return a human readable description of the channel. + */ + public String description() + { + final StringBuilder builder = new StringBuilder("UdpChannel - "); + if (null != localInterface) + { + builder + .append("interface: ") + .append(localInterface.getDisplayName()) + .append(", "); + } + + builder + .append("localData: ").append(localData) + .append(", remoteData: ").append(remoteData) + .append(", ttl: ").append(multicastTtl); + + return builder.toString(); + } + + /** + * Channels are considered equal if the {@link #canonicalForm()} is equal. + * + * @param o object to be compared with. + * @return true if the {@link #canonicalForm()} is equal, otherwise false. + */ + @Override + public boolean equals(final Object o) + { + if (this == o) + { + return true; + } + + if (o == null || getClass() != o.getClass()) + { + return false; + } + + final UdpChannel that = (UdpChannel)o; + + return Objects.equals(canonicalForm, that.canonicalForm); + } + + /** + * The hash code for the {@link #canonicalForm()}. + * + * @return the hash code for the {@link #canonicalForm()}. + */ + @Override + public int hashCode() + { + return canonicalForm != null ? canonicalForm.hashCode() : 0; + } + + /** + * Get the endpoint destination address from the URI. + * + * @param uri to check. + * @param nameResolver to use for resolution + * @return endpoint address for URI. + */ + public static InetSocketAddress destinationAddress(final ChannelUri uri, final NameResolver nameResolver) + { + try + { + validateConfiguration(uri); + return getEndpointAddress(uri, nameResolver); + } + catch (final Exception ex) + { + throw new InvalidChannelException(ex); + } + } + + /** + * Resolve and endpoint into a {@link InetSocketAddress}. + * + * @param endpoint to resolve + * @param uriParamName for the resolution + * @param isReResolution for the resolution + * @param nameResolver to be used for hostname. + * @return address for endpoint + * @throws UnknownHostException if the endpoint can not be resolved. + */ + public static InetSocketAddress resolve( + final String endpoint, final String uriParamName, final boolean isReResolution, final NameResolver nameResolver) + throws UnknownHostException + { + return SocketAddressParser.parse(endpoint, uriParamName, isReResolution, nameResolver); + } + + private static InetSocketAddress getMulticastControlAddress(final InetSocketAddress endpointAddress) + throws UnknownHostException + { + final byte[] addressAsBytes = endpointAddress.getAddress().getAddress(); + validateDataAddress(addressAsBytes); + + addressAsBytes[addressAsBytes.length - 1]++; + return new InetSocketAddress(getByAddress(addressAsBytes), endpointAddress.getPort()); + } + + private static InterfaceSearchAddress getInterfaceSearchAddress(final ChannelUri uri) throws UnknownHostException + { + final String interfaceValue = uri.get(CommonContext.INTERFACE_PARAM_NAME); + if (null != interfaceValue) + { + return InterfaceSearchAddress.parse(interfaceValue); + } + + return InterfaceSearchAddress.wildcard(); + } + + private static InetSocketAddress getEndpointAddress(final ChannelUri uri, final NameResolver nameResolver) + { + InetSocketAddress address = null; + final String endpointValue = uri.get(CommonContext.ENDPOINT_PARAM_NAME); + if (null != endpointValue) + { + try + { + address = SocketAddressParser.parse( + endpointValue, CommonContext.ENDPOINT_PARAM_NAME, false, nameResolver); + } + catch (final UnknownHostException ex) + { + LangUtil.rethrowUnchecked(ex); + } + } + + return address; + } + + private static InetSocketAddress getExplicitControlAddress(final ChannelUri uri, final NameResolver nameResolver) + { + InetSocketAddress address = null; + final String controlValue = uri.get(CommonContext.MDC_CONTROL_PARAM_NAME); + if (null != controlValue) + { + try + { + address = SocketAddressParser.parse( + controlValue, CommonContext.MDC_CONTROL_PARAM_NAME, false, nameResolver); + } + catch (final UnknownHostException ex) + { + LangUtil.rethrowUnchecked(ex); + } + } + + return address; + } + + private static void validateDataAddress(final byte[] addressAsBytes) + { + if (BitUtil.isEven(addressAsBytes[addressAsBytes.length - 1])) + { + throw new IllegalArgumentException("multicast data address must be odd"); + } + } + + private static void validateConfiguration(final ChannelUri uri) + { + validateMedia(uri); + } + + private static void validateMedia(final ChannelUri uri) + { + if (!uri.isUdp()) + { + throw new IllegalArgumentException("UdpChannel only supports UDP media: " + uri); + } + } + + private static InetSocketAddress resolveToAddressOfInterface( + final NetworkInterface localInterface, final InterfaceSearchAddress searchAddress) + { + final InetAddress interfaceAddress = findAddressOnInterface( + localInterface, searchAddress.getInetAddress(), searchAddress.getSubnetPrefix()); + + if (null == interfaceAddress) + { + throw new IllegalStateException(); + } + + return new InetSocketAddress(interfaceAddress, searchAddress.getPort()); + } + + private static NetworkInterface findInterface(final InterfaceSearchAddress searchAddress) + throws SocketException + { + final NetworkInterface[] filteredInterfaces = filterBySubnet( + searchAddress.getInetAddress(), searchAddress.getSubnetPrefix()); + + for (final NetworkInterface networkInterface : filteredInterfaces) + { + if (networkInterface.supportsMulticast() || networkInterface.isLoopback()) + { + return networkInterface; + } + } + + throw new IllegalArgumentException(errorNoMatchingInterfaces(filteredInterfaces, searchAddress)); + } + + private static String errorNoMatchingInterfaces( + final NetworkInterface[] filteredInterfaces, final InterfaceSearchAddress address) + throws SocketException + { + final StringBuilder builder = new StringBuilder() + .append("Unable to find multicast interface matching criteria: ") + .append(address.getAddress()) + .append('/') + .append(address.getSubnetPrefix()); + + if (filteredInterfaces.length > 0) + { + builder.append(lineSeparator()).append(" Candidates:"); + + for (final NetworkInterface ifc : filteredInterfaces) + { + builder + .append(lineSeparator()) + .append(" - Name: ") + .append(ifc.getDisplayName()) + .append(", addresses: ") + .append(ifc.getInterfaceAddresses()) + .append(", multicast: ") + .append(ifc.supportsMulticast()); + } + } + + return builder.toString(); + } + + static class Context + { + long tagId; + int multicastTtl; + InetSocketAddress remoteData; + InetSocketAddress localData; + InetSocketAddress remoteControl; + InetSocketAddress localControl; + String uriStr; + String canonicalForm; + NetworkInterface localInterface; + ProtocolFamily protocolFamily; + ChannelUri channelUri; + boolean isManualControlMode = false; + boolean isDynamicControlMode = false; + boolean hasExplicitEndpoint = false; + boolean hasExplicitControl = false; + boolean isMulticast = false; + boolean hasMulticastTtl = false; + boolean hasTagId = false; + boolean hasNoDistinguishingCharacteristic = false; + + Context uriStr(final String uri) + { + uriStr = uri; + return this; + } + + Context remoteDataAddress(final InetSocketAddress remoteData) + { + this.remoteData = remoteData; + return this; + } + + Context localDataAddress(final InetSocketAddress localData) + { + this.localData = localData; + return this; + } + + Context remoteControlAddress(final InetSocketAddress remoteControl) + { + this.remoteControl = remoteControl; + return this; + } + + Context localControlAddress(final InetSocketAddress localControl) + { + this.localControl = localControl; + return this; + } + + Context canonicalForm(final String canonicalForm) + { + this.canonicalForm = canonicalForm; + return this; + } + + Context localInterface(final NetworkInterface networkInterface) + { + this.localInterface = networkInterface; + return this; + } + + Context protocolFamily(final ProtocolFamily protocolFamily) + { + this.protocolFamily = protocolFamily; + return this; + } + + Context hasMulticastTtl(final boolean hasMulticastTtl) + { + this.hasMulticastTtl = hasMulticastTtl; + return this; + } + + Context multicastTtl(final int multicastTtl) + { + this.multicastTtl = multicastTtl; + return this; + } + + Context tagId(final long tagId) + { + this.tagId = tagId; + return this; + } + + Context channelUri(final ChannelUri channelUri) + { + this.channelUri = channelUri; + return this; + } + + Context isManualControlMode(final boolean isManualControlMode) + { + this.isManualControlMode = isManualControlMode; + return this; + } + + Context isDynamicControlMode(final boolean isDynamicControlMode) + { + this.isDynamicControlMode = isDynamicControlMode; + return this; + } + + Context hasExplicitEndpoint(final boolean hasExplicitEndpoint) + { + this.hasExplicitEndpoint = hasExplicitEndpoint; + return this; + } + + Context hasExplicitControl(final boolean hasExplicitControl) + { + this.hasExplicitControl = hasExplicitControl; + return this; + } + + Context isMulticast(final boolean isMulticast) + { + this.isMulticast = isMulticast; + return this; + } + + Context hasTagId(final boolean hasTagId) + { + this.hasTagId = hasTagId; + return this; + } + + Context hasNoDistinguishingCharacteristic(final boolean hasNoDistinguishingCharacteristic) + { + this.hasNoDistinguishingCharacteristic = hasNoDistinguishingCharacteristic; + return this; + } + } +} diff --git a/test/dorkboxTest/network/BaseTest.kt b/test/dorkboxTest/network/BaseTest.kt index 451276c8..faf2b48f 100644 --- a/test/dorkboxTest/network/BaseTest.kt +++ b/test/dorkboxTest/network/BaseTest.kt @@ -81,7 +81,6 @@ abstract class BaseTest { fun serverConfig(): ServerConfiguration { val configuration = ServerConfiguration() - configuration.listenIpAddress = LOOPBACK configuration.subscriptionPort = 2000 configuration.publicationPort = 2001 diff --git a/test/dorkboxTest/network/rmi/RmiDelayedInvocationTest.kt b/test/dorkboxTest/network/rmi/RmiDelayedInvocationTest.kt index aa84705b..ef91a1bd 100644 --- a/test/dorkboxTest/network/rmi/RmiDelayedInvocationTest.kt +++ b/test/dorkboxTest/network/rmi/RmiDelayedInvocationTest.kt @@ -32,7 +32,7 @@ class RmiDelayedInvocationTest : BaseTest() { @Test fun rmiNetwork() { runBlocking { - rmi() { configuration -> + rmi { configuration -> configuration.enableIpcForLoopback = false } } @@ -108,7 +108,7 @@ class RmiDelayedInvocationTest : BaseTest() { client.connect(LOOPBACK) } - waitForThreads(9999999) + waitForThreads() } private interface TestObject { diff --git a/test/dorkboxTest/network/rmi/RmiSimpleTest.kt b/test/dorkboxTest/network/rmi/RmiSimpleTest.kt index 1e7bf178..be66ff65 100644 --- a/test/dorkboxTest/network/rmi/RmiSimpleTest.kt +++ b/test/dorkboxTest/network/rmi/RmiSimpleTest.kt @@ -34,6 +34,8 @@ */ package dorkboxTest.network.rmi +import dorkbox.netUtil.IPv4 +import dorkbox.netUtil.IPv6 import dorkbox.network.Client import dorkbox.network.Configuration import dorkbox.network.Server @@ -49,15 +51,58 @@ import org.junit.Test class RmiSimpleTest : BaseTest() { @Test - fun rmiNetworkGlobal() { - rmiGlobal() { configuration -> + fun rmiIPv4NetworkGlobal() { + rmiGlobal(isIpv4 = true, isIpv6 = false) { configuration -> configuration.enableIpcForLoopback = false } } @Test - fun rmiNetworkConnection() { - rmi { configuration -> + fun rmiIPv6NetworkGlobal() { + rmiGlobal(isIpv4 = true, isIpv6 = false) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + @Test + fun rmiBothIPv4ConnectNetworkGlobal() { + rmiGlobal(isIpv4 = true, isIpv6 = true) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + @Test + fun rmiBothIPv6ConnectNetworkGlobal() { + rmiGlobal(isIpv4 = true, isIpv6 = true, runIpv4Connect = true) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + @Test + fun rmiIPv4NetworkConnection() { + rmi(isIpv4 = true, isIpv6 = false) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + @Test + fun rmiIPv6NetworkConnection() { + rmi(isIpv4 = false, isIpv6 = true) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + @Test + fun rmiBothIPv4ConnectNetworkConnection() { + rmi(isIpv4 = true, isIpv6 = true) { configuration -> + configuration.enableIpcForLoopback = false + } + } + + + @Test + fun rmiBothIPv6ConnectNetworkConnection() { + rmi(isIpv4 = true, isIpv6 = true, runIpv4Connect = true) { configuration -> configuration.enableIpcForLoopback = false } } @@ -72,10 +117,13 @@ class RmiSimpleTest : BaseTest() { rmi() } - fun rmi(config: (Configuration) -> Unit = {}) { + fun rmi(isIpv4: Boolean = false, isIpv6: Boolean = false, runIpv4Connect: Boolean = true, config: (Configuration) -> Unit = {}) { run { val configuration = serverConfig() + configuration.enableIPv4 = isIpv4 + configuration.enableIPv6 = isIpv6 config(configuration) + configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java) configuration.serialization.register(MessageWithTestCow::class.java) configuration.serialization.register(UnsupportedOperationException::class.java) @@ -86,19 +134,19 @@ class RmiSimpleTest : BaseTest() { server.bind() server.onMessage { connection, m -> - System.err.println("Received finish signal for test for: Client -> Server") + server.logger.error("Received finish signal for test for: Client -> Server") val `object` = m.testCow val id = `object`.id() Assert.assertEquals(23, id.toLong()) - System.err.println("Finished test for: Client -> Server") + server.logger.error("Finished test for: Client -> Server") - System.err.println("Starting test for: Server -> Client") + server.logger.error("Starting test for: Server -> Client") // NOTE: THIS IS BI-DIRECTIONAL! connection.createObject(123) { rmiId, remoteObject -> - System.err.println("Running test for: Server -> Client") + server.logger.error("Running test for: Server -> Client") RmiCommonTest.runTests(connection, remoteObject, 123) - System.err.println("Done with test for: Server -> Client") + server.logger.error("Done with test for: Server -> Client") } } } @@ -108,39 +156,47 @@ class RmiSimpleTest : BaseTest() { config(configuration) // configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java) - val client = Client(configuration) addEndPoint(client) client.onConnect { connection -> connection.createObject(23) { rmiId, remoteObject -> - System.err.println("Running test for: Client -> Server") + client.logger.error("Running test for: Client -> Server") RmiCommonTest.runTests(connection, remoteObject, 23) - System.err.println("Done with test for: Client -> Server") + client.logger.error("Done with test for: Client -> Server") } } client.onMessage { _, m -> - System.err.println("Received finish signal for test for: Client -> Server") + client.logger.error("Received finish signal for test for: Client -> Server") val `object` = m.testCow val id = `object`.id() Assert.assertEquals(123, id.toLong()) - System.err.println("Finished test for: Client -> Server") + client.logger.error("Finished test for: Client -> Server") stopEndPoints(2000) } runBlocking { - client.connect(LOOPBACK) + when { + isIpv4 && isIpv6 && runIpv4Connect -> client.connect(IPv4.LOCALHOST) + isIpv4 && isIpv6 && !runIpv4Connect -> client.connect(IPv6.LOCALHOST) + isIpv4 -> client.connect(IPv4.LOCALHOST) + isIpv6 -> client.connect(IPv6.LOCALHOST) + else -> client.connect() + } } } waitForThreads() } - fun rmiGlobal(config: (Configuration) -> Unit = {}) { + fun rmiGlobal(isIpv4: Boolean = false, isIpv6: Boolean = false, runIpv4Connect: Boolean = true, config: (Configuration) -> Unit = {}) { run { val configuration = serverConfig() + configuration.enableIPv4 = isIpv4 + configuration.enableIPv6 = isIpv6 config(configuration) + configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java) configuration.serialization.register(MessageWithTestCow::class.java) configuration.serialization.register(UnsupportedOperationException::class.java) @@ -153,20 +209,20 @@ class RmiSimpleTest : BaseTest() { server.bind() server.onMessage { connection, m -> - System.err.println("Received finish signal for test for: Client -> Server") + server.logger.error("Received finish signal for test for: Client -> Server") val `object` = m.testCow val id = `object`.id() Assert.assertEquals(44, id.toLong()) - System.err.println("Finished test for: Client -> Server") + server.logger.error("Finished test for: Client -> Server") // normally this is in the 'connected', but we do it here, so that it's more linear and easier to debug connection.createObject(4) { rmiId, remoteObject -> - System.err.println("Running test for: Server -> Client") + server.logger.error("Running test for: Server -> Client") RmiCommonTest.runTests(connection, remoteObject, 4) - System.err.println("Done with test for: Server -> Client") + server.logger.error("Done with test for: Server -> Client") } } } @@ -180,24 +236,30 @@ class RmiSimpleTest : BaseTest() { addEndPoint(client) client.onMessage { _, m -> - System.err.println("Received finish signal for test for: Client -> Server") + client.logger.error("Received finish signal for test for: Client -> Server") val `object` = m.testCow val id = `object`.id() Assert.assertEquals(4, id.toLong()) - System.err.println("Finished test for: Client -> Server") + client.logger.error("Finished test for: Client -> Server") stopEndPoints(2000) } runBlocking { - client.connect(LOOPBACK) + when { + isIpv4 && isIpv6 && runIpv4Connect -> client.connect(IPv4.LOCALHOST) + isIpv4 && isIpv6 && !runIpv4Connect -> client.connect(IPv6.LOCALHOST) + isIpv4 -> client.connect(IPv4.LOCALHOST) + isIpv6 -> client.connect(IPv6.LOCALHOST) + else -> client.connect() + } - System.err.println("Starting test for: Client -> Server") + client.logger.error("Starting test for: Client -> Server") // this creates a GLOBAL object on the server (instead of a connection specific object) client.createObject(44) { rmiId, remoteObject -> - System.err.println("Running test for: Client -> Server") + client.logger.error("Running test for: Client -> Server") RmiCommonTest.runTests(client.connection, remoteObject, 44) - System.err.println("Done with test for: Client -> Server") + client.logger.error("Done with test for: Client -> Server") } } }