diff --git a/src/dorkbox/network/handshake/ClientHandshake.kt b/src/dorkbox/network/handshake/ClientHandshake.kt index 92450de5..52624f56 100644 --- a/src/dorkbox/network/handshake/ClientHandshake.kt +++ b/src/dorkbox/network/handshake/ClientHandshake.kt @@ -25,17 +25,18 @@ import io.aeron.FragmentAssembler import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.Header import org.agrona.DirectBuffer -import java.security.SecureRandom internal class ClientHandshake(private val crypto: CryptoManagement, private val endPoint: EndPoint) { // @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is // correct when this happens. There are no race-conditions to be wary of. - // a one-time key for connecting - private val oneTimePad = SecureRandom().nextInt() private val handler: FragmentHandler + // a one-time key for connecting + @Volatile + var oneTimeKey = 0 + @Volatile private var connectionHelloInfo: ClientConnectionInfo? = null @@ -48,9 +49,6 @@ internal class ClientHandshake(private val crypto: Crypt @Volatile private var failed: Exception? = null - @Volatile - private var sessionId: Int = 0 - init { // now we have a bi-directional connection with the server on the handshake "socket". handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> @@ -79,9 +77,8 @@ internal class ClientHandshake(private val crypto: Crypt return@FragmentAssembler } - if (this@ClientHandshake.sessionId != message.sessionId) { - failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " + - "${this@ClientHandshake.sessionId})") + if (oneTimeKey != message.oneTimeKey) { + failed = ClientException("[$message.sessionId] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})") return@FragmentAssembler } @@ -130,19 +127,21 @@ internal class ClientHandshake(private val crypto: Crypt // called from the connect thread suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { - val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, endPoint.settingsStore.getPublicKey()!!) + oneTimeKey = endPoint.crypto.secureRandom.nextInt() + val publicKey = endPoint.settingsStore.getPublicKey()!! // Send the one-time pad to the server. - endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) - sessionId = handshakeConnection.publication.sessionId() + val publication = handshakeConnection.publication + val subscription = handshakeConnection.subscription + val pollIdleStrategy = endPoint.config.pollIdleStrategy + endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey)) + // block until we receive the connection information from the server failed = null var pollCount: Int - val subscription = handshakeConnection.subscription - val pollIdleStrategy = endPoint.config.pollIdleStrategy val startTime = System.currentTimeMillis() while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) { @@ -176,7 +175,7 @@ internal class ClientHandshake(private val crypto: Crypt // called from the connect thread suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { - val registrationMessage = HandshakeMessage.doneFromClient() + val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey) // Send the done message to the server. endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) diff --git a/src/dorkbox/network/handshake/HandshakeMessage.kt b/src/dorkbox/network/handshake/HandshakeMessage.kt index d94df6d5..306b9127 100644 --- a/src/dorkbox/network/handshake/HandshakeMessage.kt +++ b/src/dorkbox/network/handshake/HandshakeMessage.kt @@ -26,7 +26,7 @@ internal class HandshakeMessage private constructor() { // used to keep track and associate UDP/etc sessions. This is always defined by the server // a sessionId if '0', means we are still figuring it out. - var oneTimePad = 0 + var oneTimeKey = 0 // -1 means there is an error var state = INVALID @@ -58,37 +58,41 @@ internal class HandshakeMessage private constructor() { const val DONE = 3 const val DONE_ACK = 4 - fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage { + fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage { val hello = HandshakeMessage() hello.state = HELLO - hello.oneTimePad = oneTimePad + hello.oneTimeKey = oneTimeKey hello.publicKey = publicKey return hello } - fun helloAckToClient(sessionId: Int): HandshakeMessage { + fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage { val hello = HandshakeMessage() hello.state = HELLO_ACK - hello.sessionId = sessionId // has to be the same as before (the client expects this) + hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this) + hello.sessionId = sessionId return hello } - fun helloAckIpcToClient(sessionId: Int): HandshakeMessage { + fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage { val hello = HandshakeMessage() hello.state = HELLO_ACK_IPC - hello.sessionId = sessionId // has to be the same as before (the client expects this) + hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this) + hello.sessionId = sessionId return hello } - fun doneFromClient(): HandshakeMessage { + fun doneFromClient(oneTimeKey: Int): HandshakeMessage { val hello = HandshakeMessage() hello.state = DONE + hello.oneTimeKey = oneTimeKey return hello } - fun doneToClient(sessionId: Int): HandshakeMessage { + fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage { val hello = HandshakeMessage() hello.state = DONE_ACK + hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this) hello.sessionId = sessionId return hello } @@ -131,6 +135,6 @@ internal class HandshakeMessage private constructor() { } - return "HandshakeMessage($sessionId : oneTimePad=$oneTimePad $stateStr$errorMsg)" + return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)" } } diff --git a/src/dorkbox/network/handshake/ServerHandshake.kt b/src/dorkbox/network/handshake/ServerHandshake.kt index 4a25579c..137bdcf3 100644 --- a/src/dorkbox/network/handshake/ServerHandshake.kt +++ b/src/dorkbox/network/handshake/ServerHandshake.kt @@ -127,7 +127,7 @@ internal class ServerHandshake(private val logger: KLog // now tell the client we are done actionDispatch.launch { - server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) + server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId)) listenerManager.notifyConnect(pendingConnection) } @@ -296,7 +296,7 @@ internal class ServerHandshake(private val logger: KLog // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! - val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId) + val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId) // if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint @@ -468,7 +468,7 @@ internal class ServerHandshake(private val logger: KLog // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! - val successMessage = HandshakeMessage.helloAckToClient(sessionId) + val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId) // Also send the RMI registration data to the client (so the client doesn't register anything)