Fixed threading issue with read/write global kryo

This commit is contained in:
nathan 2020-09-01 14:39:18 +02:00
parent 9a4c79e445
commit 3ade7f229e
5 changed files with 76 additions and 65 deletions

View File

@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit
* This connection is established once the registration information is validated, and the various connect/filter checks have passed * This connection is established once the registration information is validated, and the various connect/filter checks have passed
*/ */
open class Connection(connectionParameters: ConnectionParams<*>) { open class Connection(connectionParameters: ConnectionParams<*>) {
private var messageHandler: FragmentAssembler
private val subscription: Subscription private val subscription: Subscription
private val publication: Publication private val publication: Publication
@ -126,8 +127,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// a record of how many messages are in progress of being sent. When closing the connection, this number must be 0 // a record of how many messages are in progress of being sent. When closing the connection, this number must be 0
private val messagesInProgress = atomic(0) private val messagesInProgress = atomic(0)
private var messageHandler: FragmentAssembler
init { init {
val mediaDriverConnection = connectionParameters.mediaDriverConnection val mediaDriverConnection = connectionParameters.mediaDriverConnection
@ -264,10 +263,9 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/** /**
* A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network
* *
* @return the number of messages in progress for this connection. * @return the number of messages in progress for this connection.
*
* A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network
*/ */
fun messagesInProgress(): Int { fun messagesInProgress(): Int {
return messagesInProgress.value return messagesInProgress.value
@ -278,7 +276,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @return `true` if this connection has no subscribers (which means this connection longer has a remote connection) * @return `true` if this connection has no subscribers (which means this connection longer has a remote connection)
*/ */
internal fun isExpired(): Boolean { internal fun isExpired(): Boolean {
return subscription.imageCount() == 0 return !subscription.isConnected
} }
@ -439,7 +437,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any): Int { fun saveObject(`object`: Any): Int {
val rmiId = rmiConnectionSupport.saveImplObject(`object`) val rmiId = rmiConnectionSupport.saveImplObject(`object`)
if (rmiId == RemoteObjectStorage.INVALID_RMI) { if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -468,7 +466,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @see RemoteObject * @see RemoteObject
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any, objectId: Int): Boolean { fun saveObject(`object`: Any, objectId: Int): Boolean {
val success = rmiConnectionSupport.saveImplObject(`object`, objectId) val success = rmiConnectionSupport.saveImplObject(`object`, objectId)
if (!success) { if (!success) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated") val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")

View File

@ -247,7 +247,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
serialization.finishInit(type) serialization.finishInit(type)
} }
internal suspend fun initEndpointState(): Aeron { internal fun initEndpointState(): Aeron {
val aeronDirectory = config.aeronLogDirectory!!.absolutePath val aeronDirectory = config.aeronLogDirectory!!.absolutePath
val threadFactory = NamedThreadFactory("Aeron", false) val threadFactory = NamedThreadFactory("Aeron", false)
@ -398,7 +398,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* The error is also sent to an error log before this method is called. * The error is also sent to an error log before this method is called.
*/ */
fun onError(function: suspend (CONNECTION, Throwable) -> Unit) { fun onError(function: (CONNECTION, Throwable) -> Unit) {
actionDispatch.launch { actionDispatch.launch {
listenerManager.onError(function) listenerManager.onError(function)
} }
@ -409,7 +409,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* The error is also sent to an error log before this method is called. * The error is also sent to an error log before this method is called.
*/ */
fun onError(function: suspend (Throwable) -> Unit) { fun onError(function: (Throwable) -> Unit) {
actionDispatch.launch { actionDispatch.launch {
listenerManager.onError(function) listenerManager.onError(function)
} }
@ -441,9 +441,13 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message" "[${publication.sessionId()}] send: $message"
} }
// we are not thread-safe!
val kryo = serialization.takeKryo()
try { try {
// NOTE: it is safe to use the global, single-threaded kryo instance! kryo.write(message)
val buffer = serialization.writeMessage(message)
val buffer = kryo.writerBuffer
val objectSize = buffer.position() val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer val internalBuffer = buffer.internalBuffer
@ -469,6 +473,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
listenerManager.notifyError(newException("Error serializing message $message", e)) listenerManager.notifyError(newException("Error serializing message $message", e))
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnKryo(kryo)
} }
} }
@ -494,9 +499,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val exception = newException("[${sessionId}] Error de-serializing message", e) val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
actionDispatch.launch { listenerManager.notifyError(exception)
listenerManager.notifyError(exception)
}
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e) logger.error("Error de-serializing message on connection ${header.sessionId()}!", e)
return null return null
@ -529,9 +532,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val exception = newException("[${sessionId}] Error de-serializing message", e) val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
actionDispatch.launch { listenerManager.notifyError(connection, exception)
listenerManager.notifyError(connection, exception)
}
return // don't do anything! return // don't do anything!
} }
@ -577,17 +578,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
} }
else -> { else -> {
actionDispatch.launch { // do nothing, there were problems with the message
// do nothing, there were problems with the message val exception = if (message != null) {
val exception = if (message != null) { MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}") } else {
} else { MessageNotRegisteredException("Unknown message received!!")
MessageNotRegisteredException("Unknown message received!!")
}
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
} }
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
} }
} }
} }

View File

@ -85,10 +85,10 @@ internal class ListenerManager<CONNECTION: Connection> {
private val onDisconnectList = atomic(Array<suspend (CONNECTION) -> Unit>(0) { { } }) private val onDisconnectList = atomic(Array<suspend (CONNECTION) -> Unit>(0) { { } })
private val onDisconnectMutex = Mutex() private val onDisconnectMutex = Mutex()
private val onErrorList = atomic(Array<suspend (CONNECTION, Throwable) -> Unit>(0) { { _, _ -> } }) private val onErrorList = atomic(Array<(CONNECTION, Throwable) -> Unit>(0) { { _, _ -> } })
private val onErrorMutex = Mutex() private val onErrorMutex = Mutex()
private val onErrorGlobalList = atomic(Array<suspend (Throwable) -> Unit>(0) { { _ -> } }) private val onErrorGlobalList = atomic(Array<(Throwable) -> Unit>(0) { { _ -> } })
private val onErrorGlobalMutex = Mutex() private val onErrorGlobalMutex = Mutex()
private val onMessageMap = atomic(IdentityMap<Class<*>, Array<suspend (CONNECTION, Any) -> Unit>>(32, LOAD_FACTOR)) private val onMessageMap = atomic(IdentityMap<Class<*>, Array<suspend (CONNECTION, Any) -> Unit>>(32, LOAD_FACTOR))
@ -171,7 +171,7 @@ internal class ListenerManager<CONNECTION: Connection> {
* *
* The error is also sent to an error log before this method is called. * The error is also sent to an error log before this method is called.
*/ */
suspend fun onError(function: suspend (CONNECTION, throwable: Throwable) -> Unit) { suspend fun onError(function: (CONNECTION, throwable: Throwable) -> Unit) {
onErrorMutex.withLock { onErrorMutex.withLock {
// we have to follow the single-writer principle! // we have to follow the single-writer principle!
onErrorList.lazySet(add(function, onErrorList.value)) onErrorList.lazySet(add(function, onErrorList.value))
@ -183,7 +183,7 @@ internal class ListenerManager<CONNECTION: Connection> {
* *
* The error is also sent to an error log before this method is called. * The error is also sent to an error log before this method is called.
*/ */
suspend fun onError(function: suspend (throwable: Throwable) -> Unit) { suspend fun onError(function: (throwable: Throwable) -> Unit) {
onErrorGlobalMutex.withLock { onErrorGlobalMutex.withLock {
// we have to follow the single-writer principle! // we have to follow the single-writer principle!
onErrorGlobalList.lazySet(add(function, onErrorGlobalList.value)) onErrorGlobalList.lazySet(add(function, onErrorGlobalList.value))
@ -252,7 +252,7 @@ internal class ListenerManager<CONNECTION: Connection> {
* *
* @return true if the connection will be allowed to connect. False if we should terminate this connection * @return true if the connection will be allowed to connect. False if we should terminate this connection
*/ */
suspend fun notifyFilter(connection: CONNECTION): Boolean { fun notifyFilter(connection: CONNECTION): Boolean {
// NOTE: pass a reference to a string, so if there is an error, we can get it! (and log it, and send it to the client) // NOTE: pass a reference to a string, so if there is an error, we can get it! (and log it, and send it to the client)
// first run through the IP connection filters, THEN run through the "custom" filters // first run through the IP connection filters, THEN run through the "custom" filters
@ -264,12 +264,12 @@ internal class ListenerManager<CONNECTION: Connection> {
// these are the IP filters (optimized checking based on simple IP rules) // these are the IP filters (optimized checking based on simple IP rules)
onConnectIpFilterList.value.forEach { onConnectIpFilterList.value.forEach {
// if (it.matches()) // if (it.matches(connection))
// //
// //
// if (!it(connection)) { // if (!it(connection)) {
// return false // return false
// } // }
} }
@ -344,7 +344,7 @@ internal class ListenerManager<CONNECTION: Connection> {
* *
* The error is also sent to an error log before notifying callbacks * The error is also sent to an error log before notifying callbacks
*/ */
suspend fun notifyError(connection: CONNECTION, exception: Throwable) { fun notifyError(connection: CONNECTION, exception: Throwable) {
onErrorList.value.forEach { onErrorList.value.forEach {
it(connection, exception) it(connection, exception)
} }
@ -355,7 +355,7 @@ internal class ListenerManager<CONNECTION: Connection> {
* *
* The error is also sent to an error log before notifying callbacks * The error is also sent to an error log before notifying callbacks
*/ */
suspend fun notifyError(exception: Throwable) { fun notifyError(exception: Throwable) {
onErrorGlobalList.value.forEach { onErrorGlobalList.value.forEach {
it(exception) it(exception)
} }

View File

@ -54,8 +54,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
EndPoint.RESERVED_SESSION_ID_HIGH) EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE) private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
// note: this is called in action dispatch // note: CANNOT be called in action dispatch
suspend fun processHandshakeMessageServer(handshakePublication: Publication, fun processHandshakeMessageServer(handshakePublication: Publication,
sessionId: Int, sessionId: Int,
clientAddressString: String, clientAddressString: String,
clientAddress: Int, clientAddress: Int,
@ -66,7 +66,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
}
return return
} }
@ -84,10 +87,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// this enables the connection to start polling for messages // this enables the connection to start polling for messages
server.connections.add(pendingConnection) server.connections.add(pendingConnection)
// now tell the client we are done
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
server.actionDispatch.launch { server.actionDispatch.launch {
// now tell the client we are done
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
listenerManager.notifyConnect(pendingConnection) listenerManager.notifyConnect(pendingConnection)
} }
@ -104,7 +106,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
if (server.connections.connectionCount() >= config.maxClientCount) { if (server.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
return return
} }
@ -124,17 +128,21 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// VALIDATE:: we are now connected to the client and are going to create a new connection. // VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress) val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) { if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always) // decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Too many connections for IP address")) listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
return return
} }
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
return return
} }
@ -156,9 +164,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
HandshakeMessage.error("Connection error!")) }
return return
} }
@ -172,9 +180,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
HandshakeMessage.error("Connection error!")) }
return return
} }
@ -218,8 +226,11 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception) listenerManager.notifyError(connection, exception)
server.writeHandshakeMessage(handshakePublication, server.actionDispatch.launch {
HandshakeMessage.error("Connection was not permitted!")) server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
}
return return
} }
@ -229,6 +240,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
/////////////// ///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information // if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
// NOTE: This modifies the readKryo! This cannot be on a different thread!
serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage -> serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage ->
listenerManager.notifyError(connection, listenerManager.notifyError(connection,
ClientRejectedException(errorMessage)) ClientRejectedException(errorMessage))
@ -264,7 +276,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
server.writeHandshakeMessage(handshakePublication, successMessage) server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, successMessage)
}
} catch (e: Exception) { } catch (e: Exception) {
// have to unwind actions! // have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress) connectionsPerIpCounts.getAndDecrement(clientAddress)

View File

@ -103,7 +103,7 @@ internal class RmiManagerGlobal<CONNECTION : Connection>(logger: KLogger,
/** /**
* @return the removed object. If null, an error log will be emitted * @return the removed object. If null, an error log will be emitted
*/ */
suspend fun <T> removeImplObject(endPoint: EndPoint<CONNECTION>, objectId: Int): T? { fun <T> removeImplObject(endPoint: EndPoint<CONNECTION>, objectId: Int): T? {
val success = removeImplObject<Any>(objectId) val success = removeImplObject<Any>(objectId)
if (success == null) { if (success == null) {
val exception = Exception("Error trying to remove RMI impl object id $objectId.") val exception = Exception("Error trying to remove RMI impl object id $objectId.")
@ -124,12 +124,12 @@ internal class RmiManagerGlobal<CONNECTION : Connection>(logger: KLogger,
/** /**
* called on "client" * called on "client"
*/ */
private suspend fun onGenericObjectResponse(endPoint: EndPoint<CONNECTION>, private fun onGenericObjectResponse(endPoint: EndPoint<CONNECTION>,
connection: CONNECTION, connection: CONNECTION,
isGlobal: Boolean, isGlobal: Boolean,
rmiId: Int, rmiId: Int,
callback: suspend (Int, Any) -> Unit, callback: suspend (Int, Any) -> Unit,
serialization: Serialization) { serialization: Serialization) {
// we only create the proxy + execute the callback if the RMI id is valid! // we only create the proxy + execute the callback if the RMI id is valid!
if (rmiId == RemoteObjectStorage.INVALID_RMI) { if (rmiId == RemoteObjectStorage.INVALID_RMI) {