Handshake validate is only done based on the one-time key (instead off session id) as the session ID was shared in some situations (the aeron session ID is not guarateed to be unique, unless it is manually set)

This commit is contained in:
nathan 2020-09-28 16:30:38 +02:00
parent 4dc58f2485
commit cc9742fe14
3 changed files with 31 additions and 28 deletions

View File

@ -25,17 +25,18 @@ import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(private val crypto: CryptoManagement, private val endPoint: EndPoint<CONNECTION>) { internal class ClientHandshake<CONNECTION: Connection>(private val crypto: CryptoManagement, private val endPoint: EndPoint<CONNECTION>) {
// @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is // @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is
// correct when this happens. There are no race-conditions to be wary of. // correct when this happens. There are no race-conditions to be wary of.
// a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt()
private val handler: FragmentHandler private val handler: FragmentHandler
// a one-time key for connecting
@Volatile
var oneTimeKey = 0
@Volatile @Volatile
private var connectionHelloInfo: ClientConnectionInfo? = null private var connectionHelloInfo: ClientConnectionInfo? = null
@ -48,9 +49,6 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
@Volatile @Volatile
private var failed: Exception? = null private var failed: Exception? = null
@Volatile
private var sessionId: Int = 0
init { init {
// now we have a bi-directional connection with the server on the handshake "socket". // now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
@ -79,9 +77,8 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
return@FragmentAssembler return@FragmentAssembler
} }
if (this@ClientHandshake.sessionId != message.sessionId) { if (oneTimeKey != message.oneTimeKey) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " + failed = ClientException("[$message.sessionId] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})")
"${this@ClientHandshake.sessionId})")
return@FragmentAssembler return@FragmentAssembler
} }
@ -130,19 +127,21 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// called from the connect thread // called from the connect thread
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo { suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, endPoint.settingsStore.getPublicKey()!!) oneTimeKey = endPoint.crypto.secureRandom.nextInt()
val publicKey = endPoint.settingsStore.getPublicKey()!!
// Send the one-time pad to the server. // Send the one-time pad to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) val publication = handshakeConnection.publication
sessionId = handshakeConnection.publication.sessionId() val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
// block until we receive the connection information from the server // block until we receive the connection information from the server
failed = null failed = null
var pollCount: Int var pollCount: Int
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis() val startTime = System.currentTimeMillis()
while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) { while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) {
@ -176,7 +175,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// called from the connect thread // called from the connect thread
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean { suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient() val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
// Send the done message to the server. // Send the done message to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage) endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)

View File

@ -26,7 +26,7 @@ internal class HandshakeMessage private constructor() {
// used to keep track and associate UDP/etc sessions. This is always defined by the server // used to keep track and associate UDP/etc sessions. This is always defined by the server
// a sessionId if '0', means we are still figuring it out. // a sessionId if '0', means we are still figuring it out.
var oneTimePad = 0 var oneTimeKey = 0
// -1 means there is an error // -1 means there is an error
var state = INVALID var state = INVALID
@ -58,37 +58,41 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3 const val DONE = 3
const val DONE_ACK = 4 const val DONE_ACK = 4
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage { fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO hello.state = HELLO
hello.oneTimePad = oneTimePad hello.oneTimeKey = oneTimeKey
hello.publicKey = publicKey hello.publicKey = publicKey
return hello return hello
} }
fun helloAckToClient(sessionId: Int): HandshakeMessage { fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO_ACK hello.state = HELLO_ACK
hello.sessionId = sessionId // has to be the same as before (the client expects this) hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
return hello return hello
} }
fun helloAckIpcToClient(sessionId: Int): HandshakeMessage { fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO_ACK_IPC hello.state = HELLO_ACK_IPC
hello.sessionId = sessionId // has to be the same as before (the client expects this) hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
return hello return hello
} }
fun doneFromClient(): HandshakeMessage { fun doneFromClient(oneTimeKey: Int): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = DONE hello.state = DONE
hello.oneTimeKey = oneTimeKey
return hello return hello
} }
fun doneToClient(sessionId: Int): HandshakeMessage { fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = DONE_ACK hello.state = DONE_ACK
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId hello.sessionId = sessionId
return hello return hello
} }
@ -131,6 +135,6 @@ internal class HandshakeMessage private constructor() {
} }
return "HandshakeMessage($sessionId : oneTimePad=$oneTimePad $stateStr$errorMsg)" return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)"
} }
} }

View File

@ -127,7 +127,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// now tell the client we are done // now tell the client we are done
actionDispatch.launch { actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId)) server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId))
listenerManager.notifyConnect(pendingConnection) listenerManager.notifyConnect(pendingConnection)
} }
@ -296,7 +296,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId) val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId)
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint // if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
@ -468,7 +468,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is! // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(sessionId) val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId)
// Also send the RMI registration data to the client (so the client doesn't register anything) // Also send the RMI registration data to the client (so the client doesn't register anything)