More responsive shutdown logic during handshake and fixed crashes when forcing a shutdown

This commit is contained in:
Robinson 2024-02-19 10:08:56 +01:00
parent c197a2f627
commit 3f77af5894
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
9 changed files with 135 additions and 45 deletions

View File

@ -36,6 +36,7 @@ import io.aeron.driver.reports.LossReportReader
import io.aeron.driver.reports.LossReportUtil
import io.aeron.logbuffer.BufferClaim
import io.aeron.protocol.DataHeaderFlyweight
import kotlinx.atomicfu.AtomicBoolean
import org.agrona.*
import org.agrona.concurrent.AtomicBuffer
import org.agrona.concurrent.IdleStrategy
@ -525,6 +526,7 @@ class AeronDriver(config: Configuration, val logger: Logger, val endPoint: EndPo
* The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
*/
fun waitForConnection(
shutdown: AtomicBoolean,
publication: Publication,
handshakeTimeoutNs: Long,
logInfo: String,
@ -540,6 +542,9 @@ class AeronDriver(config: Configuration, val logger: Logger, val endPoint: EndPo
if (publication.isConnected) {
return
}
if (shutdown.value) {
break
}
Thread.sleep(200L)
}
@ -562,6 +567,7 @@ class AeronDriver(config: Configuration, val logger: Logger, val endPoint: EndPo
* For subscriptions, in the client we want to guarantee that the remote server has connected BACK to us!
*/
fun waitForConnection(
shutdown: AtomicBoolean,
subscription: Subscription,
handshakeTimeoutNs: Long,
logInfo: String,
@ -577,6 +583,9 @@ class AeronDriver(config: Configuration, val logger: Logger, val endPoint: EndPo
if (subscription.isConnected && subscription.imageCount() > 0) {
return
}
if (shutdown.value) {
break
}
Thread.sleep(200L)
}

View File

@ -146,7 +146,7 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
internal val endpointIsRunning = atomic(false)
// this only prevents multiple shutdowns (in the event this close() is called multiple times)
private var shutdown = atomic(false)
internal var shutdown = atomic(false)
internal val shutdownInProgress = atomic(false)
@Volatile
@ -949,19 +949,37 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
* 1) We should reset 100% of the state+events, so that every time we connect, everything is redone
* 2) We preserve the state+event, BECAUSE adding the onConnect/Disconnect/message event states might be VERY expensive.
*
* NOTE: This method does NOT block, as the connection state is asynchronous. Use "waitForClose()" to wait for this to finish
* NOTE: This method does NOT block, as the connection state is asynchronous. Use "waitForClose()" to wait for this to finish.
*
* @param closeEverything unless explicitly called, this is only false when a connection is closed in the client.
* This will unblock the thread waiting in "waitForClose()" when it is finished.
*
* @param closeEverything 'true' is only possible via the Client.close() or Server.close() methods.
*/
internal fun close(
closeEverything: Boolean,
sendDisconnectMessage: Boolean,
releaseWaitingThreads: Boolean)
releaseWaitingThreads: Boolean,
redispatched: Boolean = false)
{
if (isShutdown()) {
// we have already closed! Don't try to close again
logger.debug("Already shutting down endpoint, skipping multiple attempts...")
return
}
if (!eventDispatch.CLOSE.isDispatch()) {
eventDispatch.CLOSE.launch {
close(closeEverything, sendDisconnectMessage, releaseWaitingThreads)
// only time the redispatch is true!
close(closeEverything, sendDisconnectMessage, releaseWaitingThreads, true)
}
if (closeEverything) {
waitForClose()
shutdownEventDispatcher() // once shutdown, it cannot be restarted!
}
logger.info("Done shutting down the endpoint.")
return
}
@ -1004,12 +1022,8 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
}
}
if (logger.isDebugEnabled) {
logger.debug("Shutting down endpoint...")
}
shutdown.lazySet(true)
// always do this. It is OK to run this multiple times
// the server has to be able to call server.notifyDisconnect() on a list of connections. If we remove the connections
@ -1018,6 +1032,11 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
it.closeImmediately(sendDisconnectMessage = sendDisconnectMessage, closeEverything = closeEverything)
}
if (this is Client<*>) {
// if there is a client connection IN PROGRESS... then we must wait for that to timeout so we can make sure everything is closed in the right order.
clientConnectionInProgress.await()
}
// this closes the endpoint specific instance running in the poller
@ -1058,8 +1077,6 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
// we might be restarting the aeron driver, so make sure it's closed.
aeronDriver.close()
shutdown.lazySet(true)
// the shutdown here must be in the launchSequentially lambda, this way we can guarantee the driver is closed before we move on
shutdownInProgress.lazySet(false)
shutdownLatch.countDown()
@ -1069,8 +1086,15 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
closeLatch.countDown()
}
if (!redispatched) {
if (closeEverything) {
waitForClose()
shutdownEventDispatcher() // once shutdown, it cannot be restarted!
}
logger.info("Done shutting down the endpoint.")
}
}
/**
* @return true if the current execution thread is in the primary network event dispatch
@ -1088,8 +1112,9 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
* @param timeoutUnit what the unit count is
*/
fun shutdownEventDispatcher(timeout: Long = 15, timeoutUnit: TimeUnit = TimeUnit.SECONDS) {
logger.info("Waiting for Event Dispatcher to shutdown...")
logger.debug("Waiting for Event Dispatcher to shutdown...")
eventDispatch.shutdownAndWait(timeout, timeoutUnit)
logger.info("Done shutting down Event Dispatcher...")
}
/**

View File

@ -25,6 +25,7 @@ import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientRetryException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.CommonContext
import kotlinx.atomicfu.AtomicBoolean
import java.net.Inet4Address
import java.net.InetAddress
@ -41,6 +42,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
companion object {
fun build(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
handshakeConnection: ClientHandshakeDriver,
@ -68,6 +70,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
logInfo = "CONNECTION-IPC"
pubSub = buildIPC(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
@ -92,6 +95,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
}
pubSub = buildUDP(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
@ -114,6 +118,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
@Throws(ClientTimedOutException::class)
private fun buildIPC(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
@ -139,7 +144,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server!", cause)
}
@ -150,7 +155,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// wait for the REMOTE end to also connect to us!
aeronDriver.waitForConnection(subscription, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, subscription, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}
@ -173,6 +178,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
@Throws(ClientTimedOutException::class)
private fun buildUDP(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
@ -206,7 +212,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server $remoteAddressString", cause)
}
@ -225,7 +231,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// wait for the REMOTE end to also connect to us!
aeronDriver.waitForConnection(subscription, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, subscription, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}

View File

@ -217,7 +217,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
pubSub.sub.poll(handler, 1)
if (failedException != null || connectionHelloInfo != null) {
if (endPoint.isShutdown() || failedException != null || connectionHelloInfo != null) {
break
}
@ -283,7 +283,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
handshakePubSub.sub.poll(handler, 1)
if (failedException != null || connectionDone) {
if (endPoint.isShutdown() || failedException != null || connectionDone) {
break
}

View File

@ -16,7 +16,6 @@
package dorkbox.network.handshake
import dorkbox.network.Configuration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
@ -34,6 +33,7 @@ import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.util.Sys
import io.aeron.CommonContext
import io.aeron.Subscription
import kotlinx.atomicfu.AtomicBoolean
import org.slf4j.Logger
import java.net.Inet4Address
import java.net.InetAddress
@ -55,7 +55,7 @@ internal class ClientHandshakeDriver(
) {
companion object {
fun build(
config: Configuration,
endpoint: EndPoint<*>,
aeronDriver: AeronDriver,
autoChangeToIpc: Boolean,
remoteAddress: InetAddress?,
@ -105,6 +105,8 @@ internal class ClientHandshakeDriver(
"[Handshake: ${Sys.getTimePrettyFull(handshakeTimeoutNs)}, Max connection attempt: Unlimited]"
}
val config = endpoint.config
val shutdown = endpoint.shutdown
if (isUsingIPC) {
streamIdPub = config.ipcId
@ -117,6 +119,7 @@ internal class ClientHandshakeDriver(
try {
pubSub = buildIPC(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
@ -168,6 +171,7 @@ internal class ClientHandshakeDriver(
}
pubSub = buildUDP(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
remoteAddress = remoteAddress,
@ -203,13 +207,15 @@ internal class ClientHandshakeDriver(
@Throws(ClientTimedOutException::class)
private fun buildIPC(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
streamIdPub: Int, streamIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
logInfo: String,
): PubSub {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
@ -227,7 +233,7 @@ internal class ClientHandshakeDriver(
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
@ -253,6 +259,7 @@ internal class ClientHandshakeDriver(
@Throws(ClientTimedOutException::class)
private fun buildUDP(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
remoteAddress: InetAddress,
@ -293,7 +300,7 @@ internal class ClientHandshakeDriver(
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
streamIdAllocator.free(streamIdSub) // we don't continue, so close this as well
ClientTimedOutException("$logInfo publication cannot connect with server in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}

View File

@ -73,6 +73,11 @@ internal class ServerHandshakeDriver(
}
}
/**
 * Closes this handshake driver's subscription by handing it (together with the log info) to the aeron driver.
 *
 * NOTE(review): no error handling is done here, so a failure inside `aeronDriver.close(...)` propagates to the
 * caller -- presumably this is why it is named "unsafe". Callers that invoke this while shutting down should be
 * prepared to catch and ignore the exception.
 */
fun unsafeClose() {
// we might not be able to close this connection.
aeronDriver.close(subscription, logInfo)
}
override fun toString(): String {
return info
}

View File

@ -64,6 +64,8 @@ internal object ServerHandshakePollers {
private val isReliable = server.config.isReliable
private val handshaker = server.handshaker
private val handshakeTimeoutNs = handshake.handshakeTimeoutNs
private val shutdownInProgress = server.shutdownInProgress
private val shutdown = server.shutdown
// note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val publications = ExpiringMap.builder()
@ -94,6 +96,12 @@ internal object ServerHandshakePollers {
val logInfo = "$sessionId/$streamId : IPC" // Server is the "source", client mirrors the server
if (shutdownInProgress.value) {
driver.deleteLogFile(image)
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] server is shutting down. Aborting new connection attempts."))
return
}
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length)
@ -149,7 +157,7 @@ internal object ServerHandshakePollers {
try {
// we actually have to wait for it to connect before we continue
driver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
driver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ServerTimedoutException("$logInfo publication cannot connect with client in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
}
@ -275,6 +283,8 @@ internal object ServerHandshakePollers {
private val ipInfo = server.ipInfo
private val handshaker = server.handshaker
private val handshakeTimeoutNs = handshake.handshakeTimeoutNs
private val shutdownInProgress = server.shutdownInProgress
private val shutdown = server.shutdown
private val serverPortSub = server.port1
// MDC 'dynamic control mode' means that the server will listen for status messages and NAK (from the client) on a port.
@ -351,6 +361,12 @@ internal object ServerHandshakePollers {
val logInfo = "$sessionId/$streamId:$clientAddressString"
if (shutdownInProgress.value) {
driver.deleteLogFile(image)
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] server is shutting down. Aborting new connection attempts."))
return
}
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length)
@ -407,7 +423,7 @@ internal object ServerHandshakePollers {
try {
// we actually have to wait for it to connect before we continue.
//
driver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
driver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ServerTimedoutException("$logInfo publication cannot connect with client in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
} catch (e: Exception) {
@ -556,7 +572,12 @@ internal object ServerHandshakePollers {
override fun close() {
delegate.close()
handler.clear()
driver.close(server)
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPC poller")
}
@ -605,7 +626,12 @@ internal object ServerHandshakePollers {
override fun close() {
delegate.close()
handler.clear()
driver.close(server)
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4 poller")
}
@ -652,7 +678,12 @@ internal object ServerHandshakePollers {
override fun close() {
delegate.close()
handler.clear()
driver.close(server)
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4 poller")
}
@ -700,7 +731,12 @@ internal object ServerHandshakePollers {
override fun close() {
delegate.close()
handler.clear()
driver.close(server)
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4+6 poller")
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -84,7 +84,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -167,7 +167,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -250,7 +250,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -274,7 +274,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -367,7 +367,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -395,7 +395,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}
@ -487,7 +487,7 @@ class AeronPubSubTest : BaseTest() {
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
clientDriver.waitForConnection(publication, handshakeTimeoutNs, "client_$index") { cause ->
clientDriver.waitForConnection(atomic(false), publication, handshakeTimeoutNs, "client_$index") { cause ->
ClientTimedOutException("Client publication cannot connect with localhost server", cause)
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,9 +33,6 @@ import java.util.concurrent.*
@Suppress("UNUSED_ANONYMOUS_PARAMETER")
class MultiClientTest : BaseTest() {
// this can be upped to 100 for stress testing, but for general unit tests this should be smaller (as this is sensitive on the load of the machine)
private val totalCount = 80
private val clientConnectCount = atomic(0)
private val serverConnectCount = atomic(0)
private val disconnectCount = atomic(0)
@ -43,6 +40,11 @@ class MultiClientTest : BaseTest() {
@OptIn(DelicateCoroutinesApi::class, ExperimentalCoroutinesApi::class)
@Test
fun multiConnectClient() {
// this can be upped to 100 for stress testing, but for general unit tests this should be smaller (as this is sensitive to the load of the machine)
// THE ONLY limitation you will have with this, is the size of the temp drive space.
val totalCount = 30
val server = run {
val config = serverConfig()
config.uniqueAeronDirectory = true