Added support for PER-CONNECTION buffering of messages (default is enabled)

master
Robinson 2023-11-27 11:14:52 +01:00
parent f1a06fd8fd
commit bf0cd3f0e6
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
14 changed files with 220 additions and 35 deletions

View File

@ -797,19 +797,26 @@ open class Client<CONNECTION : Connection>(config: ClientConfiguration = ClientC
val pubSub = clientConnection.connectionInfo
val connectionType = if (connectionInfo.bufferedMessages) {
"buffered connection"
} else {
"connection"
}
val logInfo = pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new buffered connection to $logInfo")
logger.debug("Creating new $connectionType to $logInfo")
} else {
logger.info("Creating new buffered connection to $logInfo")
logger.info("Creating new $connectionType to $logInfo")
}
val newConnection = newConnection(ConnectionParams(
connectionInfo.publicKey,
this,
clientConnection.connectionInfo,
validateRemoteAddress,
connectionInfo.secretKey
publicKey = connectionInfo.publicKey,
endPoint = this,
connectionInfo = clientConnection.connectionInfo,
publicKeyValidation = validateRemoteAddress,
enableBufferedMessages = connectionInfo.bufferedMessages,
cryptoKey = connectionInfo.secretKey
))
bufferedManager?.onConnect(newConnection)

View File

@ -21,7 +21,6 @@ import dorkbox.network.connection.*
import dorkbox.network.connection.IpInfo.Companion.IpListenType
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.connection.buffer.BufferManager
import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.handshake.ServerHandshakePollers
@ -370,10 +369,31 @@ open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerC
* @param function clientAddress: UDP connection address
* tagName: the connection tag name
*/
fun filter(function: InetAddress.(String) -> Boolean) {
fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) {
listenerManager.filter(function)
}
/**
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages
* for a connection should be enabled
*
* By default, if there are no rules, then all connections will have buffered messages enabled
* If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the buffered messages for a connection are enabled.
* If the function returns FALSE, then the buffered messages for a connection is disabled.
*
* If ANY rule that is applied returns true, then the buffered messages for a connection are enabled
*
* @param function clientAddress: not-null when UDP connection, null when IPC connection
* tagName: the connection tag name
*/
fun enablePendingMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) {
listenerManager.enableBufferedMessages(function)
}
/**
* Runs an action for each connection
*/

View File

@ -121,6 +121,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*/
private val bufferedSession: BufferedSession
/**
* used to determine if this connection will have buffered messages enabled or not.
*/
internal val enableBufferedMessages = connectionParameters.enableBufferedMessages
/**
* The largest size a SINGLE message via AERON can be. Because the maximum size we can send in a "single fragment" is the
* publication.maxPayloadLength() function (which is the MTU length less header). We could depend on Aeron for fragment reassembly,
@ -338,13 +343,15 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
}
internal fun sendBufferedMessages() {
// now send all buffered/pending messages
if (logger.isDebugEnabled) {
logger.debug("Sending pending messages: ${bufferedSession.pendingMessagesQueue.size}")
}
if (enableBufferedMessages) {
// now send all buffered/pending messages
if (logger.isDebugEnabled) {
logger.debug("Sending buffered messages: ${bufferedSession.pendingMessagesQueue.size}")
}
bufferedSession.pendingMessagesQueue.forEach {
sendNoBuffer(it)
bufferedSession.pendingMessagesQueue.forEach {
sendNoBuffer(it)
}
}
}

View File

@ -23,5 +23,6 @@ data class ConnectionParams<CONNECTION : Connection>(
val endPoint: EndPoint<CONNECTION>,
val connectionInfo: PubSub,
val publicKeyValidation: PublicKeyValidationState,
val enableBufferedMessages: Boolean,
val cryptoKey: SecretKeySpec
)

View File

@ -183,6 +183,7 @@ internal class CryptoManagement(val logger: Logger,
val streamIdSub = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val sessionTimeout = cryptInput.readLong()
val bufferedMessages = cryptInput.readBoolean()
val regDetails = cryptInput.readBytes(regDetailsSize)
// now save data off
@ -193,6 +194,7 @@ internal class CryptoManagement(val logger: Logger,
streamIdSub = streamIdSub,
publicKey = serverPublicKeyBytes,
sessionTimeout = sessionTimeout,
bufferedMessages = bufferedMessages,
kryoRegistrationDetails = regDetails,
secretKey = secretKey)
}
@ -204,6 +206,7 @@ internal class CryptoManagement(val logger: Logger,
streamIdPub: Int,
streamIdSub: Int,
sessionTimeout: Long,
bufferedMessages: Boolean,
kryoRegDetails: ByteArray
): ByteArray {
@ -216,6 +219,7 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBoolean(bufferedMessages)
cryptOutput.writeBytes(kryoRegDetails)
cryptOutput.toBytes()
@ -266,6 +270,7 @@ internal class CryptoManagement(val logger: Logger,
streamIdPub: Int,
streamIdSub: Int,
sessionTimeout: Long,
bufferedMessages: Boolean,
kryoRegDetails: ByteArray
): ByteArray {
@ -283,6 +288,7 @@ internal class CryptoManagement(val logger: Logger,
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBoolean(bufferedMessages)
cryptOutput.writeBytes(kryoRegDetails)
return iv + aesCipher.doFinal(cryptOutput.toBytes())

View File

@ -165,9 +165,13 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
// initialize emtpy arrays
@Volatile
private var onConnectFilterList = Array<(InetAddress.(String) -> Boolean)>(0) { { true } }
private var onConnectFilterList = Array<((InetAddress, String) -> Boolean)>(0) { { _, _ -> true } }
private val onConnectFilterLock = ReentrantReadWriteLock()
@Volatile
private var onConnectBufferedMessageFilterList = Array<((InetAddress?, String) -> Boolean)>(0) { { _, _ -> true } }
private val onConnectBufferedMessageFilterLock = ReentrantReadWriteLock()
@Volatile
private var onInitList = Array<(CONNECTION.() -> Unit)>(0) { { } }
private val onInitLock = ReentrantReadWriteLock()
@ -202,9 +206,9 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* If there are rules added, then a rule MUST be matched to be allowed
*/
fun filter(ipFilterRule: IpFilterRule) {
filter {
filter { clientAddress, _ ->
// IPC will not filter
ipFilterRule.matches(this)
ipFilterRule.matches(clientAddress)
}
}
@ -226,14 +230,41 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called for **only** network clients (IPC client are excluded)
*
* @param function clientAddress: UDP connection address
* tagName: the connection tag name
*/
fun filter(function: InetAddress.(String) -> Boolean) {
fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) {
onConnectFilterLock.write {
// we have to follow the single-writer principle!
onConnectFilterList = add(function, onConnectFilterList)
}
}
/**
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages
* for a connection should be enabled
*
* By default, if there are no rules, then all connections will have buffered messages enabled
* If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the buffered messages for a connection are enabled.
* If the function returns FALSE, then the buffered messages for a connection is disabled.
*
* If ANY rule that is applied returns true, then the buffered messages for a connection are enabled
*
* @param function clientAddress: not-null when UDP connection, null when IPC connection
* tagName: the connection tag name
*/
fun enableBufferedMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) {
onConnectBufferedMessageFilterLock.write {
// we have to follow the single-writer principle!
onConnectBufferedMessageFilterList = add(function, onConnectBufferedMessageFilterList)
}
}
/**
* Adds a function that will be called when a client/server connection is FIRST initialized, but before it's
* connected to the remote endpoint
@ -375,6 +406,33 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
return list.isEmpty()
}
/**
* Invoked just after a connection is created, but before it is connected.
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* This is run directly on the thread that calls it!
*
* @return true if the connection will have pending messages enabled. False if pending messages for this connection should be disabled.
*/
fun notifyEnableBufferedMessages(clientAddress: InetAddress?, clientTagName: String): Boolean {
// by default, there is a SINGLE rule that will always exist, and will always PERMIT pending messages.
// This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectBufferedMessageFilterList
// if there is a rule, a connection must match for it to enable pending messages
list.forEach {
if (it.invoke(clientAddress, clientTagName)) {
return true
}
}
// default if nothing matches
// NO RULES ADDED -> ALLOW Pending Messages
// RULES ADDED -> DISABLE Pending Messages
return list.isEmpty()
}
/**
* Invoked when a connection is first initialized, but BEFORE it's connected to the remote address.
*
@ -554,7 +612,10 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: Logge
logger.debug("Closing the listener manager")
onConnectFilterLock.write {
onConnectFilterList = Array(0) { { true } }
onConnectFilterList = Array(0) { { _, _ -> true } }
}
onConnectBufferedMessageFilterLock.write {
onConnectBufferedMessageFilterList = Array(0) { { _, _ -> true } }
}
onInitLock.write {
onInitList = Array(0) { { } }

View File

@ -38,6 +38,11 @@ open class BufferedSession(@Volatile var connection: Connection) {
}
}
if (!connection.enableBufferedMessages) {
// nothing, since we emit logs during connection initialization that pending messages are DISABLED
return false
}
if (!abortEarly) {
// this was a "normal" send (instead of the disconnect message).
pendingMessagesQueue.put(message)

View File

@ -24,6 +24,7 @@ internal class ClientConnectionInfo(
val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0),
val sessionTimeout: Long,
val bufferedMessages: Boolean,
val kryoRegistrationDetails: ByteArray,
val secretKey: SecretKeySpec
)

View File

@ -353,11 +353,19 @@ internal class ServerHandshake<CONNECTION : Connection>(
reliable = true
)
val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(null, clientTagName)
val connectionType = if (enableBufferedMessagesForConnection) {
"buffered connection"
} else {
"connection"
}
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new buffered connection to $logInfo")
logger.debug("Creating new $connectionType to $logInfo")
} else {
logger.info("Creating new buffered connection to $logInfo")
logger.info("Creating new $connectionType to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -365,6 +373,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
endPoint = server,
connectionInfo = newConnectionDriver.pubSub,
publicKeyValidation = PublicKeyValidationState.VALID,
enableBufferedMessages = enableBufferedMessagesForConnection,
cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections
))
@ -393,6 +402,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
bufferedMessages = enableBufferedMessagesForConnection,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
@ -611,11 +621,19 @@ internal class ServerHandshake<CONNECTION : Connection>(
val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes)
val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(clientAddress, clientTagName)
val connectionType = if (enableBufferedMessagesForConnection) {
"buffered connection"
} else {
"connection"
}
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new buffered connection to $logInfo")
logger.debug("Creating new $connectionType to $logInfo")
} else {
logger.info("Creating new buffered connection to $logInfo")
logger.info("Creating new $connectionType to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
@ -623,6 +641,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
endPoint = server,
connectionInfo = newConnectionDriver.pubSub,
publicKeyValidation = validateRemoteAddress,
enableBufferedMessages = enableBufferedMessagesForConnection,
cryptoKey = cryptoSecretKey
))
@ -647,6 +666,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
bufferedMessages = enableBufferedMessagesForConnection,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)

View File

@ -27,6 +27,7 @@ import kotlinx.atomicfu.atomic
import org.junit.Assert
import org.junit.Test
@Suppress("UNUSED_ANONYMOUS_PARAMETER")
class ConnectionFilterTest : BaseTest() {
@Test
fun autoAcceptAll() {
@ -395,7 +396,7 @@ class ConnectionFilterTest : BaseTest() {
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.filter {
server.filter { clientAddress, tagName ->
true
}
@ -498,7 +499,7 @@ class ConnectionFilterTest : BaseTest() {
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.filter {
server.filter { clientAddress, tagName ->
false
}
@ -544,7 +545,7 @@ class ConnectionFilterTest : BaseTest() {
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.filter {
server.filter { clientAddress, tagName ->
false
}
@ -583,4 +584,60 @@ class ConnectionFilterTest : BaseTest() {
waitForThreads()
}
@Test
fun acceptAllCustomClientNoPendingMessages() {
val serverConnectSuccess = atomic(false)
val clientConnectSuccess = atomic(false)
val server = run {
val configuration = serverConfig()
val server: Server<Connection> = Server(configuration)
addEndPoint(server)
server.enablePendingMessages { clientAddress, tagName ->
false
}
server.onConnect {
serverConnectSuccess.value = true
close()
}
server
}
val client = run {
val config = clientConfig()
val client: Client<Connection> = Client(config)
addEndPoint(client)
client.onConnect {
clientConnectSuccess.value = true
}
client.onDisconnect {
stopEndPoints()
}
client
}
server.bind(2000)
try {
client.connect(LOCALHOST, 2000)
} catch (e: Exception) {
stopEndPoints()
waitForThreads()
throw e
}
waitForThreads()
Assert.assertTrue(serverConnectSuccess.value)
Assert.assertTrue(clientConnectSuccess.value)
}
}

View File

@ -205,8 +205,8 @@ class AeronClientServer {
server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32))
server.filter {
println("should the connection $this be allowed?")
server.filter { clientAddress, tagName ->
println("should the connection $clientAddress be allowed?")
true
}

View File

@ -216,8 +216,8 @@ class AeronClientServerForever {
server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32))
server.filter {
println("should the connection $this be allowed?")
server.filter { clientAddress, tagName ->
println("should the connection $clientAddress be allowed?")
true
}

View File

@ -221,8 +221,8 @@ class AeronClientServerRMIForever {
server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32))
server.filter {
println("should the connection $this be allowed?")
server.filter { clientAddress, tagName ->
println("should the connection $clientAddress be allowed?")
true
}

View File

@ -105,8 +105,8 @@ object AeronServer {
throw IllegalStateException("Aeron was unable to shut down in a timely manner.")
}
server.filter {
println("should the connection $this be allowed?")
server.filter { clientAddress, tagName ->
println("should the connection $clientAddress be allowed?")
true
}