Network/src/dorkbox/network/connection/ConnectionManagerServer.kt

266 lines
11 KiB
Kotlin

package dorkbox.network.connection
import dorkbox.network.ServerConfiguration
import dorkbox.network.aeron.client.ClientRejectedException
import dorkbox.network.aeron.server.AllocationException
import dorkbox.network.aeron.server.PortAllocator
import dorkbox.network.aeron.server.RandomIdAllocator
import dorkbox.network.aeron.server.ServerException
import dorkbox.network.connection.registration.Registration
import dorkbox.network.other.NetworkUtil
import io.aeron.Image
import io.aeron.Publication
import io.aeron.logbuffer.Header
import org.agrona.DirectBuffer
import org.agrona.collections.Int2IntCounterMap
import org.slf4j.Logger
/**
 * Server-side connection manager. It validates incoming handshake [Registration] messages and allocates the
 * per-client resources (ports, session ID, stream ID). It then creates the client connection and releases those
 * resources again when the connection is cleaned up in [poll].
 *
 * TODO: when adding a "custom" connection, it's super important to not have to worry about the sessionID (which is what we key off of)
 *
 * @throws IllegalArgumentException If the port range is not valid
 */
class ConnectionManagerServer<C : Connection>(logger: Logger,
                                              config: ServerConfiguration) : ConnectionManager<C>(logger, config) {
    companion object {
        // this is the number of ports used per client. Depending on how a client is configured, this number can change
        const val portsPerClient = 2
    }

    // hands out `portsPerClient` ports per client, starting at the configured client start port
    private val portAllocator: PortAllocator

    // number of connections per remote address (keyed by the address converted to an int); missing keys read as 0
    private val connectionsPerIpCounts = Int2IntCounterMap(0)

    // guarantee that session/stream ID's will ALWAYS be unique! (there can NEVER be a collision!)
    private val sessionIdAllocator = RandomIdAllocator(EndPoint.RESERVED_SESSION_ID_LOW, EndPoint.RESERVED_SESSION_ID_HIGH)
    private val streamIdAllocator = RandomIdAllocator(1, Int.MAX_VALUE)

    init {
        val minPort = config.clientStartPort
        val maxPortCount = portsPerClient * config.maxClientCount
        portAllocator = PortAllocator(minPort, maxPortCount)
    }

    /**
     * Handles a single handshake message received by the server.
     *
     * The message must be a [Registration]. It is validated against, in order:
     * - the max client count;
     * - the remote public key;
     * - the serialization registration data;
     * - the per-IP connection limit.
     *
     * If every check passes, ports and session/stream IDs are allocated and the connection filter is consulted.
     * A `helloAck` [Registration] is then written back on [handshakePublication].
     *
     * Everything allocated here is released again if the handshake fails before the connection has been added to
     * this manager. Once it has been added, [poll] is responsible for releasing those resources.
     *
     * @param handshakePublication publication used to reply to the client during the handshake
     * @param buffer buffer holding the raw handshake message
     * @param offset offset of the message within [buffer]
     * @param length length of the message within [buffer]
     * @param header header of the received message; its context must be the [Image] it arrived on
     * @param endPoint endpoint used to read/write handshake messages and to create the new connection
     */
    @Throws(ServerException::class)
    suspend fun receiveHandshakeMessageServer(handshakePublication: Publication,
                                              buffer: DirectBuffer, offset: Int, length: Int, header: Header,
                                              endPoint: EndPoint<C>) {
        // The sessionId is unique within a Subscription and unique across all Publication's from a sourceIdentity.
        // For the handshake, the sessionId IS NOT GLOBALLY UNIQUE
        val sessionId = header.sessionId()

        // note: this address will ALWAYS be an IP:PORT combo. Only the IP part is used.
        val remoteIpAndPort = (header.context() as Image).sourceIdentity()
        val clientAddressString = remoteIpAndPort.substringBeforeLast(':')
        val clientAddress = NetworkUtil.IP.toInt(clientAddressString)

        // the base-class config is always a ServerConfiguration for the server (smart-casts `config` below)
        config as ServerConfiguration

        val message = endPoint.readHandshakeMessage(buffer, offset, length, header)

        // VALIDATE:: a Registration object is the only acceptable message during the connection phase
        if (message !is Registration) {
            endPoint.writeHandshakeMessage(handshakePublication, Registration.error("Invalid connection request"))
            return
        }

        try {
            // VALIDATE:: Check to see if there are already too many clients connected.
            if (connectionCount() >= config.maxClientCount) {
                logger.debug("server is full")
                endPoint.writeHandshakeMessage(handshakePublication, Registration.error("server full. Max allowed is ${config.maxClientCount}"))
                return
            }

            // VALIDATE:: check to see if the remote connection's public key has changed!
            if (!endPoint.validateRemoteAddress(clientAddress, message.publicKey)) {
                // TODO: this should provide info to a callback
                logger.error("connection not allowed! public key mismatch")
                return
            }

            // VALIDATE:: make sure the serialization matches between the client/server!
            val registrationData = message.registrationData
            if (registrationData == null) {
                // TODO: this should provide info to a callback
                logger.error("connection not allowed! registration data missing")
                return
            }
            if (!config.serialization.verifyKryoRegistration(registrationData)) {
                // TODO: this should provide info to a callback
                logger.error("connection not allowed! registration data mismatch")
                return
            }

            // VALIDATE:: we are now connected to the client and are going to create a new connection.
            val currentCountForIp = connectionsPerIpCounts.getAndIncrement(clientAddress)
            if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
                // decrement it now, since we aren't going to permit this connection (take the hit on failure, instead)
                connectionsPerIpCounts.getAndDecrement(clientAddress)
                logger.debug("too many connections for IP address")
                endPoint.writeHandshakeMessage(handshakePublication, Registration.error("too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}"))
                return
            }
        } catch (e: Exception) {
            // A validation failure must NEVER fall through to allocating resources for this client. The per-IP
            // count is consistent here: it is only incremented as the last step, and its reject branch decrements it
            // before doing anything that can throw.
            logger.error("could not validate client message: ", e)
            notifyError(e)
            return
        }

        // VALIDATE:: TODO: ?? check to see if this session is ALREADY connected??. It should not be!

        /////
        /////
        ///// DONE WITH VALIDATION
        /////
        /////

        // allocate ports for the client
        val connectionPorts: IntArray = try {
            // throws exception if this is not possible
            portAllocator.allocate(portsPerClient)
        } catch (e: IllegalArgumentException) {
            // have to unwind actions!
            connectionsPerIpCounts.getAndDecrement(clientAddress)
            logger.error("Unable to allocate $portsPerClient ports for client connection!")
            return
        }

        // allocate session/stream id's
        val connectionSessionId: Int = try {
            sessionIdAllocator.allocate()
        } catch (e: AllocationException) {
            // have to unwind actions!
            connectionsPerIpCounts.getAndDecrement(clientAddress)
            portAllocator.free(connectionPorts)
            logger.error("Unable to allocate a session ID for the client connection!")
            return
        }

        val connectionStreamId: Int = try {
            streamIdAllocator.allocate()
        } catch (e: AllocationException) {
            // have to unwind actions!
            connectionsPerIpCounts.getAndDecrement(clientAddress)
            portAllocator.free(connectionPorts)
            sessionIdAllocator.free(connectionSessionId)
            logger.error("Unable to allocate a stream ID for the client connection!")
            return
        }

        val serverAddress = config.listenIpAddress // TODO :: my IP address?? this should be the IP of the box?
        val subscriptionPort = connectionPorts[0]
        val publicationPort = connectionPorts[1]

        // Once the connection has been added, poll() owns releasing its resources. Unwinding them here as well
        // would free the ports/IDs twice.
        var connectionAdded = false

        // create a new connection. The session ID is encrypted.
        try {
            // connection timeout of 0 doesn't matter. it is not used by the server
            val clientConnection = UdpMediaDriverConnection(
                    serverAddress, subscriptionPort, publicationPort,
                    connectionStreamId, connectionSessionId, 0, message.isReliable)

            // NOTE(review): if the handshake is rejected/fails below, this connection object is not closed here.
            // Confirm whether Connection holds resources that need an explicit close.
            val connection: Connection = endPoint.newConnection(endPoint, clientConnection)

            // VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
            @Suppress("UNCHECKED_CAST")
            val permitConnection = notifyFilter(connection as C)
            if (!permitConnection) {
                // have to unwind actions!
                releaseAllocations(clientAddress, connectionPorts, connectionSessionId, connectionStreamId)
                logger.error("Connection from $clientAddressString was not permitted!")
                notifyError(connection, ClientRejectedException("Connection was not permitted!"))
                return
            }

            logger.info("Client connected [$clientAddressString:$subscriptionPort|$publicationPort] (session: $sessionId)")
            logger.debug("[{}] created new client connection", connectionSessionId)

            // The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
            val successMessage = Registration.helloAck(message.oneTimePad xor connectionSessionId)
            successMessage.sessionId = sessionId // has to be the same as before (the client expects this)
            successMessage.streamId = message.oneTimePad xor connectionStreamId
            successMessage.subscriptionPort = subscriptionPort
            successMessage.publicationPort = publicationPort
            successMessage.publicKey = config.settingsStore.getPublicKey()

            endPoint.writeHandshakeMessage(handshakePublication, successMessage)

            addConnection(connection)
            connectionAdded = true

            notifyConnect(connection)
        } catch (e: Exception) {
            if (connectionAdded) {
                // the connection is tracked by this manager; poll() releases its resources when it is cleaned up
                logger.error("[$connectionSessionId] error while notifying connect", e)
            } else {
                // have to unwind actions!
                releaseAllocations(clientAddress, connectionPorts, connectionSessionId, connectionStreamId)
                logger.error("Error creating new duologue")
                logger.error("could not process client message: $message", e)
            }
            notifyError(e)
        }
    }

    /**
     * Releases everything reserved for a client whose handshake failed after all allocations succeeded.
     * This covers one per-IP connection count, the client's ports, and its session and stream IDs.
     */
    private fun releaseAllocations(clientAddress: Int, connectionPorts: IntArray, sessionId: Int, streamId: Int) {
        connectionsPerIpCounts.getAndDecrement(clientAddress)
        portAllocator.free(connectionPorts)
        sessionIdAllocator.free(sessionId)
        streamIdAllocator.free(streamId)
    }

    /**
     * Polls every active connection for activity and cleans up connections that have expired or been closed.
     *
     * Cleaning up a connection does the following, in order:
     * - removes it from this manager;
     * - releases its per-IP count, both of its ports, and its session and stream IDs;
     * - calls `notifyDisconnect` for it.
     *
     * @return the sum of `pollSubscriptions()` over the connections that were polled
     */
    suspend fun poll(): Int {
        // Get the current time, used to cleanup connections
        val now = System.currentTimeMillis()
        var pollCount = 0

        forEachConnectionCleanup({ connection ->
            // If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
            val expired = connection.isExpired(now)
            if (expired) {
                logger.debug("[{}] connection expired", connection.sessionId)
            }

            val closed = connection.isClosed()
            if (closed) {
                logger.debug("[{}] connection closed", connection.sessionId)
            }

            if (expired || closed) {
                true
            } else {
                // Otherwise, poll the duologue for activity.
                pollCount += connection.pollSubscriptions()
                false
            }
        }, { connectionToClean ->
            logger.debug("[{}] deleted connection", connectionToClean.sessionId)

            removeConnection(connectionToClean)

            // have to free up resources!
            connectionsPerIpCounts.getAndDecrement(connectionToClean.remoteAddressInt)
            portAllocator.free(connectionToClean.subscriptionPort)
            portAllocator.free(connectionToClean.publicationPort)
            sessionIdAllocator.free(connectionToClean.sessionId)
            streamIdAllocator.free(connectionToClean.streamId)

            notifyDisconnect(connectionToClean)
        })

        return pollCount
    }
}