Now use the client public-key as the client ID for the connection (to determine/know exactly which connection belongs to what client)

This commit is contained in:
Robinson 2023-07-02 21:56:46 +02:00
parent 87e790dcaf
commit e403e1d6e1
No known key found for this signature in database
GPG Key ID: 8E7DB78588BD6F5C
10 changed files with 45 additions and 73 deletions

View File

@ -665,7 +665,7 @@ open class Client<CONNECTION : Connection>(
// throws(ConnectTimedOutException::class, ClientRejectedException::class, ClientException::class)
val connectionInfo = handshake.hello(handshakeConnection, connectionTimeoutSec, uuid)
val connectionInfo = handshake.hello(handshakeConnection, connectionTimeoutSec)
// VALIDATE:: check to see if the remote connection's public key has changed!
val validateRemoteAddress = if (handshakeConnection.pubSub.isIpc) {
@ -732,9 +732,9 @@ open class Client<CONNECTION : Connection>(
val newConnection: CONNECTION
if (handshakeConnection.pubSub.isIpc) {
newConnection = connectionFunc(ConnectionParams(uuid, this, clientConnection.connectionInfo, PublicKeyValidationState.VALID))
newConnection = connectionFunc(ConnectionParams(connectionInfo.publicKey, this, clientConnection.connectionInfo, PublicKeyValidationState.VALID))
} else {
newConnection = connectionFunc(ConnectionParams(uuid, this, clientConnection.connectionInfo, validateRemoteAddress))
newConnection = connectionFunc(ConnectionParams(connectionInfo.publicKey, this, clientConnection.connectionInfo, validateRemoteAddress))
address!!
// NOTE: Client can ALWAYS connect to the server. The server makes the decision if the client can connect or not.
@ -907,7 +907,7 @@ open class Client<CONNECTION : Connection>(
}
override fun toString(): String {
return "EndPoint [Client: $uuid]"
return "EndPoint [Client: $${storage.publicKey!!.toHexString()}]"
}
fun <R> use(block: (Client<CONNECTION>) -> R): R {

View File

@ -16,25 +16,19 @@
package dorkbox.network.aeron
import dorkbox.bytes.ByteArrayWrapper
import dorkbox.collections.ConcurrentIterator
import dorkbox.network.Configuration
import dorkbox.network.connection.EndPoint
import dorkbox.util.NamedThreadFactory
import dorkbox.util.sync.CountDownLatch
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import mu.KotlinLogging
import org.agrona.concurrent.IdleStrategy
import java.util.*
import java.util.concurrent.*
/**
@ -65,7 +59,7 @@ internal class EventPoller {
// this is thread safe
private val pollEvents = ConcurrentIterator<Pair<suspend EventPoller.()->Int, suspend ()->Unit>>()
private val submitEvents = atomic(0)
private val configureEventsEndpoints = mutableSetOf<UUID>()
private val configureEventsEndpoints = mutableSetOf<ByteArrayWrapper>()
@Volatile
private var delayClose = false
@ -86,7 +80,7 @@ internal class EventPoller {
fun configure(logger: KLogger, config: Configuration, endPoint: EndPoint<*>) = runBlocking {
mutex.withLock {
logger.debug { "Initializing the Network Event Poller..." }
configureEventsEndpoints.add(endPoint.uuid)
configureEventsEndpoints.add(ByteArrayWrapper.wrap(endPoint.storage.publicKey)!!)
if (!configured) {
logger.trace { "Configuring the Network Event Poller..." }
@ -198,7 +192,8 @@ internal class EventPoller {
// ONLY if there are no more poll-events do we ACTUALLY shut down.
// when an endpoint closes its polling, it will automatically be removed from this datastructure.
configureEventsEndpoints.removeIf { it == endPoint.uuid }
val publicKeyWrapped = ByteArrayWrapper.wrap(endPoint.storage.publicKey)
configureEventsEndpoints.removeIf { it == publicKeyWrapped }
val cEvents = configureEventsEndpoints.size
// these prevent us from closing too early

View File

@ -75,7 +75,7 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/**
* This is the client UUID. This is useful determine if the same client is connecting multiple times to a server (instead of only using IP address)
*/
val uuid = connectionParameters.clientUuid
val uuid = connectionParameters.publicKey
/**
* The unique session id of this connection, assigned by the server.

View File

@ -16,10 +16,9 @@
package dorkbox.network.connection
import dorkbox.network.handshake.PubSub
import java.util.*
data class ConnectionParams<CONNECTION : Connection>(
val clientUuid: UUID,
val publicKey: ByteArray,
val endPoint: EndPoint<CONNECTION>,
val connectionInfo: PubSub,
val publicKeyValidation: PublicKeyValidationState

View File

@ -15,7 +15,6 @@
*/
package dorkbox.network.connection
import com.fasterxml.uuid.impl.RandomBasedGenerator
import dorkbox.collections.ConcurrentIterator
import dorkbox.netUtil.IP
import dorkbox.network.Client
@ -118,12 +117,6 @@ abstract class EndPoint<CONNECTION : Connection> private constructor(val type: C
}
}
/**
* The UUID is a unique, in-memory instance that is created on object construction
*/
val uuid = RandomBasedGenerator(CryptoManagement.secureRandom).generate()
// the ID would be different?? but the UUID would be the same??
val logger: KLogger = KotlinLogging.logger(loggerName)
// this is rather silly, BUT if there are more complex errors WITH the coroutine that occur, a regular try/catch WILL NOT catch it.

View File

@ -28,7 +28,6 @@ import io.aeron.logbuffer.Header
import kotlinx.coroutines.delay
import mu.KLogger
import org.agrona.DirectBuffer
import java.util.*
import java.util.concurrent.*
internal class ClientHandshake<CONNECTION: Connection>(
@ -88,7 +87,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo)
val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
@ -178,7 +177,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed
suspend fun hello(handshakeConnection: ClientHandshakeDriver, handshakeTimeoutSec: Int, uuid: UUID) : ClientConnectionInfo {
suspend fun hello(handshakeConnection: ClientHandshakeDriver, handshakeTimeoutSec: Int) : ClientConnectionInfo {
val pubSub = handshakeConnection.pubSub
// is our pub still connected??
@ -190,17 +189,14 @@ internal class ClientHandshake<CONNECTION: Connection>(
reset()
connectKey = getSafeConnectKey()
val publicKey = client.storage.getPublicKey()!!
try {
// Send the one-time pad to the server.
handshaker.writeMessage(pubSub.pub, handshakeConnection.details,
HandshakeMessage.helloFromClient(
connectKey = connectKey,
publicKey = publicKey,
publicKey = client.storage.publicKey!!,
streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub,
uuid = uuid
portSub = pubSub.portSub
))
} catch (e: Exception) {
handshakeConnection.close()

View File

@ -15,9 +15,6 @@
*/
package dorkbox.network.handshake
import dorkbox.bytes.LittleEndian
import java.util.*
/**
* Internal message to handle the connection registration process
*/
@ -54,19 +51,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
private val uuidWriter: (UUID) -> ByteArray = { uuid ->
val bytes = byteArrayOf(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0) // 16 elements
LittleEndian.Long_.toBytes(uuid.mostSignificantBits, bytes, 0)
LittleEndian.Long_.toBytes(uuid.leastSignificantBits, bytes, 8)
bytes
}
internal val uuidReader: (ByteArray) -> UUID = { bytes ->
UUID(LittleEndian.Long_.from(bytes, 0, 8),
LittleEndian.Long_.from(bytes, 8, 8))
}
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, uuid: UUID): HandshakeMessage {
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -74,7 +59,6 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub
hello.port = portSub
hello.registrationData = uuidWriter(uuid)
return hello
}

View File

@ -165,7 +165,7 @@ internal class Handshaker<CONNECTION : Connection>(
*
* @return the message
*/
internal fun readMessage(buffer: DirectBuffer, offset: Int, length: Int, logInfo: String): Any? {
internal fun readMessage(buffer: DirectBuffer, offset: Int, length: Int): Any? {
// NOTE: This ABSOLUTELY MUST be done on the same thread! This cannot be done on a new one, because the buffer could change!
return handshakeReadKryo.read(buffer, offset, length)
}

View File

@ -33,7 +33,6 @@ import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap
import java.net.Inet4Address
import java.net.InetAddress
import java.util.*
import java.util.concurrent.*
@ -210,7 +209,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
handshaker: Handshaker<CONNECTION>,
aeronDriver: AeronDriver,
handshakePublication: Publication,
clientUuid: UUID,
publicKey: ByteArray,
message: HandshakeMessage,
aeronLogInfo: String,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
@ -331,7 +330,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
connection = connectionFunc(ConnectionParams(clientUuid, server, newConnectionDriver.pubSub, PublicKeyValidationState.VALID))
connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, PublicKeyValidationState.VALID))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
// NOTE: all IPC client connections are, by default, always allowed to connect, because they are running on the same machine
@ -389,7 +388,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>,
handshakePublication: Publication,
clientUuid: UUID,
publicKey: ByteArray,
clientAddress: InetAddress,
clientAddressString: String,
isReliable: Boolean,
@ -548,7 +547,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
logger.info { "Creating new connection to $logInfo" }
}
connection = connectionFunc(ConnectionParams(clientUuid, server, newConnectionDriver.pubSub, validateRemoteAddress))
connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, validateRemoteAddress))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(connection)

View File

@ -93,7 +93,7 @@ internal object ServerHandshakePollers {
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo)
val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
@ -138,12 +138,14 @@ internal object ServerHandshakePollers {
return@launch
}
try {
val success = handshake.processIpcHandshakeMessageServer(
server = server,
handshaker = handshaker,
aeronDriver = driver,
handshakePublication = publication,
clientUuid = HandshakeMessage.uuidReader(message.registrationData!!),
publicKey = message.publicKey!!,
message = message,
aeronLogInfo = logInfo,
connectionFunc = connectionFunc,
@ -155,6 +157,10 @@ internal object ServerHandshakePollers {
} else {
driver.closeAndDeletePublication(publication, "HANDSHAKE-IPC")
}
} catch (e: Exception) {
driver.closeAndDeletePublication(publication, "HANDSHAKE-IPC")
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] Error processing IPC handshake", e))
}
} else {
val publication = publications.remove(connectKey)
@ -259,7 +265,7 @@ internal object ServerHandshakePollers {
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length, logInfo)
val msg = handshaker.readMessage(buffer, offset, length)
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
@ -313,7 +319,7 @@ internal object ServerHandshakePollers {
server = server,
handshaker = handshaker,
handshakePublication = publication,
clientUuid = HandshakeMessage.uuidReader(message.registrationData!!),
publicKey = message.publicKey!!,
clientAddress = clientAddress,
clientAddressString = clientAddressString,
isReliable = isReliable,