Added support for IPv6.

This commit is contained in:
nathan 2020-09-09 01:33:09 +02:00
parent 1e36cba8f2
commit b7da14834e
39 changed files with 1865 additions and 451 deletions

View File

@ -16,26 +16,26 @@
package dorkbox.network
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.handshake.ClientHandshake
import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.TimeoutException
import dorkbox.util.exceptions.SecurityException
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.launch
import java.net.Inet4Address
import java.net.InetAddress
/**
* The client is both SYNC and ASYNC. It starts off SYNC (blocks thread until it's done), then once it's connected to the server, it's
@ -59,7 +59,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* For the IPC (Inter-Process-Communication) address. it must be:
* - the IPC integer ID, "0x1337c0de", "0x12312312", etc.
*/
private var remoteAddress0 = ""
private var remoteAddress0: InetAddress? = IPv4.LOCALHOST
@Volatile
private var isConnected = false
@ -67,9 +67,6 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// is valid when there is a connection to the server, otherwise it is null
private var connection0: CONNECTION? = null
@Volatile
private var connectionTimeoutMS: Long = 5_000 // default is 5 seconds
private val previousClosedConnectionActivity: Long = 0
private val rmiConnectionSupport = RmiManagerConnections(logger, rmiGlobalSupport, serialization)
@ -107,9 +104,12 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* - an InetAddress address
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY. ie: just call `connect()`
* - EMPTY.
* - `connect()`
* - `connect("")`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
@ -121,8 +121,113 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
suspend fun connect(remoteAddress: String,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
when {
// this is default IPC settings
remoteAddress.isEmpty() -> connect(connectionTimeoutMS = connectionTimeoutMS)
IPv4.isPreferred -> connect(remoteAddress = Inet4Address.getAllByName(remoteAddress)[0],
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
else -> connect(remoteAddress = Inet4Address.getAllByName(remoteAddress)[0],
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
}
}
/**
* Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed.
*
* Default connection is to localhost
*
* ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
* - an InetAddress address
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY.
* - `connect()`
* - `connect("")`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
* @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
suspend fun connect(remoteAddress: InetAddress,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
// Default IPC ports are flipped because they are in the perspective of the SERVER
connect(remoteAddress = remoteAddress,
ipcPublicationId = IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId = IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS = connectionTimeoutMS,
reliable = reliable)
}
/**
* Will attempt to connect to the server via IPC, with a default 30 second connection timeout and will block until completed.
*
* @param ipcPublicationId The IPC publication address for the client to connect to
* @param ipcSubscriptionId The IPC subscription address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely.
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("DuplicatedCode")
suspend fun connect(remoteAddress: String = IPv4.LOCALHOST.hostAddress, connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
suspend fun connect(ipcPublicationId: Int = IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L) {
// Default IPC ports are flipped because they are in the perspective of the SERVER
require(ipcPublicationId != ipcSubscriptionId) { "IPC publication and subscription ports cannot be the same! The must match the server's configuration." }
connect(remoteAddress = null, // required!
ipcPublicationId = ipcPublicationId,
ipcSubscriptionId = ipcSubscriptionId,
connectionTimeoutMS = connectionTimeoutMS)
}
/**
* Will attempt to connect to the server, with a default 30 second connection timeout and will block until completed.
*
* Default connection is to localhost
*
* ### For a network address, it can be:
* - a network name ("localhost", "loopback", "lo", "bob.example.org")
* - an IP address ("127.0.0.1", "123.123.123.123", "::1")
*
* ### For the IPC (Inter-Process-Communication) it must be:
* - EMPTY. ie: just call `connect()`
* - Specified EMPTY. ie: just call `connect()`
*
* ### Case does not matter, and "localhost" is the default. IPC address must be in HEX notation (starting with '0x')
*
* @param remoteAddress The network or if localhost, IPC address for the client to connect to
* @param ipcPublicationId The IPC publication address for the client to connect to
* @param ipcSubscriptionId The IPC subscription address for the client to connect to
* @param connectionTimeoutMS wait for x milliseconds. 0 will wait indefinitely.
* @param reliable true if we want to create a reliable connection. IPC connections are always reliable
*
* @throws IllegalArgumentException if the remote address is invalid
* @throws ClientTimedOutException if the client is unable to connect in x amount of time
* @throws ClientRejectedException if the client connection is rejected
*/
@Suppress("DuplicatedCode")
private suspend fun connect(remoteAddress: InetAddress? = null,
// Default IPC ports are flipped because they are in the perspective of the SERVER
ipcPublicationId: Int = IPC_HANDSHAKE_STREAM_ID_SUB,
ipcSubscriptionId: Int = IPC_HANDSHAKE_STREAM_ID_PUB,
connectionTimeoutMS: Long = 30_000L, reliable: Boolean = true) {
// this will exist ONLY if we are reconnecting via a "disconnect" callback
lockStepForReconnect.value?.doWait()
@ -143,80 +248,31 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
logger.info("Media driver is running. Support for enable auto-switch from LOCALHOST -> IPC enabled")
}
this.connectionTimeoutMS = connectionTimeoutMS
val isIpcConnection: Boolean
// NETWORK OR IPC ADDRESS
// if we connect to "loopback", then we substitute if for IPC (with log message)
// if we connect to "loopback", then MAYBE we substitute if for IPC (with log message)
// localhost/loopback IP might not always be 127.0.0.1 or ::1
when (remoteAddress) {
"0.0.0.0" -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
"loopback", "localhost", "lo", "" -> {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
}
}
"0x" -> {
isIpcConnection = true
this.remoteAddress0 = "ipc"
}
else -> when {
IPv4.isLoopback(remoteAddress) -> {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress0 = IPv4.LOCALHOST.hostAddress
}
}
IPv6.isLoopback(remoteAddress) -> {
if (canAutoChangeToIpc) {
isIpcConnection = true
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = "ipc"
} else {
isIpcConnection = false
this.remoteAddress0 = IPv6.LOCALHOST.hostAddress
}
}
else -> {
isIpcConnection = false
this.remoteAddress0 = remoteAddress
}
when {
remoteAddress == null -> this.remoteAddress0 = null
remoteAddress.isAnyLocalAddress -> throw IllegalArgumentException("0.0.0.0 is an invalid address to connect to!")
canAutoChangeToIpc && remoteAddress.isLoopbackAddress -> {
logger.info { "Auto-changing network connection from $remoteAddress -> IPC" }
this.remoteAddress0 = null
}
else -> this.remoteAddress0 = remoteAddress
}
if (IPv6.isValid(this.remoteAddress0)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (this.remoteAddress0.count { it == '[' } < 1 &&
this.remoteAddress0.count { it == ']' } < 1) {
this.remoteAddress0 = """[${this.remoteAddress0}]"""
}
}
val handshake = ClientHandshake(logger, config, crypto, this)
// initially we only connect to the handshake connect ports. Ports are flipped because they are in the perspective of the SERVER
val handshakeConnection = if (isIpcConnection) {
IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_PUB,
streamId = IPC_HANDSHAKE_STREAM_ID_SUB,
val handshakeConnection = if (this.remoteAddress0 == null) {
IpcMediaDriverConnection(streamIdSubscription = ipcSubscriptionId,
streamId = ipcPublicationId,
sessionId = RESERVED_SESSION_ID_INVALID)
}
else {
UdpMediaDriverConnection(address = this.remoteAddress0,
UdpMediaDriverConnection(address = this.remoteAddress0!!,
publicationPort = config.subscriptionPort,
subscriptionPort = config.publicationPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
@ -227,7 +283,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// throws a ConnectTimedOutException if the client cannot connect for any reason to the server handshake ports
handshakeConnection.buildClient(aeron)
handshakeConnection.buildClient(aeron, logger)
logger.info(handshakeConnection.clientInfo())
@ -238,10 +294,10 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// VALIDATE:: check to see if the remote connection's public key has changed!
val validateRemoteAddress = if (isIpcConnection) {
val validateRemoteAddress = if (this.remoteAddress0 == null) {
PublicKeyValidationState.VALID
} else {
crypto.validateRemoteAddress(IPv4.toInt(this.remoteAddress0), connectionInfo.publicKey)
crypto.validateRemoteAddress(this.remoteAddress0!!, connectionInfo.publicKey)
}
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
@ -258,7 +314,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
// we are now connected, so we can connect to the NEW client-specific ports
val reliableClientConnection = if (isIpcConnection) {
val reliableClientConnection = if (this.remoteAddress0 == null) {
IpcMediaDriverConnection(sessionId = connectionInfo.sessionId,
// NOTE: pub/sub must be switched!
streamIdSubscription = connectionInfo.publicationPort,
@ -266,7 +322,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
connectionTimeoutMS = connectionTimeoutMS)
}
else {
UdpMediaDriverConnection(address = handshakeConnection.address,
UdpMediaDriverConnection(address = handshakeConnection.address!!,
// NOTE: pub/sub must be switched!
subscriptionPort = connectionInfo.publicationPort,
publicationPort = connectionInfo.subscriptionPort,
@ -277,7 +333,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
// we have to construct how the connection will communicate!
reliableClientConnection.buildClient(aeron)
reliableClientConnection.buildClient(aeron, logger)
// only the client connects to the server, so here we have to connect. The server (when creating the new "connection" object)
// does not need to do anything
@ -302,7 +358,7 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
}
val newConnection = if (isIpcConnection) {
val newConnection = if (this.remoteAddress0 == null) {
newConnection(ConnectionParams(this, reliableClientConnection, PublicKeyValidationState.VALID))
} else {
newConnection(ConnectionParams(this, reliableClientConnection, validateRemoteAddress))
@ -406,11 +462,17 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
val remoteKeyHasChanged: Boolean
get() = connection.hasRemoteKeyChanged()
/**
* the remote address
*/
val remoteAddress: InetAddress?
get() = remoteAddress0
/**
* the remote address, as a string.
*/
val remoteAddress: String
get() = remoteAddress0
val remoteAddressString: String
get() = remoteAddress0?.hostAddress ?: "ipc"
/**
* true if this connection is an IPC connection
@ -464,12 +526,10 @@ open class Client<CONNECTION : Connection>(config: Configuration = Configuration
/**
* Removes the specified host address from the list of registered server keys.
*/
@Throws(SecurityException::class)
fun removeRegisteredServerKey(hostAddress: String) {
val address = IPv4.toInt(hostAddress)
fun removeRegisteredServerKey(address: InetAddress) {
val savedPublicKey = settingsStore.getRegisteredServerKey(address)
if (savedPublicKey != null) {
logger.debug { "Deleting remote IP address key $hostAddress" }
logger.debug { "Deleting remote IP address key $address" }
settingsStore.removeRegisteredServerKey(address)
}
}

View File

@ -18,6 +18,7 @@ package dorkbox.network
import dorkbox.network.aeron.CoroutineBackoffIdleStrategy
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy
import dorkbox.network.connection.EndPoint
import dorkbox.network.serialization.Serialization
import dorkbox.network.storage.PropertyStore
import dorkbox.network.storage.SettingsStore
@ -30,6 +31,21 @@ import mu.KLogger
import java.io.File
class ServerConfiguration : dorkbox.network.Configuration() {
/**
* Enables the ability to use the IPv4 network stack.
*/
var enableIPv4 = true
/**
* Enables the ability to use the IPv6 network stack.
*/
var enableIPv6 = true
/**
* Enables the ability use IPC (Inter Process Communication)
*/
var enableIPC = true
/**
* The address for the server to listen on. "*" will accept connections from all interfaces, otherwise specify
* the hostname (or IP) to bind to.
@ -37,14 +53,24 @@ class ServerConfiguration : dorkbox.network.Configuration() {
var listenIpAddress = "*"
/**
* The maximum number of clients allowed for a server
* The maximum number of clients allowed for a server. IPC is unlimited
*/
var maxClientCount = 0
/**
* The maximum number of client connection allowed per IP address
* The maximum number of client connection allowed per IP address. IPC is unlimited
*/
var maxConnectionsPerIpAddress = 0
/**
* The IPC Publication ID is used to define what ID the server will send data on. The client IPC subscription ID must match this value.
*/
var ipcPublicationId = EndPoint.IPC_HANDSHAKE_STREAM_ID_PUB
/**
* The IPC Subscription ID is used to define what ID the server will receive data on. The client IPC publication ID must match this value.
*/
var ipcSubscriptionId = EndPoint.IPC_HANDSHAKE_STREAM_ID_SUB
}
open class Configuration {

View File

@ -15,19 +15,23 @@
*/
package dorkbox.network
import dorkbox.netUtil.IP
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.aeron.AeronPoller
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.connection.connectionType.ConnectionRule
import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.TimeoutException
import io.aeron.Aeron
import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.Header
@ -35,6 +39,9 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import java.util.concurrent.CopyOnWriteArrayList
/**
@ -72,6 +79,13 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
@Volatile
private var bindAlreadyCalled = false
/**
* These are run in lock-step to shutdown/close the server. Afterwards, bind() can be called again
*/
private val shutdownPollWaiter = SuspendWaiter()
private val shutdownEventWaiter = SuspendWaiter()
/**
* Used for handshake connections
*/
@ -82,38 +96,43 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
*/
private val connectionRules = CopyOnWriteArrayList<ConnectionRule>()
internal val listenIPv4Address: InetAddress?
internal val listenIPv6Address: InetAddress?
init {
// have to do some basic validation of our configuration
config.listenIpAddress = config.listenIpAddress.toLowerCase()
require(config.listenIpAddress.isNotBlank()) { "Blank listen IP address, cannot continue"}
// localhost/loopback IP might not always be 127.0.0.1 or ::1
when (config.listenIpAddress) {
"loopback", "localhost", "lo", "" -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
else -> when {
IPv4.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv4.LOCALHOST.hostAddress
IPv6.isLoopback(config.listenIpAddress) -> config.listenIpAddress = IPv6.LOCALHOST.hostAddress
else -> config.listenIpAddress = "0.0.0.0" // we set this to "0.0.0.0" so that it is clear that we are trying to bind to that address.
// We want to listen on BOTH IPv4 and IPv6 (config option lets us configure this)
listenIPv4Address = if (!config.enableIPv4) {
null
} else {
when (config.listenIpAddress) {
"loopback", "localhost", "lo" -> IPv4.LOCALHOST
"0", "::", "0.0.0.0", "*" -> {
// this is the "wildcard" address. Windows has problems with this.
InetAddress.getByAddress("", byteArrayOf(0, 0, 0, 0))
}
else -> Inet4Address.getAllByName(config.listenIpAddress)[0]
}
}
// if we are IPv4 wildcard
if (config.listenIpAddress == "0.0.0.0") {
// this will also fixup windows!
config.listenIpAddress = IPv4.WILDCARD
}
if (IPv6.isValid(config.listenIpAddress)) {
// "[" and "]" are valid for ipv6 addresses... we want to make sure it is so
// if we are IPv6, the IP must be in '[]'
if (config.listenIpAddress.count { it == '[' } < 1 &&
config.listenIpAddress.count { it == ']' } < 1) {
config.listenIpAddress = """[${config.listenIpAddress}]"""
listenIPv6Address = if (!config.enableIPv6) {
null
} else {
when (config.listenIpAddress) {
"loopback", "localhost", "lo" -> IPv6.LOCALHOST
"0", "::", "0.0.0.0", "*" -> {
// this is the "wildcard" address. Windows has problems with this.
InetAddress.getByAddress("", byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
}
else -> Inet6Address.getAllByName(config.listenIpAddress)[0]
}
}
if (config.publicationPort <= 0) { throw ServerException("configuration port must be > 0") }
if (config.publicationPort >= 65535) { throw ServerException("configuration port must be < 65535") }
@ -133,6 +152,253 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
return ServerException(message, cause)
}
private fun getIpcPoller(aeron: Aeron, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIPC) {
val driver = IpcMediaDriverConnection(streamIdSubscription = config.ipcSubscriptionId,
streamId = config.ipcPublicationId,
sessionId = RESERVED_SESSION_ID_INVALID)
driver.buildServer(aeron, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processIpcHandshakeMessageServer(this@Server,
publication,
sessionId,
message,
aeron)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPC Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
private fun getIpv4Poller(aeron: Aeron, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIPv4) {
val driver = UdpMediaDriverConnection(address = listenIPv4Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
driver.buildServer(aeron, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
val clientAddress = IPv4.getByNameUnsafe(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeron,
false)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
private fun getIpv6Poller(aeron: Aeron, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIPv6) {
val driver = UdpMediaDriverConnection(address = listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
driver.buildServer(aeron, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
val clientAddress = IPv6.getByName(clientAddressString)!!
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeron,
false)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv6 Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
private fun getIpv6WildcardPoller(aeron: Aeron, config: ServerConfiguration): AeronPoller {
val poller = if (config.enableIPv6) {
val driver = UdpMediaDriverConnection(address = listenIPv6Address!!,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
driver.buildServer(aeron, logger)
val publication = driver.publication
val subscription = driver.subscription
object : AeronPoller {
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
// this should never be null, because we are feeding it a valid IP address from aeron
// maybe IPv4, maybe IPv6!!!
val clientAddress = IP.getByName(clientAddressString)!!
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processUdpHandshakeMessageServer(this@Server,
publication,
sessionId,
clientAddressString,
clientAddress,
message,
aeron,
true)
}
override fun poll(): Int { return subscription.poll(handler, 1) }
override fun close() { driver.close() }
override fun serverInfo(): String { return driver.serverInfo() }
}
} else {
object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv6 Disabled" }
}
}
logger.info(poller.serverInfo())
return poller
}
/**
* Binds the server to AERON configuration
*/
@ -150,79 +416,32 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
config as ServerConfiguration
val ipcHandshakeDriver = IpcMediaDriverConnection(streamIdSubscription = IPC_HANDSHAKE_STREAM_ID_SUB,
streamId = IPC_HANDSHAKE_STREAM_ID_PUB,
sessionId = RESERVED_SESSION_ID_INVALID)
ipcHandshakeDriver.buildServer(aeron)
val ipcHandshakePublication = ipcHandshakeDriver.publication
val ipcHandshakeSubscription = ipcHandshakeDriver.subscription
val ipcPoller: AeronPoller = getIpcPoller(aeron, config)
// if we are binding to WILDCARD, then we have to do something special if BOTH IPv4 and IPv6 are enabled!
val isWildcard = listenIPv4Address == IPv4.WILDCARD || listenIPv6Address != IPv6.WILDCARD
val ipv4Poller: AeronPoller
val ipv6Poller: AeronPoller
val udpHandshakeDriver = UdpMediaDriverConnection(address = config.listenIpAddress,
publicationPort = config.publicationPort,
subscriptionPort = config.subscriptionPort,
streamId = UDP_HANDSHAKE_STREAM_ID,
sessionId = RESERVED_SESSION_ID_INVALID)
udpHandshakeDriver.buildServer(aeron)
val handshakePublication = udpHandshakeDriver.publication
val handshakeSubscription = udpHandshakeDriver.subscription
logger.info(ipcHandshakeDriver.serverInfo())
logger.info(udpHandshakeDriver.serverInfo())
/**
* Note:
* Reassembly has been shown to be minimal impact to latency. But not totally negligible. If the lowest latency is
* desired, then limiting message sizes to MTU size is a good practice.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*/
val udpHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
// note: this address will ALWAYS be an IP:PORT combo OR it will be aeron:ipc (if IPC, it will be a different handler!)
val remoteIpAndPort = (header.context() as Image).sourceIdentity()
// split
val splitPoint = remoteIpAndPort.lastIndexOf(':')
val clientAddressString = remoteIpAndPort.substring(0, splitPoint)
// val port = remoteIpAndPort.substring(splitPoint+1)
val clientAddress = IPv4.toInt(clientAddressString)
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processUdpHandshakeMessageServer(this@Server,
handshakePublication,
sessionId,
clientAddressString,
clientAddress,
message,
aeron)
if (isWildcard) {
// IPv6 will bind to IPv4 wildcard as well!!
if (config.enableIPv4 && config.enableIPv6) {
ipv4Poller = object : AeronPoller {
override fun poll(): Int { return 0 }
override fun close() {}
override fun serverInfo(): String { return "IPv4 Disabled" }
}
ipv6Poller = getIpv6WildcardPoller(aeron, config)
} else {
// only 1 will be a real poller
ipv4Poller = getIpv4Poller(aeron, config)
ipv6Poller = getIpv6Poller(aeron, config)
}
} else {
ipv4Poller = getIpv4Poller(aeron, config)
ipv6Poller = getIpv6Poller(aeron, config)
}
val ipcHandshakeHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
// The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
// for the handshake, the sessionId IS NOT GLOBALLY UNIQUE
val sessionId = header.sessionId()
val message = readHandshakeMessage(buffer, offset, length, header)
handshake.processIpcHandshakeMessageServer(this@Server,
ipcHandshakePublication,
sessionId,
message,
aeron)
}
actionDispatch.launch {
val pollIdleStrategy = config.pollIdleStrategy
@ -237,10 +456,11 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
// this checks to see if there are NEW clients on the handshake ports
pollCount += handshakeSubscription.poll(udpHandshakeHandler, 1)
pollCount += ipv4Poller.poll()
pollCount += ipv6Poller.poll()
// this checks to see if there are NEW clients via IPC
pollCount += ipcHandshakeSubscription.poll(ipcHandshakeHandler, 1)
pollCount += ipcPoller.poll()
// this manages existing clients (for cleanup + connection polling)
@ -291,12 +511,53 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
// 0 means we idle. >0 means reset and don't idle (because there are likely more poll events)
pollIdleStrategy.idle(pollCount)
}
} finally {
handshakePublication.close()
handshakeSubscription.close()
ipcHandshakePublication.close()
ipcHandshakeSubscription.close()
// we want to process **actual** close cleanup events on this thread as well, otherwise we will have threading problems
shutdownPollWaiter.doWait()
// we have to manually cleanup the connections and call server-notifyDisconnect because otherwise this will never get called
val jobs = mutableListOf<Job>()
// we want to clear all the connections FIRST (since we are shutting down)
val cons = mutableListOf<CONNECTION>()
connections.forEach { cons.add(it) }
connections.clear()
cons.forEach { connection ->
logger.error("${connection.id} cleanup")
// have to free up resources!
// NOTE: This can only occur on the polling dispatch thread!!
handshake.cleanup(connection)
// make sure the connection is closed (close can only happen once, so a duplicate call does nothing!)
connection.close()
// have to manually notify the server-listenerManager that this connection was closed
// if the connection was MANUALLY closed (via calling connection.close()), then the connection-listenermanager is
// instantly notified and on cleanup, the server-listenermanager is called
// NOTE: this must be the LAST thing happening!
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
val job = actionDispatch.launch {
listenerManager.notifyDisconnect(connection)
}
jobs.add(job)
}
// reset all of the handshake info
handshake.clear()
// when we close a client or a server, we want to make sure that ALL notifications are finished.
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
jobs.forEach { it.join() }
} finally {
ipv4Poller.close()
ipv6Poller.close()
ipcPoller.close()
// finish closing -- this lets us make sure that we don't run into race conditions on the thread that calls close()
shutdownEventWaiter.doNotify()
}
}
}
@ -362,41 +623,13 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
override fun close0() {
bindAlreadyCalled = false
// when we call close, it will shutdown the polling mechanism, so we have to manually cleanup the connections and call server-notifyDisconnect
// on them
// when we call close, it will shutdown the polling mechanism then wait for us to tell it to cleanup connections.
//
// Aeron + the Media Driver will have already been shutdown at this point.
runBlocking {
val jobs = mutableListOf<Job>()
// we want to clear all the connections FIRST (since we are shutting down)
val cons = mutableListOf<CONNECTION>()
connections.forEach { cons.add(it) }
connections.clear()
cons.forEach { connection ->
logger.error("${connection.id} cleanup")
// have to free up resources!
handshake.cleanup(connection)
// make sure the connection is closed (close can only happen once, so a duplicate call does nothing!)
connection.close()
// have to manually notify the server-listenerManager that this connection was closed
// if the connection was MANUALLY closed (via calling connection.close()), then the connection-listenermanager is
// instantly notified and on cleanup, the server-listenermanager is called
// NOTE: this must be the LAST thing happening!
// this always has to be on a new dispatch, otherwise we can have weird logic loops if we reconnect within a disconnect callback
val job = actionDispatch.launch {
listenerManager.notifyDisconnect(connection)
}
jobs.add(job)
}
// when we close a client or a server, we want to make sure that ALL notifications are finished.
// when it's just a connection getting closed, we don't care about this. We only care when it's "global" shutdown
jobs.forEach { it.join() }
// These are run in lock-step
shutdownPollWaiter.doNotify()
shutdownEventWaiter.doWait()
}
}

View File

@ -0,0 +1,8 @@
package dorkbox.network.aeron
internal interface AeronPoller {
fun poll(): Int
fun close()
fun serverInfo(): String
}

View File

@ -13,19 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
@file:Suppress("DuplicatedCode")
import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.server.ServerException
package dorkbox.network.aeron
import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.Aeron
import io.aeron.ChannelUriStringBuilder
import io.aeron.Publication
import io.aeron.Subscription
import kotlinx.coroutines.delay
import mu.KLogger
import java.net.Inet4Address
import java.net.InetAddress
interface MediaDriverConnection : AutoCloseable {
val address: String
val address: InetAddress?
val streamId: Int
val sessionId: Int
@ -38,9 +42,8 @@ interface MediaDriverConnection : AutoCloseable {
val isReliable: Boolean
@Throws(ClientTimedOutException::class)
suspend fun buildClient(aeron: Aeron)
fun buildServer(aeron: Aeron)
suspend fun buildClient(aeron: Aeron, logger: KLogger)
fun buildServer(aeron: Aeron, logger: KLogger)
fun clientInfo() : String
fun serverInfo() : String
@ -49,7 +52,7 @@ interface MediaDriverConnection : AutoCloseable {
/**
* For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER
*/
class UdpMediaDriverConnection(override val address: String,
class UdpMediaDriverConnection(override val address: InetAddress,
override val publicationPort: Int,
override val subscriptionPort: Int,
override val streamId: Int,
@ -62,6 +65,19 @@ class UdpMediaDriverConnection(override val address: String,
var success: Boolean = false
val addressString: String by lazy {
if (address is Inet4Address) {
address.hostAddress
} else {
// IPv6 requires the address to be bracketed by [...]
val host = address.hostAddress
if (host[0] == '[') {
host
} else {
"[${address.hostAddress}]"
}
}
}
private fun uri(): ChannelUriStringBuilder {
val builder = ChannelUriStringBuilder().reliable(isReliable).media("udp")
@ -73,27 +89,33 @@ class UdpMediaDriverConnection(override val address: String,
}
@Suppress("DuplicatedCode")
override suspend fun buildClient(aeron: Aeron) {
if (address.isEmpty()) {
throw ClientException("Invalid address : '$address'")
}
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
.controlEndpoint("$address:$subscriptionPort")
.controlMode("dynamic")
override suspend fun buildClient(aeron: Aeron, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()
.endpoint("$address:$publicationPort")
.endpoint("$addressString:$publicationPort")
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
.controlEndpoint("$addressString:$subscriptionPort")
.controlMode("dynamic")
if (logger.isTraceEnabled) {
if (address is Inet4Address) {
logger.trace("IPV4 client pub URI: ${publicationUri.build()}")
logger.trace("IPV4 client sub URI: ${subscriptionUri.build()}")
} else {
logger.trace("IPV6 client pub URI: ${publicationUri.build()}")
logger.trace("IPV6 client sub URI: ${subscriptionUri.build()}")
}
}
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
val subscription = aeron.addSubscription(subscriptionUri.build(), streamId)
val publication = aeron.addPublication(publicationUri.build(), streamId)
val subscription = aeron.addSubscription(subscriptionUri.build(), streamId)
var success = false
@ -139,27 +161,33 @@ class UdpMediaDriverConnection(override val address: String,
this.publication = publication
}
override fun buildServer(aeron: Aeron) {
if (address.isEmpty()) {
throw ServerException("Invalid address. It is empty!")
}
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
.endpoint("$address:$subscriptionPort")
override fun buildServer(aeron: Aeron, logger: KLogger) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()
.controlEndpoint("$address:$publicationPort")
.controlEndpoint("$addressString:$publicationPort")
.controlMode("dynamic")
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
.endpoint("$addressString:$subscriptionPort")
if (logger.isTraceEnabled) {
if (address is Inet4Address) {
logger.trace("IPV4 server pub URI: ${publicationUri.build()}")
logger.trace("IPV4 server sub URI: ${subscriptionUri.build()}")
} else {
logger.trace("IPV6 server pub URI: ${publicationUri.build()}")
logger.trace("IPV6 server sub URI: ${subscriptionUri.build()}")
}
}
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
subscription = aeron.addSubscription(subscriptionUri.build(), streamId)
publication = aeron.addPublication(publicationUri.build(), streamId)
subscription = aeron.addSubscription(subscriptionUri.build(), streamId)
}
@ -187,7 +215,7 @@ class UdpMediaDriverConnection(override val address: String,
}
override fun toString(): String {
return "$address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
return "$addressString [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
}
}
@ -200,8 +228,8 @@ class IpcMediaDriverConnection(override val streamId: Int,
private val connectionTimeoutMS: Long = 30_000,
) : MediaDriverConnection {
override val address: InetAddress? = null
override val isReliable = true
override val address = "ipc"
override val subscriptionPort = 0
override val publicationPort = 0
@ -220,19 +248,24 @@ class IpcMediaDriverConnection(override val streamId: Int,
}
@Throws(ClientTimedOutException::class)
override suspend fun buildClient(aeron: Aeron) {
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
override suspend fun buildClient(aeron: Aeron, logger: KLogger) {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
if (logger.isTraceEnabled) {
logger.trace("IPC client pub URI: ${publicationUri.build()}")
logger.trace("IPC server sub URI: ${subscriptionUri.build()}")
}
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
val subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription)
val publication = aeron.addPublication(publicationUri.build(), streamId)
val subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription)
var success = false
@ -278,25 +311,31 @@ class IpcMediaDriverConnection(override val streamId: Int,
this.publication = publication
}
override fun buildServer(aeron: Aeron) {
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
override fun buildServer(aeron: Aeron, logger: KLogger) {
// Create a publication with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
val publicationUri = uri()
// Create a subscription with a control port (for dynamic MDC) at the given address and port, using the given stream ID.
val subscriptionUri = uri()
if (logger.isTraceEnabled) {
logger.trace("IPC server pub URI: ${publicationUri.build()}")
logger.trace("IPC server sub URI: ${subscriptionUri.build()}")
}
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription)
publication = aeron.addPublication(publicationUri.build(), streamId)
subscription = aeron.addSubscription(subscriptionUri.build(), streamIdSubscription)
}
override fun clientInfo() : String {
return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] aeron connection established to [$streamIdSubscription|$streamId]"
} else {
"Connecting IPC with handshake to [$streamIdSubscription|$streamId]"
"Connecting handshake to IPC [$streamIdSubscription|$streamId]"
}
}
@ -304,7 +343,7 @@ class IpcMediaDriverConnection(override val streamId: Int,
return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
"[$sessionId] IPC listening on [$streamIdSubscription|$streamId] "
} else {
"IPC listening with handshake on [$streamIdSubscription|$streamId]"
"Listening handshake on IPC [$streamIdSubscription|$streamId]"
}
}

View File

@ -15,10 +15,13 @@
*/
package dorkbox.network.connection
import dorkbox.netUtil.IPv4
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.connection.ping.PingFuture
import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.handshake.ConnectionCounts
import dorkbox.network.handshake.RandomIdAllocator
import dorkbox.network.ping.Ping
import dorkbox.network.ping.PingFuture
import dorkbox.network.ping.PingMessage
import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.TimeoutException
@ -33,8 +36,8 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.agrona.DirectBuffer
import org.agrona.collections.Int2IntCounterMap
import java.io.IOException
import java.net.InetAddress
import java.util.concurrent.TimeUnit
/**
@ -62,14 +65,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
val id: Int
/**
* the remote address, as a string. Will be "ipc" for IPC connections
* the remote address, as a string. Will be null for IPC connections
*/
val remoteAddress: String
val remoteAddress: InetAddress?
/**
* the remote address, as an integer. Can be 0 for IPC connections
* the remote address, as a string. Will be "ipc" for IPC connections
*/
private val remoteAddressInt: Int
val remoteAddressString: String
/**
* @return true if this connection is an IPC connection
@ -125,7 +128,8 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// a record of how many messages are in progress of being sent. When closing the connection, this number must be 0
private val messagesInProgress = atomic(0)
val toString0: () -> String
// we customize the toString() value for this connection, and it's just better to cache it's value (since it's a modestly complex string)
private val toString0: String
init {
val mediaDriverConnection = connectionParameters.mediaDriverConnection
@ -141,16 +145,19 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
streamId = 0 // this is because with IPC, we have stream sub/pub (which are replaced as port sub/pub)
subscriptionPort = mediaDriverConnection.streamIdSubscription
publicationPort = mediaDriverConnection.streamId
remoteAddressInt = 0
remoteAddressString = "ipc"
toString0 = { "[$id] IPC [$subscriptionPort|$publicationPort]" }
toString0 = "[$id] IPC [$subscriptionPort|$publicationPort]"
} else {
mediaDriverConnection as UdpMediaDriverConnection
streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server!
subscriptionPort = mediaDriverConnection.subscriptionPort
publicationPort = mediaDriverConnection.publicationPort
remoteAddressInt = IPv4.toInt(mediaDriverConnection.address)
toString0 = { "[$id] $remoteAddress [$publicationPort|$subscriptionPort]" }
remoteAddressString = mediaDriverConnection.addressString
toString0 = "[$id] $remoteAddressString [$publicationPort|$subscriptionPort]"
}
@ -412,7 +419,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
//
//
override fun toString(): String {
return toString0()
return toString0
}
override fun hashCode(): Int {
@ -435,13 +442,14 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
}
// cleans up the connection information
fun cleanup(connectionsPerIpCounts: Int2IntCounterMap, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) {
internal fun cleanup(connectionsPerIpCounts: ConnectionCounts, sessionIdAllocator: RandomIdAllocator, streamIdAllocator: RandomIdAllocator) {
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
if (isIpc) {
sessionIdAllocator.free(subscriptionPort)
sessionIdAllocator.free(publicationPort)
streamIdAllocator.free(streamId)
} else {
connectionsPerIpCounts.getAndDecrement(remoteAddressInt)
connectionsPerIpCounts.decrementSlow(remoteAddress!!)
sessionIdAllocator.free(id)
streamIdAllocator.free(streamId)
}

View File

@ -15,6 +15,8 @@
*/
package dorkbox.network.connection
import dorkbox.network.aeron.MediaDriverConnection
data class ConnectionParams<C : Connection>(val endPoint: EndPoint<C>,
val mediaDriverConnection: MediaDriverConnection,
val publicKeyValidation: PublicKeyValidationState)

View File

@ -15,7 +15,7 @@
*/
package dorkbox.network.connection
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IP
import dorkbox.network.Configuration
import dorkbox.network.handshake.ClientConnectionInfo
import dorkbox.network.other.CryptoEccNative
@ -26,6 +26,7 @@ import dorkbox.util.Sys
import dorkbox.util.entropy.Entropy
import dorkbox.util.exceptions.SecurityException
import mu.KLogger
import java.net.InetAddress
import java.security.KeyFactory
import java.security.KeyPair
import java.security.KeyPairGenerator
@ -122,21 +123,21 @@ internal class CryptoManagement(val logger: KLogger,
/**
* If the key does not match AND we have disabled remote key validation, then metachannel.changedRemoteKey = true. OTHERWISE, key validation is REQUIRED!
* If the key does not match AND we have disabled remote key validation -- key validation is REQUIRED!
*
* @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.)
* false if we should abort
*/
internal fun validateRemoteAddress(remoteAddress: Int, publicKey: ByteArray?): PublicKeyValidationState {
internal fun validateRemoteAddress(remoteAddress: InetAddress, publicKey: ByteArray?): PublicKeyValidationState {
if (publicKey == null) {
logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}! It was null (and should not have been)")
logger.error("Error validating public key for ${IP.toString(remoteAddress)}! It was null (and should not have been)")
return PublicKeyValidationState.INVALID
}
try {
val savedPublicKey = settingsStore.getRegisteredServerKey(remoteAddress)
if (savedPublicKey == null) {
logger.info("Adding new remote IP address key for ${IPv4.toString(remoteAddress)} : ${Sys.bytesToHex(publicKey)}")
logger.info("Adding new signature for ${IP.toString(remoteAddress)} : ${Sys.bytesToHex(publicKey)}")
settingsStore.addRegisteredServerKey(remoteAddress, publicKey)
} else {
@ -144,18 +145,18 @@ internal class CryptoManagement(val logger: KLogger,
if (!publicKey.contentEquals(savedPublicKey)) {
return if (enableRemoteSignatureValidation) {
// keys do not match, abort!
logger.error("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Denying connection attempt")
logger.error("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Denying connection attempt")
PublicKeyValidationState.INVALID
}
else {
logger.warn("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Permitting connection attempt.")
logger.warn("The public key for remote connection ${IP.toString(remoteAddress)} does not match. Permitting connection attempt.")
PublicKeyValidationState.TAMPERED
}
}
}
} catch (e: SecurityException) {
// keys do not match, abort!
logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}!", e)
logger.error("Error validating public key for ${IP.toString(remoteAddress)}!", e)
return PublicKeyValidationState.INVALID
}

View File

@ -20,10 +20,11 @@ import dorkbox.network.Configuration
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.exceptions.MessageNotRegisteredException
import dorkbox.network.handshake.HandshakeMessage
import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.other.coroutines.SuspendWaiter
import dorkbox.network.ping.PingMessage
import dorkbox.network.rmi.RmiManagerConnections
import dorkbox.network.rmi.RmiManagerGlobal
import dorkbox.network.rmi.messages.RmiMessage
@ -145,8 +146,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val rmiGlobalSupport = RmiManagerGlobal<CONNECTION>(logger, actionDispatch, config.serialization)
init {
logger.error("NETWORK STACK IS ONLY IPV4 AT THE MOMENT. IPV6 is in progress!")
runBlocking {
// our default onError handler. All error messages go though this
listenerManager.onError { throwable ->

View File

@ -1,24 +0,0 @@
/*
* Copyright 2020 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
enum class MediaDriverType(private val type: String) {
IPC("ipc"), UDP("udp");
override fun toString(): String {
return type
}
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.connectionType
package dorkbox.network.connectionType
import dorkbox.network.handshake.UpgradeType

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.connectionType
package dorkbox.network.connectionType
import java.math.BigInteger
import java.net.Inet4Address

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.connectionType
package dorkbox.network.connectionType
import java.net.InetSocketAddress

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.server
package dorkbox.network.exceptions
/**
* A session/stream could not be allocated.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.client
package dorkbox.network.exceptions
/**
* The type of exceptions raised by the client.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.client
package dorkbox.network.exceptions
/**
* The server rejected this client when it tried to connect.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.client
package dorkbox.network.exceptions
/**
* The client timed out when it attempted to connect to the server.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
package dorkbox.network.exceptions
/**
* thrown when a message is received, and does not have any registered 'onMessage' handlers.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.server
package dorkbox.network.exceptions
/**
* A port could not be allocated.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.server
package dorkbox.network.exceptions
/**
* The type of exceptions raised by the server.

View File

@ -16,12 +16,12 @@
package dorkbox.network.handshake
import dorkbox.network.Configuration
import dorkbox.network.aeron.client.ClientException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.MediaDriverConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.MediaDriverConnection
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header

View File

@ -0,0 +1,30 @@
package dorkbox.network.handshake
import org.agrona.collections.Object2IntHashMap
import java.net.InetAddress
/**
*
*/
internal class ConnectionCounts {
private val connectionsPerIpCounts = Object2IntHashMap<InetAddress>(-1)
fun get(inetAddress: InetAddress): Int {
return connectionsPerIpCounts.getOrPut(inetAddress) { 0 }
}
fun increment(inetAddress: InetAddress, currentCount: Int) {
connectionsPerIpCounts[inetAddress] = currentCount + 1
}
fun decrement(inetAddress: InetAddress, currentCount: Int) {
connectionsPerIpCounts[inetAddress] = currentCount - 1
}
fun decrementSlow(inetAddress: InetAddress) {
if (connectionsPerIpCounts.containsKey(inetAddress)) {
val defaultVal = connectionsPerIpCounts.getValue(inetAddress)
connectionsPerIpCounts[inetAddress] = defaultVal - 1
}
}
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.server
package dorkbox.network.handshake
import org.agrona.collections.IntArrayList

View File

@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron.server
package dorkbox.network.handshake
import dorkbox.network.exceptions.AllocationException
import org.agrona.collections.IntHashSet
import java.security.SecureRandom
/**
* An allocator for session IDs. The allocator randomly selects values from
* the given range `[min, max]` and will not return a previously-returned value `x`
* until `x` has been freed with `{ SessionAllocator#free(int)}.
* </p>
* An allocator for session IDs.
*
* The allocator randomly selects values from the given range `[min, max]` and will not return a previously-returned value `x`
* until `x` has been freed with `{ SessionAllocator#free(int)}.
*
* <p>
* This implementation uses storage proportional to the number of currently-allocated
* values. Allocation time is bounded by { max - min}, will be { O(1)}
* with no allocated values, and will increase to { O(n)} as the number
* of allocated values approached { max - min}.
* </p>`
*
* NOTE: THIS IS NOT THREAD SAFE!
*
* @param min The minimum session ID (inclusive)
* @param max The maximum session ID (exclusive)
@ -55,7 +56,6 @@ class RandomIdAllocator(private val min: Int, max: Int) {
*
* @throws AllocationException If there are no non-allocated sessions left
*/
@Throws(AllocationException::class)
fun allocate(): Int {
if (used.size == maxAssignments) {
throw AllocationException("No session IDs left to allocate")
@ -81,4 +81,12 @@ class RandomIdAllocator(private val min: Int, max: Int) {
fun free(session: Int) {
used.remove(session)
}
/**
* Removes all used sessions from the internal data structures
*/
fun clear() {
used.clear()
}
}

View File

@ -21,24 +21,24 @@ import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import dorkbox.network.Server
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.client.ClientTimedOutException
import dorkbox.network.aeron.server.AllocationException
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.aeron.IpcMediaDriverConnection
import dorkbox.network.aeron.UdpMediaDriverConnection
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpcMediaDriverConnection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.PublicKeyValidationState
import dorkbox.network.connection.UdpMediaDriverConnection
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ClientRejectedException
import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.network.exceptions.ServerException
import io.aeron.Aeron
import io.aeron.Publication
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger
import org.agrona.collections.Int2IntCounterMap
import java.net.Inet4Address
import java.net.InetAddress
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.write
@ -53,7 +53,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
private val listenerManager: ListenerManager<CONNECTION>) {
private val pendingConnectionsLock = ReentrantReadWriteLock()
private val pendingConnections: Cache<Int,CONNECTION> = Caffeine.newBuilder()
private val pendingConnections: Cache<Int, CONNECTION> = Caffeine.newBuilder()
.expireAfterAccess(config.connectionCloseTimeoutInSeconds.toLong(), TimeUnit.SECONDS)
.removalListener(RemovalListener<Any?, Any?> { _, value, cause ->
if (cause == RemovalCause.EXPIRED) {
@ -67,17 +67,17 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
}).build()
private val connectionsPerIpCounts = Int2IntCounterMap(0)
private val connectionsPerIpCounts = ConnectionCounts()
// guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!)
private val sessionIdAllocator = RandomIdAllocator(EndPoint.RESERVED_SESSION_ID_LOW,
EndPoint.RESERVED_SESSION_ID_HIGH)
private val sessionIdAllocator = RandomIdAllocator(EndPoint.RESERVED_SESSION_ID_LOW, EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
/**
* @return true if we should continue parsing the incoming message, false if we should abort
*/
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
private fun validateMessageTypeAndDoPending(server: Server<CONNECTION>,
handshakePublication: Publication,
message: Any?,
@ -126,11 +126,17 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
/**
* @return true if we should continue parsing the incoming message, false if we should abort
*/
private fun validateConnectionInfo(server: Server<CONNECTION>,
handshakePublication: Publication,
config: ServerConfiguration,
clientAddressString: String,
clientAddress: Int): Boolean {
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
private fun validateUdpConnectionInfo(server: Server<CONNECTION>,
handshakePublication: Publication,
config: ServerConfiguration,
clientAddressString: String,
clientAddress: InetAddress): Boolean {
if (clientAddress.isLoopbackAddress) {
// we do not want to limit loopback addresses
return true
}
try {
// VALIDATE:: Check to see if there are already too many clients connected.
@ -143,11 +149,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return false
}
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
val currentCountForIp = connectionsPerIpCounts.get(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress)
connectionsPerIpCounts.decrement(clientAddress, currentCountForIp)
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch {
@ -156,6 +163,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return false
}
connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
} catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.actionDispatch.launch {
@ -168,7 +176,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
// note: CANNOT be called in action dispatch
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
fun processIpcHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
@ -243,7 +251,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionTimeoutMS = 0)
// we have to construct how the connection will communicate!
clientConnection.buildServer(aeron)
clientConnection.buildServer(aeron, logger)
logger.info {
"[${clientConnection.sessionId}] aeron IPC connection established to $clientConnection"
@ -264,8 +272,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
listenerManager.notifyError(connection, exception)
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
}
return
@ -320,14 +327,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
// note: CANNOT be called in action dispatch
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
fun processUdpHandshakeMessageServer(server: Server<CONNECTION>,
handshakePublication: Publication,
sessionId: Int,
clientAddressString: String,
clientAddress: Int,
clientAddress: InetAddress,
message: Any?,
aeron: Aeron) {
aeron: Aeron,
isIpv6Wildcard: Boolean) {
if (!validateMessageTypeAndDoPending(server, handshakePublication, message, sessionId, clientAddressString)) {
return
@ -345,7 +353,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return
}
if (!validateConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) {
if (!validateUdpConnectionInfo(server, handshakePublication, config, clientAddressString, clientAddress)) {
return
}
@ -363,7 +371,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionSessionId = sessionIdAllocator.allocate()
} catch (e: AllocationException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
connectionsPerIpCounts.decrementSlow(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
server.actionDispatch.launch {
@ -378,7 +386,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionStreamId = streamIdAllocator.allocate()
} catch (e: AllocationException) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
@ -388,8 +396,6 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
return
}
val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box?
// the pub/sub do not necessarily have to be the same. The can be ANY port
val publicationPort = config.publicationPort
val subscriptionPort = config.subscriptionPort
@ -398,16 +404,29 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// create a new connection. The session ID is encrypted.
try {
// connection timeout of 0 doesn't matter. it is not used by the server
val clientConnection = UdpMediaDriverConnection(serverAddress,
publicationPort,
subscriptionPort,
connectionStreamId,
connectionSessionId,
0,
message.isReliable)
// the client address WILL BE either IPv4 or IPv6
val clientConnection = if (clientAddress is Inet4Address && !isIpv6Wildcard) {
UdpMediaDriverConnection(server.listenIPv4Address!!,
publicationPort,
subscriptionPort,
connectionStreamId,
connectionSessionId,
0,
message.isReliable)
} else {
// wildcard is SPECIAL, in that if we bind wildcard, it will ALSO bind to IPv4, so we can't bind both!
UdpMediaDriverConnection(server.listenIPv6Address!!,
publicationPort,
subscriptionPort,
connectionStreamId,
connectionSessionId,
0,
message.isReliable)
}
// we have to construct how the connection will communicate!
clientConnection.buildServer(aeron)
clientConnection.buildServer(aeron, logger)
logger.info {
"Creating new connection from $clientConnection"
@ -420,7 +439,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
val permitConnection = listenerManager.notifyFilter(connection)
if (!permitConnection) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
@ -429,8 +448,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
listenerManager.notifyError(connection, exception)
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
}
return
@ -470,7 +488,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
} catch (e: Exception) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionId)
streamIdAllocator.free(connectionStreamId)
@ -482,7 +500,17 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
* Free up resources from the closed connection
*/
fun cleanup(connection: CONNECTION) {
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
connection.cleanup(connectionsPerIpCounts, sessionIdAllocator, streamIdAllocator)
}
/**
* Reset and clear all connection information
*/
fun clear() {
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
sessionIdAllocator.clear()
streamIdAllocator.clear()
pendingConnections.invalidateAll()
}
}

View File

@ -13,7 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
package dorkbox.network.ping
import dorkbox.network.connection.Connection
interface Ping {
/**

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.ping
package dorkbox.network.ping
import java.io.IOException

View File

@ -13,11 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.ping
package dorkbox.network.ping
import dorkbox.network.connection.Connection
import dorkbox.network.connection.Ping
import dorkbox.network.connection.PingListener
import java.util.concurrent.atomic.AtomicInteger
class PingFuture internal constructor() : Ping {

View File

@ -13,7 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection;
package dorkbox.network.ping;
import dorkbox.network.connection.Connection;
// note that we specifically DO NOT implement equals/hashCode, because we cannot create two separate
// listeners that are somehow equal to each other.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.ping
package dorkbox.network.ping
/**
* Internal message to determine round trip time.

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.ping
package dorkbox.network.ping
import dorkbox.network.connection.Connection

View File

@ -24,11 +24,6 @@ class DB_Server {
* The storage key used to save all server connections
*/
val STORAGE_KEY = StorageKey("servers")
/**
* Address 0.0.0.0/32 may be used as a source address for this host on this network.
*/
const val IP_SELF = 0
}

View File

@ -18,6 +18,7 @@ package dorkbox.network.storage
import dorkbox.network.serialization.Serialization
import dorkbox.util.exceptions.SecurityException
import dorkbox.util.storage.Storage
import java.net.InetAddress
import java.security.SecureRandom
class NullSettingsStore : SettingsStore() {
@ -54,17 +55,17 @@ class NullSettingsStore : SettingsStore() {
}
@Throws(SecurityException::class)
override fun getRegisteredServerKey(hostAddress: Int): ByteArray {
override fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray {
TODO("not impl")
}
@Throws(SecurityException::class)
override fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray) {
override fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray) {
TODO("not impl")
}
@Throws(SecurityException::class)
override fun removeRegisteredServerKey(hostAddress: Int): Boolean {
override fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean {
return true
}

View File

@ -15,10 +15,13 @@
*/
package dorkbox.network.storage
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.serialization.Serialization
import dorkbox.util.storage.Storage
import org.agrona.collections.Int2ObjectHashMap
import org.agrona.collections.Object2NullableObjectHashMap
import java.net.InetAddress
import java.security.SecureRandom
/**
@ -26,7 +29,15 @@ import java.security.SecureRandom
*/
class PropertyStore : SettingsStore() {
private lateinit var storage: Storage
private lateinit var servers: Int2ObjectHashMap<DB_Server>
private lateinit var servers: Object2NullableObjectHashMap<InetAddress, DB_Server>
/**
* Address 0.0.0.0 or ::0 may be used as a source address for this host on this network.
*
* Because we assigned BOTH to the same thing, it doesn't matter which one we use
*/
private val ipv4Host = IPv4.WILDCARD
private val ipv6Host = IPv6.WILDCARD
/**
* Method of preference for creating/getting this connection store.
@ -35,13 +46,20 @@ class PropertyStore : SettingsStore() {
*/
override fun init(serializationManager: Serialization, storage: Storage) {
this.storage = storage
servers = this.storage.get(DB_Server.STORAGE_KEY, Int2ObjectHashMap())
servers = this.storage.get(DB_Server.STORAGE_KEY, Object2NullableObjectHashMap())
// this will always be null and is here to help people that copy/paste code
var localServer = servers[DB_Server.IP_SELF]
var localServer = servers[ipv4Host]
if (localServer == null) {
localServer = DB_Server()
servers[DB_Server.IP_SELF] = localServer
servers[ipv4Host] = localServer
// have to always specify what we are saving
this.storage.put(DB_Server.STORAGE_KEY, servers)
}
if (servers[ipv6Host] == null) {
servers[ipv6Host] = localServer
// have to always specify what we are saving
this.storage.put(DB_Server.STORAGE_KEY, servers)
@ -54,7 +72,7 @@ class PropertyStore : SettingsStore() {
@Synchronized
override fun getPrivateKey(): ByteArray? {
checkAccess(CryptoManagement::class.java)
return servers[DB_Server.IP_SELF]!!.privateKey
return servers[ipv4Host]!!.privateKey
}
/**
@ -63,7 +81,7 @@ class PropertyStore : SettingsStore() {
@Synchronized
override fun savePrivateKey(serverPrivateKey: ByteArray) {
checkAccess(CryptoManagement::class.java)
servers[DB_Server.IP_SELF]!!.privateKey = serverPrivateKey
servers[ipv4Host]!!.privateKey = serverPrivateKey
// have to always specify what we are saving
storage.put(DB_Server.STORAGE_KEY, servers)
@ -74,7 +92,7 @@ class PropertyStore : SettingsStore() {
*/
@Synchronized
override fun getPublicKey(): ByteArray? {
return servers[DB_Server.IP_SELF]!!.publicKey
return servers[ipv4Host]!!.publicKey
}
/**
@ -83,7 +101,7 @@ class PropertyStore : SettingsStore() {
@Synchronized
override fun savePublicKey(serverPublicKey: ByteArray) {
checkAccess(CryptoManagement::class.java)
servers[DB_Server.IP_SELF]!!.publicKey = serverPublicKey
servers[ipv4Host]!!.publicKey = serverPublicKey
// have to always specify what we are saving
storage.put(DB_Server.STORAGE_KEY, servers)
@ -94,7 +112,7 @@ class PropertyStore : SettingsStore() {
*/
@Synchronized
override fun getSalt(): ByteArray {
val localServer = servers[DB_Server.IP_SELF]
val localServer = servers[ipv4Host]
var salt = localServer!!.salt
// we don't care who gets the server salt
@ -118,7 +136,7 @@ class PropertyStore : SettingsStore() {
* Simple, property based method to getting a connected computer by host IP address
*/
@Synchronized
override fun getRegisteredServerKey(hostAddress: Int): ByteArray? {
override fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray? {
return servers[hostAddress]?.publicKey
}
@ -126,7 +144,7 @@ class PropertyStore : SettingsStore() {
* Saves a connected computer by host IP address and public key
*/
@Synchronized
override fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray) {
override fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray) {
// checkAccess(RegistrationWrapper.class);
var db_server = servers[hostAddress]
if (db_server == null) {
@ -144,7 +162,7 @@ class PropertyStore : SettingsStore() {
* Deletes a registered computer by host IP address
*/
@Synchronized
override fun removeRegisteredServerKey(hostAddress: Int): Boolean {
override fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean {
// checkAccess(RegistrationWrapper.class);
val db_server = servers.remove(hostAddress)

View File

@ -20,6 +20,7 @@ import dorkbox.util.bytes.ByteArrayWrapper
import dorkbox.util.exceptions.SecurityException
import dorkbox.util.storage.Storage
import org.slf4j.LoggerFactory
import java.net.InetAddress
import java.util.*
/**
@ -77,13 +78,13 @@ abstract class SettingsStore : AutoCloseable {
* Gets a previously registered computer by host IP address
*/
@Throws(SecurityException::class)
abstract fun getRegisteredServerKey(hostAddress: Int): ByteArray?
abstract fun getRegisteredServerKey(hostAddress: InetAddress): ByteArray?
/**
* Saves a registered computer by host IP address and public key
*/
@Throws(SecurityException::class)
abstract fun addRegisteredServerKey(hostAddress: Int, publicKey: ByteArray)
abstract fun addRegisteredServerKey(hostAddress: InetAddress, publicKey: ByteArray)
/**
* Deletes a registered computer by host IP address
@ -91,7 +92,7 @@ abstract class SettingsStore : AutoCloseable {
* @return true if successful, false if there were problems (or it didn't exist)
*/
@Throws(SecurityException::class)
abstract fun removeRegisteredServerKey(hostAddress: Int): Boolean
abstract fun removeRegisteredServerKey(hostAddress: InetAddress): Boolean
/**
* Take the proper steps to close the storage system.

View File

@ -0,0 +1,918 @@
/*
* Copyright 2014-2020 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// REMOVE WHEN ISSUE https://github.com/real-logic/aeron/issues/1057 is resolved!
package io.aeron.driver.media;
import static io.aeron.driver.media.NetworkUtil.filterBySubnet;
import static io.aeron.driver.media.NetworkUtil.findAddressOnInterface;
import static io.aeron.driver.media.NetworkUtil.getProtocolFamily;
import static java.lang.System.lineSeparator;
import static java.net.InetAddress.getByAddress;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.ProtocolFamily;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import org.agrona.BitUtil;
import org.agrona.LangUtil;
import io.aeron.ChannelUri;
import io.aeron.CommonContext;
import io.aeron.driver.DefaultNameResolver;
import io.aeron.driver.NameResolver;
import io.aeron.driver.exceptions.InvalidChannelException;
/**
* The media configuration for Aeron UDP channels as an instantiation of the socket addresses for a {@link ChannelUri}.
*
* @see ChannelUri
* @see io.aeron.ChannelUriStringBuilder
*/
public final class UdpChannel
{
private static final AtomicInteger UNIQUE_CANONICAL_FORM_VALUE = new AtomicInteger();
private final boolean isManualControlMode;
private final boolean isDynamicControlMode;
private final boolean hasExplicitControl;
private final boolean hasExplicitEndpoint;
private final boolean isMulticast;
private final boolean hasMulticastTtl;
private final boolean hasTag;
private final int multicastTtl;
private final long tag;
private final InetSocketAddress remoteData;
private final InetSocketAddress localData;
private final InetSocketAddress remoteControl;
private final InetSocketAddress localControl;
private final String uriStr;
private final String canonicalForm;
private final NetworkInterface localInterface;
private final ProtocolFamily protocolFamily;
private final ChannelUri channelUri;
private UdpChannel(final Context context)
{
isManualControlMode = context.isManualControlMode;
isDynamicControlMode = context.isDynamicControlMode;
hasExplicitEndpoint = context.hasExplicitEndpoint;
hasExplicitControl = context.hasExplicitControl;
isMulticast = context.isMulticast;
hasTag = context.hasTagId;
tag = context.tagId;
hasMulticastTtl = context.hasMulticastTtl;
multicastTtl = context.multicastTtl;
remoteData = context.remoteData;
localData = context.localData;
remoteControl = context.remoteControl;
localControl = context.localControl;
uriStr = context.uriStr;
canonicalForm = context.canonicalForm;
localInterface = context.localInterface;
protocolFamily = context.protocolFamily;
channelUri = context.channelUri;
}
/**
* Parse channel URI and create a {@link UdpChannel}.
*
* @param channelUriString to parse.
* @return a new {@link UdpChannel} as the result of parsing.
* @throws InvalidChannelException if an error occurs.
*/
public static UdpChannel parse(final String channelUriString)
{
return parse(channelUriString, DefaultNameResolver.INSTANCE);
}
/**
* Parse channel URI and create a {@link UdpChannel}.
*
* @param channelUriString to parse.
* @param nameResolver to use for resolving names
* @return a new {@link UdpChannel} as the result of parsing.
* @throws InvalidChannelException if an error occurs.
*/
@SuppressWarnings("MethodLength")
public static UdpChannel parse(final String channelUriString, final NameResolver nameResolver)
{
try
{
final ChannelUri channelUri = ChannelUri.parse(channelUriString);
validateConfiguration(channelUri);
InetSocketAddress endpointAddress = getEndpointAddress(channelUri, nameResolver);
final InetSocketAddress explicitControlAddress = getExplicitControlAddress(channelUri, nameResolver);
final String tagIdStr = channelUri.channelTag();
final String controlMode = channelUri.get(CommonContext.MDC_CONTROL_MODE_PARAM_NAME);
final boolean isManualControlMode = CommonContext.MDC_CONTROL_MODE_MANUAL.equals(controlMode);
final boolean isDynamicControlMode = CommonContext.MDC_CONTROL_MODE_DYNAMIC.equals(controlMode);
final boolean requiresAdditionalSuffix =
null == endpointAddress && null == explicitControlAddress ||
(null != endpointAddress && endpointAddress.getPort() == 0) ||
(null != explicitControlAddress && explicitControlAddress.getPort() == 0);
final boolean hasNoDistinguishingCharacteristic =
null == endpointAddress && null == explicitControlAddress && null == tagIdStr;
if (isDynamicControlMode && null == explicitControlAddress)
{
throw new IllegalArgumentException(
"explicit control expected with dynamic control mode: " + channelUriString);
}
if (hasNoDistinguishingCharacteristic && !isManualControlMode)
{
throw new IllegalArgumentException(
"URIs for UDP must specify an endpoint, control, tags, or control-mode=manual: " +
channelUriString);
}
if (null != endpointAddress && endpointAddress.isUnresolved())
{
throw new UnknownHostException("could not resolve endpoint address: " + endpointAddress);
}
if (null != explicitControlAddress && explicitControlAddress.isUnresolved())
{
throw new UnknownHostException("could not resolve control address: " + explicitControlAddress);
}
boolean hasExplicitEndpoint = true;
if (null == endpointAddress)
{
hasExplicitEndpoint = false;
if (explicitControlAddress == null || explicitControlAddress.getAddress() instanceof Inet4Address) {
endpointAddress = new InetSocketAddress(InetAddress.getByAddress("", new byte[]{0,0,0,0}), 0);
} else {
endpointAddress = new InetSocketAddress(InetAddress.getByAddress("", new byte[]{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}), 0);
}
}
final Context context = new Context()
.uriStr(channelUriString)
.channelUri(channelUri)
.isManualControlMode(isManualControlMode)
.isDynamicControlMode(isDynamicControlMode)
.hasExplicitEndpoint(hasExplicitEndpoint)
.hasNoDistinguishingCharacteristic(hasNoDistinguishingCharacteristic);
if (null != tagIdStr)
{
context.hasTagId(true).tagId(Long.parseLong(tagIdStr));
}
if (endpointAddress.getAddress().isMulticastAddress())
{
final InetSocketAddress controlAddress = getMulticastControlAddress(endpointAddress);
final InterfaceSearchAddress searchAddress = getInterfaceSearchAddress(channelUri);
final NetworkInterface localInterface = findInterface(searchAddress);
final InetSocketAddress resolvedAddress = resolveToAddressOfInterface(localInterface, searchAddress);
context
.isMulticast(true)
.localControlAddress(resolvedAddress)
.remoteControlAddress(controlAddress)
.localDataAddress(resolvedAddress)
.remoteDataAddress(endpointAddress)
.localInterface(localInterface)
.protocolFamily(getProtocolFamily(endpointAddress.getAddress()))
.canonicalForm(canonicalise(null, resolvedAddress, null, endpointAddress));
final String ttlValue = channelUri.get(CommonContext.TTL_PARAM_NAME);
if (null != ttlValue)
{
context.hasMulticastTtl(true).multicastTtl(Integer.parseInt(ttlValue));
}
}
else if (null != explicitControlAddress)
{
final String controlVal = channelUri.get(CommonContext.MDC_CONTROL_PARAM_NAME);
final String endpointVal = channelUri.get(CommonContext.ENDPOINT_PARAM_NAME);
String suffix = "";
if (requiresAdditionalSuffix)
{
suffix = (null != tagIdStr) ? "#" + tagIdStr : ("-" + UNIQUE_CANONICAL_FORM_VALUE.getAndAdd(1));
}
final String canonicalForm = canonicalise(
controlVal, explicitControlAddress, endpointVal, endpointAddress) + suffix;
context
.hasExplicitControl(true)
.remoteControlAddress(endpointAddress)
.remoteDataAddress(endpointAddress)
.localControlAddress(explicitControlAddress)
.localDataAddress(explicitControlAddress)
.protocolFamily(getProtocolFamily(endpointAddress.getAddress()))
.canonicalForm(canonicalForm);
}
else
{
final InterfaceSearchAddress searchAddress = getInterfaceSearchAddress(channelUri);
final InetSocketAddress localAddress = searchAddress.getInetAddress().isAnyLocalAddress() ?
searchAddress.getAddress() :
resolveToAddressOfInterface(findInterface(searchAddress), searchAddress);
final String endpointVal = channelUri.get(CommonContext.ENDPOINT_PARAM_NAME);
String suffix = "";
if (requiresAdditionalSuffix)
{
suffix = (null != tagIdStr) ? "#" + tagIdStr : ("-" + UNIQUE_CANONICAL_FORM_VALUE.getAndAdd(1));
}
context
.remoteControlAddress(endpointAddress)
.remoteDataAddress(endpointAddress)
.localControlAddress(localAddress)
.localDataAddress(localAddress)
.protocolFamily(getProtocolFamily(endpointAddress.getAddress()))
.canonicalForm(canonicalise(null, localAddress, endpointVal, endpointAddress) + suffix);
}
return new UdpChannel(context);
}
catch (final Exception ex)
{
throw new InvalidChannelException(ex);
}
}
/**
* Return a string which is a canonical form of the channel suitable for use as a file or directory
* name and also as a method of hashing, etc.
* <p>
* The general format is:
* UDP-interface:localPort-remoteAddress:remotePort
*
* @param localParamValue interface or MDC control param value or null for not set.
* @param localData address/interface for the channel.
* @param remoteParamValue endpoint param value or null if not set.
* @param remoteData address for the channel.
* @return canonical representation as a string.
*/
public static String canonicalise(
final String localParamValue,
final InetSocketAddress localData,
final String remoteParamValue,
final InetSocketAddress remoteData)
{
final StringBuilder builder = new StringBuilder(48);
builder.append("UDP-");
if (null == localParamValue)
{
builder.append(localData.getHostString())
.append(':')
.append(localData.getPort());
}
else
{
builder.append(localParamValue);
}
builder.append('-');
if (null == remoteParamValue)
{
builder.append(remoteData.getHostString())
.append(':')
.append(remoteData.getPort());
}
else
{
builder.append(remoteParamValue);
}
return builder.toString();
}
/**
* Remote data address and port.
*
* @return remote data address and port.
*/
public InetSocketAddress remoteData()
{
return remoteData;
}
/**
* Local data address and port.
*
* @return local data address port.
*/
public InetSocketAddress localData()
{
return localData;
}
/**
* Remote control address information
*
* @return remote control address information
*/
public InetSocketAddress remoteControl()
{
return remoteControl;
}
/**
* Local control address and port.
*
* @return local control address and port.
*/
public InetSocketAddress localControl()
{
return localControl;
}
/**
* Get the {@link ChannelUri} for this channel.
*
* @return the {@link ChannelUri} for this channel.
*/
public ChannelUri channelUri()
{
return channelUri;
}
/**
* Has this channel got a multicast TTL value set so that {@link #multicastTtl()} is valid.
*
* @return true if this channel is a multicast TTL set otherwise false.
*/
public boolean hasMulticastTtl()
{
return hasMulticastTtl;
}
/**
* Multicast TTL value.
*
* @return multicast TTL value.
*/
public int multicastTtl()
{
return multicastTtl;
}
/**
* The canonical form for the channel
* <p>
* {@link UdpChannel#canonicalise}
*
* @return canonical form for channel.
*/
public String canonicalForm()
{
return canonicalForm;
}
/**
* The {@link #canonicalForm()} for the channel.
*
* @return the {@link #canonicalForm()} for the channel.
*/
@Override
public String toString()
{
return canonicalForm;
}
/**
* Is the channel UDP multicast.
*
* @return true if the channel is UDP multicast.
*/
public boolean isMulticast()
{
return isMulticast;
}
/**
* Local interface to be used by the channel.
*
* @return {@link NetworkInterface} for the local interface used by the channel.
*/
public NetworkInterface localInterface()
{
return localInterface;
}
/**
* Original URI of the channel URI.
*
* @return the original uri string from the client.
*/
public String originalUriString()
{
return uriStr;
}
/**
* Get the {@link ProtocolFamily} for this channel.
*
* @return the {@link ProtocolFamily} for this channel.
*/
public ProtocolFamily protocolFamily()
{
return protocolFamily;
}
/**
* Get the tag value on the channel which is only valid if {@link #hasTag()} is true.
*
* @return the tag value on the channel.
*/
public long tag()
{
return tag;
}
/**
* Does the channel have manual control mode specified.
*
* @return does channel have manual control mode specified.
*/
public boolean isManualControlMode()
{
return isManualControlMode;
}
/**
* Does the channel have dynamic control mode specified.
*
* @return does channel have dynamic control mode specified.
*/
public boolean isDynamicControlMode()
{
return isDynamicControlMode;
}
/**
* Does the channel have an explicit endpoint address?
*
* @return does channel have an explicit endpoint address or not?
*/
public boolean hasExplicitEndpoint()
{
return hasExplicitEndpoint;
}
/**
* Does the channel have an explicit control address as used with multi-destination-cast or not?
*
* @return does channel have an explicit control address or not?
*/
public boolean hasExplicitControl()
{
return hasExplicitControl;
}
/**
* Has the URI a tag to indicate entity relationships and if {@link #tag()} is valid.
*
* @return true if the channel has a tag.
*/
public boolean hasTag()
{
return hasTag;
}
/**
* Is the channel configured as multi-destination.
*
* @return true if he channel configured as multi-destination.
*/
public boolean isMultiDestination()
{
return isDynamicControlMode || isManualControlMode || hasExplicitControl;
}
/**
* Does this channel have a tag match to another channel including endpoints.
*
* @param udpChannel to match against.
* @return true if there is a match otherwise false.
*/
public boolean matchesTag(final UdpChannel udpChannel)
{
if (!hasTag || !udpChannel.hasTag() || tag != udpChannel.tag())
{
return false;
}
if (udpChannel.remoteData().getAddress().isAnyLocalAddress() &&
udpChannel.remoteData().getPort() == 0 &&
udpChannel.localData().getAddress().isAnyLocalAddress() &&
udpChannel.localData().getPort() == 0)
{
return true;
}
throw new IllegalArgumentException(
"matching tag has set endpoint or control address - " + uriStr + " <> " + udpChannel.uriStr);
}
/**
* Used for debugging to get a human readable description of the channel.
*
* @return a human readable description of the channel.
*/
public String description()
{
final StringBuilder builder = new StringBuilder("UdpChannel - ");
if (null != localInterface)
{
builder
.append("interface: ")
.append(localInterface.getDisplayName())
.append(", ");
}
builder
.append("localData: ").append(localData)
.append(", remoteData: ").append(remoteData)
.append(", ttl: ").append(multicastTtl);
return builder.toString();
}
/**
* Channels are considered equal if the {@link #canonicalForm()} is equal.
*
* @param o object to be compared with.
* @return true if the {@link #canonicalForm()} is equal, otherwise false.
*/
@Override
public boolean equals(final Object o)
{
if (this == o)
{
return true;
}
if (o == null || getClass() != o.getClass())
{
return false;
}
final UdpChannel that = (UdpChannel)o;
return Objects.equals(canonicalForm, that.canonicalForm);
}
/**
* The hash code for the {@link #canonicalForm()}.
*
* @return the hash code for the {@link #canonicalForm()}.
*/
@Override
public int hashCode()
{
return canonicalForm != null ? canonicalForm.hashCode() : 0;
}
/**
* Get the endpoint destination address from the URI.
*
* @param uri to check.
* @param nameResolver to use for resolution
* @return endpoint address for URI.
*/
public static InetSocketAddress destinationAddress(final ChannelUri uri, final NameResolver nameResolver)
{
try
{
validateConfiguration(uri);
return getEndpointAddress(uri, nameResolver);
}
catch (final Exception ex)
{
throw new InvalidChannelException(ex);
}
}
/**
* Resolve and endpoint into a {@link InetSocketAddress}.
*
* @param endpoint to resolve
* @param uriParamName for the resolution
* @param isReResolution for the resolution
* @param nameResolver to be used for hostname.
* @return address for endpoint
* @throws UnknownHostException if the endpoint can not be resolved.
*/
public static InetSocketAddress resolve(
final String endpoint, final String uriParamName, final boolean isReResolution, final NameResolver nameResolver)
throws UnknownHostException
{
return SocketAddressParser.parse(endpoint, uriParamName, isReResolution, nameResolver);
}
private static InetSocketAddress getMulticastControlAddress(final InetSocketAddress endpointAddress)
throws UnknownHostException
{
final byte[] addressAsBytes = endpointAddress.getAddress().getAddress();
validateDataAddress(addressAsBytes);
addressAsBytes[addressAsBytes.length - 1]++;
return new InetSocketAddress(getByAddress(addressAsBytes), endpointAddress.getPort());
}
private static InterfaceSearchAddress getInterfaceSearchAddress(final ChannelUri uri) throws UnknownHostException
{
final String interfaceValue = uri.get(CommonContext.INTERFACE_PARAM_NAME);
if (null != interfaceValue)
{
return InterfaceSearchAddress.parse(interfaceValue);
}
return InterfaceSearchAddress.wildcard();
}
private static InetSocketAddress getEndpointAddress(final ChannelUri uri, final NameResolver nameResolver)
{
InetSocketAddress address = null;
final String endpointValue = uri.get(CommonContext.ENDPOINT_PARAM_NAME);
if (null != endpointValue)
{
try
{
address = SocketAddressParser.parse(
endpointValue, CommonContext.ENDPOINT_PARAM_NAME, false, nameResolver);
}
catch (final UnknownHostException ex)
{
LangUtil.rethrowUnchecked(ex);
}
}
return address;
}
private static InetSocketAddress getExplicitControlAddress(final ChannelUri uri, final NameResolver nameResolver)
{
InetSocketAddress address = null;
final String controlValue = uri.get(CommonContext.MDC_CONTROL_PARAM_NAME);
if (null != controlValue)
{
try
{
address = SocketAddressParser.parse(
controlValue, CommonContext.MDC_CONTROL_PARAM_NAME, false, nameResolver);
}
catch (final UnknownHostException ex)
{
LangUtil.rethrowUnchecked(ex);
}
}
return address;
}
private static void validateDataAddress(final byte[] addressAsBytes)
{
if (BitUtil.isEven(addressAsBytes[addressAsBytes.length - 1]))
{
throw new IllegalArgumentException("multicast data address must be odd");
}
}
private static void validateConfiguration(final ChannelUri uri)
{
validateMedia(uri);
}
private static void validateMedia(final ChannelUri uri)
{
if (!uri.isUdp())
{
throw new IllegalArgumentException("UdpChannel only supports UDP media: " + uri);
}
}
private static InetSocketAddress resolveToAddressOfInterface(
final NetworkInterface localInterface, final InterfaceSearchAddress searchAddress)
{
final InetAddress interfaceAddress = findAddressOnInterface(
localInterface, searchAddress.getInetAddress(), searchAddress.getSubnetPrefix());
if (null == interfaceAddress)
{
throw new IllegalStateException();
}
return new InetSocketAddress(interfaceAddress, searchAddress.getPort());
}
private static NetworkInterface findInterface(final InterfaceSearchAddress searchAddress)
throws SocketException
{
final NetworkInterface[] filteredInterfaces = filterBySubnet(
searchAddress.getInetAddress(), searchAddress.getSubnetPrefix());
for (final NetworkInterface networkInterface : filteredInterfaces)
{
if (networkInterface.supportsMulticast() || networkInterface.isLoopback())
{
return networkInterface;
}
}
throw new IllegalArgumentException(errorNoMatchingInterfaces(filteredInterfaces, searchAddress));
}
private static String errorNoMatchingInterfaces(
final NetworkInterface[] filteredInterfaces, final InterfaceSearchAddress address)
throws SocketException
{
final StringBuilder builder = new StringBuilder()
.append("Unable to find multicast interface matching criteria: ")
.append(address.getAddress())
.append('/')
.append(address.getSubnetPrefix());
if (filteredInterfaces.length > 0)
{
builder.append(lineSeparator()).append(" Candidates:");
for (final NetworkInterface ifc : filteredInterfaces)
{
builder
.append(lineSeparator())
.append(" - Name: ")
.append(ifc.getDisplayName())
.append(", addresses: ")
.append(ifc.getInterfaceAddresses())
.append(", multicast: ")
.append(ifc.supportsMulticast());
}
}
return builder.toString();
}
static class Context
{
long tagId;
int multicastTtl;
InetSocketAddress remoteData;
InetSocketAddress localData;
InetSocketAddress remoteControl;
InetSocketAddress localControl;
String uriStr;
String canonicalForm;
NetworkInterface localInterface;
ProtocolFamily protocolFamily;
ChannelUri channelUri;
boolean isManualControlMode = false;
boolean isDynamicControlMode = false;
boolean hasExplicitEndpoint = false;
boolean hasExplicitControl = false;
boolean isMulticast = false;
boolean hasMulticastTtl = false;
boolean hasTagId = false;
boolean hasNoDistinguishingCharacteristic = false;
Context uriStr(final String uri)
{
uriStr = uri;
return this;
}
Context remoteDataAddress(final InetSocketAddress remoteData)
{
this.remoteData = remoteData;
return this;
}
Context localDataAddress(final InetSocketAddress localData)
{
this.localData = localData;
return this;
}
Context remoteControlAddress(final InetSocketAddress remoteControl)
{
this.remoteControl = remoteControl;
return this;
}
Context localControlAddress(final InetSocketAddress localControl)
{
this.localControl = localControl;
return this;
}
Context canonicalForm(final String canonicalForm)
{
this.canonicalForm = canonicalForm;
return this;
}
Context localInterface(final NetworkInterface networkInterface)
{
this.localInterface = networkInterface;
return this;
}
Context protocolFamily(final ProtocolFamily protocolFamily)
{
this.protocolFamily = protocolFamily;
return this;
}
Context hasMulticastTtl(final boolean hasMulticastTtl)
{
this.hasMulticastTtl = hasMulticastTtl;
return this;
}
Context multicastTtl(final int multicastTtl)
{
this.multicastTtl = multicastTtl;
return this;
}
Context tagId(final long tagId)
{
this.tagId = tagId;
return this;
}
Context channelUri(final ChannelUri channelUri)
{
this.channelUri = channelUri;
return this;
}
Context isManualControlMode(final boolean isManualControlMode)
{
this.isManualControlMode = isManualControlMode;
return this;
}
Context isDynamicControlMode(final boolean isDynamicControlMode)
{
this.isDynamicControlMode = isDynamicControlMode;
return this;
}
Context hasExplicitEndpoint(final boolean hasExplicitEndpoint)
{
this.hasExplicitEndpoint = hasExplicitEndpoint;
return this;
}
Context hasExplicitControl(final boolean hasExplicitControl)
{
this.hasExplicitControl = hasExplicitControl;
return this;
}
Context isMulticast(final boolean isMulticast)
{
this.isMulticast = isMulticast;
return this;
}
Context hasTagId(final boolean hasTagId)
{
this.hasTagId = hasTagId;
return this;
}
Context hasNoDistinguishingCharacteristic(final boolean hasNoDistinguishingCharacteristic)
{
this.hasNoDistinguishingCharacteristic = hasNoDistinguishingCharacteristic;
return this;
}
}
}

View File

@ -81,7 +81,6 @@ abstract class BaseTest {
fun serverConfig(): ServerConfiguration {
val configuration = ServerConfiguration()
configuration.listenIpAddress = LOOPBACK
configuration.subscriptionPort = 2000
configuration.publicationPort = 2001

View File

@ -32,7 +32,7 @@ class RmiDelayedInvocationTest : BaseTest() {
@Test
fun rmiNetwork() {
runBlocking {
rmi() { configuration ->
rmi { configuration ->
configuration.enableIpcForLoopback = false
}
}
@ -108,7 +108,7 @@ class RmiDelayedInvocationTest : BaseTest() {
client.connect(LOOPBACK)
}
waitForThreads(9999999)
waitForThreads()
}
private interface TestObject {

View File

@ -34,6 +34,8 @@
*/
package dorkboxTest.network.rmi
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.Client
import dorkbox.network.Configuration
import dorkbox.network.Server
@ -49,15 +51,58 @@ import org.junit.Test
class RmiSimpleTest : BaseTest() {
@Test
fun rmiNetworkGlobal() {
rmiGlobal() { configuration ->
fun rmiIPv4NetworkGlobal() {
rmiGlobal(isIpv4 = true, isIpv6 = false) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiNetworkConnection() {
rmi { configuration ->
fun rmiIPv6NetworkGlobal() {
rmiGlobal(isIpv4 = true, isIpv6 = false) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiBothIPv4ConnectNetworkGlobal() {
rmiGlobal(isIpv4 = true, isIpv6 = true) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiBothIPv6ConnectNetworkGlobal() {
rmiGlobal(isIpv4 = true, isIpv6 = true, runIpv4Connect = true) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiIPv4NetworkConnection() {
rmi(isIpv4 = true, isIpv6 = false) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiIPv6NetworkConnection() {
rmi(isIpv4 = false, isIpv6 = true) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiBothIPv4ConnectNetworkConnection() {
rmi(isIpv4 = true, isIpv6 = true) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@Test
fun rmiBothIPv6ConnectNetworkConnection() {
rmi(isIpv4 = true, isIpv6 = true, runIpv4Connect = true) { configuration ->
configuration.enableIpcForLoopback = false
}
}
@ -72,10 +117,13 @@ class RmiSimpleTest : BaseTest() {
rmi()
}
fun rmi(config: (Configuration) -> Unit = {}) {
fun rmi(isIpv4: Boolean = false, isIpv6: Boolean = false, runIpv4Connect: Boolean = true, config: (Configuration) -> Unit = {}) {
run {
val configuration = serverConfig()
configuration.enableIPv4 = isIpv4
configuration.enableIPv6 = isIpv6
config(configuration)
configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java)
configuration.serialization.register(MessageWithTestCow::class.java)
configuration.serialization.register(UnsupportedOperationException::class.java)
@ -86,19 +134,19 @@ class RmiSimpleTest : BaseTest() {
server.bind()
server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")
server.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(23, id.toLong())
System.err.println("Finished test for: Client -> Server")
server.logger.error("Finished test for: Client -> Server")
System.err.println("Starting test for: Server -> Client")
server.logger.error("Starting test for: Server -> Client")
// NOTE: THIS IS BI-DIRECTIONAL!
connection.createObject<TestCow>(123) { rmiId, remoteObject ->
System.err.println("Running test for: Server -> Client")
server.logger.error("Running test for: Server -> Client")
RmiCommonTest.runTests(connection, remoteObject, 123)
System.err.println("Done with test for: Server -> Client")
server.logger.error("Done with test for: Server -> Client")
}
}
}
@ -108,39 +156,47 @@ class RmiSimpleTest : BaseTest() {
config(configuration)
// configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java)
val client = Client<Connection>(configuration)
addEndPoint(client)
client.onConnect { connection ->
connection.createObject<TestCow>(23) { rmiId, remoteObject ->
System.err.println("Running test for: Client -> Server")
client.logger.error("Running test for: Client -> Server")
RmiCommonTest.runTests(connection, remoteObject, 23)
System.err.println("Done with test for: Client -> Server")
client.logger.error("Done with test for: Client -> Server")
}
}
client.onMessage<MessageWithTestCow> { _, m ->
System.err.println("Received finish signal for test for: Client -> Server")
client.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(123, id.toLong())
System.err.println("Finished test for: Client -> Server")
client.logger.error("Finished test for: Client -> Server")
stopEndPoints(2000)
}
runBlocking {
client.connect(LOOPBACK)
when {
isIpv4 && isIpv6 && runIpv4Connect -> client.connect(IPv4.LOCALHOST)
isIpv4 && isIpv6 && !runIpv4Connect -> client.connect(IPv6.LOCALHOST)
isIpv4 -> client.connect(IPv4.LOCALHOST)
isIpv6 -> client.connect(IPv6.LOCALHOST)
else -> client.connect()
}
}
}
waitForThreads()
}
fun rmiGlobal(config: (Configuration) -> Unit = {}) {
fun rmiGlobal(isIpv4: Boolean = false, isIpv6: Boolean = false, runIpv4Connect: Boolean = true, config: (Configuration) -> Unit = {}) {
run {
val configuration = serverConfig()
configuration.enableIPv4 = isIpv4
configuration.enableIPv6 = isIpv6
config(configuration)
configuration.serialization.registerRmi(TestCow::class.java, TestCowImpl::class.java)
configuration.serialization.register(MessageWithTestCow::class.java)
configuration.serialization.register(UnsupportedOperationException::class.java)
@ -153,20 +209,20 @@ class RmiSimpleTest : BaseTest() {
server.bind()
server.onMessage<MessageWithTestCow> { connection, m ->
System.err.println("Received finish signal for test for: Client -> Server")
server.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(44, id.toLong())
System.err.println("Finished test for: Client -> Server")
server.logger.error("Finished test for: Client -> Server")
// normally this is in the 'connected', but we do it here, so that it's more linear and easier to debug
connection.createObject<TestCow>(4) { rmiId, remoteObject ->
System.err.println("Running test for: Server -> Client")
server.logger.error("Running test for: Server -> Client")
RmiCommonTest.runTests(connection, remoteObject, 4)
System.err.println("Done with test for: Server -> Client")
server.logger.error("Done with test for: Server -> Client")
}
}
}
@ -180,24 +236,30 @@ class RmiSimpleTest : BaseTest() {
addEndPoint(client)
client.onMessage<MessageWithTestCow> { _, m ->
System.err.println("Received finish signal for test for: Client -> Server")
client.logger.error("Received finish signal for test for: Client -> Server")
val `object` = m.testCow
val id = `object`.id()
Assert.assertEquals(4, id.toLong())
System.err.println("Finished test for: Client -> Server")
client.logger.error("Finished test for: Client -> Server")
stopEndPoints(2000)
}
runBlocking {
client.connect(LOOPBACK)
when {
isIpv4 && isIpv6 && runIpv4Connect -> client.connect(IPv4.LOCALHOST)
isIpv4 && isIpv6 && !runIpv4Connect -> client.connect(IPv6.LOCALHOST)
isIpv4 -> client.connect(IPv4.LOCALHOST)
isIpv6 -> client.connect(IPv6.LOCALHOST)
else -> client.connect()
}
System.err.println("Starting test for: Client -> Server")
client.logger.error("Starting test for: Client -> Server")
// this creates a GLOBAL object on the server (instead of a connection specific object)
client.createObject<TestCow>(44) { rmiId, remoteObject ->
System.err.println("Running test for: Client -> Server")
client.logger.error("Running test for: Client -> Server")
RmiCommonTest.runTests(client.connection, remoteObject, 44)
System.err.println("Done with test for: Client -> Server")
client.logger.error("Done with test for: Client -> Server")
}
}
}