Fixed threading issue with read/write global kryo

This commit is contained in:
nathan 2020-09-01 14:39:18 +02:00
parent 9a4c79e445
commit 3ade7f229e
5 changed files with 76 additions and 65 deletions

View File

@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit
* This connection is established once the registration information is validated, and the various connect/filter checks have passed
*/
open class Connection(connectionParameters: ConnectionParams<*>) {
private var messageHandler: FragmentAssembler
private val subscription: Subscription
private val publication: Publication
@ -126,8 +127,6 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// a record of how many messages are in progress of being sent. When closing the connection, this number must be 0
private val messagesInProgress = atomic(0)
private var messageHandler: FragmentAssembler
init {
val mediaDriverConnection = connectionParameters.mediaDriverConnection
@ -264,10 +263,9 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/**
* A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network
*
* @return the number of messages in progress for this connection.
*
* A message in progress means that we have requested to to send an object over the network, but it hasn't finished sending over the network
*/
fun messagesInProgress(): Int {
return messagesInProgress.value
@ -278,7 +276,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @return `true` if this connection has no subscribers (which means this connection longer has a remote connection)
*/
internal fun isExpired(): Boolean {
return subscription.imageCount() == 0
return !subscription.isConnected
}
@ -439,7 +437,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @see RemoteObject
*/
@Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any): Int {
fun saveObject(`object`: Any): Int {
val rmiId = rmiConnectionSupport.saveImplObject(`object`)
if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")
@ -468,7 +466,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* @see RemoteObject
*/
@Suppress("DuplicatedCode")
suspend fun saveObject(`object`: Any, objectId: Int): Boolean {
fun saveObject(`object`: Any, objectId: Int): Boolean {
val success = rmiConnectionSupport.saveImplObject(`object`, objectId)
if (!success) {
val exception = Exception("RMI implementation '${`object`::class.java}' could not be saved! No more RMI id's could be generated")

View File

@ -247,7 +247,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
serialization.finishInit(type)
}
internal suspend fun initEndpointState(): Aeron {
internal fun initEndpointState(): Aeron {
val aeronDirectory = config.aeronLogDirectory!!.absolutePath
val threadFactory = NamedThreadFactory("Aeron", false)
@ -398,7 +398,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*
* The error is also sent to an error log before this method is called.
*/
fun onError(function: suspend (CONNECTION, Throwable) -> Unit) {
fun onError(function: (CONNECTION, Throwable) -> Unit) {
actionDispatch.launch {
listenerManager.onError(function)
}
@ -409,7 +409,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*
* The error is also sent to an error log before this method is called.
*/
fun onError(function: suspend (Throwable) -> Unit) {
fun onError(function: (Throwable) -> Unit) {
actionDispatch.launch {
listenerManager.onError(function)
}
@ -441,9 +441,13 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
"[${publication.sessionId()}] send: $message"
}
// we are not thread-safe!
val kryo = serialization.takeKryo()
try {
// NOTE: it is safe to use the global, single-threaded kryo instance!
val buffer = serialization.writeMessage(message)
kryo.write(message)
val buffer = kryo.writerBuffer
val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer
@ -469,6 +473,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
listenerManager.notifyError(newException("Error serializing message $message", e))
} finally {
sendIdleStrategy.reset()
serialization.returnKryo(kryo)
}
}
@ -494,9 +499,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(exception)
}
listenerManager.notifyError(exception)
logger.error("Error de-serializing message on connection ${header.sessionId()}!", e)
return null
@ -529,9 +532,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
val exception = newException("[${sessionId}] Error de-serializing message", e)
ListenerManager.cleanStackTrace(exception)
actionDispatch.launch {
listenerManager.notifyError(connection, exception)
}
listenerManager.notifyError(connection, exception)
return // don't do anything!
}
@ -577,17 +578,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
}
}
else -> {
actionDispatch.launch {
// do nothing, there were problems with the message
val exception = if (message != null) {
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
} else {
MessageNotRegisteredException("Unknown message received!!")
}
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
// do nothing, there were problems with the message
val exception = if (message != null) {
MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}")
} else {
MessageNotRegisteredException("Unknown message received!!")
}
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(exception)
}
}
}

View File

@ -85,10 +85,10 @@ internal class ListenerManager<CONNECTION: Connection> {
private val onDisconnectList = atomic(Array<suspend (CONNECTION) -> Unit>(0) { { } })
private val onDisconnectMutex = Mutex()
private val onErrorList = atomic(Array<suspend (CONNECTION, Throwable) -> Unit>(0) { { _, _ -> } })
private val onErrorList = atomic(Array<(CONNECTION, Throwable) -> Unit>(0) { { _, _ -> } })
private val onErrorMutex = Mutex()
private val onErrorGlobalList = atomic(Array<suspend (Throwable) -> Unit>(0) { { _ -> } })
private val onErrorGlobalList = atomic(Array<(Throwable) -> Unit>(0) { { _ -> } })
private val onErrorGlobalMutex = Mutex()
private val onMessageMap = atomic(IdentityMap<Class<*>, Array<suspend (CONNECTION, Any) -> Unit>>(32, LOAD_FACTOR))
@ -171,7 +171,7 @@ internal class ListenerManager<CONNECTION: Connection> {
*
* The error is also sent to an error log before this method is called.
*/
suspend fun onError(function: suspend (CONNECTION, throwable: Throwable) -> Unit) {
suspend fun onError(function: (CONNECTION, throwable: Throwable) -> Unit) {
onErrorMutex.withLock {
// we have to follow the single-writer principle!
onErrorList.lazySet(add(function, onErrorList.value))
@ -183,7 +183,7 @@ internal class ListenerManager<CONNECTION: Connection> {
*
* The error is also sent to an error log before this method is called.
*/
suspend fun onError(function: suspend (throwable: Throwable) -> Unit) {
suspend fun onError(function: (throwable: Throwable) -> Unit) {
onErrorGlobalMutex.withLock {
// we have to follow the single-writer principle!
onErrorGlobalList.lazySet(add(function, onErrorGlobalList.value))
@ -252,7 +252,7 @@ internal class ListenerManager<CONNECTION: Connection> {
*
* @return true if the connection will be allowed to connect. False if we should terminate this connection
*/
suspend fun notifyFilter(connection: CONNECTION): Boolean {
fun notifyFilter(connection: CONNECTION): Boolean {
// NOTE: pass a reference to a string, so if there is an error, we can get it! (and log it, and send it to the client)
// first run through the IP connection filters, THEN run through the "custom" filters
@ -264,12 +264,12 @@ internal class ListenerManager<CONNECTION: Connection> {
// these are the IP filters (optimized checking based on simple IP rules)
onConnectIpFilterList.value.forEach {
// if (it.matches())
// if (it.matches(connection))
//
//
// if (!it(connection)) {
// return false
// }
// }
}
@ -344,7 +344,7 @@ internal class ListenerManager<CONNECTION: Connection> {
*
* The error is also sent to an error log before notifying callbacks
*/
suspend fun notifyError(connection: CONNECTION, exception: Throwable) {
fun notifyError(connection: CONNECTION, exception: Throwable) {
onErrorList.value.forEach {
it(connection, exception)
}
@ -355,7 +355,7 @@ internal class ListenerManager<CONNECTION: Connection> {
*
* The error is also sent to an error log before notifying callbacks
*/
suspend fun notifyError(exception: Throwable) {
fun notifyError(exception: Throwable) {
onErrorGlobalList.value.forEach {
it(exception)
}

View File

@ -54,8 +54,8 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
EndPoint.RESERVED_SESSION_ID_HIGH)
private val streamIdAllocator = RandomIdAllocator(1, Integer.MAX_VALUE)
// note: this is called in action dispatch
suspend fun processHandshakeMessageServer(handshakePublication: Publication,
// note: CANNOT be called in action dispatch
fun processHandshakeMessageServer(handshakePublication: Publication,
sessionId: Int,
clientAddressString: String,
clientAddress: Int,
@ -66,7 +66,10 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Invalid connection request"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
}
return
}
@ -84,10 +87,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// this enables the connection to start polling for messages
server.connections.add(pendingConnection)
// now tell the client we are done
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
server.actionDispatch.launch {
// now tell the client we are done
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
listenerManager.notifyConnect(pendingConnection)
}
@ -104,7 +106,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
if (server.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
}
return
}
@ -124,17 +128,21 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.getAndDecrement(clientAddress)
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Too many connections for IP address"))
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
}
return
}
} catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
}
return
}
@ -156,9 +164,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionsPerIpCounts.getAndDecrement(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return
}
@ -172,9 +180,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection error!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
}
return
}
@ -218,8 +226,11 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception)
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication,
HandshakeMessage.error("Connection was not permitted!"))
}
return
}
@ -229,6 +240,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
///////////////
// if necessary (and only for RMI id's that have never been seen before) we want to re-write our kryo information
// NOTE: This modifies the readKryo! This cannot be on a different thread!
serialization.updateKryoIdsForRmi(connection, message.registrationRmiIdData!!) { errorMessage ->
listenerManager.notifyError(connection,
ClientRejectedException(errorMessage))
@ -264,7 +276,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
}
// this tells the client all of the info to connect.
server.writeHandshakeMessage(handshakePublication, successMessage)
server.actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, successMessage)
}
} catch (e: Exception) {
// have to unwind actions!
connectionsPerIpCounts.getAndDecrement(clientAddress)

View File

@ -103,7 +103,7 @@ internal class RmiManagerGlobal<CONNECTION : Connection>(logger: KLogger,
/**
* @return the removed object. If null, an error log will be emitted
*/
suspend fun <T> removeImplObject(endPoint: EndPoint<CONNECTION>, objectId: Int): T? {
fun <T> removeImplObject(endPoint: EndPoint<CONNECTION>, objectId: Int): T? {
val success = removeImplObject<Any>(objectId)
if (success == null) {
val exception = Exception("Error trying to remove RMI impl object id $objectId.")
@ -124,12 +124,12 @@ internal class RmiManagerGlobal<CONNECTION : Connection>(logger: KLogger,
/**
* called on "client"
*/
private suspend fun onGenericObjectResponse(endPoint: EndPoint<CONNECTION>,
connection: CONNECTION,
isGlobal: Boolean,
rmiId: Int,
callback: suspend (Int, Any) -> Unit,
serialization: Serialization) {
private fun onGenericObjectResponse(endPoint: EndPoint<CONNECTION>,
connection: CONNECTION,
isGlobal: Boolean,
rmiId: Int,
callback: suspend (Int, Any) -> Unit,
serialization: Serialization) {
// we only create the proxy + execute the callback if the RMI id is valid!
if (rmiId == RemoteObjectStorage.INVALID_RMI) {