Fixed issues with connection handshake. ALL kryo usage for a handshake is now ONLY on the primary coroutine

This commit is contained in:
nathan 2020-09-09 12:24:04 +02:00
parent 06a35ed027
commit ce5eb8cb77
4 changed files with 51 additions and 62 deletions

View File

@ -128,6 +128,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
*/ */
val serialization: Serialization val serialization: Serialization
private val handshakeKryo: KryoExtra
private val sendIdleStrategy: CoroutineIdleStrategy private val sendIdleStrategy: CoroutineIdleStrategy
/** /**
@ -279,6 +282,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// serialization stuff // serialization stuff
serialization = config.serialization serialization = config.serialization
sendIdleStrategy = config.sendIdleStrategy sendIdleStrategy = config.sendIdleStrategy
handshakeKryo = serialization.initHandshakeKryo()
// we have to be able to specify WHAT property store we want to use, since it can change! // we have to be able to specify WHAT property store we want to use, since it can change!
settingsStore = config.settingsStore settingsStore = config.settingsStore
@ -461,27 +465,28 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
} }
@Suppress("DuplicatedCode")
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
internal suspend fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) { internal suspend fun writeHandshakeMessage(publication: Publication, message: HandshakeMessage) {
// The sessionId is globally unique, and is assigned by the server. // The sessionId is globally unique, and is assigned by the server.
logger.trace { logger.trace {
"[${publication.sessionId()}] send HS: $message" "[${publication.sessionId()}] send HS: $message"
} }
// we are not thread-safe!
val kryo = serialization.takeHandshakeKryo()
try { try {
kryo.write(message) // we are not thread-safe!
handshakeKryo.write(message)
val buffer = kryo.writerBuffer val buffer = handshakeKryo.writerBuffer
val objectSize = buffer.position() val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer val internalBuffer = buffer.internalBuffer
var result: Long var result: Long
while (true) { while (true) {
result = publication.offer(internalBuffer, 0, objectSize) result = publication.offer(internalBuffer, 0, objectSize)
// success! if (result >= 0) {
if (result > 0) { // success!
return return
} }
@ -514,7 +519,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
listenerManager.notifyError(exception) listenerManager.notifyError(exception)
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnHandshakeKryo(kryo)
} }
} }
@ -526,9 +530,11 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* @return the message * @return the message
*/ */
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
try { try {
val message = serialization.readHandshakeMessage(buffer, offset, length) val message = handshakeKryo.read(buffer, offset, length)
logger.trace { logger.trace {
"[${header.sessionId()}] received HS: $message" "[${header.sessionId()}] received HS: $message"
} }
@ -633,6 +639,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
// NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine! // NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine!
@Suppress("DuplicatedCode")
internal suspend fun send(message: Any, publication: Publication, connection: Connection) { internal suspend fun send(message: Any, publication: Publication, connection: Connection) {
// The sessionId is globally unique, and is assigned by the server. // The sessionId is globally unique, and is assigned by the server.
logger.trace { logger.trace {
@ -651,8 +658,8 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
var result: Long var result: Long
while (true) { while (true) {
result = publication.offer(internalBuffer, 0, objectSize) result = publication.offer(internalBuffer, 0, objectSize)
// success! if (result >= 0) {
if (result > 0) { // success!
return return
} }
@ -746,15 +753,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
connections.forEach { connections.forEach {
it.close() it.close()
} }
// the storage is closed via this as well.
settingsStore.close()
// Connections are closed first, because we want to make sure that no RMI messages can be received
// when we close the RMI support objects (in which case, weird - but harmless - errors show up)
rmiGlobalSupport.close()
} }
// the storage is closed via this as well.
settingsStore.close()
// Connections are closed first, because we want to make sure that no RMI messages can be received
// when we close the RMI support objects (in which case, weird - but harmless - errors show up)
rmiGlobalSupport.close()
close0() close0()
// if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now) // if we are waiting for shutdown, cancel the waiting thread (since we have shutdown now)

View File

@ -110,10 +110,10 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
} }
} }
// called from the connect thread
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, config.settingsStore.getPublicKey()!!) val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, config.settingsStore.getPublicKey()!!)
// Send the one-time pad to the server. // Send the one-time pad to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
sessionId = handshakeConnection.publication.sessionId() sessionId = handshakeConnection.publication.sessionId()
@ -156,6 +156,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val logger: KLogg
return connectionHelloInfo!! return connectionHelloInfo!!
} }
// called from the connect thread
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient() val registrationMessage = HandshakeMessage.doneFromClient()

View File

@ -45,7 +45,7 @@ import kotlin.concurrent.write
/** /**
* @throws IllegalArgumentException If the port range is not valid * 'notifyConnect' must be THE ONLY THING in this class to use the action dispatch!
*/ */
@Suppress("DuplicatedCode") @Suppress("DuplicatedCode")
internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger, internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLogger,
@ -88,7 +88,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
if (message !is HandshakeMessage) { if (message !is HandshakeMessage) {
listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request")) listenerManager.notifyError(ClientRejectedException("[$sessionId] Connection from $connectionString not allowed! Invalid connection request"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection request"))
} }
return false return false
@ -110,9 +110,12 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// this enables the connection to start polling for messages // this enables the connection to start polling for messages
server.connections.add(pendingConnection) server.connections.add(pendingConnection)
server.actionDispatch.launch { // now tell the client we are done
// now tell the client we are done runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
}
server.actionDispatch.launch {
// this must be THE ONLY THING in this class to use the action dispatch!
listenerManager.notifyConnect(pendingConnection) listenerManager.notifyConnect(pendingConnection)
} }
} }
@ -143,7 +146,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
if (server.connections.connectionCount() >= config.maxClientCount) { if (server.connections.connectionCount() >= config.maxClientCount) {
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Server is full"))
} }
return false return false
@ -157,16 +160,15 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionsPerIpCounts.decrement(clientAddress, currentCountForIp) connectionsPerIpCounts.decrement(clientAddress, currentCountForIp)
listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}")) listenerManager.notifyError(ClientRejectedException("Too many connections for IP address $clientAddressString. Max allowed is ${config.maxConnectionsPerIpAddress}"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Too many connections for IP address"))
} }
return false return false
} }
connectionsPerIpCounts.increment(clientAddress, currentCountForIp) connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(ClientRejectedException("could not validate client message", e)) listenerManager.notifyError(ClientRejectedException("could not validate client message", e))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Invalid connection"))
} }
return false return false
@ -205,7 +207,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionSessionId = sessionIdAllocator.allocate() connectionSessionId = sessionIdAllocator.allocate()
} catch (e: AllocationException) { } catch (e: AllocationException) {
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a session ID for the client connection!"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} }
return return
@ -220,7 +222,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} }
return return
@ -235,7 +237,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionStreamPubId) sessionIdAllocator.free(connectionStreamPubId)
listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $connectionString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} }
return return
@ -271,10 +273,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception) listenerManager.notifyError(connection, exception)
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
} }
return return
} }
@ -314,7 +315,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, successMessage) server.writeHandshakeMessage(handshakePublication, successMessage)
} }
} catch (e: Exception) { } catch (e: Exception) {
@ -374,7 +375,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
connectionsPerIpCounts.decrementSlow(clientAddress) connectionsPerIpCounts.decrementSlow(clientAddress)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a session ID for the client connection!"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} }
return return
@ -390,7 +391,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
sessionIdAllocator.free(connectionSessionId) sessionIdAllocator.free(connectionSessionId)
listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!")) listenerManager.notifyError(ClientRejectedException("Connection from $clientAddressString not allowed! Unable to allocate a stream ID for the client connection!"))
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection error!"))
} }
return return
@ -447,10 +448,9 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
ListenerManager.cleanStackTrace(exception) ListenerManager.cleanStackTrace(exception)
listenerManager.notifyError(connection, exception) listenerManager.notifyError(connection, exception)
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!")) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.error("Connection was not permitted!"))
} }
return return
} }
@ -483,7 +483,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
} }
// this tells the client all of the info to connect. // this tells the client all of the info to connect.
server.actionDispatch.launch { runBlocking {
server.writeHandshakeMessage(handshakePublication, successMessage) server.writeHandshakeMessage(handshakePublication, successMessage)
} }
} catch (e: Exception) { } catch (e: Exception) {

View File

@ -44,6 +44,7 @@ import dorkbox.os.OS
import dorkbox.util.serialization.SerializationDefaults import dorkbox.util.serialization.SerializationDefaults
import dorkbox.util.serialization.SerializationManager import dorkbox.util.serialization.SerializationManager
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
@ -93,8 +94,6 @@ open class Serialization(private val references: Boolean = true, private val fac
private var initialized = atomic(false) private var initialized = atomic(false)
private val initializedKryoCount = atomic(0) private val initializedKryoCount = atomic(0)
private val kryoPool = MultithreadConcurrentQueue<KryoExtra>(1024) // reasonable size of available kryo's
private val kryoHandshakePool = MultithreadConcurrentQueue<KryoExtra>(1024) // reasonable size of available kryo's
// used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class) // used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class)
// All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems.
@ -119,14 +118,15 @@ open class Serialization(private val references: Boolean = true, private val fac
val rmiHolder = RmiHolder() val rmiHolder = RmiHolder()
// reflectASM doesn't work on android // reflectASM doesn't work on android
private val useAsm = !OS.isAndroid() private val useAsm = !OS.isAndroid()
// These are GLOBAL, single threaded only kryo instances. // These are GLOBAL, single threaded only kryo instances.
// The readKryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem) // The readKryo WILL RE-CONFIGURED during the client handshake! (it is all the same thread, so object visibility is not a problem)
// NOTE: These following can ONLY be called on a single thread!
private var readKryo = initGlobalKryo() private var readKryo = initGlobalKryo()
private var readHandshakeKryo = initHandshakeKryo() private val kryoPool = MultithreadConcurrentQueue<KryoExtra>(1024) // reasonable size of available kryo's?
/** /**
* Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer]. * Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer].
@ -268,7 +268,7 @@ open class Serialization(private val references: Boolean = true, private val fac
/** /**
* Kryo specifically for handshakes * Kryo specifically for handshakes
*/ */
private fun initHandshakeKryo(): KryoExtra { internal fun initHandshakeKryo(): KryoExtra {
val kryo = KryoExtra() val kryo = KryoExtra()
kryo.instantiatorStrategy = instantiatorStrategy kryo.instantiatorStrategy = instantiatorStrategy
@ -613,22 +613,6 @@ open class Serialization(private val references: Boolean = true, private val fac
return initializeClassRegistrations(kryo) return initializeClassRegistrations(kryo)
} }
/**
* @return takes a kryo instance from the pool, or creates one if the pool was empty
*/
fun takeHandshakeKryo(): KryoExtra {
// ALWAYS get as many as needed. Recycle them to prevent too many getting created
return kryoHandshakePool.poll() ?: initHandshakeKryo()
}
/**
* Returns a kryo instance to the pool for re-use later on
*/
fun returnHandshakeKryo(kryo: KryoExtra) {
// return as much as we can. don't suspend if the pool is full, we just throw it away.
kryoHandshakePool.offer(kryo)
}
/** /**
* @return The number of kryo instances created. This does not reflect the size of the pool, just the number of * @return The number of kryo instances created. This does not reflect the size of the pool, just the number of
* existing kryo instances. * existing kryo instances.
@ -819,9 +803,6 @@ open class Serialization(private val references: Boolean = true, private val fac
} }
// NOTE: These following functions are ONLY called on a single thread! // NOTE: These following functions are ONLY called on a single thread!
fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? {
return readHandshakeKryo.read(buffer, offset, length)
}
fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, connection: Connection): Any? { fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, connection: Connection): Any? {
return readKryo.read(buffer, offset, length, connection) return readKryo.read(buffer, offset, length, connection)
} }