Handshake validate is only done based on the one-time key (instead off session id) as the session ID was shared in some situations (the aeron session ID is not guarateed to be unique, unless it is manually set)
This commit is contained in:
parent
4dc58f2485
commit
cc9742fe14
|
@ -25,17 +25,18 @@ import io.aeron.FragmentAssembler
|
|||
import io.aeron.logbuffer.FragmentHandler
|
||||
import io.aeron.logbuffer.Header
|
||||
import org.agrona.DirectBuffer
|
||||
import java.security.SecureRandom
|
||||
|
||||
internal class ClientHandshake<CONNECTION: Connection>(private val crypto: CryptoManagement, private val endPoint: EndPoint<CONNECTION>) {
|
||||
|
||||
// @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is
|
||||
// correct when this happens. There are no race-conditions to be wary of.
|
||||
|
||||
// a one-time key for connecting
|
||||
private val oneTimePad = SecureRandom().nextInt()
|
||||
private val handler: FragmentHandler
|
||||
|
||||
// a one-time key for connecting
|
||||
@Volatile
|
||||
var oneTimeKey = 0
|
||||
|
||||
@Volatile
|
||||
private var connectionHelloInfo: ClientConnectionInfo? = null
|
||||
|
||||
|
@ -48,9 +49,6 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
|
|||
@Volatile
|
||||
private var failed: Exception? = null
|
||||
|
||||
@Volatile
|
||||
private var sessionId: Int = 0
|
||||
|
||||
init {
|
||||
// now we have a bi-directional connection with the server on the handshake "socket".
|
||||
handler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
|
||||
|
@ -79,9 +77,8 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
|
|||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
if (this@ClientHandshake.sessionId != message.sessionId) {
|
||||
failed = ClientException("[$message.sessionId] ignored message intended for another client (mine is: " +
|
||||
"${this@ClientHandshake.sessionId})")
|
||||
if (oneTimeKey != message.oneTimeKey) {
|
||||
failed = ClientException("[$message.sessionId] ignored message (one-time key: ${message.oneTimeKey}) intended for another client (mine is: ${oneTimeKey})")
|
||||
return@FragmentAssembler
|
||||
}
|
||||
|
||||
|
@ -130,19 +127,21 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
|
|||
|
||||
// called from the connect thread
|
||||
suspend fun handshakeHello(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long) : ClientConnectionInfo {
|
||||
val registrationMessage = HandshakeMessage.helloFromClient(oneTimePad, endPoint.settingsStore.getPublicKey()!!)
|
||||
oneTimeKey = endPoint.crypto.secureRandom.nextInt()
|
||||
val publicKey = endPoint.settingsStore.getPublicKey()!!
|
||||
|
||||
// Send the one-time pad to the server.
|
||||
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
|
||||
sessionId = handshakeConnection.publication.sessionId()
|
||||
val publication = handshakeConnection.publication
|
||||
val subscription = handshakeConnection.subscription
|
||||
val pollIdleStrategy = endPoint.config.pollIdleStrategy
|
||||
|
||||
|
||||
endPoint.writeHandshakeMessage(publication, HandshakeMessage.helloFromClient(oneTimeKey, publicKey))
|
||||
|
||||
// block until we receive the connection information from the server
|
||||
|
||||
failed = null
|
||||
var pollCount: Int
|
||||
val subscription = handshakeConnection.subscription
|
||||
val pollIdleStrategy = endPoint.config.pollIdleStrategy
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
while (connectionTimeoutMS == 0L || System.currentTimeMillis() - startTime < connectionTimeoutMS) {
|
||||
|
@ -176,7 +175,7 @@ internal class ClientHandshake<CONNECTION: Connection>(private val crypto: Crypt
|
|||
|
||||
// called from the connect thread
|
||||
suspend fun handshakeDone(handshakeConnection: MediaDriverConnection, connectionTimeoutMS: Long): Boolean {
|
||||
val registrationMessage = HandshakeMessage.doneFromClient()
|
||||
val registrationMessage = HandshakeMessage.doneFromClient(oneTimeKey)
|
||||
|
||||
// Send the done message to the server.
|
||||
endPoint.writeHandshakeMessage(handshakeConnection.publication, registrationMessage)
|
||||
|
|
|
@ -26,7 +26,7 @@ internal class HandshakeMessage private constructor() {
|
|||
|
||||
// used to keep track and associate UDP/etc sessions. This is always defined by the server
|
||||
// a sessionId if '0', means we are still figuring it out.
|
||||
var oneTimePad = 0
|
||||
var oneTimeKey = 0
|
||||
|
||||
// -1 means there is an error
|
||||
var state = INVALID
|
||||
|
@ -58,37 +58,41 @@ internal class HandshakeMessage private constructor() {
|
|||
const val DONE = 3
|
||||
const val DONE_ACK = 4
|
||||
|
||||
fun helloFromClient(oneTimePad: Int, publicKey: ByteArray): HandshakeMessage {
|
||||
fun helloFromClient(oneTimeKey: Int, publicKey: ByteArray): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO
|
||||
hello.oneTimePad = oneTimePad
|
||||
hello.oneTimeKey = oneTimeKey
|
||||
hello.publicKey = publicKey
|
||||
return hello
|
||||
}
|
||||
|
||||
fun helloAckToClient(sessionId: Int): HandshakeMessage {
|
||||
fun helloAckToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO_ACK
|
||||
hello.sessionId = sessionId // has to be the same as before (the client expects this)
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
return hello
|
||||
}
|
||||
|
||||
fun helloAckIpcToClient(sessionId: Int): HandshakeMessage {
|
||||
fun helloAckIpcToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = HELLO_ACK_IPC
|
||||
hello.sessionId = sessionId // has to be the same as before (the client expects this)
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
return hello
|
||||
}
|
||||
|
||||
fun doneFromClient(): HandshakeMessage {
|
||||
fun doneFromClient(oneTimeKey: Int): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = DONE
|
||||
hello.oneTimeKey = oneTimeKey
|
||||
return hello
|
||||
}
|
||||
|
||||
fun doneToClient(sessionId: Int): HandshakeMessage {
|
||||
fun doneToClient(oneTimeKey: Int, sessionId: Int): HandshakeMessage {
|
||||
val hello = HandshakeMessage()
|
||||
hello.state = DONE_ACK
|
||||
hello.oneTimeKey = oneTimeKey // has to be the same as before (the client expects this)
|
||||
hello.sessionId = sessionId
|
||||
return hello
|
||||
}
|
||||
|
@ -131,6 +135,6 @@ internal class HandshakeMessage private constructor() {
|
|||
}
|
||||
|
||||
|
||||
return "HandshakeMessage($sessionId : oneTimePad=$oneTimePad $stateStr$errorMsg)"
|
||||
return "HandshakeMessage($sessionId : oneTimePad=$oneTimeKey $stateStr$errorMsg)"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,7 +127,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
// now tell the client we are done
|
||||
actionDispatch.launch {
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(sessionId))
|
||||
server.writeHandshakeMessage(handshakePublication, HandshakeMessage.doneToClient(message.oneTimeKey, sessionId))
|
||||
|
||||
listenerManager.notifyConnect(pendingConnection)
|
||||
}
|
||||
|
@ -296,7 +296,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
|
||||
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
|
||||
val successMessage = HandshakeMessage.helloAckIpcToClient(sessionId)
|
||||
val successMessage = HandshakeMessage.helloAckIpcToClient(message.oneTimeKey, sessionId)
|
||||
|
||||
|
||||
// if necessary, we also send the kryo RMI id's that are registered as RMI on this endpoint, but maybe not on the other endpoint
|
||||
|
@ -468,7 +468,7 @@ internal class ServerHandshake<CONNECTION : Connection>(private val logger: KLog
|
|||
|
||||
|
||||
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
|
||||
val successMessage = HandshakeMessage.helloAckToClient(sessionId)
|
||||
val successMessage = HandshakeMessage.helloAckToClient(message.oneTimeKey, sessionId)
|
||||
|
||||
|
||||
// Also send the RMI registration data to the client (so the client doesn't register anything)
|
||||
|
|
Loading…
Reference in New Issue
Block a user