Added support for connection tags (so the client can set a name for its connection, and the server will get that name). This is usefull for identifying different connections (and doing different things) based on their tag name.

This commit is contained in:
Robinson 2023-10-28 20:55:49 +02:00
parent fe98763712
commit 58535a923b
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
16 changed files with 253 additions and 136 deletions

View File

@ -116,6 +116,15 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
var addressPrettyString: String = "UNKNOWN" var addressPrettyString: String = "UNKNOWN"
private set private set
/**
* The tag name assigned (by the configuration) to the client. The server will receive this tag during the handshake. The max length is
* 32 characters.
*/
@Volatile
var tag: String = ""
private set
/** /**
* The default connection reliability type (ie: can the lower-level network stack throw away data that has errors, for example real-time-voice) * The default connection reliability type (ie: can the lower-level network stack throw away data that has errors, for example real-time-voice)
*/ */
@ -507,6 +516,10 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
this.port1 = port1 this.port1 = port1
this.port2 = port2 this.port2 = port2
// DOUBLE CHECK!
require(config.tag.length <= 32) { "Client tag name length must be <= 32" }
this.tag = config.tag
this.reliable = reliable this.reliable = reliable
this.connectionTimeoutSec = connectionTimeoutSec this.connectionTimeoutSec = connectionTimeoutSec
@ -578,12 +591,13 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
remotePort2 = port2, remotePort2 = port2,
handshakeTimeoutNs = handshakeTimeoutNs, handshakeTimeoutNs = handshakeTimeoutNs,
reliable = reliable, reliable = reliable,
tagName = tag,
logger = logger logger = logger
) )
val pubSub = handshakeConnection.pubSub val pubSub = handshakeConnection.pubSub
val logInfo = pubSub.getLogInfo(logger)
val logInfo = pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) { if (logger.isDebugEnabled) {
logger.debug("Creating new handshake to $logInfo") logger.debug("Creating new handshake to $logInfo")
} else { } else {
@ -701,6 +715,7 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
// throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class) // throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class)
val connectionInfo = handshake.hello( val connectionInfo = handshake.hello(
tagName = tag,
endPoint = this, endPoint = this,
handshakeConnection = handshakeConnection, handshakeConnection = handshakeConnection,
handshakeTimeoutNs = handshakeTimeoutNs handshakeTimeoutNs = handshakeTimeoutNs
@ -771,7 +786,8 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
handshakeTimeoutNs = handshakeTimeoutNs, handshakeTimeoutNs = handshakeTimeoutNs,
handshakeConnection = handshakeConnection, handshakeConnection = handshakeConnection,
connectionInfo = connectionInfo, connectionInfo = connectionInfo,
port2Server = port2 port2Server = port2,
tagName = tag
) )
val pubSub = clientConnection.connectionInfo val pubSub = clientConnection.connectionInfo

View File

@ -167,6 +167,17 @@ class ClientConfiguration : dorkbox.network.Configuration() {
field = value field = value
} }
/**
* The tag name to be assigned to this connection and the server will receive this tag name during the handshake.
* The max length is 32 characters.
*/
var tag: String = ""
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/** /**
* Validates the current configuration. Throws an exception if there are problems. * Validates the current configuration. Throws an exception if there are problems.
*/ */
@ -180,6 +191,8 @@ class ClientConfiguration : dorkbox.network.Configuration() {
require(port > 0) { "Client listen port must be > 0" } require(port > 0) { "Client listen port must be > 0" }
require(port < 65535) { "Client listen port must be < 65535" } require(port < 65535) { "Client listen port must be < 65535" }
} }
require(tag.length <= 32) { "Client tag name length must be <= 32" }
} }
override fun initialize(logger: Logger): dorkbox.network.ClientConfiguration { override fun initialize(logger: Logger): dorkbox.network.ClientConfiguration {
@ -191,6 +204,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
super.copy(config) super.copy(config)
config.port = port config.port = port
config.tag = tag
return config return config
} }
@ -201,6 +215,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
if (!super.equals(other)) return false if (!super.equals(other)) return false
if (port != other.port) return false if (port != other.port) return false
if (tag != other.tag) return false
return true return true
} }
@ -208,6 +223,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
override fun hashCode(): Int { override fun hashCode(): Int {
var result = super.hashCode() var result = super.hashCode()
result = 31 * result + port.hashCode() result = 31 * result + port.hashCode()
result = 31 * result + tag.hashCode()
return result return result
} }
} }

View File

@ -365,7 +365,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* *
* This function will be called for **only** network clients (IPC client are excluded) * This function will be called for **only** network clients (IPC client are excluded)
*/ */
fun filter(function: CONNECTION.() -> Boolean) { fun filter(function: InetAddress.(String) -> Boolean) {
listenerManager.filter(function) listenerManager.filter(function)
} }

View File

@ -86,6 +86,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
info.sessionIdSub info.sessionIdSub
} }
/**
* The tag name for a connection permits an INCOMING client to define a custom string. The max length is 32
*/
val tag = info.tagName
/** /**
* The remote address, as a string. Will be null for IPC connections * The remote address, as a string. Will be null for IPC connections
*/ */

View File

@ -182,7 +182,6 @@ internal class CryptoManagement(val logger: Logger,
val streamIdPub = cryptInput.readInt() val streamIdPub = cryptInput.readInt()
val streamIdSub = cryptInput.readInt() val streamIdSub = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt() val regDetailsSize = cryptInput.readInt()
val enableSession = cryptInput.readBoolean()
val sessionTimeout = cryptInput.readLong() val sessionTimeout = cryptInput.readLong()
val regDetails = cryptInput.readBytes(regDetailsSize) val regDetails = cryptInput.readBytes(regDetailsSize)
@ -193,7 +192,6 @@ internal class CryptoManagement(val logger: Logger,
streamIdPub = streamIdPub, streamIdPub = streamIdPub,
streamIdSub = streamIdSub, streamIdSub = streamIdSub,
publicKey = serverPublicKeyBytes, publicKey = serverPublicKeyBytes,
enableSession = enableSession,
sessionTimeout = sessionTimeout, sessionTimeout = sessionTimeout,
kryoRegistrationDetails = regDetails, kryoRegistrationDetails = regDetails,
secretKey = secretKey) secretKey = secretKey)
@ -205,7 +203,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdPub: Int,
streamIdSub: Int, streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long, sessionTimeout: Long,
kryoRegDetails: ByteArray kryoRegDetails: ByteArray
): ByteArray { ): ByteArray {
@ -218,7 +215,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub) cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub) cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size) cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout) cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails) cryptOutput.writeBytes(kryoRegDetails)
@ -269,7 +265,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdPub: Int,
streamIdSub: Int, streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long, sessionTimeout: Long,
kryoRegDetails: ByteArray kryoRegDetails: ByteArray
): ByteArray { ): ByteArray {
@ -287,7 +282,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub) cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub) cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size) cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout) cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails) cryptOutput.writeBytes(kryoRegDetails)

View File

@ -189,7 +189,7 @@ internal class IpInfo(config: ServerConfiguration) {
} }
else -> { else -> {
ipType = IPC ipType = IPC
listenAddressString = "IPC" listenAddressString = EndPoint.IPC_NAME
formattedListenAddressString = listenAddressString formattedListenAddressString = listenAddressString
} }
} }

View File

@ -22,6 +22,7 @@ import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.os.OS import dorkbox.os.OS
import net.jodah.typetools.TypeResolver import net.jodah.typetools.TypeResolver
import org.slf4j.Logger import org.slf4j.Logger
import java.net.InetAddress
import java.util.concurrent.locks.* import java.util.concurrent.locks.*
import kotlin.concurrent.write import kotlin.concurrent.write
@ -144,7 +145,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return this return this
} }
internal inline fun <reified T> add(thing: T, array: Array<T>): Array<T> { internal inline fun <reified T: Any> add(thing: T, array: Array<T>): Array<T> {
val currentLength: Int = array.size val currentLength: Int = array.size
// add the new subscription to the END of the array // add the new subscription to the END of the array
@ -155,7 +156,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return newMessageArray return newMessageArray
} }
internal inline fun <reified T> remove(thing: T, array: Array<T>): Array<T> { internal inline fun <reified T: Any> remove(thing: T, array: Array<T>): Array<T> {
// remove the subscription form the array // remove the subscription form the array
// THIS IS IDENTITY CHECKS, NOT EQUALITY // THIS IS IDENTITY CHECKS, NOT EQUALITY
return array.filter { it !== thing }.toTypedArray() return array.filter { it !== thing }.toTypedArray()
@ -164,7 +165,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
// initialize emtpy arrays // initialize emtpy arrays
@Volatile @Volatile
private var onConnectFilterList = Array<(CONNECTION.() -> Boolean)>(0) { { true } } private var onConnectFilterList = Array<(InetAddress.(String) -> Boolean)>(0) { { true } }
private val onConnectFilterLock = ReentrantReadWriteLock() private val onConnectFilterLock = ReentrantReadWriteLock()
@Volatile @Volatile
@ -202,8 +203,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
*/ */
fun filter(ipFilterRule: IpFilterRule) { fun filter(ipFilterRule: IpFilterRule) {
filter { filter {
// IPC will not filter, so this is OK to coerce to not-null // IPC will not filter
ipFilterRule.matches(remoteAddress!!) ipFilterRule.matches(this)
} }
} }
@ -212,15 +213,21 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if a connection * Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if a connection
* should be allowed * should be allowed
* *
* By default, if there are no filter rules, then all connections are allowed to connect
* If there are filter rules - then ONLY connections for the filter that returns true are allowed to connect (all else are denied)
*
* It is the responsibility of the custom filter to write the error, if there is one * It is the responsibility of the custom filter to write the error, if there is one
* *
* If the function returns TRUE, then the connection will continue to connect. * If the function returns TRUE, then the connection will continue to connect.
* If the function returns FALSE, then the other end of the connection will * If the function returns FALSE, then the other end of the connection will
* receive a connection error * receive a connection error
* *
* For a server, this function will be called for ALL clients. *
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called for **only** network clients (IPC client are excluded)
*/ */
fun filter(function: CONNECTION.() -> Boolean) { fun filter(function: InetAddress.(String) -> Boolean) {
onConnectFilterLock.write { onConnectFilterLock.write {
// we have to follow the single-writer principle! // we have to follow the single-writer principle!
onConnectFilterList = add(function, onConnectFilterList) onConnectFilterList = add(function, onConnectFilterList)
@ -348,23 +355,16 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* *
* NOTE: This is run directly on the thread that calls it! * NOTE: This is run directly on the thread that calls it!
* *
* @return true if the connection will be allowed to connect. False if we should terminate this connection * @return true if the client address is allowed to connect. False if we should terminate this connection
*/ */
fun notifyFilter(connection: CONNECTION): Boolean { fun notifyFilter(clientAddress: InetAddress, clientTagName: String): Boolean {
// remote address will NOT be null at this stage, but best to verify.
val remoteAddress = connection.remoteAddress
if (remoteAddress == null) {
logger.error("Connection ${connection.id}: Unable to attempt connection stages when no remote address is present")
return false
}
// by default, there is a SINGLE rule that will always exist, and will always ACCEPT ALL connections. // by default, there is a SINGLE rule that will always exist, and will always ACCEPT ALL connections.
// This is so the array types can be setup (the compiler needs SOMETHING there) // This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectFilterList val list = onConnectFilterList
// if there is a rule, a connection must match for it to connect // if there is a rule, a connection must match for it to connect
list.forEach { list.forEach {
if (it.invoke(connection)) { if (it.invoke(clientAddress, clientTagName)) {
return true return true
} }
} }

View File

@ -21,6 +21,7 @@ import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.uri import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.aeron.controlEndpoint import dorkbox.network.aeron.controlEndpoint
import dorkbox.network.aeron.endpoint import dorkbox.network.aeron.endpoint
import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientRetryException import dorkbox.network.exceptions.ClientRetryException
import dorkbox.network.exceptions.ClientTimedOutException import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.CommonContext import io.aeron.CommonContext
@ -45,6 +46,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
handshakeConnection: ClientHandshakeDriver, handshakeConnection: ClientHandshakeDriver,
connectionInfo: ClientConnectionInfo, connectionInfo: ClientConnectionInfo,
port2Server: Int, // this is the port2 value from the server port2Server: Int, // this is the port2 value from the server
tagName: String
): ClientConnectionDriver { ): ClientConnectionDriver {
val handshakePubSub = handshakeConnection.pubSub val handshakePubSub = handshakeConnection.pubSub
val reliable = handshakePubSub.reliable val reliable = handshakePubSub.reliable
@ -73,6 +75,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub = streamIdPub, streamIdPub = streamIdPub,
streamIdSub = streamIdSub, streamIdSub = streamIdSub,
reliable = reliable, reliable = reliable,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
} }
@ -101,6 +104,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub = portSub, portSub = portSub,
port2Server = port2Server, port2Server = port2Server,
reliable = reliable, reliable = reliable,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
} }
@ -117,6 +121,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub: Int, streamIdPub: Int,
streamIdSub: Int, streamIdSub: Int,
reliable: Boolean, reliable: Boolean,
tagName: String,
logInfo: String logInfo: String
): PubSub { ): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back) // on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -150,10 +155,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
} }
return PubSub(publication, subscription, return PubSub(
sessionIdPub, sessionIdSub, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable) sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
} }
@Throws(ClientTimedOutException::class) @Throws(ClientTimedOutException::class)
@ -170,6 +185,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub: Int, portSub: Int,
port2Server: Int, // this is the port2 value from the server port2Server: Int, // this is the port2 value from the server
reliable: Boolean, reliable: Boolean,
tagName: String,
logInfo: String, logInfo: String,
): PubSub { ): PubSub {
val isRemoteIpv4 = remoteAddress is Inet4Address val isRemoteIpv4 = remoteAddress is Inet4Address
@ -213,12 +229,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause) ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
} }
return PubSub(publication, subscription, return PubSub(
sessionIdPub, sessionIdSub, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable, sessionIdPub = sessionIdPub,
remoteAddress, remoteAddressString, sessionIdSub = sessionIdSub,
portPub, portSub) streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
} }
} }
} }

View File

@ -23,7 +23,6 @@ internal class ClientConnectionInfo(
val streamIdPub: Int, val streamIdPub: Int,
val streamIdSub: Int = 0, val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0), val publicKey: ByteArray = ByteArray(0),
val enableSession: Boolean,
val sessionTimeout: Long, val sessionTimeout: Long,
val kryoRegistrationDetails: ByteArray, val kryoRegistrationDetails: ByteArray,
val secretKey: SecretKeySpec val secretKey: SecretKeySpec

View File

@ -178,6 +178,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread // called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed // when exceptions are thrown, the handshake pub/sub will be closed
fun hello( fun hello(
tagName: String,
endPoint: EndPoint<CONNECTION>, endPoint: EndPoint<CONNECTION>,
handshakeConnection: ClientHandshakeDriver, handshakeConnection: ClientHandshakeDriver,
handshakeTimeoutNs: Long handshakeTimeoutNs: Long
@ -200,7 +201,8 @@ internal class ClientHandshake<CONNECTION: Connection>(
connectKey = connectKey, connectKey = connectKey,
publicKey = client.storage.publicKey, publicKey = client.storage.publicKey,
streamIdSub = pubSub.streamIdSub, streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub portSub = pubSub.portSub,
tagName = tagName
)) ))
} catch (e: Exception) { } catch (e: Exception) {
handshakeConnection.close(endPoint) handshakeConnection.close(endPoint)

View File

@ -65,6 +65,7 @@ internal class ClientHandshakeDriver(
clientListenPort: Int, clientListenPort: Int,
handshakeTimeoutNs: Long, handshakeTimeoutNs: Long,
reliable: Boolean, reliable: Boolean,
tagName: String,
logger: Logger logger: Logger
): ClientHandshakeDriver { ): ClientHandshakeDriver {
logger.trace("Starting client handshake") logger.trace("Starting client handshake")
@ -111,6 +112,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub, streamIdPub = streamIdPub,
streamIdSub = streamIdSub, streamIdSub = streamIdSub,
reliable = reliable, reliable = reliable,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
} catch (exception: Exception) { } catch (exception: Exception) {
@ -160,6 +162,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub, streamIdPub = streamIdPub,
reliable = reliable, reliable = reliable,
streamIdSub = streamIdSub, streamIdSub = streamIdSub,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
@ -188,6 +191,7 @@ internal class ClientHandshakeDriver(
sessionIdPub: Int, sessionIdPub: Int,
streamIdPub: Int, streamIdSub: Int, streamIdPub: Int, streamIdSub: Int,
reliable: Boolean, reliable: Boolean,
tagName: String,
logInfo: String logInfo: String
): PubSub { ): PubSub {
// Create a publication at the given address and port, using the given stream ID. // Create a publication at the given address and port, using the given stream ID.
@ -214,10 +218,20 @@ internal class ClientHandshakeDriver(
val subscriptionUri = uriHandshake(CommonContext.IPC_MEDIA, reliable) val subscriptionUri = uriHandshake(CommonContext.IPC_MEDIA, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true) val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription, return PubSub(
sessionIdPub, 0, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable) sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
} }
@Throws(ClientTimedOutException::class) @Throws(ClientTimedOutException::class)
@ -233,6 +247,7 @@ internal class ClientHandshakeDriver(
streamIdPub: Int, streamIdPub: Int,
reliable: Boolean, reliable: Boolean,
streamIdSub: Int, streamIdSub: Int,
tagName: String,
logInfo: String, logInfo: String,
): PubSub { ): PubSub {
@Suppress("NAME_SHADOWING") @Suppress("NAME_SHADOWING")
@ -320,12 +335,20 @@ internal class ClientHandshakeDriver(
throw ex throw ex
} }
return PubSub(publication, subscription, return PubSub(
sessionIdPub, 0, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable, sessionIdPub = sessionIdPub,
remoteAddress, remoteAddressString, sessionIdSub = 0,
portPub, portSub) streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
} }
} }

View File

@ -32,6 +32,9 @@ internal class HandshakeMessage private constructor() {
// -1 means there is an error // -1 means there is an error
var state = INVALID var state = INVALID
// used to name a connection (via the client)
var tag: String = ""
var errorMessage: String? = null var errorMessage: String? = null
var port = 0 var port = 0
@ -51,7 +54,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3 const val DONE = 3
const val DONE_ACK = 4 const val DONE_ACK = 4
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage { fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, tagName: String): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -59,6 +62,7 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way! hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub hello.streamId = streamIdSub
hello.port = portSub hello.port = portSub
hello.tag = tagName
return hello return hello
} }
@ -135,6 +139,6 @@ internal class HandshakeMessage private constructor() {
"" ""
} }
return "HandshakeMessage($stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})" return "HandshakeMessage($tag :: $stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
} }
} }

View File

@ -16,9 +16,9 @@
package dorkbox.network.handshake package dorkbox.network.handshake
import dorkbox.network.connection.EndPoint
import io.aeron.Publication import io.aeron.Publication
import io.aeron.Subscription import io.aeron.Subscription
import org.slf4j.Logger
import java.net.Inet4Address import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
@ -30,29 +30,43 @@ data class PubSub(
val streamIdPub: Int, val streamIdPub: Int,
val streamIdSub: Int, val streamIdSub: Int,
val reliable: Boolean, val reliable: Boolean,
val remoteAddress: InetAddress? = null, val remoteAddress: InetAddress?,
val remoteAddressString: String = "IPC", val remoteAddressString: String,
val portPub: Int = 0, val portPub: Int,
val portSub: Int = 0 val portSub: Int,
val tagName: String // will either be "", or will be "[tag_name]"
) { ) {
val isIpc get() = remoteAddress == null val isIpc get() = remoteAddress == null
fun getLogInfo(logger: Logger): String { fun getLogInfo(extraDetails: Boolean): String {
val detailed = logger.isTraceEnabled
return if (isIpc) { return if (isIpc) {
if (detailed) { val prefix = if (tagName.isNotEmpty()) {
"IPC sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}" EndPoint.IPC_NAME + " $tagName"
} else { } else {
"IPC" EndPoint.IPC_NAME
}
if (extraDetails) {
if (tagName.isNotEmpty()) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
}
} else {
prefix
} }
} else { } else {
val prefix = if (remoteAddress is Inet4Address) { var prefix = if (remoteAddress is Inet4Address) {
"IPv4 $remoteAddressString" "IPv4 $remoteAddressString"
} else { } else {
"IPv6 $remoteAddressString" "IPv6 $remoteAddressString"
} }
if (detailed) { if (tagName.isNotEmpty()) {
prefix += " $tagName"
}
if (extraDetails) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}" "$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else { } else {
prefix prefix

View File

@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.aeron.AeronDriver import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uri import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo import dorkbox.network.connection.IpInfo
import io.aeron.CommonContext import io.aeron.CommonContext
import java.net.Inet4Address import java.net.Inet4Address
@ -33,16 +34,17 @@ import java.net.InetAddress
internal class ServerConnectionDriver(val pubSub: PubSub) { internal class ServerConnectionDriver(val pubSub: PubSub) {
companion object { companion object {
fun build(isIpc: Boolean, fun build(isIpc: Boolean,
aeronDriver: AeronDriver, aeronDriver: AeronDriver,
sessionIdPub: Int, sessionIdSub: Int, sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int, streamIdPub: Int, streamIdSub: Int,
ipInfo: IpInfo, ipInfo: IpInfo,
remoteAddress: InetAddress?, remoteAddress: InetAddress?,
remoteAddressString: String, remoteAddressString: String,
portPubMdc: Int, portPub: Int, portSub: Int, portPubMdc: Int, portPub: Int, portSub: Int,
reliable: Boolean, reliable: Boolean,
logInfo: String): ServerConnectionDriver { tagName: String,
logInfo: String): ServerConnectionDriver {
val pubSub: PubSub val pubSub: PubSub
@ -54,6 +56,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
streamIdPub = streamIdPub, streamIdPub = streamIdPub,
streamIdSub = streamIdSub, streamIdSub = streamIdSub,
reliable = reliable, reliable = reliable,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
} else { } else {
@ -70,6 +73,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub = portPub, portPub = portPub,
portSub = portSub, portSub = portSub,
reliable = reliable, reliable = reliable,
tagName = tagName,
logInfo = logInfo logInfo = logInfo
) )
} }
@ -82,6 +86,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
sessionIdPub: Int, sessionIdSub: Int, sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int, streamIdPub: Int, streamIdSub: Int,
reliable: Boolean, reliable: Boolean,
tagName: String,
logInfo: String logInfo: String
): PubSub { ): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back) // on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -98,10 +103,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable) val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true) val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription, return PubSub(
sessionIdPub, sessionIdSub, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable) sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
} }
private fun buildUdp( private fun buildUdp(
@ -114,6 +129,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub: Int, portPub: Int,
portSub: Int, portSub: Int,
reliable: Boolean, reliable: Boolean,
tagName: String,
logInfo: String logInfo: String
): PubSub { ): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back) // on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -146,12 +162,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false) val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
return PubSub(publication, subscription, return PubSub(
sessionIdPub, sessionIdSub, pub = publication,
streamIdPub, streamIdSub, sub = subscription,
reliable, sessionIdPub = sessionIdPub,
remoteAddress, remoteAddressString, sessionIdSub = sessionIdSub,
portPub, portSub) streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
} }
} }
} }

View File

@ -21,8 +21,6 @@ import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.sessionIdAllocator import dorkbox.network.aeron.AeronDriver.Companion.sessionIdAllocator
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
import dorkbox.network.connection.* import dorkbox.network.connection.*
import dorkbox.network.connection.session.SessionConnection
import dorkbox.network.connection.session.SessionServer
import dorkbox.network.exceptions.AllocationException import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ServerHandshakeException import dorkbox.network.exceptions.ServerHandshakeException
import dorkbox.network.exceptions.ServerTimedoutException import dorkbox.network.exceptions.ServerTimedoutException
@ -131,23 +129,16 @@ internal class ServerHandshake<CONNECTION : Connection>(
} }
// Server is the "source", client mirrors the server // Server is the "source", client mirrors the server
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) { if (logger.isTraceEnabled) {
logger.trace("[${newConnection}] (${message.connectKey}) $connType (${newConnection.id}) done with handshake.") logger.trace("[${newConnection}] (${message.connectKey}) Buffered connection (${newConnection.id}) done with handshake.")
} else if (logger.isDebugEnabled) { } else if (logger.isDebugEnabled) {
logger.debug("[${newConnection}] $connType (${newConnection.id}) done with handshake.") logger.debug("[${newConnection}] Buffered connection (${newConnection.id}) done with handshake.")
} }
// in the specific case of using sessions, we don't want to call 'init' or `connect` for a connection that is resuming a session
// when applicable - we ALSO want to restore RMI objects BEFORE the connection is fully setup!
val newSession = server.sessionManager.onInit(newConnection)
newConnection.setImage() newConnection.setImage()
// before we finish creating the connection, we initialize it (in case there needs to be logic that happens-before `onConnect` calls // before we finish creating the connection, we initialize it (in case there needs to be logic that happens-before `onConnect` calls
if (newSession) { listenerManager.notifyInit(newConnection)
listenerManager.notifyInit(newConnection)
}
// this enables the connection to start polling for messages // this enables the connection to start polling for messages
server.addConnection(newConnection) server.addConnection(newConnection)
@ -160,9 +151,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
listenerManager.notifyConnect(newConnection) listenerManager.notifyConnect(newConnection)
if (!newSession) { newConnection.sendBufferedMessages()
(newConnection as SessionConnection).sendPendingMessages()
}
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(newConnection, TransmitException("[$newConnection] Handshake error", e)) listenerManager.notifyError(newConnection, TransmitException("[$newConnection] Handshake error", e))
} }
@ -248,6 +237,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
): Boolean { ): Boolean {
val serialization = config.serialization val serialization = config.serialization
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
///// /////
///// /////
///// DONE WITH VALIDATION ///// DONE WITH VALIDATION
@ -339,7 +335,8 @@ internal class ServerHandshake<CONNECTION : Connection>(
aeronDriver = aeronDriver, aeronDriver = aeronDriver,
ipInfo = server.ipInfo, ipInfo = server.ipInfo,
isIpc = true, isIpc = true,
logInfo = "IPC", tagName = clientTagName,
logInfo = EndPoint.IPC_NAME,
remoteAddress = null, remoteAddress = null,
remoteAddressString = "", remoteAddressString = "",
@ -353,12 +350,11 @@ internal class ServerHandshake<CONNECTION : Connection>(
reliable = true reliable = true
) )
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger) val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
val connectionType = if (server is SessionServer) "session connection" else "connection"
if (logger.isDebugEnabled) { if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo") logger.debug("Creating new buffered connection to $logInfo")
} else { } else {
logger.info("Creating new $connectionType to $logInfo") logger.info("Creating new buffered connection to $logInfo")
} }
newConnection = server.newConnection(ConnectionParams( newConnection = server.newConnection(ConnectionParams(
@ -369,7 +365,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections
)) ))
server.sessionManager.onNewConnection(newConnection) server.bufferedManager.onConnect(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@ -393,8 +389,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub, sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub, streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub, streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection, sessionTimeout = config.bufferedConnectionTimeoutSeconds,
sessionTimeout = config.sessionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails() kryoRegDetails = serialization.getKryoRegistrationDetails()
) )
@ -403,11 +398,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) { if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${message.connectKey}) $connType (${newConnection.id}) responding to handshake hello.") logger.trace("[$logInfo] (${message.connectKey}) buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) { } else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.") logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
} }
// this tells the client all the info to connect. // this tells the client all the info to connect.
@ -470,6 +464,27 @@ internal class ServerHandshake<CONNECTION : Connection>(
return false return false
} }
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(clientAddress, clientTagName)
if (!permitConnection) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
///// /////
///// /////
@ -585,18 +600,18 @@ internal class ServerHandshake<CONNECTION : Connection>(
portPubMdc = mdcPortPub, portPubMdc = mdcPortPub,
portPub = portPub, portPub = portPub,
portSub = portSub, portSub = portSub,
tagName = clientTagName,
reliable = isReliable reliable = isReliable
) )
val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes) val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger) val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
val connectionType = if (server is SessionServer) "session connection" else "connection"
if (logger.isDebugEnabled) { if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo") logger.debug("Creating new buffered connection to $logInfo")
} else { } else {
logger.info("Creating new $connectionType to $logInfo") logger.info("Creating new buffered connection to $logInfo")
} }
newConnection = server.newConnection(ConnectionParams( newConnection = server.newConnection(ConnectionParams(
@ -607,31 +622,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = cryptoSecretKey cryptoKey = cryptoSecretKey
)) ))
server.sessionManager.onNewConnection(newConnection) server.bufferedManager.onConnect(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
// this will also unwind/free allocations
newConnection.close()
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
/////////////// ///////////////
/// HANDSHAKE /// HANDSHAKE
/////////////// ///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey) val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
@ -645,8 +642,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub, sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub, streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub, streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection, sessionTimeout = config.bufferedConnectionTimeoutSeconds,
sessionTimeout = config.sessionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails() kryoRegDetails = serialization.getKryoRegistrationDetails()
) )
@ -655,11 +651,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data // before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) { if (logger.isTraceEnabled) {
logger.trace("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.") logger.trace("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) { } else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.") logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
} }
// this tells the client all the info to connect. // this tells the client all the info to connect.

View File

@ -98,6 +98,7 @@ abstract class BaseTest {
val configuration = ClientConfiguration() val configuration = ClientConfiguration()
configuration.appId = "network_test" configuration.appId = "network_test"
configuration.tag = "**Client**"
configuration.settingsStore = Storage.Memory() // don't want to persist anything on disk! configuration.settingsStore = Storage.Memory() // don't want to persist anything on disk!
configuration.enableIpc = false configuration.enableIpc = false