Added support for connection tags (so the client can set a name for its connection, and the server will get that name). This is usefull for identifying different connections (and doing different things) based on their tag name.

This commit is contained in:
Robinson 2023-10-28 20:55:49 +02:00
parent fe98763712
commit 58535a923b
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
16 changed files with 253 additions and 136 deletions

View File

@ -116,6 +116,15 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
var addressPrettyString: String = "UNKNOWN"
private set
/**
* The tag name assigned (by the configuration) to the client. The server will receive this tag during the handshake. The max length is
* 32 characters.
*/
@Volatile
var tag: String = ""
private set
/**
* The default connection reliability type (ie: can the lower-level network stack throw away data that has errors, for example real-time-voice)
*/
@ -507,6 +516,10 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
this.port1 = port1
this.port2 = port2
// DOUBLE CHECK!
require(config.tag.length <= 32) { "Client tag name length must be <= 32" }
this.tag = config.tag
this.reliable = reliable
this.connectionTimeoutSec = connectionTimeoutSec
@ -578,12 +591,13 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
remotePort2 = port2,
handshakeTimeoutNs = handshakeTimeoutNs,
reliable = reliable,
tagName = tag,
logger = logger
)
val pubSub = handshakeConnection.pubSub
val logInfo = pubSub.getLogInfo(logger)
val logInfo = pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new handshake to $logInfo")
} else {
@ -701,6 +715,7 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
// throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class)
val connectionInfo = handshake.hello(
tagName = tag,
endPoint = this,
handshakeConnection = handshakeConnection,
handshakeTimeoutNs = handshakeTimeoutNs
@ -771,7 +786,8 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
handshakeTimeoutNs = handshakeTimeoutNs,
handshakeConnection = handshakeConnection,
connectionInfo = connectionInfo,
port2Server = port2
port2Server = port2,
tagName = tag
)
val pubSub = clientConnection.connectionInfo

View File

@ -167,6 +167,17 @@ class ClientConfiguration : dorkbox.network.Configuration() {
field = value
}
/**
* The tag name to be assigned to this connection and the server will receive this tag name during the handshake.
* The max length is 32 characters.
*/
var tag: String = ""
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Validates the current configuration. Throws an exception if there are problems.
*/
@ -180,6 +191,8 @@ class ClientConfiguration : dorkbox.network.Configuration() {
require(port > 0) { "Client listen port must be > 0" }
require(port < 65535) { "Client listen port must be < 65535" }
}
require(tag.length <= 32) { "Client tag name length must be <= 32" }
}
override fun initialize(logger: Logger): dorkbox.network.ClientConfiguration {
@ -191,6 +204,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
super.copy(config)
config.port = port
config.tag = tag
return config
}
@ -201,6 +215,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
if (!super.equals(other)) return false
if (port != other.port) return false
if (tag != other.tag) return false
return true
}
@ -208,6 +223,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + port.hashCode()
result = 31 * result + tag.hashCode()
return result
}
}

View File

@ -365,7 +365,7 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
*
* This function will be called for **only** network clients (IPC client are excluded)
*/
fun filter(function: CONNECTION.() -> Boolean) {
fun filter(function: InetAddress.(String) -> Boolean) {
listenerManager.filter(function)
}

View File

@ -86,6 +86,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
info.sessionIdSub
}
/**
* The tag name for a connection permits an INCOMING client to define a custom string. The max length is 32
*/
val tag = info.tagName
/**
* The remote address, as a string. Will be null for IPC connections
*/

View File

@ -182,7 +182,6 @@ internal class CryptoManagement(val logger: Logger,
val streamIdPub = cryptInput.readInt()
val streamIdSub = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val enableSession = cryptInput.readBoolean()
val sessionTimeout = cryptInput.readLong()
val regDetails = cryptInput.readBytes(regDetailsSize)
@ -193,7 +192,6 @@ internal class CryptoManagement(val logger: Logger,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
publicKey = serverPublicKeyBytes,
enableSession = enableSession,
sessionTimeout = sessionTimeout,
kryoRegistrationDetails = regDetails,
secretKey = secretKey)
@ -205,7 +203,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long,
kryoRegDetails: ByteArray
): ByteArray {
@ -218,7 +215,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails)
@ -269,7 +265,6 @@ internal class CryptoManagement(val logger: Logger,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
enableSession: Boolean,
sessionTimeout: Long,
kryoRegDetails: ByteArray
): ByteArray {
@ -287,7 +282,6 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBoolean(enableSession)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBytes(kryoRegDetails)

View File

@ -189,7 +189,7 @@ internal class IpInfo(config: ServerConfiguration) {
}
else -> {
ipType = IPC
listenAddressString = "IPC"
listenAddressString = EndPoint.IPC_NAME
formattedListenAddressString = listenAddressString
}
}

View File

@ -22,6 +22,7 @@ import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.os.OS
import net.jodah.typetools.TypeResolver
import org.slf4j.Logger
import java.net.InetAddress
import java.util.concurrent.locks.*
import kotlin.concurrent.write
@ -144,7 +145,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return this
}
internal inline fun <reified T> add(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> add(thing: T, array: Array<T>): Array<T> {
val currentLength: Int = array.size
// add the new subscription to the END of the array
@ -155,7 +156,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return newMessageArray
}
internal inline fun <reified T> remove(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> remove(thing: T, array: Array<T>): Array<T> {
// remove the subscription form the array
// THIS IS IDENTITY CHECKS, NOT EQUALITY
return array.filter { it !== thing }.toTypedArray()
@ -164,7 +165,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
// initialize emtpy arrays
@Volatile
private var onConnectFilterList = Array<(CONNECTION.() -> Boolean)>(0) { { true } }
private var onConnectFilterList = Array<(InetAddress.(String) -> Boolean)>(0) { { true } }
private val onConnectFilterLock = ReentrantReadWriteLock()
@Volatile
@ -202,8 +203,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
*/
fun filter(ipFilterRule: IpFilterRule) {
filter {
// IPC will not filter, so this is OK to coerce to not-null
ipFilterRule.matches(remoteAddress!!)
// IPC will not filter
ipFilterRule.matches(this)
}
}
@ -212,15 +213,21 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if a connection
* should be allowed
*
* By default, if there are no filter rules, then all connections are allowed to connect
* If there are filter rules - then ONLY connections for the filter that returns true are allowed to connect (all else are denied)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the connection will continue to connect.
* If the function returns FALSE, then the other end of the connection will
* receive a connection error
*
* For a server, this function will be called for ALL clients.
*
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called for **only** network clients (IPC client are excluded)
*/
fun filter(function: CONNECTION.() -> Boolean) {
fun filter(function: InetAddress.(String) -> Boolean) {
onConnectFilterLock.write {
// we have to follow the single-writer principle!
onConnectFilterList = add(function, onConnectFilterList)
@ -348,23 +355,16 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
*
* NOTE: This is run directly on the thread that calls it!
*
* @return true if the connection will be allowed to connect. False if we should terminate this connection
* @return true if the client address is allowed to connect. False if we should terminate this connection
*/
fun notifyFilter(connection: CONNECTION): Boolean {
// remote address will NOT be null at this stage, but best to verify.
val remoteAddress = connection.remoteAddress
if (remoteAddress == null) {
logger.error("Connection ${connection.id}: Unable to attempt connection stages when no remote address is present")
return false
}
fun notifyFilter(clientAddress: InetAddress, clientTagName: String): Boolean {
// by default, there is a SINGLE rule that will always exist, and will always ACCEPT ALL connections.
// This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectFilterList
// if there is a rule, a connection must match for it to connect
list.forEach {
if (it.invoke(connection)) {
if (it.invoke(clientAddress, clientTagName)) {
return true
}
}

View File

@ -21,6 +21,7 @@ import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.aeron.controlEndpoint
import dorkbox.network.aeron.endpoint
import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientRetryException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.CommonContext
@ -45,6 +46,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
handshakeConnection: ClientHandshakeDriver,
connectionInfo: ClientConnectionInfo,
port2Server: Int, // this is the port2 value from the server
tagName: String
): ClientConnectionDriver {
val handshakePubSub = handshakeConnection.pubSub
val reliable = handshakePubSub.reliable
@ -73,6 +75,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -101,6 +104,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub = portSub,
port2Server = port2Server,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -117,6 +121,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub: Int,
streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -150,10 +155,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
}
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
@Throws(ClientTimedOutException::class)
@ -170,6 +185,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
portSub: Int,
port2Server: Int, // this is the port2 value from the server
reliable: Boolean,
tagName: String,
logInfo: String,
): PubSub {
val isRemoteIpv4 = remoteAddress is Inet4Address
@ -213,12 +229,20 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}

View File

@ -23,7 +23,6 @@ internal class ClientConnectionInfo(
val streamIdPub: Int,
val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0),
val enableSession: Boolean,
val sessionTimeout: Long,
val kryoRegistrationDetails: ByteArray,
val secretKey: SecretKeySpec

View File

@ -178,6 +178,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed
fun hello(
tagName: String,
endPoint: EndPoint<CONNECTION>,
handshakeConnection: ClientHandshakeDriver,
handshakeTimeoutNs: Long
@ -200,7 +201,8 @@ internal class ClientHandshake<CONNECTION: Connection>(
connectKey = connectKey,
publicKey = client.storage.publicKey,
streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub
portSub = pubSub.portSub,
tagName = tagName
))
} catch (e: Exception) {
handshakeConnection.close(endPoint)

View File

@ -65,6 +65,7 @@ internal class ClientHandshakeDriver(
clientListenPort: Int,
handshakeTimeoutNs: Long,
reliable: Boolean,
tagName: String,
logger: Logger
): ClientHandshakeDriver {
logger.trace("Starting client handshake")
@ -111,6 +112,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} catch (exception: Exception) {
@ -160,6 +162,7 @@ internal class ClientHandshakeDriver(
streamIdPub = streamIdPub,
reliable = reliable,
streamIdSub = streamIdSub,
tagName = tagName,
logInfo = logInfo
)
@ -188,6 +191,7 @@ internal class ClientHandshakeDriver(
sessionIdPub: Int,
streamIdPub: Int, streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// Create a publication at the given address and port, using the given stream ID.
@ -214,10 +218,20 @@ internal class ClientHandshakeDriver(
val subscriptionUri = uriHandshake(CommonContext.IPC_MEDIA, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
@Throws(ClientTimedOutException::class)
@ -233,6 +247,7 @@ internal class ClientHandshakeDriver(
streamIdPub: Int,
reliable: Boolean,
streamIdSub: Int,
tagName: String,
logInfo: String,
): PubSub {
@Suppress("NAME_SHADOWING")
@ -320,12 +335,20 @@ internal class ClientHandshakeDriver(
throw ex
}
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}

View File

@ -32,6 +32,9 @@ internal class HandshakeMessage private constructor() {
// -1 means there is an error
var state = INVALID
// used to name a connection (via the client)
var tag: String = ""
var errorMessage: String? = null
var port = 0
@ -51,7 +54,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage {
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, tagName: String): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -59,6 +62,7 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub
hello.port = portSub
hello.tag = tagName
return hello
}
@ -135,6 +139,6 @@ internal class HandshakeMessage private constructor() {
""
}
return "HandshakeMessage($stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
return "HandshakeMessage($tag :: $stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
}
}

View File

@ -16,9 +16,9 @@
package dorkbox.network.handshake
import dorkbox.network.connection.EndPoint
import io.aeron.Publication
import io.aeron.Subscription
import org.slf4j.Logger
import java.net.Inet4Address
import java.net.InetAddress
@ -30,29 +30,43 @@ data class PubSub(
val streamIdPub: Int,
val streamIdSub: Int,
val reliable: Boolean,
val remoteAddress: InetAddress? = null,
val remoteAddressString: String = "IPC",
val portPub: Int = 0,
val portSub: Int = 0
val remoteAddress: InetAddress?,
val remoteAddressString: String,
val portPub: Int,
val portSub: Int,
val tagName: String // will either be "", or will be "[tag_name]"
) {
val isIpc get() = remoteAddress == null
fun getLogInfo(logger: Logger): String {
val detailed = logger.isTraceEnabled
fun getLogInfo(extraDetails: Boolean): String {
return if (isIpc) {
if (detailed) {
"IPC sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
val prefix = if (tagName.isNotEmpty()) {
EndPoint.IPC_NAME + " $tagName"
} else {
"IPC"
EndPoint.IPC_NAME
}
if (extraDetails) {
if (tagName.isNotEmpty()) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
}
} else {
val prefix = if (remoteAddress is Inet4Address) {
prefix
}
} else {
var prefix = if (remoteAddress is Inet4Address) {
"IPv4 $remoteAddressString"
} else {
"IPv6 $remoteAddressString"
}
if (detailed) {
if (tagName.isNotEmpty()) {
prefix += " $tagName"
}
if (extraDetails) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
prefix

View File

@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo
import io.aeron.CommonContext
import java.net.Inet4Address
@ -42,6 +43,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
remoteAddressString: String,
portPubMdc: Int, portPub: Int, portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String): ServerConnectionDriver {
val pubSub: PubSub
@ -54,6 +56,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} else {
@ -70,6 +73,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub = portPub,
portSub = portSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -82,6 +86,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -98,10 +103,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
private fun buildUdp(
@ -114,6 +129,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
portPub: Int,
portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -146,12 +162,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}

View File

@ -21,8 +21,6 @@ import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.sessionIdAllocator
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
import dorkbox.network.connection.*
import dorkbox.network.connection.session.SessionConnection
import dorkbox.network.connection.session.SessionServer
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ServerHandshakeException
import dorkbox.network.exceptions.ServerTimedoutException
@ -131,23 +129,16 @@ internal class ServerHandshake<CONNECTION : Connection>(
}
// Server is the "source", client mirrors the server
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[${newConnection}] (${message.connectKey}) $connType (${newConnection.id}) done with handshake.")
logger.trace("[${newConnection}] (${message.connectKey}) Buffered connection (${newConnection.id}) done with handshake.")
} else if (logger.isDebugEnabled) {
logger.debug("[${newConnection}] $connType (${newConnection.id}) done with handshake.")
logger.debug("[${newConnection}] Buffered connection (${newConnection.id}) done with handshake.")
}
// in the specific case of using sessions, we don't want to call 'init' or `connect` for a connection that is resuming a session
// when applicable - we ALSO want to restore RMI objects BEFORE the connection is fully setup!
val newSession = server.sessionManager.onInit(newConnection)
newConnection.setImage()
// before we finish creating the connection, we initialize it (in case there needs to be logic that happens-before `onConnect` calls
if (newSession) {
listenerManager.notifyInit(newConnection)
}
// this enables the connection to start polling for messages
server.addConnection(newConnection)
@ -160,9 +151,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
listenerManager.notifyConnect(newConnection)
if (!newSession) {
(newConnection as SessionConnection).sendPendingMessages()
}
newConnection.sendBufferedMessages()
} catch (e: Exception) {
listenerManager.notifyError(newConnection, TransmitException("[$newConnection] Handshake error", e))
}
@ -248,6 +237,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
): Boolean {
val serialization = config.serialization
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
/////
/////
///// DONE WITH VALIDATION
@ -339,7 +335,8 @@ internal class ServerHandshake<CONNECTION : Connection>(
aeronDriver = aeronDriver,
ipInfo = server.ipInfo,
isIpc = true,
logInfo = "IPC",
tagName = clientTagName,
logInfo = EndPoint.IPC_NAME,
remoteAddress = null,
remoteAddressString = "",
@ -353,12 +350,11 @@ internal class ServerHandshake<CONNECTION : Connection>(
reliable = true
)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger)
val connectionType = if (server is SessionServer) "session connection" else "connection"
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo")
logger.debug("Creating new buffered connection to $logInfo")
} else {
logger.info("Creating new $connectionType to $logInfo")
logger.info("Creating new buffered connection to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -369,7 +365,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections
))
server.sessionManager.onNewConnection(newConnection)
server.bufferedManager.onConnect(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
@ -393,8 +389,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection,
sessionTimeout = config.sessionTimeoutSeconds,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
@ -403,11 +398,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${message.connectKey}) $connType (${newConnection.id}) responding to handshake hello.")
logger.trace("[$logInfo] (${message.connectKey}) buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.
@ -470,6 +464,27 @@ internal class ServerHandshake<CONNECTION : Connection>(
return false
}
val clientTagName = message.tag.let { if (it.isEmpty()) "" else "[$it]" }
if (clientTagName.length > 34) {
// 34 to account for []
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(clientAddress, clientTagName)
if (!permitConnection) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
/////
/////
@ -585,18 +600,18 @@ internal class ServerHandshake<CONNECTION : Connection>(
portPubMdc = mdcPortPub,
portPub = portPub,
portSub = portSub,
tagName = clientTagName,
reliable = isReliable
)
val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger)
val connectionType = if (server is SessionServer) "session connection" else "connection"
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo")
logger.debug("Creating new buffered connection to $logInfo")
} else {
logger.info("Creating new $connectionType to $logInfo")
logger.info("Creating new buffered connection to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -607,31 +622,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
cryptoKey = cryptoSecretKey
))
server.sessionManager.onNewConnection(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(newConnection)
if (!permitConnection) {
// this will also unwind/free allocations
newConnection.close()
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
server.bufferedManager.onConnect(newConnection)
///////////////
/// HANDSHAKE
///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
@ -645,8 +642,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
enableSession = newConnection is SessionConnection,
sessionTimeout = config.sessionTimeoutSeconds,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
@ -655,11 +651,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = newConnection
val connType = if (newConnection is SessionConnection) "Session connection" else "Connection"
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.trace("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connType (${newConnection.id}) responding to handshake hello.")
logger.debug("[$logInfo] Buffered connection (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.

View File

@ -98,6 +98,7 @@ abstract class BaseTest {
val configuration = ClientConfiguration()
configuration.appId = "network_test"
configuration.tag = "**Client**"
configuration.settingsStore = Storage.Memory() // don't want to persist anything on disk!
configuration.enableIpc = false