Now use the client public-key as the client ID for the connection (to determine/know exactly which connection belongs to what client)

This commit is contained in:
Robinson 2023-07-02 21:56:46 +02:00
parent 87e790dcaf
commit e403e1d6e1
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
10 changed files with 45 additions and 73 deletions

View File

@ -665,7 +665,7 @@ open class Client<CONNECTION : Connection>(
// throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class) // throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class)
val connectionInfo = handshake.hello(handshakeConnection, connectionTimeoutSec, uuid) val connectionInfo = handshake.hello(handshakeConnection, connectionTimeoutSec)
// VALIDATE:: check to see if the remote connection's public key has changed! // VALIDATE:: check to see if the remote connection's public key has changed!
val validateRemoteAddress = if (handshakeConnection.pubSub.isIpc) { val validateRemoteAddress = if (handshakeConnection.pubSub.isIpc) {
@ -732,9 +732,9 @@ open class Client<CONNECTION : Connection>(
val newConnection: CONNECTION val newConnection: CONNECTION
if (handshakeConnection.pubSub.isIpc) { if (handshakeConnection.pubSub.isIpc) {
newConnection = connectionFunc(ConnectionParams(uuid, this, clientConnection.connectionInfo, PublicKeyValidationState.VALID)) newConnection = connectionFunc(ConnectionParams(connectionInfo.publicKey, this, clientConnection.connectionInfo, PublicKeyValidationState.VALID))
} else { } else {
newConnection = connectionFunc(ConnectionParams(uuid, this, clientConnection.connectionInfo, validateRemoteAddress)) newConnection = connectionFunc(ConnectionParams(connectionInfo.publicKey, this, clientConnection.connectionInfo, validateRemoteAddress))
address!! address!!
// NOTE: Client can ALWAYS connect to the server. The server makes the decision if the client can connect or not. // NOTE: Client can ALWAYS connect to the server. The server makes the decision if the client can connect or not.
@ -907,7 +907,7 @@ open class Client<CONNECTION : Connection>(
} }
override fun toString(): String { override fun toString(): String {
return "EndPoint [Client: $uuid]" return "EndPoint [Client: $${storage.publicKey!!.toHexString()}]"
} }
fun <R> use(block: (Client<CONNECTION>) -> R): R { fun <R> use(block: (Client<CONNECTION>) -> R): R {

View File

@ -16,25 +16,19 @@
package dorkbox.network.aeron package dorkbox.network.aeron
import dorkbox.bytes.ByteArrayWrapper
import dorkbox.collections.ConcurrentIterator import dorkbox.collections.ConcurrentIterator
import dorkbox.network.Configuration import dorkbox.network.Configuration
import dorkbox.network.connection.EndPoint import dorkbox.network.connection.EndPoint
import dorkbox.util.NamedThreadFactory import dorkbox.util.NamedThreadFactory
import dorkbox.util.sync.CountDownLatch import dorkbox.util.sync.CountDownLatch
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.*
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import mu.KLogger import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
import org.agrona.concurrent.IdleStrategy import org.agrona.concurrent.IdleStrategy
import java.util.*
import java.util.concurrent.* import java.util.concurrent.*
/** /**
@ -65,7 +59,7 @@ internal class EventPoller {
// this is thread safe // this is thread safe
private val pollEvents = ConcurrentIterator<Pair<suspend EventPoller.()->Int, suspend ()->Unit>>() private val pollEvents = ConcurrentIterator<Pair<suspend EventPoller.()->Int, suspend ()->Unit>>()
private val submitEvents = atomic(0) private val submitEvents = atomic(0)
private val configureEventsEndpoints = mutableSetOf<UUID>() private val configureEventsEndpoints = mutableSetOf<ByteArrayWrapper>()
@Volatile @Volatile
private var delayClose = false private var delayClose = false
@ -86,7 +80,7 @@ internal class EventPoller {
fun configure(logger: KLogger, config: Configuration, endPoint: EndPoint<*>) = runBlocking { fun configure(logger: KLogger, config: Configuration, endPoint: EndPoint<*>) = runBlocking {
mutex.withLock { mutex.withLock {
logger.debug { "Initializing the Network Event Poller..." } logger.debug { "Initializing the Network Event Poller..." }
configureEventsEndpoints.add(endPoint.uuid) configureEventsEndpoints.add(ByteArrayWrapper.wrap(endPoint.storage.publicKey)!!)
if (!configured) { if (!configured) {
logger.trace { "Configuring the Network Event Poller..." } logger.trace { "Configuring the Network Event Poller..." }
@ -198,7 +192,8 @@ internal class EventPoller {
// ONLY if there are no more poll-events do we ACTUALLY shut down. // ONLY if there are no more poll-events do we ACTUALLY shut down.
// when an endpoint closes its polling, it will automatically be removed from this datastructure. // when an endpoint closes its polling, it will automatically be removed from this datastructure.
configureEventsEndpoints.removeIf { it == endPoint.uuid } val publicKeyWrapped = ByteArrayWrapper.wrap(endPoint.storage.publicKey)
configureEventsEndpoints.removeIf { it == publicKeyWrapped }
val cEvents = configureEventsEndpoints.size val cEvents = configureEventsEndpoints.size
// these prevent us from closing too early // these prevent us from closing too early

View File

@ -75,7 +75,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/** /**
* This is the client UUID. This is useful determine if the same client is connecting multiple times to a server (instead of only using IP address) * This is the client UUID. This is useful determine if the same client is connecting multiple times to a server (instead of only using IP address)
*/ */
val uuid = connectionParameters.clientUuid val uuid = connectionParameters.publicKey
/** /**
* The unique session id of this connection, assigned by the server. * The unique session id of this connection, assigned by the server.

View File

@ -16,10 +16,9 @@
package dorkbox.network.connection package dorkbox.network.connection
import dorkbox.network.handshake.PubSub import dorkbox.network.handshake.PubSub
import java.util.*
data class ConnectionParams<CONNECTION : Connection>( data class ConnectionParams<CONNECTION : Connection>(
val clientUuid: UUID, val publicKey: ByteArray,
val endPoint: EndPoint<CONNECTION>, val endPoint: EndPoint<CONNECTION>,
val connectionInfo: PubSub, val connectionInfo: PubSub,
val publicKeyValidation: PublicKeyValidationState val publicKeyValidation: PublicKeyValidationState

View File

@ -15,7 +15,6 @@
*/ */
package dorkbox.network.connection package dorkbox.network.connection
import com.fasterxml.uuid.impl.RandomBasedGenerator
import dorkbox.collections.ConcurrentIterator import dorkbox.collections.ConcurrentIterator
import dorkbox.netUtil.IP import dorkbox.netUtil.IP
import dorkbox.network.Client import dorkbox.network.Client
@ -118,12 +117,6 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
} }
} }
/**
* The UUID is a unique, in-memory instance that is created on object construction
*/
val uuid = RandomBasedGenerator(CryptoManagement.secureRandom).generate()
// the ID would be different?? but the UUID would be the same??
val logger: KLogger = KotlinLogging.logger(loggerName) val logger: KLogger = KotlinLogging.logger(loggerName)
// this is rather silly, BUT if there are more complex errors WITH the coroutine that occur, a regular try/catch WILL NOT catch it. // this is rather silly, BUT if there are more complex errors WITH the coroutine that occur, a regular try/catch WILL NOT catch it.

View File

@ -28,7 +28,6 @@ import io.aeron.logbuffer.Header
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import mu.KLogger import mu.KLogger
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
import java.util.*
import java.util.concurrent.* import java.util.concurrent.*
internal class ClientHandshake<CONNECTION: Connection>( internal class ClientHandshake<CONNECTION: Connection>(
@ -88,7 +87,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// ugh, this is verbose -- but necessary // ugh, this is verbose -- but necessary
val message = try { val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo) val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) { if (msg !is HandshakeMessage) {
@ -178,7 +177,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread // called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed // when exceptions are thrown, the handshake pub/sub will be closed
suspend fun hello(handshakeConnection: ClientHandshakeDriver, handshakeTimeoutSec: Int, uuid: UUID) : ClientConnectionInfo { suspend fun hello(handshakeConnection: ClientHandshakeDriver, handshakeTimeoutSec: Int) : ClientConnectionInfo {
val pubSub = handshakeConnection.pubSub val pubSub = handshakeConnection.pubSub
// is our pub still connected?? // is our pub still connected??
@ -190,17 +189,14 @@ internal class ClientHandshake<CONNECTION: Connection>(
reset() reset()
connectKey = getSafeConnectKey() connectKey = getSafeConnectKey()
val publicKey = client.storage.getPublicKey()!!
try { try {
// Send the one-time pad to the server. // Send the one-time pad to the server.
handshaker.writeMessage(pubSub.pub, handshakeConnection.details, handshaker.writeMessage(pubSub.pub, handshakeConnection.details,
HandshakeMessage.helloFromClient( HandshakeMessage.helloFromClient(
connectKey = connectKey, connectKey = connectKey,
publicKey = publicKey, publicKey = client.storage.publicKey!!,
streamIdSub = pubSub.streamIdSub, streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub, portSub = pubSub.portSub
uuid = uuid
)) ))
} catch (e: Exception) { } catch (e: Exception) {
handshakeConnection.close() handshakeConnection.close()

View File

@ -15,9 +15,6 @@
*/ */
package dorkbox.network.handshake package dorkbox.network.handshake
import dorkbox.bytes.LittleEndian
import java.util.*
/** /**
* Internal message to handle the connection registration process * Internal message to handle the connection registration process
*/ */
@ -54,19 +51,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3 const val DONE = 3
const val DONE_ACK = 4 const val DONE_ACK = 4
private val uuidWriter: (UUID) -> ByteArray = { uuid -> fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage {
val bytes = byteArrayOf(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0) // 16 elements
LittleEndian.Long_.toBytes(uuid.mostSignificantBits, bytes, 0)
LittleEndian.Long_.toBytes(uuid.leastSignificantBits, bytes, 8)
bytes
}
internal val uuidReader: (ByteArray) -> UUID = { bytes ->
UUID(LittleEndian.Long_.from(bytes, 0, 8),
LittleEndian.Long_.from(bytes, 8, 8))
}
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, uuid: UUID): HandshakeMessage {
val hello = HandshakeMessage() val hello = HandshakeMessage()
hello.state = HELLO hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -74,7 +59,6 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way! hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub hello.streamId = streamIdSub
hello.port = portSub hello.port = portSub
hello.registrationData = uuidWriter(uuid)
return hello return hello
} }

View File

@ -165,7 +165,7 @@ internal class Handshaker<CONNECTION : Connection>(
* *
* @return the message * @return the message
*/ */
internal fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, logInfo: String): Any? { internal fun readMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? {
// NOTE: This ABSOLUTELY MUST be done on the same thread! This cannot be done on a new one, because the buffer could change! // NOTE: This ABSOLUTELY MUST be done on the same thread! This cannot be done on a new one, because the buffer could change!
return handshakeReadKryo.read(buffer, offset, length) return handshakeReadKryo.read(buffer, offset, length)
} }

View File

@ -33,7 +33,6 @@ import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap import net.jodah.expiringmap.ExpiringMap
import java.net.Inet4Address import java.net.Inet4Address
import java.net.InetAddress import java.net.InetAddress
import java.util.*
import java.util.concurrent.* import java.util.concurrent.*
@ -210,7 +209,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
handshaker: Handshaker<CONNECTION>, handshaker: Handshaker<CONNECTION>,
aeronDriver: AeronDriver, aeronDriver: AeronDriver,
handshakePublication: Publication, handshakePublication: Publication,
clientUuid: UUID, publicKey: ByteArray,
message: HandshakeMessage, message: HandshakeMessage,
aeronLogInfo: String, aeronLogInfo: String,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION, connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
@ -331,7 +330,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
connection = connectionFunc(ConnectionParams(clientUuid, server, newConnectionDriver.pubSub, PublicKeyValidationState.VALID)) connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, PublicKeyValidationState.VALID))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
// NOTE: all IPC client connections are, by default, always allowed to connect, because they are running on the same machine // NOTE: all IPC client connections are, by default, always allowed to connect, because they are running on the same machine
@ -389,7 +388,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
server: Server<CONNECTION>, server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>, handshaker: Handshaker<CONNECTION>,
handshakePublication: Publication, handshakePublication: Publication,
clientUuid: UUID, publicKey: ByteArray,
clientAddress: InetAddress, clientAddress: InetAddress,
clientAddressString: String, clientAddressString: String,
isReliable: Boolean, isReliable: Boolean,
@ -548,7 +547,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
logger.info { "Creating new connection to $logInfo" } logger.info { "Creating new connection to $logInfo" }
} }
connection = connectionFunc(ConnectionParams(clientUuid, server, newConnectionDriver.pubSub, validateRemoteAddress)) connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information) // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(connection) val permitConnection = listenerManager.notifyFilter(connection)

View File

@ -93,7 +93,7 @@ internal object ServerHandshakePollers {
// ugh, this is verbose -- but necessary // ugh, this is verbose -- but necessary
val message = try { val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo) val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) { if (msg !is HandshakeMessage) {
@ -138,12 +138,14 @@ internal object ServerHandshakePollers {
return@launch return@launch
} }
try {
val success = handshake.processIpcHandshakeMessageServer( val success = handshake.processIpcHandshakeMessageServer(
server = server, server = server,
handshaker = handshaker, handshaker = handshaker,
aeronDriver = driver, aeronDriver = driver,
handshakePublication = publication, handshakePublication = publication,
clientUuid = HandshakeMessage.uuidReader(message.registrationData!!), publicKey = message.publicKey!!,
message = message, message = message,
aeronLogInfo = logInfo, aeronLogInfo = logInfo,
connectionFunc = connectionFunc, connectionFunc = connectionFunc,
@ -155,6 +157,10 @@ internal object ServerHandshakePollers {
} else { } else {
driver.closeAndDeletePublication(publication, "HANDSHAKE-IPC") driver.closeAndDeletePublication(publication, "HANDSHAKE-IPC")
} }
} catch (e: Exception) {
driver.closeAndDeletePublication(publication, "HANDSHAKE-IPC")
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] Error processing IPC handshake", e))
}
} else { } else {
val publication = publications.remove(connectKey) val publication = publications.remove(connectKey)
@ -259,7 +265,7 @@ internal object ServerHandshakePollers {
// ugh, this is verbose -- but necessary // ugh, this is verbose -- but necessary
val message = try { val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo) val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase // VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) { if (msg !is HandshakeMessage) {
@ -313,7 +319,7 @@ internal object ServerHandshakePollers {
server = server, server = server,
handshaker = handshaker, handshaker = handshaker,
handshakePublication = publication, handshakePublication = publication,
clientUuid = HandshakeMessage.uuidReader(message.registrationData!!), publicKey = message.publicKey!!,
clientAddress = clientAddress, clientAddress = clientAddress,
clientAddressString = clientAddressString, clientAddressString = clientAddressString,
isReliable = isReliable, isReliable = isReliable,