diff --git a/src/dorkbox/network/Client.kt b/src/dorkbox/network/Client.kt index c0bd3063..347a4685 100644 --- a/src/dorkbox/network/Client.kt +++ b/src/dorkbox/network/Client.kt @@ -797,19 +797,26 @@ open class Client(config: ClientConfiguration = ClientC val pubSub = clientConnection.connectionInfo + val connectionType = if (connectionInfo.bufferedMessages) { + "buffered connection" + } else { + "connection" + } + val logInfo = pubSub.getLogInfo(logger.isDebugEnabled) if (logger.isDebugEnabled) { - logger.debug("Creating new buffered connection to $logInfo") + logger.debug("Creating new $connectionType to $logInfo") } else { - logger.info("Creating new buffered connection to $logInfo") + logger.info("Creating new $connectionType to $logInfo") } val newConnection = newConnection(ConnectionParams( - connectionInfo.publicKey, - this, - clientConnection.connectionInfo, - validateRemoteAddress, - connectionInfo.secretKey + publicKey = connectionInfo.publicKey, + endPoint = this, + connectionInfo = clientConnection.connectionInfo, + publicKeyValidation = validateRemoteAddress, + enableBufferedMessages = connectionInfo.bufferedMessages, + cryptoKey = connectionInfo.secretKey )) bufferedManager?.onConnect(newConnection) diff --git a/src/dorkbox/network/Server.kt b/src/dorkbox/network/Server.kt index 48dcecbf..3fdd952a 100644 --- a/src/dorkbox/network/Server.kt +++ b/src/dorkbox/network/Server.kt @@ -21,7 +21,6 @@ import dorkbox.network.connection.* import dorkbox.network.connection.IpInfo.Companion.IpListenType import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace import dorkbox.network.connection.buffer.BufferManager -import dorkbox.network.connectionType.ConnectionRule import dorkbox.network.exceptions.ServerException import dorkbox.network.handshake.ServerHandshake import dorkbox.network.handshake.ServerHandshakePollers @@ -370,10 +369,31 @@ open class Server(config: ServerConfiguration = ServerC * @param function clientAddress: UDP connection address * tagName: the connection tag name */ - fun filter(function: InetAddress.(String) -> Boolean) { + fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) { listenerManager.filter(function) } + /** + * Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages + * for a connection should be enabled + * + * By default, if there are no rules, then all connections will have buffered messages enabled + * If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled) + * + * It is the responsibility of the custom filter to write the error, if there is one + * + * If the function returns TRUE, then the buffered messages for a connection are enabled. + * If the function returns FALSE, then the buffered messages for a connection is disabled. + * + * If ANY rule that is applied returns true, then the buffered messages for a connection are enabled + * + * @param function clientAddress: not-null when UDP connection, null when IPC connection + * tagName: the connection tag name + */ + fun enablePendingMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) { + listenerManager.enableBufferedMessages(function) + } + /** * Runs an action for each connection */ diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 41c362c9..eac7314e 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -121,6 +121,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) { */ private val bufferedSession: BufferedSession + /** + * used to determine if this connection will have buffered messages enabled or not. + */ + internal val enableBufferedMessages = connectionParameters.enableBufferedMessages + /** * The largest size a SINGLE message via AERON can be. Because the maximum size we can send in a "single fragment" is the * publication.maxPayloadLength() function (which is the MTU length less header). We could depend on Aeron for fragment reassembly, @@ -338,13 +343,15 @@ open class Connection(connectionParameters: ConnectionParams<*>) { } internal fun sendBufferedMessages() { - // now send all buffered/pending messages - if (logger.isDebugEnabled) { - logger.debug("Sending pending messages: ${bufferedSession.pendingMessagesQueue.size}") - } + if (enableBufferedMessages) { + // now send all buffered/pending messages + if (logger.isDebugEnabled) { + logger.debug("Sending buffered messages: ${bufferedSession.pendingMessagesQueue.size}") + } - bufferedSession.pendingMessagesQueue.forEach { - sendNoBuffer(it) + bufferedSession.pendingMessagesQueue.forEach { + sendNoBuffer(it) + } } } diff --git a/src/dorkbox/network/connection/ConnectionParams.kt b/src/dorkbox/network/connection/ConnectionParams.kt index eb056364..4c40092e 100644 --- a/src/dorkbox/network/connection/ConnectionParams.kt +++ b/src/dorkbox/network/connection/ConnectionParams.kt @@ -23,5 +23,6 @@ data class ConnectionParams( val endPoint: EndPoint, val connectionInfo: PubSub, val publicKeyValidation: PublicKeyValidationState, + val enableBufferedMessages: Boolean, val cryptoKey: SecretKeySpec ) diff --git a/src/dorkbox/network/connection/CryptoManagement.kt b/src/dorkbox/network/connection/CryptoManagement.kt index 63fa9696..0db003f5 100644 --- a/src/dorkbox/network/connection/CryptoManagement.kt +++ b/src/dorkbox/network/connection/CryptoManagement.kt @@ -183,6 +183,7 @@ internal class CryptoManagement(val logger: Logger, val streamIdSub = cryptInput.readInt() val regDetailsSize = cryptInput.readInt() val sessionTimeout = cryptInput.readLong() + val bufferedMessages = cryptInput.readBoolean() val regDetails = cryptInput.readBytes(regDetailsSize) // now save data off @@ -193,6 +194,7 @@ internal class CryptoManagement(val logger: Logger, streamIdSub = streamIdSub, publicKey = serverPublicKeyBytes, sessionTimeout = sessionTimeout, + bufferedMessages = bufferedMessages, kryoRegistrationDetails = regDetails, secretKey = secretKey) } @@ -204,6 +206,7 @@ internal class CryptoManagement(val logger: Logger, streamIdPub: Int, streamIdSub: Int, sessionTimeout: Long, + bufferedMessages: Boolean, kryoRegDetails: ByteArray ): ByteArray { @@ -216,6 +219,7 @@ internal class CryptoManagement(val logger: Logger, cryptOutput.writeInt(streamIdSub) cryptOutput.writeInt(kryoRegDetails.size) cryptOutput.writeLong(sessionTimeout) + cryptOutput.writeBoolean(bufferedMessages) cryptOutput.writeBytes(kryoRegDetails) cryptOutput.toBytes() @@ -266,6 +270,7 @@ internal class CryptoManagement(val logger: Logger, streamIdPub: Int, streamIdSub: Int, sessionTimeout: Long, + bufferedMessages: Boolean, kryoRegDetails: ByteArray ): ByteArray { @@ -283,6 +288,7 @@ internal class CryptoManagement(val logger: Logger, cryptOutput.writeInt(streamIdSub) cryptOutput.writeInt(kryoRegDetails.size) cryptOutput.writeLong(sessionTimeout) + cryptOutput.writeBoolean(bufferedMessages) cryptOutput.writeBytes(kryoRegDetails) return iv + aesCipher.doFinal(cryptOutput.toBytes()) diff --git a/src/dorkbox/network/connection/ListenerManager.kt b/src/dorkbox/network/connection/ListenerManager.kt index 4a64fb4e..17b75103 100644 --- a/src/dorkbox/network/connection/ListenerManager.kt +++ b/src/dorkbox/network/connection/ListenerManager.kt @@ -165,9 +165,13 @@ internal class ListenerManager(private val logger: Logge // initialize emtpy arrays @Volatile - private var onConnectFilterList = Array<(InetAddress.(String) -> Boolean)>(0) { { true } } + private var onConnectFilterList = Array<((InetAddress, String) -> Boolean)>(0) { { _, _ -> true } } private val onConnectFilterLock = ReentrantReadWriteLock() + @Volatile + private var onConnectBufferedMessageFilterList = Array<((InetAddress?, String) -> Boolean)>(0) { { _, _ -> true } } + private val onConnectBufferedMessageFilterLock = ReentrantReadWriteLock() + @Volatile private var onInitList = Array<(CONNECTION.() -> Unit)>(0) { { } } private val onInitLock = ReentrantReadWriteLock() @@ -202,9 +206,9 @@ internal class ListenerManager(private val logger: Logge * If there are rules added, then a rule MUST be matched to be allowed */ fun filter(ipFilterRule: IpFilterRule) { - filter { + filter { clientAddress, _ -> // IPC will not filter - ipFilterRule.matches(this) + ipFilterRule.matches(clientAddress) } } @@ -226,14 +230,41 @@ internal class ListenerManager(private val logger: Logge * If ANY filter rule that is applied returns true, then the connection is permitted * * This function will be called for **only** network clients (IPC client are excluded) + * + * @param function clientAddress: UDP connection address + * tagName: the connection tag name */ - fun filter(function: InetAddress.(String) -> Boolean) { + fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) { onConnectFilterLock.write { // we have to follow the single-writer principle! onConnectFilterList = add(function, onConnectFilterList) } } + /** + * Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages + * for a connection should be enabled + * + * By default, if there are no rules, then all connections will have buffered messages enabled + * If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled) + * + * It is the responsibility of the custom filter to write the error, if there is one + * + * If the function returns TRUE, then the buffered messages for a connection are enabled. + * If the function returns FALSE, then the buffered messages for a connection is disabled. + * + * If ANY rule that is applied returns true, then the buffered messages for a connection are enabled + * + * @param function clientAddress: not-null when UDP connection, null when IPC connection + * tagName: the connection tag name + */ + fun enableBufferedMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) { + onConnectBufferedMessageFilterLock.write { + // we have to follow the single-writer principle! + onConnectBufferedMessageFilterList = add(function, onConnectBufferedMessageFilterList) + } + } + /** * Adds a function that will be called when a client/server connection is FIRST initialized, but before it's * connected to the remote endpoint @@ -375,6 +406,33 @@ internal class ListenerManager(private val logger: Logge return list.isEmpty() } + /** + * Invoked just after a connection is created, but before it is connected. + * + * It is the responsibility of the custom filter to write the error, if there is one + * + * This is run directly on the thread that calls it! + * + * @return true if the connection will have pending messages enabled. False if pending messages for this connection should be disabled. + */ + fun notifyEnableBufferedMessages(clientAddress: InetAddress?, clientTagName: String): Boolean { + // by default, there is a SINGLE rule that will always exist, and will always PERMIT pending messages. + // This is so the array types can be setup (the compiler needs SOMETHING there) + val list = onConnectBufferedMessageFilterList + + // if there is a rule, a connection must match for it to enable pending messages + list.forEach { + if (it.invoke(clientAddress, clientTagName)) { + return true + } + } + + // default if nothing matches + // NO RULES ADDED -> ALLOW Pending Messages + // RULES ADDED -> DISABLE Pending Messages + return list.isEmpty() + } + /** * Invoked when a connection is first initialized, but BEFORE it's connected to the remote address. * @@ -554,7 +612,10 @@ internal class ListenerManager(private val logger: Logge logger.debug("Closing the listener manager") onConnectFilterLock.write { - onConnectFilterList = Array(0) { { true } } + onConnectFilterList = Array(0) { { _, _ -> true } } + } + onConnectBufferedMessageFilterLock.write { + onConnectBufferedMessageFilterList = Array(0) { { _, _ -> true } } } onInitLock.write { onInitList = Array(0) { { } } diff --git a/src/dorkbox/network/connection/buffer/BufferedSession.kt b/src/dorkbox/network/connection/buffer/BufferedSession.kt index d80f4cb1..34cfddf5 100644 --- a/src/dorkbox/network/connection/buffer/BufferedSession.kt +++ b/src/dorkbox/network/connection/buffer/BufferedSession.kt @@ -38,6 +38,11 @@ open class BufferedSession(@Volatile var connection: Connection) { } } + if (!connection.enableBufferedMessages) { + // nothing, since we emit logs during connection initialization that pending messages are DISABLED + return false + } + if (!abortEarly) { // this was a "normal" send (instead of the disconnect message). pendingMessagesQueue.put(message) diff --git a/src/dorkbox/network/handshake/ClientConnectionInfo.kt b/src/dorkbox/network/handshake/ClientConnectionInfo.kt index fb191e94..fe157748 100644 --- a/src/dorkbox/network/handshake/ClientConnectionInfo.kt +++ b/src/dorkbox/network/handshake/ClientConnectionInfo.kt @@ -24,6 +24,7 @@ internal class ClientConnectionInfo( val streamIdSub: Int = 0, val publicKey: ByteArray = ByteArray(0), val sessionTimeout: Long, + val bufferedMessages: Boolean, val kryoRegistrationDetails: ByteArray, val secretKey: SecretKeySpec ) diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index d4e14d39..280482bb 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -353,11 +353,19 @@ internal class ServerHandshake( reliable = true ) + + val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(null, clientTagName) + val connectionType = if (enableBufferedMessagesForConnection) { + "buffered connection" + } else { + "connection" + } + val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled) if (logger.isDebugEnabled) { - logger.debug("Creating new buffered connection to $logInfo") + logger.debug("Creating new $connectionType to $logInfo") } else { - logger.info("Creating new buffered connection to $logInfo") + logger.info("Creating new $connectionType to $logInfo") } newConnection = server.newConnection(ConnectionParams( @@ -365,6 +373,7 @@ internal class ServerHandshake( endPoint = server, connectionInfo = newConnectionDriver.pubSub, publicKeyValidation = PublicKeyValidationState.VALID, + enableBufferedMessages = enableBufferedMessagesForConnection, cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections )) @@ -393,6 +402,7 @@ internal class ServerHandshake( streamIdPub = connectionStreamIdPub, streamIdSub = connectionStreamIdSub, sessionTimeout = config.bufferedConnectionTimeoutSeconds, + bufferedMessages = enableBufferedMessagesForConnection, kryoRegDetails = serialization.getKryoRegistrationDetails() ) @@ -611,11 +621,19 @@ internal class ServerHandshake( val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes) + val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(clientAddress, clientTagName) + val connectionType = if (enableBufferedMessagesForConnection) { + "buffered connection" + } else { + "connection" + } + + val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled) if (logger.isDebugEnabled) { - logger.debug("Creating new buffered connection to $logInfo") + logger.debug("Creating new $connectionType to $logInfo") } else { - logger.info("Creating new buffered connection to $logInfo") + logger.info("Creating new $connectionType to $logInfo") } newConnection = server.newConnection(ConnectionParams( @@ -623,6 +641,7 @@ internal class ServerHandshake( endPoint = server, connectionInfo = newConnectionDriver.pubSub, publicKeyValidation = validateRemoteAddress, + enableBufferedMessages = enableBufferedMessagesForConnection, cryptoKey = cryptoSecretKey )) @@ -647,6 +666,7 @@ internal class ServerHandshake( streamIdPub = connectionStreamIdPub, streamIdSub = connectionStreamIdSub, sessionTimeout = config.bufferedConnectionTimeoutSeconds, + bufferedMessages = enableBufferedMessagesForConnection, kryoRegDetails = serialization.getKryoRegistrationDetails() ) diff --git a/test/dorkboxTest/network/ConnectionFilterTest.kt b/test/dorkboxTest/network/ConnectionFilterTest.kt index 86d88ea1..7e749d6e 100644 --- a/test/dorkboxTest/network/ConnectionFilterTest.kt +++ b/test/dorkboxTest/network/ConnectionFilterTest.kt @@ -27,6 +27,7 @@ import kotlinx.atomicfu.atomic import org.junit.Assert import org.junit.Test +@Suppress("UNUSED_ANONYMOUS_PARAMETER") class ConnectionFilterTest : BaseTest() { @Test fun autoAcceptAll() { @@ -395,7 +396,7 @@ class ConnectionFilterTest : BaseTest() { val server: Server = Server(configuration) addEndPoint(server) - server.filter { + server.filter { clientAddress, tagName -> true } @@ -498,7 +499,7 @@ class ConnectionFilterTest : BaseTest() { val server: Server = Server(configuration) addEndPoint(server) - server.filter { + server.filter { clientAddress, tagName -> false } @@ -544,7 +545,7 @@ class ConnectionFilterTest : BaseTest() { val server: Server = Server(configuration) addEndPoint(server) - server.filter { + server.filter { clientAddress, tagName -> false } @@ -583,4 +584,60 @@ class ConnectionFilterTest : BaseTest() { waitForThreads() } + + + @Test + fun acceptAllCustomClientNoPendingMessages() { + val serverConnectSuccess = atomic(false) + val clientConnectSuccess = atomic(false) + + val server = run { + val configuration = serverConfig() + + val server: Server = Server(configuration) + addEndPoint(server) + + server.enablePendingMessages { clientAddress, tagName -> + false + } + + server.onConnect { + serverConnectSuccess.value = true + close() + } + server + } + + val client = run { + val config = clientConfig() + + val client: Client = Client(config) + addEndPoint(client) + + + client.onConnect { + clientConnectSuccess.value = true + } + + client.onDisconnect { + stopEndPoints() + } + + client + } + + server.bind(2000) + try { + client.connect(LOCALHOST, 2000) + } catch (e: Exception) { + stopEndPoints() + waitForThreads() + throw e + } + + waitForThreads() + + Assert.assertTrue(serverConnectSuccess.value) + Assert.assertTrue(clientConnectSuccess.value) + } } diff --git a/test/dorkboxTest/network/app/AeronClientServer.kt b/test/dorkboxTest/network/app/AeronClientServer.kt index 8ec21328..6a6ac141 100644 --- a/test/dorkboxTest/network/app/AeronClientServer.kt +++ b/test/dorkboxTest/network/app/AeronClientServer.kt @@ -205,8 +205,8 @@ class AeronClientServer { server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32)) - server.filter { - println("should the connection $this be allowed?") + server.filter { clientAddress, tagName -> + println("should the connection $clientAddress be allowed?") true } diff --git a/test/dorkboxTest/network/app/AeronClientServerForever.kt b/test/dorkboxTest/network/app/AeronClientServerForever.kt index f5b510a4..265c5d7a 100644 --- a/test/dorkboxTest/network/app/AeronClientServerForever.kt +++ b/test/dorkboxTest/network/app/AeronClientServerForever.kt @@ -216,8 +216,8 @@ class AeronClientServerForever { server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32)) - server.filter { - println("should the connection $this be allowed?") + server.filter { clientAddress, tagName -> + println("should the connection $clientAddress be allowed?") true } diff --git a/test/dorkboxTest/network/app/AeronClientServerRMIForever.kt b/test/dorkboxTest/network/app/AeronClientServerRMIForever.kt index 38a956a5..992a9cde 100644 --- a/test/dorkboxTest/network/app/AeronClientServerRMIForever.kt +++ b/test/dorkboxTest/network/app/AeronClientServerRMIForever.kt @@ -221,8 +221,8 @@ class AeronClientServerRMIForever { server.filter(IpSubnetFilterRule(IPv4.LOCALHOST, 32)) - server.filter { - println("should the connection $this be allowed?") + server.filter { clientAddress, tagName -> + println("should the connection $clientAddress be allowed?") true } diff --git a/test/dorkboxTest/network/app/AeronServer.kt b/test/dorkboxTest/network/app/AeronServer.kt index d642962e..f922d7c5 100644 --- a/test/dorkboxTest/network/app/AeronServer.kt +++ b/test/dorkboxTest/network/app/AeronServer.kt @@ -105,8 +105,8 @@ object AeronServer { throw IllegalStateException("Aeron was unable to shut down in a timely manner.") } - server.filter { - println("should the connection $this be allowed?") + server.filter { clientAddress, tagName -> + println("should the connection $clientAddress be allowed?") true }