Added support for connection tags (so the client can set a name for its connection, and the server will get that name). This is usefull for identifying different connections (and doing different things) based on their tag name.

master
Robinson 2023-10-28 20:55:49 +02:00
parent fe98763712
commit 58535a923b
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
16 changed files with 253 additions and 136 deletions

View File

@ -116,6 +116,15 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
var addressPrettyString: String = "UNKNOWN"
private set
/**
* The tag name assigned (by the configuration) to the client. The server will receive this tag during the handshake. The max length is
* 32 characters.
*/
@Volatile
var tag: String = ""
private set
/**
* The default connection reliability type (ie: can the lower-level network stack throw away data that has errors, for example real-time-voice)
*/
@ -507,6 +516,10 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
this.port1 = port1
this.port2 = port2
// DOUBLE CHECK!
require(config.tag.length <= 32) { "Client tag name length must be <= 32" }
this.tag = config.tag
this.reliable = reliable
this.connectionTimeoutSec = connectionTimeoutSec
@ -578,12 +591,13 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
remotePort2 = port2,
handshakeTimeoutNs = handshakeTimeoutNs,
reliable = reliable,
tagName = tag,
logger = logger
)
val pubSub = handshakeConnection.pubSub
val logInfo = pubSub.getLogInfo(logger)
val logInfo = pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new handshake to $logInfo")
} else {
@ -701,6 +715,7 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
// throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class)
val connectionInfo = handshake.hello(
tagName = tag,
endPoint = this,
handshakeConnection = handshakeConnection,
handshakeTimeoutNs = handshakeTimeoutNs
@ -771,7 +786,8 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
handshakeTimeoutNs = handshakeTimeoutNs,
handshakeConnection = handshakeConnection,
connectionInfo = connectionInfo,
port2Server = port2
port2Server = port2,
tagName = tag
)
val pubSub = clientConnection.connectionInfo

View File

@ -167,6 +167,17 @@ class ClientConfiguration : dorkbox.network.Configuration() {
field = value
}
/**
* The tag name to be assigned to this connection and the server will receive this tag name during the handshake.
* The max length is 32 characters.
*/
var tag: String = ""
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Validates the current configuration. Throws an exception if there are problems.
*/
@ -180,6 +191,8 @@ class ClientConfiguration : dorkbox.network.Configuration() {
require(port > 0) { "Client listen port must be > 0" }
require(port < 65535) { "Client listen port must be < 65535" }
}
require(tag.length <= 32) { "Client tag name length must be <= 32" }
}
override fun initialize(logger: Logger): dorkbox.network.ClientConfiguration {
@ -191,6 +204,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
super.copy(config)
config.port = port
config.tag = tag
return config
}
@ -201,6 +215,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
if (!super.equals(other)) return false
if (port != other.port) return false
if (tag != other.tag) return false
return true
}
@ -208,6 +223,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + port.hashCode()
result = 31 * result + tag.hashCode()
return result
}
}

View File

@ -365,7 +365,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
*
* This function will be called for **only** network clients (IPC client are excluded)
*/
fun filter(function: CONNECTION.() -> Boolean) {
fun filter(function: InetAddress.(String) -> Boolean) {
listenerManager.filter(function)
}

View File

@ -86,6 +86,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
info.sessionIdSub
}
/**
* The tag name for a connection permits an INCOMING client to define a custom string. The max length is 32
*/
val tag = info.tagName
/**
* The remote address, as a string. Will be null for IPC connections
*/

View File

@ -182,7 +182,6 @@ internal class CryptoManagement(val logger: Logger,
val streamIdPub = cryptInput.readInt()
val streamIdSub = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val enableSession = cryptInput.readBoolean()
val sessionTimeout = cryptInput.readLong()
val regDetails = cryptInput.readBytes(regDetailsSize)
@ -193,7 +192,6 @@ internal class CryptoManagement(val logger: Logger,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
publicKey = serverPublicKeyBytes,
enableSession = enableSession,
sessionTimeout = sessionTimeout,
kryoRegistrationDetails = regDetails,
secretKey = secretKey)
@ -205,7 +203,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long,
kryoRegDetails: ByteArray
): ByteArray {
@ -218,7 +215,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails)
@ -269,7 +265,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long,
kryoRegDetails: ByteArray
): ByteArray {
@ -287,7 +282,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails)

View File

@ -189,7 +189,7 @@ internal class IpInfo(config: ServerConfiguration) {
}
else -> {
ipType = IPC
listenAddressString = "IPC"
listenAddressString = EndPoint.IPC_NAME
formattedListenAddressString = listenAddressString
}
}

View File

@ -22,6 +22,7 @@ import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.os.OS
import net.jodah.typetools.TypeResolver
import org.slf4j.Logger
import java.net.InetAddress
import java.util.concurrent.locks.*
import kotlin.concurrent.write
@ -144,7 +145,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return this
}
internal inline fun <reified T> add(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> add(thing: T, array: Array<T>): Array<T> {
val currentLength: Int = array.size
// add the new subscription to the END of the array
@ -155,7 +156,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return newMessageArray
}
internal inline fun <reified T> remove(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> remove(thing: T, array: Array<T>): Array<T> {
// remove the subscription form the array
// THIS IS IDENTITY CHECKS, NOT EQUALITY
return array.filter { it !== thing }.toTypedArray()
@ -164,7 +165,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
// initialize emtpy arrays
@Volatile
private var onConnectFilterList = Array<(CONNECTION.() -> Boolean)>(0) { { true } }
private var onConnectFilterList = Array<(InetAddress.(String) -> Boolean)>(0) { { true } }
private val onConnectFilterLock = ReentrantReadWriteLock()
@Volatile
@ -202,8 +203,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
*/
fun filter(ipFilterRule: IpFilterRule) {
filter {
// IPC will not filter, so this is OK to coerce to not-null
ipFilterRule.matches(remoteAddress!!)
// IPC will not filter
ipFilterRule.matches(this)
}
}
@ -212,15 +213,21 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if a connection
* should be allowed
*
* By default, if there are no filter rules, then all connections are allowed to connect
* If there are filter rules - then ONLY connections for the filter that returns true are allowed to connect (all else are denied)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the connection will continue to connect.
* If the function returns FALSE, then the other end of the connection will
* receive a connection error
*
* For a server, this function will be called for ALL clients.
*
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called for **only** network clients (IPC client are excluded)
*/
fun filter(function: CONNECTION.() -> Boolean) {
fun filter(function: InetAddress.(String) -> Boolean) {
onConnectFilterLock.write {
// we have to follow the single-writer principle!
onConnectFilterList = add(function, onConnectFilterList)
@ -348,23 +355,16 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
*
* NOTE: This is run directly on the thread that calls it!
*
* @return true if the connection will be allowed to connect. False if we should terminate this connection
* @return true if the client address is allowed to connect. False if we should terminate this connection
*/
fun notifyFilter(connection: CONNECTION): Boolean {
// remote address will NOT be null at this stage, but best to verify.
val remoteAddress = connection.remoteAddress
if (remoteAddress == null) {
logger.error("Connection ${connection.id}: Unable to attempt connection stages when no remote address is present")
return false
}
fun notifyFilter(clientAddress: InetAddress, clientTagName: String): Boolean {
// by default, there is a SINGLE rule that will always exist, and will always ACCEPT ALL connections.
// This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectFilterList
// if there is a rule, a connection must match for it to connect
list.forEach {
if (it.invoke(connection)) {
if (it.invoke(clientAddress, clientTagName)) {
return true
}
}

View File

@ -21,6 +21,7 @@ import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.aeron.controlEndpoint
import dorkbox.network.aeron.endpoint
import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientRetryException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.CommonContext
@ -45,6 +46,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
handshakeConnection: ClientHandshakeDriver,
connectionInfo: ClientConnectionInfo,
port2Server: Int, // this is the port2 value from the server
tagName: String
): ClientConnectionDriver {
val handshakePubSub = handshakeConnection.pubSub
val reliable = handshakePubSub.reliable
@ -73,6 +75,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -101,6 +104,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub = portSub,
port2Server = port2Server,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -117,6 +121,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub: Int,
streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -150,10 +155,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
}
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
@Throws(ClientTimedOutException::class)
@ -170,6 +185,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub: Int,
port2Server: Int, // this is the port2 value from the server
reliable: Boolean,
tagName: String,
logInfo: String,
): PubSub {
val isRemoteIpv4 = remoteAddress is Inet4Address
@ -213,12 +229,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}

View File

@ -23,7 +23,6 @@ internal class ClientConnectionInfo(
val streamIdPub: Int,
val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0),
val enableSession: Boolean,
val sessionTimeout: Long,
val kryoRegistrationDetails: ByteArray,
val secretKey: SecretKeySpec

View File

@ -178,6 +178,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed
fun hello(
tagName: String,
endPoint: EndPoint<CONNECTION>,
handshakeConnection: ClientHandshakeDriver,
handshakeTimeoutNs: Long
@ -200,7 +201,8 @@ internal class ClientHandshake<CONNECTION: Connection>(
connectKey = connectKey,
publicKey = client.storage.publicKey,
streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub
portSub = pubSub.portSub,
tagName = tagName
))
} catch (e: Exception) {
handshakeConnection.close(endPoint)

View File

@ -65,6 +65,7 @@ internal class ClientHandshakeDriver(
clientListenPort: Int,
handshakeTimeoutNs: Long,
reliable: Boolean,
tagName: String,
logger: Logger
): ClientHandshakeDriver {
logger.trace("Starting client handshake")
@ -111,6 +112,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} catch (exception: Exception) {
@ -160,6 +162,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub,
reliable = reliable,
streamIdSub = streamIdSub,
tagName = tagName,
logInfo = logInfo
)
@ -188,6 +191,7 @@ internal class ClientHandshakeDriver(
sessionIdPub: Int,
streamIdPub: Int, streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// Create a publication at the given address and port, using the given stream ID.
@ -214,10 +218,20 @@ internal class ClientHandshakeDriver(
val subscriptionUri = uriHandshake(CommonContext.IPC_MEDIA, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
@Throws(ClientTimedOutException::class)
@ -233,6 +247,7 @@ internal class ClientHandshakeDriver(
streamIdPub: Int,
reliable: Boolean,
streamIdSub: Int,
tagName: String,
logInfo: String,
): PubSub {
@Suppress("NAME_SHADOWING")
@ -320,12 +335,20 @@ internal class ClientHandshakeDriver(
throw ex
}
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}

View File

@ -32,6 +32,9 @@ internal class HandshakeMessage private constructor() {
// -1 means there is an error
var state = INVALID
// used to name a connection (via the client)
var tag: String = ""
var errorMessage: String? = null
var port = 0
@ -51,7 +54,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage {
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, tagName: String): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -59,6 +62,7 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub
hello.port = portSub
hello.tag = tagName
return hello
}
@ -135,6 +139,6 @@ internal class HandshakeMessage private constructor() {
""
}
return "HandshakeMessage($stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
return "HandshakeMessage($tag :: $stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
}
}

View File

@ -16,9 +16,9 @@
package dorkbox.network.handshake
import dorkbox.network.connection.EndPoint
import io.aeron.Publication
import io.aeron.Subscription
import org.slf4j.Logger
import java.net.Inet4Address
import java.net.InetAddress
@ -30,29 +30,43 @@ data class PubSub(
val streamIdPub: Int,
val streamIdSub: Int,
val reliable: Boolean,
val remoteAddress: InetAddress? = null,
val remoteAddressString: String = "IPC",
val portPub: Int = 0,
val portSub: Int = 0
val remoteAddress: InetAddress?,
val remoteAddressString: String,
val portPub: Int,
val portSub: Int,
val tagName: String // will either be "", or will be "[tag_name]"
) {
val isIpc get() = remoteAddress == null
fun getLogInfo(logger: Logger): String {
val detailed = logger.isTraceEnabled
fun getLogInfo(extraDetails: Boolean): String {
return if (isIpc) {
if (detailed) {
"IPC sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
val prefix = if (tagName.isNotEmpty()) {
EndPoint.IPC_NAME + " $tagName"
} else {
"IPC"
EndPoint.IPC_NAME
}
if (extraDetails) {
if (tagName.isNotEmpty()) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
}
} else {
prefix
}
} else {
val prefix = if (remoteAddress is Inet4Address) {
var prefix = if (remoteAddress is Inet4Address) {
"IPv4 $remoteAddressString"
} else {
"IPv6 $remoteAddressString"
}
if (detailed) {
if (tagName.isNotEmpty()) {
prefix += " $tagName"
}
if (extraDetails) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
prefix

View File

@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo
import io.aeron.CommonContext
import java.net.Inet4Address
@ -33,16 +34,17 @@ import java.net.InetAddress
internal class ServerConnectionDriver(val pubSub: PubSub) {
companion object {
fun build(isIpc: Boolean,
aeronDriver: AeronDriver,
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
aeronDriver: AeronDriver,
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
ipInfo: IpInfo,
remoteAddress: InetAddress?,
remoteAddressString: String,
portPubMdc: Int, portPub: Int, portSub: Int,
reliable: Boolean,
logInfo: String): ServerConnectionDriver {
ipInfo: IpInfo,
remoteAddress: InetAddress?,
remoteAddressString: String,
portPubMdc: Int, portPub: Int, portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String): ServerConnectionDriver {
val pubSub: PubSub
@ -54,6 +56,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} else {
@ -70,6 +73,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub = portPub,
portSub = portSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -82,6 +86,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -98,10 +103,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
private fun buildUdp(
@ -114,6 +129,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub: Int,
portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -146,12 +162,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}

View File

@ -21,8 +21,6 @@ import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.sessionIdAllocator
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
import dorkbox.network.connection.*
import dorkbox.network.connection.session.SessionConnection
import dorkbox.network.connection.session.SessionServer
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ServerHandshakeException
import dorkbox.network.exceptions.ServerTimedoutException
@ -131,23 +129,16 @@ internal class ServerHandshake<CONNECTION : Connection>(
}
// Server is the "source", client mirrors the server
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[${newConnection}] (${message.connectKey}) $connType (${newConnection.id}) done with handshake.")
logger.trace("[${newConnection}] (${message.connectKey}) Buffered connection (${newConnection.id}) done with handshake.")
} else if (logger.isDebugEnabled) {
logger.debug("[${newConnection}] $connType (${newConnection.id}) done with handshake.")
logger.debug("[${newConnection}] Buffered connection (${newConnection.id}) done with handshake.")
}
// in the specific case of using sessions, we don't want to call 'init' or `connect` for a connection that is resuming a session
// when applicable - we ALSO want to restore RMI objects BEFORE the connection is fully setup!
val newSession = server.sessionManager.onInit(newConnection)
newConnection.setImage()
// before we finish creating the connection, we initialize it (in case there needs to be logic that happens-before `onConnect` calls
if (newSession) {
listenerManager.notifyInit(newConnection)
}
listenerManager.notifyInit(newConnection)
// this enables the connection to start polling for messages
server.addConnection(newConnection)
@ -160,9 +151,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
listenerManager.notifyConnect(newConnection)
if (!newSession) {
(newConnection as SessionConnection).sendPendingMessages()
}
newConnection.sendBufferedMessages()
} catch (e: Exception) {
listenerManager.notifyError(newConnection, TransmitException("[$newConnection] Handshake error", e))
}
@ -248,6 +237,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
): Boolean {
val serialization = config.serialization
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
/////
/////
///// DONE WITH VALIDATION
@ -339,7 +335,8 @@ internal class ServerHandshake<CONNECTION : Connection>(
aeronDriver = aeronDriver,
ipInfo = server.ipInfo,
isIpc = true,
logInfo = "IPC",
tagName = clientTagName,
logInfo = EndPoint.IPC_NAME,
remoteAddress = null,
remoteAddressString = "",
@ -353,12 +350,11 @@ internal class ServerHandshake<CONNECTION : Connection>(
reliable = true
)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger)
val connectionType = if (server is SessionServer) "session connection" else "connection"
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo")
logger.debug("Creating new buffered connection to $logInfo")
} else {
logger.info("Creating new $connectionType to $logInfo")
logger.info("Creating new buffered connection to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -369,7 +365,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections
))
server.sessionManager.onNewConnection(newConnection)
server.bufferedManager.onConnect(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@ -393,8 +389,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection,
sessionTimeout = config.sessionTimeoutSeconds,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
@ -403,11 +398,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${message.connectKey}) $connType (${newConnection.id}) responding to handshake hello.")
logger.trace("[$logInfo] (${message.connectKey}) buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.
@ -470,6 +464,27 @@ internal class ServerHandshake<CONNECTION : Connection>(
return false
}
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(clientAddress, clientTagName)
if (!permitConnection) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
/////
/////
@ -585,18 +600,18 @@ internal class ServerHandshake<CONNECTION : Connection>(
portPubMdc = mdcPortPub,
portPub = portPub,
portSub = portSub,
tagName = clientTagName,
reliable = isReliable
)
val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger)
val connectionType = if (server is SessionServer) "session connection" else "connection"
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo")
logger.debug("Creating new buffered connection to $logInfo")
} else {
logger.info("Creating new $connectionType to $logInfo")
logger.info("Creating new buffered connection to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -607,31 +622,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = cryptoSecretKey
))
server.sessionManager.onNewConnection(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
// this will also unwind/free allocations
newConnection.close()
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
server.bufferedManager.onConnect(newConnection)
///////////////
/// HANDSHAKE
///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
@ -645,8 +642,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection,
sessionTimeout = config.sessionTimeoutSeconds,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
@ -655,11 +651,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.trace("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.

View File

@ -98,6 +98,7 @@ abstract class BaseTest {
val configuration = ClientConfiguration()
configuration.appId = "network_test"
configuration.tag = "**Client**"
configuration.settingsStore = Storage.Memory() // don't want to persist anything on disk!
configuration.enableIpc = false