diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 2588a5ea..1e7212cc 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -168,8 +168,7 @@ abstract class EndPoint private constructor(val type: C * This is only notified when endpoint.close() is called where EVERYTHING is to be closed. */ @Volatile - private var closeLatch = CountDownLatch(1) - + internal var closeLatch = CountDownLatch(0) /** * Returns the storage used by this endpoint. This is the backing data structure for key/value pairs, and can be a database, file, etc @@ -317,10 +316,20 @@ abstract class EndPoint private constructor(val type: C * The client calls this every time it attempts a connection. */ internal fun initializeState() { + // on repeated runs, we have to make sure that we release the original latches so we don't appear to deadlock. + val origCloseLatch = closeLatch + val origShutdownLatch = shutdownLatch + val origPollerLatch = pollerClosedLatch + // on the first run, we depend on these to be 0 shutdownLatch = CountDownLatch(1) closeLatch = CountDownLatch(1) + // make sure we don't deadlock if we are waiting for the server to close + origCloseLatch.countDown() + origShutdownLatch.countDown() + origPollerLatch.countDown() + endpointIsRunning.lazySet(true) shutdown = false shutdownEventPoller = false @@ -633,10 +642,10 @@ abstract class EndPoint private constructor(val type: C logger.debug("Received session disconnect message from $otherTypeName") } } - connection.close(sendDisconnectMessage = false, - notifyDisconnect = true, - closeEverything = closeEverything - ) + + // make sure we flag the connection as NOT to timeout!! + connection.isClosedWithTimeout() // we only need this to update fields + connection.close(sendDisconnectMessage = false, closeEverything = closeEverything) } // streaming message. This is used when the published data is too large for a single Aeron message. @@ -844,14 +853,33 @@ abstract class EndPoint private constructor(val type: C * @return true if the wait completed before the timeout */ internal fun waitForEndpointShutdown(timeoutMS: Long = 0L): Boolean { - return if (timeoutMS > 0) { - pollerClosedLatch.await(timeoutMS, TimeUnit.MILLISECONDS) && - shutdownLatch.await(timeoutMS, TimeUnit.MILLISECONDS) - } else { - pollerClosedLatch.await() - shutdownLatch.await() - true + // default is true, because if we haven't started up yet, we don't even check the latches + var success = true + + + var origPollerLatch: CountDownLatch? + var origShutdownLatch: CountDownLatch? = null + + + // don't need to check for both, as they are set together (we just have to check the later of the two) + while (origShutdownLatch !== shutdownLatch) { + // if we redefine the latches WHILE we are waiting for them, then we will NEVER release (since we lose the reference to the + // original latch). This makes sure to check again to make sure we don't appear to deadlock + origPollerLatch = pollerClosedLatch + origShutdownLatch = shutdownLatch + + + if (timeoutMS > 0) { + success = success && origPollerLatch.await(timeoutMS, TimeUnit.MILLISECONDS) + success = success && origShutdownLatch.await(timeoutMS, TimeUnit.MILLISECONDS) + } else { + origPollerLatch.await() + origShutdownLatch.await() + success = true + } } + + return success } @@ -876,11 +904,22 @@ abstract class EndPoint private constructor(val type: C throw IllegalStateException("Unable to 'waitForClose()' while inside the network event dispatch, this will deadlock!") } - val success = if (timeoutMS > 0) { - closeLatch.await(timeoutMS, TimeUnit.MILLISECONDS) - } else { - closeLatch.await() - true + + var origCloseLatch: CountDownLatch? = null + + var success = false + while (origCloseLatch !== closeLatch) { + // if we redefine the latches WHILE we are waiting for them, then we will NEVER release (since we lose the reference to the + // original latch). This makes sure to check again to make sure we don't appear to deadlock + origCloseLatch = closeLatch + + + success = if (timeoutMS > 0) { + origCloseLatch.await(timeoutMS, TimeUnit.MILLISECONDS) + } else { + origCloseLatch.await() + true + } } return success @@ -955,13 +994,12 @@ abstract class EndPoint private constructor(val type: C logger.debug("Shutting down endpoint...") } + // always do this. It is OK to run this multiple times // the server has to be able to call server.notifyDisconnect() on a list of connections. If we remove the connections // inside of connection.close(), then the server does not have a list of connections to call the global notifyDisconnect() connections.forEach { - it.closeImmediately(sendDisconnectMessage = sendDisconnectMessage, - notifyDisconnect = notifyDisconnect, - closeEverything = closeEverything) + it.closeImmediately(sendDisconnectMessage = sendDisconnectMessage, closeEverything = closeEverything) }