diff --git a/src/dorkbox/network/connection/Connection.kt b/src/dorkbox/network/connection/Connection.kt index 8c2bd4af..a59137a3 100644 --- a/src/dorkbox/network/connection/Connection.kt +++ b/src/dorkbox/network/connection/Connection.kt @@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit * This connection is established once the registration information is validated, and the various connect/filter checks have passed */ open class Connection(connectionParameters: ConnectionParams<*>) { + private var messageHandler: FragmentAssembler private val subscription: Subscription private val publication: Publication @@ -126,8 +127,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) { // a record of how many messages are in progress of being sent. When closing the connection, this number must be 0 private val messagesInProgress = atomic(0) - private var messageHandler: FragmentAssembler - init { val mediaDriverConnection = connectionParameters.mediaDriverConnection @@ -264,10 +263,9 @@ open class Connection(connectionParameters: ConnectionParams<*>) { /** + * A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network * * @return the number of messages in progress for this connection. - * - * A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network */ fun messagesInProgress(): Int { return messagesInProgress.value @@ -278,7 +276,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { * @return `true` if this connection has no subscribers (which means this connection longer has a remote connection) */ internal fun isExpired(): Boolean { - return subscription.imageCount() == 0 + return !subscription.isConnected } @@ -439,7 +437,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveObject(`object`: Any): Int { + fun saveObject(`object`: Any): Int { val rmiId = rmiConnectionSupport.saveImplObject(`object`) if (rmiId == RemoteObjectStorage.INVALID_RMI) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") @@ -468,7 +466,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) { * @see RemoteObject */ @Suppress("DuplicatedCode") - suspend fun saveObject(`object`: Any, objectId: Int): Boolean { + fun saveObject(`object`: Any, objectId: Int): Boolean { val success = rmiConnectionSupport.saveImplObject(`object`, objectId) if (!success) { val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index b19cf512..907c2c55 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -247,7 +247,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A serialization.finishInit(type) } - internal suspend fun initEndpointState(): Aeron { + internal fun initEndpointState(): Aeron { val aeronDirectory = config.aeronLogDirectory!!.absolutePath val threadFactory = NamedThreadFactory("Aeron", false) @@ -398,7 +398,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A * * The error is also sent to an error log before this method is called. */ - fun onError(function: suspend (CONNECTION, Throwable) -> Unit) { + fun onError(function: (CONNECTION, Throwable) -> Unit) { actionDispatch.launch { listenerManager.onError(function) } @@ -409,7 +409,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A * * The error is also sent to an error log before this method is called. */ - fun onError(function: suspend (Throwable) -> Unit) { + fun onError(function: (Throwable) -> Unit) { actionDispatch.launch { listenerManager.onError(function) } @@ -441,9 +441,13 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A "[${publication.sessionId()}] send: $message" } + // we are not thread-safe! + val kryo = serialization.takeKryo() + try { - // NOTE: it is safe to use the global, single-threaded kryo instance! - val buffer = serialization.writeMessage(message) + kryo.write(message) + + val buffer = kryo.writerBuffer val objectSize = buffer.position() val internalBuffer = buffer.internalBuffer @@ -469,6 +473,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A listenerManager.notifyError(newException("Error serializing message $message", e)) } finally { sendIdleStrategy.reset() + serialization.returnKryo(kryo) } } @@ -494,9 +499,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A val exception = newException("[${sessionId}] Error de-serializing message", e) ListenerManager.cleanStackTrace(exception) - actionDispatch.launch { - listenerManager.notifyError(exception) - } + listenerManager.notifyError(exception) logger.error("Error de-serializing message on connection ${header.sessionId()}!", e) return null @@ -529,9 +532,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A val exception = newException("[${sessionId}] Error de-serializing message", e) ListenerManager.cleanStackTrace(exception) - actionDispatch.launch { - listenerManager.notifyError(connection, exception) - } + listenerManager.notifyError(connection, exception) return // don't do anything! } @@ -577,17 +578,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A } } else -> { - actionDispatch.launch { - // do nothing, there were problems with the message - val exception = if (message != null) { - MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}") - } else { - MessageNotRegisteredException("Unknown message received!!") - } - - ListenerManager.cleanStackTrace(exception) - listenerManager.notifyError(exception) + // do nothing, there were problems with the message + val exception = if (message != null) { + MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}") + } else { + MessageNotRegisteredException("Unknown message received!!") } + + ListenerManager.cleanStackTrace(exception) + listenerManager.notifyError(exception) } } } diff --git a/src/dorkbox/network/connection/ListenerManager.kt b/src/dorkbox/network/connection/ListenerManager.kt index 5414e329..5eb4c9c8 100644 --- a/src/dorkbox/network/connection/ListenerManager.kt +++ b/src/dorkbox/network/connection/ListenerManager.kt @@ -85,10 +85,10 @@ internal class ListenerManager { private val onDisconnectList = atomic(Array Unit>(0) { { } }) private val onDisconnectMutex = Mutex() - private val onErrorList = atomic(Array Unit>(0) { { _, _ -> } }) + private val onErrorList = atomic(Array<(CONNECTION, Throwable) -> Unit>(0) { { _, _ -> } }) private val onErrorMutex = Mutex() - private val onErrorGlobalList = atomic(Array Unit>(0) { { _ -> } }) + private val onErrorGlobalList = atomic(Array<(Throwable) -> Unit>(0) { { _ -> } }) private val onErrorGlobalMutex = Mutex() private val onMessageMap = atomic(IdentityMap, Array Unit>>(32, LOAD_FACTOR)) @@ -171,7 +171,7 @@ internal class ListenerManager { * * The error is also sent to an error log before this method is called. */ - suspend fun onError(function: suspend (CONNECTION, throwable: Throwable) -> Unit) { + suspend fun onError(function: (CONNECTION, throwable: Throwable) -> Unit) { onErrorMutex.withLock { // we have to follow the single-writer principle! onErrorList.lazySet(add(function, onErrorList.value)) @@ -183,7 +183,7 @@ internal class ListenerManager { * * The error is also sent to an error log before this method is called. */ - suspend fun onError(function: suspend (throwable: Throwable) -> Unit) { + suspend fun onError(function: (throwable: Throwable) -> Unit) { onErrorGlobalMutex.withLock { // we have to follow the single-writer principle! onErrorGlobalList.lazySet(add(function, onErrorGlobalList.value)) @@ -252,7 +252,7 @@ internal class ListenerManager { * * @return true if the connection will be allowed to connect. False if we should terminate this connection */ - suspend fun notifyFilter(connection: CONNECTION): Boolean { + fun notifyFilter(connection: CONNECTION): Boolean { // NOTE: pass a reference to a string, so if there is an error, we can get it! (and log it, and send it to the client) // first run through the IP connection filters, THEN run through the "custom" filters @@ -264,12 +264,12 @@ internal class ListenerManager { // these are the IP filters (optimized checking based on simple IP rules) onConnectIpFilterList.value.forEach { -// if (it.matches()) +// if (it.matches(connection)) // // // if (!it(connection)) { // return false -// } +// } } @@ -344,7 +344,7 @@ internal class ListenerManager { * * The error is also sent to an error log before notifying callbacks */ - suspend fun notifyError(connection: CONNECTION, exception: Throwable) { + fun notifyError(connection: CONNECTION, exception: Throwable) { onErrorList.value.forEach { it(connection, exception) } @@ -355,7 +355,7 @@ internal class ListenerManager { * * The error is also sent to an error log before notifying callbacks */ - suspend fun notifyError(exception: Throwable) { + fun notifyError(exception: Throwable) { onErrorGlobalList.value.forEach { it(exception) } diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index 8e7c16f5..d7cb88eb 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -54,8 +54,8 @@ internal class ServerHandshake(private val logger: KLog EndPoint.RESERVED_SESSION_ID_HIGH) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) - // note: this is called in action dispatch - suspend fun processHandshakeMessageServer(handshakePublication: Publication, + // note: CANNOT be called in action dispatch + fun processHandshakeMessageServer(handshakePublication: Publication, sessionId: Int, clientAddressString: String, clientAddress: Int, @@ -66,7 +66,10 @@ internal class ServerHandshake(private val logger: KLog // VALIDATE:: a Registration object is the only acceptable message during the connection phase if (message !is HandshakeMessage) { listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) + + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) + } return } @@ -84,10 +87,9 @@ internal class ServerHandshake(private val logger: KLog // this enables the connection to start polling for messages server.connections.add(pendingConnection) - // now tell the client we are done - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) - server.actionDispatch.launch { + // now tell the client we are done + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) listenerManager.notifyConnect(pendingConnection) } @@ -104,7 +106,9 @@ internal class ServerHandshake(private val logger: KLog if (server.connections.connectionCount() >= config.maxClientCount) { listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) + } return } @@ -124,17 +128,21 @@ internal class ServerHandshake(private val logger: KLog // VALIDATE:: we are now connected to the client and are going to create a new connection. val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) if (currentCountForIp >= config.maxConnectionsPerIpAddress) { - listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) - // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) connectionsPerIpCounts.getAndDecrement(clientAddress) - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Too many connections for IP address")) + + listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) + } + return } } catch (e: Exception) { listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) + } return } @@ -156,9 +164,9 @@ internal class ServerHandshake(private val logger: KLog connectionsPerIpCounts.getAndDecrement(clientAddress) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) - - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Connection error!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) + } return } @@ -172,9 +180,9 @@ internal class ServerHandshake(private val logger: KLog sessionIdAllocator.free(connectionSessionId) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) - - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Connection error!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) + } return } @@ -218,8 +226,11 @@ internal class ServerHandshake(private val logger: KLog ListenerManager.cleanStackTrace(exception) listenerManager.notifyError(connection, exception) - server.writeHandshakeMessage(handshakePublication, - HandshakeMessage.error("Connection was not permitted!")) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, + HandshakeMessage.error("Connection was not permitted!")) + } + return } @@ -229,6 +240,7 @@ internal class ServerHandshake(private val logger: KLog /////////////// // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information + // NOTE: This modifies the readKryo! This cannot be on a different thread! serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage -> listenerManager.notifyError(connection, ClientRejectedException(errorMessage)) @@ -264,7 +276,9 @@ internal class ServerHandshake(private val logger: KLog } // this tells the client all of the info to connect. - server.writeHandshakeMessage(handshakePublication, successMessage) + server.actionDispatch.launch { + server.writeHandshakeMessage(handshakePublication, successMessage) + } } catch (e: Exception) { // have to unwind actions! connectionsPerIpCounts.getAndDecrement(clientAddress) diff --git a/src/dorkbox/network/rmi/RmiManagerGlobal.kt b/src/dorkbox/network/rmi/RmiManagerGlobal.kt index 87d4703e..6f36d8e8 100644 --- a/src/dorkbox/network/rmi/RmiManagerGlobal.kt +++ b/src/dorkbox/network/rmi/RmiManagerGlobal.kt @@ -103,7 +103,7 @@ internal class RmiManagerGlobal(logger: KLogger, /** * @return the removed object. If null, an error log will be emitted */ - suspend fun removeImplObject(endPoint: EndPoint, objectId: Int): T? { + fun removeImplObject(endPoint: EndPoint, objectId: Int): T? { val success = removeImplObject(objectId) if (success == null) { val exception = Exception("Error trying to remove RMI impl object id $objectId.") @@ -124,12 +124,12 @@ internal class RmiManagerGlobal(logger: KLogger, /** * called on "client" */ - private suspend fun onGenericObjectResponse(endPoint: EndPoint, - connection: CONNECTION, - isGlobal: Boolean, - rmiId: Int, - callback: suspend (Int, Any) -> Unit, - serialization: Serialization) { + private fun onGenericObjectResponse(endPoint: EndPoint, + connection: CONNECTION, + isGlobal: Boolean, + rmiId: Int, + callback: suspend (Int, Any) -> Unit, + serialization: Serialization) { // we only create the proxy + execute the callback if the RMI id is valid! if (rmiId == RemoteObjectStorage.INVALID_RMI) {