Handshake validate is only done based on the one-time key (instead off session id) as the session ID was shared in some situations (the aeron session ID is not guarateed to be unique, unless it is manually set)

This commit is contained in:
nathan 2020-09-28 16:30:38 +02:00
parent 4dc58f2485
commit cc9742fe14
3 changed files with 31 additions and 28 deletions

View File

@ -25,17 +25,18 @@ import io.aeron.FragmentAssembler
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
import org.agrona.DirectBuffer
import java.security.SecureRandom
internal class ClientHandshake<CONNECTION: Connection>(private val crypto: CryptoManagement, private val endPoint: EndPoint<CONNECTION>) {
// @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is
// correct when this happens. There are no race-conditions to be wary of.
// a one-time key for connecting
private val oneTimePad = SecureRandom().nextInt()
private val handler: FragmentHandler
// a one-time key for connecting
@Volatile
var oneTimeKey = 0
@Volatile
private var connectionHelloInfo: ClientConnectionInfo? = null
@ -48,9 +49,6 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
@Volatile
private var failed: Exception? = null
@Volatile
private var sessionId: Int = 0
init {
// now we have a bi-directional connection with the server on the handshake "socket".
handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
@ -79,9 +77,8 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
return@FragmentAssembler
}
if (this@ClientHandshake.sessionId != message.sessionId) {
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " +
"${this@ClientHandshake.sessionId})")
if (oneTimeKey != message.oneTimeKey) {
failed = ClientException("[$message.sessionId] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})")
return@FragmentAssembler
}
@ -130,19 +127,21 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// called from the connect thread
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, endPoint.settingsStore.getPublicKey()!!)
oneTimeKey = endPoint.crypto.secureRandom.nextInt()
val publicKey = endPoint.settingsStore.getPublicKey()!!
// Send the one-time pad to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
sessionId = handshakeConnection.publication.sessionId()
val publication = handshakeConnection.publication
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
// block until we receive the connection information from the server
failed = null
var pollCount: Int
val subscription = handshakeConnection.subscription
val pollIdleStrategy = endPoint.config.pollIdleStrategy
val startTime = System.currentTimeMillis()
while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) {
@ -176,7 +175,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
// called from the connect thread
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
val registrationMessage = HandshakeMessage.doneFromClient()
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
// Send the done message to the server.
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)

View File

@ -26,7 +26,7 @@ internal class HandshakeMessage private constructor() {
// used to keep track and associate UDP/etc sessions. This is always defined by the server
// a sessionId if '0', means we are still figuring it out.
var oneTimePad = 0
var oneTimeKey = 0
// -1 means there is an error
var state = INVALID
@ -58,37 +58,41 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage {
fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.oneTimePad = oneTimePad
hello.oneTimeKey = oneTimeKey
hello.publicKey = publicKey
return hello
}
fun helloAckToClient(sessionId: Int): HandshakeMessage {
fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK
hello.sessionId = sessionId // has to be the same as before (the client expects this)
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
return hello
}
fun helloAckIpcToClient(sessionId: Int): HandshakeMessage {
fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO_ACK_IPC
hello.sessionId = sessionId // has to be the same as before (the client expects this)
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
return hello
}
fun doneFromClient(): HandshakeMessage {
fun doneFromClient(oneTimeKey: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE
hello.oneTimeKey = oneTimeKey
return hello
}
fun doneToClient(sessionId: Int): HandshakeMessage {
fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE_ACK
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
hello.sessionId = sessionId
return hello
}
@ -131,6 +135,6 @@ internal class HandshakeMessage private constructor() {
}
return "HandshakeMessage($sessionId : oneTimePad=$oneTimePad $stateStr$errorMsg)"
return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)"
}
}

View File

@ -127,7 +127,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// now tell the client we are done
actionDispatch.launch {
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId))
listenerManager.notifyConnect(pendingConnection)
}
@ -296,7 +296,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId)
val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId)
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
@ -468,7 +468,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(sessionId)
val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId)
// Also send the RMI registration data to the client (so the client doesn't register anything)