Worked on RMI, moved aeron IO, tweaks to connection management

This commit is contained in:
nathan 2020-08-15 13:20:31 +02:00
parent d9ab3f7247
commit 9eb3c122d7
10 changed files with 288 additions and 890 deletions

View File

@ -22,6 +22,7 @@ import dorkbox.network.connection.ping.PingMessage
import dorkbox.network.rmi.RemoteObject import dorkbox.network.rmi.RemoteObject
import dorkbox.network.rmi.RemoteObjectStorage import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.TimeoutException import dorkbox.network.rmi.TimeoutException
import dorkbox.network.serialization.KryoExtra
import dorkbox.util.classes.ClassHelper import dorkbox.util.classes.ClassHelper
import io.aeron.FragmentAssembler import io.aeron.FragmentAssembler
import io.aeron.Publication import io.aeron.Publication
@ -29,6 +30,7 @@ import io.aeron.Subscription
import io.aeron.logbuffer.FragmentHandler import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.agrona.BitUtil import org.agrona.BitUtil
@ -43,15 +45,31 @@ import javax.crypto.SecretKey
/** /**
* This connection is established once the registration information is validated, and the various connect/filter checks have passed * This connection is established once the registration information is validated, and the various connect/filter checks have passed
*/ */
open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDriverConnection) : AutoCloseable { open class Connection(connectionParameters: ConnectionParams<*>) {
private val subscription: Subscription private val subscription: Subscription
private val publication: Publication private val publication: Publication
/** /**
* The publication port (used by aeron) for this connection. This is from the perspective of the server! * The publication port (used by aeron) for this connection. This is from the perspective of the server!
*/ */
val subscriptionPort: Int internal val subscriptionPort: Int
val publicationPort: Int internal val publicationPort: Int
/**
* the stream id of this connection.
*/
internal val streamId: Int
/**
* the session id of this connection. This value is UNIQUE
*/
internal val sessionId: Int
/**
* the id of this connection. This value is UNIQUE
*/
val id: Int
get() = sessionId
/** /**
* the remote address, as a string. * the remote address, as a string.
@ -63,15 +81,10 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
*/ */
val remoteAddressInt: Int val remoteAddressInt: Int
/**
* the stream id of this connection.
*/
val streamId: Int
/**
* the session id of this connection. This value is UNIQUE
*/
val sessionId: Int
/** /**
* @return true if this connection is established on the loopback interface * @return true if this connection is established on the loopback interface
@ -105,6 +118,9 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
return pingFuture2?.response ?: -1 return pingFuture2?.response ?: -1
} }
private val endPoint = connectionParameters.endPoint
private val listenerManager = atomic<ListenerManager<Connection>?>(null)
private val serialization = endPoint.config.serialization private val serialization = endPoint.config.serialization
private val sendIdleStrategy = endPoint.config.sendIdleStrategy.clone() private val sendIdleStrategy = endPoint.config.sendIdleStrategy.clone()
@ -132,7 +148,7 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
// private var localListenerManager: ConnectionManager<*>? = null // private var localListenerManager: ConnectionManager<*>? = null
// while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error. // while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error.
private var remoteKeyChanged = false private var remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED
// The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter) // The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter)
// The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this // The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
@ -152,6 +168,8 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
init { init {
val mediaDriverConnection = connectionParameters.mediaDriverConnection
// we have to construct how the connection will communicate! // we have to construct how the connection will communicate!
if (endPoint is Server<*>) { if (endPoint is Server<*>) {
mediaDriverConnection.buildServer(endPoint.aeron) mediaDriverConnection.buildServer(endPoint.aeron)
@ -161,7 +179,9 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
} }
} }
logger.debug("creating new connection $mediaDriverConnection") logger.trace {
"Creating new connection $mediaDriverConnection"
}
// can only get this AFTER we have built the sub/pub // can only get this AFTER we have built the sub/pub
subscription = mediaDriverConnection.subscription subscription = mediaDriverConnection.subscription
@ -172,7 +192,7 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
remoteAddress = mediaDriverConnection.address remoteAddress = mediaDriverConnection.address
remoteAddressInt = IPv4.toInt(remoteAddress) remoteAddressInt = IPv4.toInt(remoteAddress)
streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server! streamId = mediaDriverConnection.streamId // NOTE: this is UNIQUE per server!
sessionId = mediaDriverConnection.sessionId sessionId = mediaDriverConnection.sessionId // NOTE: this is UNIQUE per server!
messageHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header -> messageHandler = FragmentAssembler(FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads // small problem... If we expect IN ORDER messages (ie: setting a value, then later reading the value), multiple threads
@ -277,34 +297,6 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
} }
} }
/**
* Safely sends objects to a destination with the specified priority.
*
*
* A priority of 255 (highest) will always be sent immediately.
*
*
* A priority of 0-254 will be sent (0, the lowest, will be last) if there is no backpressure from the MediaDriver.
*/
suspend fun send(message: Any, priority: Byte) {
TODO("SEND PRIO NOT IMPL YET")
}
/** /**
* Updates the ping times for this connection (called when this connection gets a REPLY ping message). * Updates the ping times for this connection (called when this connection gets a REPLY ping message).
*/ */
@ -381,68 +373,16 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
/** /**
* Closes the connection, and removes all connection specific listeners * Closes the connection, and removes all connection specific listeners
*/ */
override fun close() { internal suspend fun close() {
if (isClosed.compareAndSet(expect = false, update = true)) { if (isClosed.compareAndSet(expect = false, update = true)) {
subscription.close() subscription.close()
publication.close() publication.close()
// a connection might have also registered for disconnect events
notifyDisconnect()
} }
// only close if we aren't already in the middle of closing.
// if (closeInProgress.compareAndSet(false, true)) {
// val idleTimeoutMs = 2000
//
// // if we are in the middle of a message, hold off.
//// synchronized(messageInProgressLock) {
//// // while loop is to prevent spurious wakeups!
//// while (messageInProgress.get()) {
//// try {
////// messageInProgressLock.wait(idleTimeoutMs.toLong())
//// } catch (ignored: InterruptedException) {
//// }
//// }
//// }
//
//
// // close out the ping future
// val pingFuture2 = pingFuture
// pingFuture2?.cancel()
// pingFuture = null
//
//// synchronized(channelIsClosed) {
//// if (!channelIsClosed.get()) {
//// // this will have netty call "channelInactive()"
////// channelWrapper.close(this, sessionManager, false)
////
//// // want to wait for the "channelInactive()" method to FINISH ALL TYPES before allowing our current thread to continue!
//// try {
//// closeLatch!!.await(idleTimeoutMs.toLong(), TimeUnit.MILLISECONDS)
//// } catch (ignored: InterruptedException) {
//// }
//// }
//// }
//
// // remove all listeners AFTER we close the channel.
// if (!keepListeners) {
// removeAll()
// }
//
//
// // remove all RMI listeners
// rmiSupport!!.close()
// }
// remove all listeners AFTER we close the channel.
// if (!keepListeners) {
// removeAll()
// }
// remove all RMI listeners
// rmiSupport.close() // TODO
} }
/** /**
* Adds a function that will be called when a client/server "disconnects" with * Adds a function that will be called when a client/server "disconnects" with
* each other * each other
@ -453,160 +393,39 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
* (via connection.addListener), meaning that ONLY that listener attached to * (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners) * the connection is notified on that event (ie, admin type listeners)
*/ */
fun onDisconnect(function: (C: Connection) -> Unit): Int { suspend fun onDisconnect(function: suspend (Connection) -> Unit) {
TODO("Not yet implemented") // make sure we atomically create the listener manager, if necessary
} listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager(logger)
/**
* Adds a function that will be called when a client/server encounters an error
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun onError(function: (C: Connection, throwable: Throwable) -> Unit): Int {
TODO("Not yet implemented")
}
/**
* Adds a function that will be called when a client/server receives a message
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun <M : Any> onMessage(function: (C: Connection, M) -> Unit): Int {
TODO("Not yet implemented")
}
/**
* Adds a listener to this connection/endpoint to be notified of
* connect/disconnect/idle/receive(object) events.
*
*
* If the listener already exists, it is not added again.
*
*
* When called by a server, NORMALLY listeners are added at the GLOBAL level
* (meaning, I add one listener, and ALL connections are notified of that
* listener.
*
*
* It is POSSIBLE to add a server connection ONLY (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
// override fun add(listener: OnConnected<Connection>): Listeners<Connection> {
// if (endPoint is EndPointServer) {
// // when we are a server, NORMALLY listeners are added at the GLOBAL level
// // meaning --
// // I add one listener, and ALL connections are notified of that listener.
// //
// // HOWEVER, it is also POSSIBLE to add a local listener (via connection.addListener), meaning that ONLY
// // that listener is notified on that event (ie, admin type listeners)
//
// // synchronized because this should be VERY uncommon, and we want to make sure that when the manager
// // is empty, we can remove it from this connection.
//// synchronized(this) {
//// if (localListenerManager == null) {
//// localListenerManager = endPoint.addListenerManager(this)
//// }
//// localListenerManager!!.add(listener)
//// }
// } else {
//// endPoint.listeners()
//// .add(listener)
// }
// return this
// }
/**
* Removes a listener from this connection/endpoint to NO LONGER be notified
* of connect/disconnect/idle/receive(object) events.
*
*
* When called by a server, NORMALLY listeners are added at the GLOBAL level
* (meaning, I add one listener, and ALL connections are notified of that
* listener.
*
*
* It is POSSIBLE to remove a server-connection 'non-global' listener (via
* connection.removeListener), meaning that ONLY that listener attached to
* the connection is removed
*/
fun removeListener(listenerId: Int) {
if (endPoint is Server<*>) {
// when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning --
// I add one listener, and ALL connections are notified of that listener.
//
// HOWEVER, it is also POSSIBLE to add a local listener (via connection.addListener), meaning that ONLY
// that listener is notified on that event (ie, admin type listeners)
// synchronized because this should be uncommon, and we want to make sure that when the manager
// is empty, we can remove it from this connection.
// synchronized(this) {
// val local = localListenerManager
// if (local != null) {
// local.remove(listener)
// if (!local.hasListeners()) {
// endPoint.removeListenerManager(this)
// }
// }
// }
} else {
// endPoint.listeners()
// .remove(listener)
} }
TODO("Not yet implemented") listenerManager.value!!.onDisconnect(function)
} }
/** /**
* Removes all registered listeners from this connection/endpoint to NO * Adds a function that will be called only for this connection, when a client/server receives a message
* LONGER be notified of connect/disconnect/idle/receive(object) events.
*
* This includes all proxy listeners
*/ */
fun removeAllListeners() { suspend fun <MESSAGE> onMessage(function: suspend (Connection, MESSAGE) -> Unit) {
// rmiSupport.removeAllListeners() // TODO // make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
if (endPoint is Server<*>) { origManager ?: ListenerManager(logger)
// when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning --
// I add one listener, and ALL connections are notified of that listener.
//
// HOWEVER, it is also POSSIBLE to add a local listener (via connection.addListener), meaning that ONLY
// that listener is notified on that event (ie, admin type listeners)
// synchronized because this should be uncommon, and we want to make sure that when the manager
// is empty, we can remove it from this connection.
// synchronized(this) {
// if (localListenerManager != null) {
// localListenerManager?.removeAll()
// localListenerManager = null
// endPoint.removeListenerManager(this)
// }
// }
} else {
// endPoint.listeners()
// .removeAll()
} }
TODO("Not yet implemented") listenerManager.value!!.onMessage(function)
}
/**
* Invoked when a connection is disconnected from the remote endpoint
*/
internal suspend fun notifyDisconnect() {
listenerManager.value?.notifyDisconnect(this)
}
/**
* Invoked when a message object was received from a remote peer.
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
} }
@ -639,7 +458,7 @@ open class Connection(val endPoint: EndPoint<*>, mediaDriverConnection: MediaDri
} }
// RMI notes (in multiple places, copypasta, because this is confusing if not written down // RMI notes (in multiple places, copypasta, because this is confusing if not written down)
// //
// only server can create a global object (in itself, via save) // only server can create a global object (in itself, via save)
// server // server

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2010 dorkbox, llc * Copyright 2020 dorkbox, llc
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,32 +16,25 @@
package dorkbox.network.connection package dorkbox.network.connection
import dorkbox.network.Configuration import dorkbox.network.Configuration
import kotlinx.coroutines.runBlocking import dorkbox.util.collections.ConcurrentEntry
import dorkbox.util.collections.ConcurrentIterator
import dorkbox.util.collections.ConcurrentIterator.headREF
import mu.KLogger import mu.KLogger
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
// Because all of our callbacks are in response to network communication, and there CANNOT be CPU race conditions over a network...
// we specifically use atomic references to set/get all of the callbacks. This ensures that these objects are visible when accessed
// from different coroutines (because, ultimately, we want to use multiple threads on the box for processing data, and if we use
// coroutines, we can ensure maximum thread output)
// .equals() compares the identity on purpose,this because we cannot create two separate objects that are somehow equal to each other. // .equals() compares the identity on purpose,this because we cannot create two separate objects that are somehow equal to each other.
internal open class ConnectionManager<CONNECTION: Connection>(val logger: KLogger, val config: Configuration, val listenerManager: ListenerManager<CONNECTION>) : AutoCloseable { @Suppress("UNCHECKED_CAST")
internal open class ConnectionManager<CONNECTION: Connection>(val logger: KLogger, val config: Configuration) {
private val connections = ConcurrentIterator<CONNECTION>()
private val connectionLock = ReentrantReadWriteLock()
private val connections = mutableListOf<CONNECTION>()
/** /**
* Invoked when aeron successfully connects to a remote address. * Invoked when aeron successfully connects to a remote address.
* *
* @param connection the connection to add * @param connection the connection to add
*/ */
fun addConnection(connection: CONNECTION) { fun add(connection: CONNECTION) {
connectionLock.write { connections.add(connection)
connections.add(connection)
}
} }
/** /**
@ -53,149 +46,69 @@ internal open class ConnectionManager<CONNECTION: Connection>(val logger: KLogge
* *
* @param connection the connection to remove * @param connection the connection to remove
*/ */
fun removeConnection(connection: CONNECTION) { fun remove(connection: CONNECTION) {
connectionLock.write { connections.remove(connection)
connections.remove(connection)
}
} }
/** /**
* Performs an action on each connection in the list inside a read lock * Performs an action on each connection in the list.
*/ */
suspend fun forEachConnectionDoRead(function: suspend (connection: CONNECTION) -> Unit) { inline fun forEach(function: (connection: CONNECTION) -> Unit) {
connectionLock.read { // access a snapshot (single-writer-principle)
connections.forEach { val head = headREF.get(connections) as ConcurrentEntry<CONNECTION>?
function(it) var current: ConcurrentEntry<CONNECTION>? = head
}
var connection: CONNECTION
while (current != null) {
// Concurrent iteration...
connection = current.value
current = current.next()
function(connection)
} }
} }
/** /**
* Performs an action on each connection in the list. * Performs an action on each connection in the list.
*/ */
private val connectionsToRemove = mutableListOf<CONNECTION>() internal inline fun forEachWithCleanup(function: (connection: CONNECTION) -> Boolean,
internal suspend fun forEachConnectionCleanup(function: suspend (connection: CONNECTION) -> Boolean, cleanup: (connection: CONNECTION) -> Unit) {
cleanup: suspend (connection: CONNECTION) -> Unit) {
connectionLock.write {
connections.forEach {
if (function(it)) {
try {
it.close()
} finally {
connectionsToRemove.add(it)
}
}
}
if (connectionsToRemove.size > 0) { val head = headREF.get(connections) as ConcurrentEntry<CONNECTION>?
connectionsToRemove.forEach { var current: ConcurrentEntry<CONNECTION>? = head
cleanup(it)
} var connection: CONNECTION
connectionsToRemove.clear() while (current != null) {
connection = current.value
current = current.next()
function(connection)
if (function(connection)) {
// Concurrent iteration...
connections.remove(connection)
cleanup(connection)
} }
} }
} }
fun connectionCount(): Int { fun connectionCount(): Int {
return connections.size return connections.size()
} }
// fun addListenerManager(connection: C): ConnectionManager<C> {
// // when we are a server, NORMALLY listeners are added at the GLOBAL level (meaning, I add one listener, and ALL connections
// // are notified of that listener.
//
// // it is POSSIBLE to add a connection-specific listener (via connection.addListener), meaning that ONLY
// // that listener is notified on that event (ie, admin type listeners)
// var created = false
// var manager = localManagers[connection]
// if (manager == null) {
// created = true
// manager = ConnectionManager<C>("$loggerName-$connection Specific", actionDispatchScope)
// localManagers.put(connection, manager)
// }
// if (created) {
// val logger2 = logger
// if (logger2.isTraceEnabled) {
// logger2.trace("Connection specific Listener Manager added for connection: {}", connection)
// }
// }
// return manager
// }
// fun removeListenerManager(connection: C) {
// var wasRemoved = false
// val removed = localManagers.remove(connection)
// if (removed != null) {
// wasRemoved = true
// }
// if (wasRemoved) {
// val logger2 = logger
// if (logger2.isTraceEnabled) {
// logger2.trace("Connection specific Listener Manager removed for connection: {}", connection)
// }
// }
// }
/** /**
* Closes all associated resources/threads/connections * Closes all associated resources/threads/connections
*/ */
override fun close() { fun close() {
connectionLock.write { connections.clear()
// runBlocking because we don't want to progress until we are 100% done closing all connections
runBlocking {
// don't need anything fast or fancy here, because this method will only be called once
connections.forEach {
it.close()
}
connections.forEach {
listenerManager.notifyDisconnect(it)
}
connections.clear()
}
}
} }
/**
* Exposes methods to send the object to all server connections (except the specified one) over the network. (or via LOCAL when it's a
* local channel).
*/
// @Override
// public
// ConnectionExceptSpecifiedBridgeServer except() {
// return this;
// }
// /**
// * Sends the message to other listeners INSIDE this endpoint for EVERY connection. It does not send it to a remote address.
// */
// @Override
// public
// ConnectionPoint self(final Object message) {
// ConcurrentEntry<ConnectionImpl> current = connectionsREF.get(this);
// ConnectionImpl c;
// while (current != null) {
// c = current.getValue();
// current = current.next();
//
// onMessage(c, message);
// }
// return this;
// }
/** /**
* Safely sends objects to a destination (such as a custom object or a standard ping). This will automatically choose which protocol * Safely sends objects to a destination (such as a custom object or a standard ping). This will automatically choose which protocol
* is available to use. * is available to use.
*/ */
suspend fun send(message: Any) { suspend inline fun send(message: Any) {
TODO("NOT IMPL YET. going to use aeron for this functionality since it's a lot faster") forEach {
// TODO: USE AERON add.dataPublisher thingy, so it's areon pushing messages out (way, WAY faster than if we are to iterate over it.send(message)
// the connections }
// for (connection in connections) {
// connection.send(message)
// }
} }
} }

View File

@ -0,0 +1,4 @@
package dorkbox.network.connection
data class ConnectionParams<C: Connection>(val endPoint: EndPoint<C>, val mediaDriverConnection: MediaDriverConnection,
val publicKeyValidation: PublicKeyValidationState)

View File

@ -6,13 +6,20 @@ import dorkbox.netUtil.IPv4
import dorkbox.network.Configuration import dorkbox.network.Configuration
import dorkbox.network.handshake.ClientConnectionInfo import dorkbox.network.handshake.ClientConnectionInfo
import dorkbox.network.other.CryptoEccNative import dorkbox.network.other.CryptoEccNative
import dorkbox.network.pipeline.AeronInput import dorkbox.network.serialization.AeronInput
import dorkbox.network.pipeline.AeronOutput import dorkbox.network.serialization.AeronOutput
import dorkbox.network.store.SettingsStore import dorkbox.network.storage.SettingsStore
import dorkbox.util.Sys
import dorkbox.util.entropy.Entropy import dorkbox.util.entropy.Entropy
import dorkbox.util.exceptions.SecurityException import dorkbox.util.exceptions.SecurityException
import mu.KLogger import mu.KLogger
import java.security.* import java.security.KeyFactory
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.MessageDigest
import java.security.PrivateKey
import java.security.PublicKey
import java.security.SecureRandom
import java.security.interfaces.XECPrivateKey import java.security.interfaces.XECPrivateKey
import java.security.interfaces.XECPublicKey import java.security.interfaces.XECPublicKey
import java.security.spec.NamedParameterSpec import java.security.spec.NamedParameterSpec
@ -24,7 +31,6 @@ import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
/** /**
* Management for all of the crypto stuff used * Management for all of the crypto stuff used
*/ */
@ -32,31 +38,30 @@ internal class CryptoManagement(val logger: KLogger,
private val settingsStore: SettingsStore, private val settingsStore: SettingsStore,
type: Class<*>, type: Class<*>,
config: Configuration) { config: Configuration) {
private val keyFactory = KeyFactory.getInstance("X25519") private val keyFactory = KeyFactory.getInstance("X25519")
private val keyAgreement = KeyAgreement.getInstance("XDH") private val keyAgreement = KeyAgreement.getInstance("XDH")
private val aesCipher = Cipher.getInstance("AES/GCM/PKCS5Padding") private val aesCipher = Cipher.getInstance("AES/GCM/PKCS5Padding")
private val hash = MessageDigest.getInstance("SHA-256"); private val hash = MessageDigest.getInstance("SHA-256");
companion object { companion object {
val AES_KEY_SIZE = 256 const val AES_KEY_SIZE = 256
val GCM_IV_LENGTH = 12 const val GCM_IV_LENGTH = 12
val GCM_TAG_LENGTH = 16 const val GCM_TAG_LENGTH = 16
} }
val privateKey: PrivateKey val privateKey: PrivateKey
val publicKey: PublicKey val publicKey: PublicKey
val publicKeyBytes: ByteArray val publicKeyBytes: ByteArray
val secureRandom: SecureRandom val secureRandom = SecureRandom(settingsStore.getSalt())
val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation
var disableRemoteKeyValidation = false
init { init {
secureRandom = SecureRandom(settingsStore.getSalt()) if (!enableRemoteSignatureValidation) {
logger.warn("WARNING: Disabling remote key validation is a security risk!!")
}
// initialize the private/public keys used for negotiating ECC handshakes // initialize the private/public keys used for negotiating ECC handshakes
// these are ONLY used for IP connections. LOCAL connections do not need a handshake! // these are ONLY used for IP connections. LOCAL connections do not need a handshake!
@ -77,8 +82,6 @@ internal class CryptoManagement(val logger: KLogger,
// save to properties file // save to properties file
settingsStore.savePrivateKey(privateKeyBytes) settingsStore.savePrivateKey(privateKeyBytes)
settingsStore.savePublicKey(publicKeyBytes) settingsStore.savePublicKey(publicKeyBytes)
logger.debug("Done with ECC keys!")
} catch (e: Exception) { } catch (e: Exception) {
val message = "Unable to initialize/generate ECC keys. FORCED SHUTDOWN." val message = "Unable to initialize/generate ECC keys. FORCED SHUTDOWN."
logger.error(message, e) logger.error(message, e)
@ -86,86 +89,62 @@ internal class CryptoManagement(val logger: KLogger,
} }
} }
logger.info("ECC public key: ${Sys.bytesToHex(publicKeyBytes)}")
this.privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateKeyBytes)) as XECPrivateKey this.privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateKeyBytes)) as XECPrivateKey
this.publicKey = keyFactory.generatePublic(X509EncodedKeySpec(publicKeyBytes)) as XECPublicKey this.publicKey = keyFactory.generatePublic(X509EncodedKeySpec(publicKeyBytes)) as XECPublicKey
this.publicKeyBytes = publicKeyBytes!! this.publicKeyBytes = publicKeyBytes!!
} }
fun createKeyPair(secureRandom: SecureRandom): KeyPair { private fun createKeyPair(secureRandom: SecureRandom): KeyPair {
val kpg: KeyPairGenerator = KeyPairGenerator.getInstance("XDH") val kpg: KeyPairGenerator = KeyPairGenerator.getInstance("XDH")
kpg.initialize(NamedParameterSpec.X25519, secureRandom) kpg.initialize(NamedParameterSpec.X25519, secureRandom)
return kpg.generateKeyPair() return kpg.generateKeyPair()
} }
/** /**
* If the key does not match AND we have disabled remote key validation, then metachannel.changedRemoteKey = true. OTHERWISE, key validation is REQUIRED! * If the key does not match AND we have disabled remote key validation, then metachannel.changedRemoteKey = true. OTHERWISE, key validation is REQUIRED!
* *
* @return true if the remote address public key matches the one saved or we disabled remote key validation. * @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.)
* false if we should abort
*/ */
internal fun validateRemoteAddress(remoteAddress: Int, publicKey: ByteArray?): Boolean { internal fun validateRemoteAddress(remoteAddress: Int, publicKey: ByteArray?): PublicKeyValidationState {
if (publicKey == null) { if (publicKey == null) {
logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}! It was null (and should not have been)") logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}! It was null (and should not have been)")
return false return PublicKeyValidationState.INVALID
} }
try { try {
val savedPublicKey = settingsStore.getRegisteredServerKey(remoteAddress) val savedPublicKey = settingsStore.getRegisteredServerKey(remoteAddress)
if (savedPublicKey == null) { if (savedPublicKey == null) {
if (logger.isDebugEnabled) { logger.info("Adding new remote IP address key for ${IPv4.toString(remoteAddress)} : ${Sys.bytesToHex(publicKey)}")
logger.debug("Adding new remote IP address key for ${IPv4.toString(remoteAddress)}")
}
settingsStore.addRegisteredServerKey(remoteAddress, publicKey) settingsStore.addRegisteredServerKey(remoteAddress, publicKey)
} else { } else {
// COMPARE! // COMPARE!
if (!publicKey.contentEquals(savedPublicKey)) { if (!publicKey.contentEquals(savedPublicKey)) {
return if (!enableRemoteSignatureValidation) { return if (enableRemoteSignatureValidation) {
logger.warn("Invalid or non-matching public key from remote connection, their public key has changed. Toggling extra flag in channel to indicate key change. To fix, remove entry for: ${IPv4.toString(remoteAddress)}")
true
} else {
// keys do not match, abort! // keys do not match, abort!
logger.error("Invalid or non-matching public key from remote connection, their public key has changed. To fix, remove entry for: ${IPv4.toString(remoteAddress)}") logger.error("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Denying connection attempt")
false PublicKeyValidationState.INVALID
}
else {
logger.warn("The public key for remote connection ${IPv4.toString(remoteAddress)} does not match. Permitting connection attempt.")
PublicKeyValidationState.TAMPERED
} }
} }
} }
} catch (e: SecurityException) { } catch (e: SecurityException) {
// keys do not match, abort! // keys do not match, abort!
logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}!", e) logger.error("Error validating public key for ${IPv4.toString(remoteAddress)}!", e)
return false return PublicKeyValidationState.INVALID
} }
return true return PublicKeyValidationState.VALID
} }
override fun hashCode(): Int {
val prime = 31
var result = 1
result = prime * result + publicKeyBytes.hashCode()
return result
}
override fun equals(other: Any?): Boolean {
if (this === other) {
return true
}
if (other == null) {
return false
}
if (javaClass != other.javaClass) {
return false
}
val other1 = other as CryptoManagement
if (!privateKey.encoded!!.contentEquals(other1.privateKey.encoded)) {
return false
}
if (!publicKeyBytes.contentEquals(other1.publicKeyBytes)) {
return false
}
return true
}
fun encrypt(publicationPort: Int, fun encrypt(publicationPort: Int,
subscriptionPort: Int, subscriptionPort: Int,
@ -246,9 +225,37 @@ internal class CryptoManagement(val logger: KLogger,
// now read data off // now read data off
return ClientConnectionInfo(sessionId = data.readInt(), return ClientConnectionInfo(sessionId = data.readInt(),
streamId = data.readInt(), streamId = data.readInt(),
// NOTE: pub/sub must be switched!
subscriptionPort = data.readInt(),
publicationPort = data.readInt(), publicationPort = data.readInt(),
publicKey = publicKeyBytes) subscriptionPort = data.readInt(),
publicKey = serverPublicKeyBytes)
}
override fun hashCode(): Int {
val prime = 31
var result = 1
result = prime * result + publicKeyBytes.hashCode()
return result
}
override fun equals(other: Any?): Boolean {
if (this === other) {
return true
}
if (other == null) {
return false
}
if (javaClass != other.javaClass) {
return false
}
val other1 = other as CryptoManagement
if (!privateKey.encoded!!.contentEquals(other1.privateKey.encoded)) {
return false
}
if (!publicKeyBytes.contentEquals(other1.publicKeyBytes)) {
return false
}
return true
} }
} }

View File

@ -25,9 +25,9 @@ import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.rmi.RmiSupport import dorkbox.network.rmi.RmiSupport
import dorkbox.network.rmi.RmiSupportConnection import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.messages.RmiMessage import dorkbox.network.rmi.messages.RmiMessage
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.NetworkSerializationManager import dorkbox.network.serialization.NetworkSerializationManager
import dorkbox.network.store.SettingsStore import dorkbox.network.storage.SettingsStore
import dorkbox.os.OS
import dorkbox.util.NamedThreadFactory import dorkbox.util.NamedThreadFactory
import dorkbox.util.exceptions.SecurityException import dorkbox.util.exceptions.SecurityException
import io.aeron.Aeron import io.aeron.Aeron
@ -35,7 +35,12 @@ import io.aeron.Publication
import io.aeron.driver.MediaDriver import io.aeron.driver.MediaDriver
import io.aeron.logbuffer.Header import io.aeron.logbuffer.Header
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.coroutines.* import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.actor
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import mu.KLogger import mu.KLogger
import mu.KotlinLogging import mu.KotlinLogging
import org.agrona.DirectBuffer import org.agrona.DirectBuffer
@ -43,6 +48,7 @@ import java.io.File
import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
// If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets! // If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets!
// it results in severe UDP packet loss and contention. // it results in severe UDP packet loss and contention.
// //
@ -62,6 +68,19 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
protected constructor(config: Configuration) : this(Client::class.java, config) protected constructor(config: Configuration) : this(Client::class.java, config)
protected constructor(config: ServerConfiguration) : this(Server::class.java, config) protected constructor(config: ServerConfiguration) : this(Server::class.java, config)
fun CoroutineScope.connectionActor() = actor<ActorMessage<CONNECTION>> {
var counter = 0
for (message in channel) {
when(message) {
is ActorMessage.AddConnection -> println("add")
is ActorMessage.RemoveConnection -> println("del")
// is ActorMessage.GetValue -> message.deferred.complete(counter)
}
}
}
companion object { companion object {
/** /**
* Identifier for invalid sessions. This must be < RESERVED_SESSION_ID_LOW * Identifier for invalid sessions. This must be < RESERVED_SESSION_ID_LOW
@ -82,10 +101,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
const val IPC_HANDSHAKE_STREAM_ID_PUB: Int = 0x1337c0de const val IPC_HANDSHAKE_STREAM_ID_PUB: Int = 0x1337c0de
const val IPC_HANDSHAKE_STREAM_ID_SUB: Int = 0x1337c0d3 const val IPC_HANDSHAKE_STREAM_ID_SUB: Int = 0x1337c0d3
init {
println("THIS IS ONLY IPV4 AT THE MOMENT. IPV6 is in progress!")
}
fun errorCodeName(result: Long): String { fun errorCodeName(result: Long): String {
return when (result) { return when (result) {
Publication.NOT_CONNECTED -> "Not connected" Publication.NOT_CONNECTED -> "Not connected"
@ -104,12 +119,10 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val autoClosableObjects = CopyOnWriteArrayList<AutoCloseable>() internal val autoClosableObjects = CopyOnWriteArrayList<AutoCloseable>()
internal val actionDispatch = CoroutineScope(Dispatchers.Default) internal val actionDispatch = CoroutineScope(Dispatchers.Default)
// internal val connectionActor = actionDispatch.connectionActor()
internal val listenerManager = ListenerManager<CONNECTION>(logger) { message, cause -> internal val listenerManager = ListenerManager<CONNECTION>(logger)
newException(message, cause) internal val connections = ConnectionManager<CONNECTION>(logger, config)
}
internal abstract val handshake: ConnectionManager<CONNECTION>
internal val mediaDriverContext: MediaDriver.Context internal val mediaDriverContext: MediaDriver.Context
internal val mediaDriver: MediaDriver internal val mediaDriver: MediaDriver
@ -135,17 +148,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
internal val rmiGlobalSupport = RmiSupport(logger, actionDispatch, config.serialization) internal val rmiGlobalSupport = RmiSupport(logger, actionDispatch, config.serialization)
/**
* Checks to see if this client has connected yet or not.
*
* Once a server has connected to ANY client, it will always return true until server.close() is called
*
* @return true if we are connected, false otherwise.
*/
abstract fun isConnected(): Boolean
init { init {
logger.error("NETWORK STACK IS ONLY IPV4 AT THE MOMENT. IPV6 is in progress!")
runBlocking { runBlocking {
// our default onError handler. All error messages go though this // our default onError handler. All error messages go though this
listenerManager.onError { throwable -> listenerManager.onError { throwable ->
@ -215,33 +220,20 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* After this command is executed the new disk will be mounted under /Volumes/DevShm. * After this command is executed the new disk will be mounted under /Volumes/DevShm.
*/ */
var aeronDirAlreadyExists = false
if (config.aeronLogDirectory == null) { if (config.aeronLogDirectory == null) {
val baseFile = when { val baseFileLocation = config.suggestAeronLogLocation(logger)
OS.isMacOsX() -> {
logger.info("It is recommended to create a RAM drive for best performance. For example\n" +
"\$ diskutil erasevolume HFS+ \"DevShm\" `hdiutil attach -nomount ram://\$((2048 * 2048))`\n" +
"\t After this, set config.aeronLogDirectory = \"/Volumes/DevShm\"")
File(System.getProperty("java.io.tmpdir"))
}
OS.isLinux() -> {
// this is significantly faster for linux than using the temp dir
File(System.getProperty("/dev/shm/"))
}
else -> {
File(System.getProperty("java.io.tmpdir"))
}
}
val baseName = "aeron-" + type.simpleName val aeronLogDirectory = File(baseFileLocation, "aeron-" + type.simpleName)
val aeronLogDirectory = File(baseFile, baseName) aeronDirAlreadyExists = aeronLogDirectory.exists()
if (aeronLogDirectory.exists()) {
logger.info("Aeron log directory already exists! This might not be what you want!")
}
logger.debug("Aeron log directory: $aeronLogDirectory")
config.aeronLogDirectory = aeronLogDirectory config.aeronLogDirectory = aeronLogDirectory
} }
logger.info("Aeron log directory: ${config.aeronLogDirectory}")
if (aeronDirAlreadyExists) {
logger.info("Aeron log directory already exists! This might not be what you want!")
}
val threadFactory = NamedThreadFactory("Aeron", false) val threadFactory = NamedThreadFactory("Aeron", false)
// LOW-LATENCY SETTINGS // LOW-LATENCY SETTINGS
@ -254,7 +246,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
mediaDriverContext = MediaDriver.Context() mediaDriverContext = MediaDriver.Context()
.publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW) .publicationReservedSessionIdLow(RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH) .publicationReservedSessionIdHigh(RESERVED_SESSION_ID_HIGH)
.dirDeleteOnStart(true) // TODO: FOR NOW? .dirDeleteOnStart(true)
.dirDeleteOnShutdown(true) .dirDeleteOnShutdown(true)
.conductorThreadFactory(threadFactory) .conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory) .receiverThreadFactory(threadFactory)
@ -304,7 +296,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
crypto = CryptoManagement(logger, settingsStore, type, config) crypto = CryptoManagement(logger, settingsStore, type, config)
// we are done with initial configuration, now finish serialization // we are done with initial configuration, now finish serialization
serialization.finishInit(type) runBlocking {
serialization.finishInit(type)
}
} }
abstract fun newException(message: String, cause: Throwable? = null): Throwable abstract fun newException(message: String, cause: Throwable? = null): Throwable
@ -313,7 +307,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* Returns the property store used by this endpoint. The property store can store via properties, * Returns the property store used by this endpoint. The property store can store via properties,
* a database, etc, or can be a "null" property store, which does nothing * a database, etc, or can be a "null" property store, which does nothing
*/ */
fun <S : SettingsStore> getPropertyStore(): S { fun <S : SettingsStore> getStorage(): S {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
return settingsStore as S return settingsStore as S
} }
@ -321,16 +315,16 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
/** /**
* This method allows the connections used by the client/server to be subclassed (with custom implementations). * This method allows the connections used by the client/server to be subclassed (with custom implementations).
* *
* As this is for the network stack, the new connection MUST subclass [ConnectionImpl] * As this is for the network stack, the new connection MUST subclass [Connection]
* *
* The parameters are ALL NULL when getting the base class, as this instance is just thrown away. * The parameters are ALL NULL when getting the base class, as this instance is just thrown away.
* *
* @return a new network connection * @return a new network connection
*/ */
@Suppress("MemberVisibilityCanBePrivate") @Suppress("MemberVisibilityCanBePrivate")
open fun newConnection(endPoint: EndPoint<CONNECTION>, mediaDriverConnection: MediaDriverConnection): CONNECTION { open fun newConnection(connectionParameters: ConnectionParams<CONNECTION>): CONNECTION {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
return Connection(endPoint, mediaDriverConnection) as CONNECTION return Connection(connectionParameters) as CONNECTION
} }
/** /**
@ -341,14 +335,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
return RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch) return RmiSupportConnection(logger, rmiGlobalSupport, serialization, actionDispatch)
} }
/**
* Disables remote endpoint public key validation when the connection is established. This is not recommended as it is a security risk
*/
fun disableRemoteKeyValidation() {
logger.info("WARNING: Disabling remote key validation is a security risk!!")
crypto.disableRemoteKeyValidation = true
}
/** /**
* Adds an IP+subnet rule that defines if that IP+subnet is allowed/denied connectivity to this server. * Adds an IP+subnet rule that defines if that IP+subnet is allowed/denied connectivity to this server.
* *
@ -439,7 +425,9 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* Runs an action for each connection inside of a read-lock * Runs an action for each connection inside of a read-lock
*/ */
suspend fun forEachConnection(function: suspend (connection: CONNECTION) -> Unit) { suspend fun forEachConnection(function: suspend (connection: CONNECTION) -> Unit) {
handshake.forEachConnectionDoRead(function) connections.forEach {
function(it)
}
} }
internal suspend fun writeHandshakeMessage(publication: Publication, message: Any) { internal suspend fun writeHandshakeMessage(publication: Publication, message: Any) {
@ -475,7 +463,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
return return
} }
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error serializing message $message", e) listenerManager.notifyError(newException("Error serializing message $message", e))
} finally { } finally {
sendIdleStrategy.reset() sendIdleStrategy.reset()
serialization.returnKryo(kryo) serialization.returnKryo(kryo)
@ -489,12 +477,12 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
* *
* @return A string * @return A string
*/ */
internal fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? { internal suspend fun readHandshakeMessage(buffer: DirectBuffer, offset: Int, length: Int, header: Header): Any? {
val kryo: KryoExtra = serialization.takeKryo() val kryo: KryoExtra = serialization.takeKryo()
try { try {
val message = kryo.read(buffer, offset, length) val message = kryo.read(buffer, offset, length)
logger.trace { logger.trace {
"[${header.sessionId()}] received: $message" "[${header.sessionId()}] received handshake: $message"
} }
return message return message
@ -529,7 +517,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
try { try {
message = kryo.read(buffer, offset, length, connection) message = kryo.read(buffer, offset, length, connection)
logger.trace { logger.trace {
"[${sessionId}] received: ${message}" "[${sessionId}] received: $message"
} }
} catch (e: Exception) { } catch (e: Exception) {
listenerManager.notifyError(newException("[${sessionId}] Error de-serializing message", e)) listenerManager.notifyError(newException("[${sessionId}] Error de-serializing message", e))
@ -537,10 +525,6 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
serialization.returnKryo(kryo) serialization.returnKryo(kryo)
} }
val data = ByteArray(length)
buffer.getBytes(offset, data)
when (message) { when (message) {
is PingMessage -> { is PingMessage -> {
// the ping listener (internal use only!) // the ping listener (internal use only!)
@ -561,33 +545,20 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
is Any -> { is Any -> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
listenerManager.notifyOnMessage(connection as CONNECTION, message) var hasListeners = listenerManager.notifyOnMessage(connection as CONNECTION, message)
// each connection registers, and is polled INDEPENDENTLY for messages.
hasListeners = hasListeners or connection.notifyOnMessage(message)
if (!hasListeners) {
listenerManager.notifyError(connection, MessageNotRegisteredException("No message callbacks found for ${message::class.java.simpleName}"))
}
} }
else -> { else -> {
// do nothing, there were problems with the message // do nothing, there were problems with the message
} }
} }
} }
override fun toString(): String { override fun toString(): String {
return "EndPoint [${type.simpleName}]" return "EndPoint [${type.simpleName}]"
@ -611,7 +582,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
return false return false
} }
other as EndPoint<CONNECTION> other as EndPoint<*>
return crypto == other.crypto return crypto == other.crypto
} }
@ -637,6 +608,15 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
} }
autoClosableObjects.clear() autoClosableObjects.clear()
runBlocking {
// don't need anything fast or fancy here, because this method will only be called once
connections.forEach {
it.close()
listenerManager.notifyDisconnect(it)
}
}
connections.close()
actionDispatch.cancel() actionDispatch.cancel()
shutdownLatch.countDown() shutdownLatch.countDown()
} }

View File

@ -15,7 +15,7 @@ import net.jodah.typetools.TypeResolver
/** /**
* Manages all of the different connect/disconnect/etc listeners * Manages all of the different connect/disconnect/etc listeners
*/ */
internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogger, private val exceptionGetter: (String, Throwable?) -> Throwable) { internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogger) {
companion object { companion object {
/** /**
* Specifies the load-factor for the IdentityMap used to manage keeping track of the number of connections + listeners * Specifies the load-factor for the IdentityMap used to manage keeping track of the number of connections + listeners
@ -175,7 +175,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
* *
* This method should not block for long periods as other network activity will not be processed until it returns. * This method should not block for long periods as other network activity will not be processed until it returns.
*/ */
suspend fun <MESSAGE : Any> onMessage(function: suspend (CONNECTION, MESSAGE) -> Unit) { suspend fun <MESSAGE> onMessage(function: suspend (CONNECTION, MESSAGE) -> Unit) {
onMessageMutex.withLock { onMessageMutex.withLock {
// we have to follow the single-writer principle! // we have to follow the single-writer principle!
@ -264,8 +264,6 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
// return true // return true
// } // }
//
// //
// for (i in 0 until size) { // for (i in 0 until size) {
// val rule = ipFilterRules[i] ?: continue // val rule = ipFilterRules[i] ?: continue
@ -299,38 +297,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
try { try {
it(connection) it(connection)
} catch (t: Throwable) { } catch (t: Throwable) {
// // NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace // NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// val throwable = result as Throwable cleanStackTrace(t)
// val reversedList = throwable.stackTrace.reversed().toMutableList()
//
// // we have to remove kotlin stuff from the stacktrace
// var reverseIter = reversedList.iterator()
// while (reverseIter.hasNext()) {
// val stackName = reverseIter.next().className
// if (stackName.startsWith("kotlinx.coroutines") || stackName.startsWith("kotlin.coroutines")) {
// // cleanup the stack elements which create the stacktrace
// reverseIter.remove()
// } else {
// // done cleaning up the tail from kotlin
// break
// }
// }
//
// // remove dorkbox network stuff
// reverseIter = reversedList.iterator()
// while (reverseIter.hasNext()) {
// val stackName = reverseIter.next().className
// if (stackName.startsWith("dorkbox.network")) {
// // cleanup the stack elements which create the stacktrace
// reverseIter.remove()
// } else {
// // done cleaning up the tail from network
// break
// }
// }
//
// throwable.stackTrace = reversedList.reversed().toTypedArray()
notifyError(connection, t) notifyError(connection, t)
} }
} }
@ -343,8 +311,10 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
onDisconnectList.value.forEach { onDisconnectList.value.forEach {
try { try {
it(connection) it(connection)
} catch (e: Throwable) { } catch (t: Throwable) {
notifyError(connection, exceptionGetter("Error during notifyDisconnect", e)) // NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
cleanStackTrace(t)
notifyError(connection, t)
} }
} }
} }
@ -371,12 +341,12 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
} }
} }
/** /**
* Invoked when a message object was received from a remote peer. * Invoked when a message object was received from a remote peer.
*
* @return true if there were listeners assigned for this message type
*/ */
suspend fun notifyOnMessage(connection: CONNECTION, message: Any) { suspend fun notifyOnMessage(connection: CONNECTION, message: Any): Boolean {
val messageClass: Class<*> = message.javaClass val messageClass: Class<*> = message.javaClass
// have to save the types + hierarchy (note: duplicates are OK, since they will just be overwritten) // have to save the types + hierarchy (note: duplicates are OK, since they will just be overwritten)
@ -394,8 +364,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
// this is EXPLICITLY listed as a "Don't" via the documentation. The ****ONLY**** reason this is actually OK is because // this is EXPLICITLY listed as a "Don't" via the documentation. The ****ONLY**** reason this is actually OK is because
// we are following the "single-writer principle", so only ONE THREAD can modify this at a time. // we are following the "single-writer principle", so only ONE THREAD can modify this at a time.
// cache the lookup (because we don't care about race conditions, since the object hierarchy will be ALREADY established at this // cache the lookup
// exact moment // we don't care about race conditions, since the object hierarchy will be ALREADY established at this exact moment
val tempMap = onMessageMap.value val tempMap = onMessageMap.value
var hasListeners = false var hasListeners = false
hierarchy.forEach { clazz -> hierarchy.forEach { clazz ->
@ -413,207 +383,6 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
} }
} }
return hasListeners
// foundListener |= onMessageReceivedManager.notifyReceived((C) connection, message, shutdown);
// now have to account for additional connection listener managers (non-global).
// access a snapshot of the managers (single-writer-principle)
// val localManager = localManagers[connection as C]
// if (localManager != null) {
// // if we found a listener during THIS method call, we need to let the NEXT method call know,
// // so it doesn't spit out error for not handling a message (since that message MIGHT have
// // been found in this method).
// foundListener = foundListener or localManager.notifyOnMessage0(connection, message, foundListener)
// }
if (!hasListeners) {
logger.error("----------- MESSAGE CALLBACK NOT REGISTERED FOR {}", messageClass.simpleName)
}
} }
//
// override fun remove(listener: OnConnected<C>): Listeners<C> {
// return this
// }
// /**
// * Adds a listener to this connection/endpoint to be notified of connect/disconnect/idle/receive(object) events.
// * <p/>
// * When called by a server, NORMALLY listeners are added at the GLOBAL level (meaning, I add one listener, and ALL connections are
// * notified of that listener.
// * <p/>
// * It is POSSIBLE to add a server connection ONLY (ie, not global) listener (via connection.addListener), meaning that ONLY that
// * listener attached to the connection is notified on that event (ie, admin type listeners)
// *
// *
// * // TODO: When converting to kotlin, use reified! to get the listener types
// * https://kotlinlang.org/docs/reference/inline-functions.html
// */
// @Override
// public final
// Listeners add(final Listener listener) {
// if (listener == null) {
// throw new IllegalArgumentException("listener cannot be null.");
// }
//
// // this is the connection generic parameter for the listener, works for lambda expressions as well
// Class<?> genericClass = ClassHelper.getGenericParameterAsClassForSuperClass(Listener.class, listener.getClass(), 0);
//
// // if we are null, it means that we have no generics specified for our listener!
// if (genericClass == this.baseClass || genericClass == TypeResolver.Unknown.class || genericClass == null) {
// // we are the base class, so we are fine.
// addListener0(listener);
// return this;
//
// }
// else if (ClassHelper.hasInterface(Connection.class, genericClass) && !ClassHelper.hasParentClass(this.baseClass, genericClass)) {
// // now we must make sure that the PARENT class is NOT the base class. ONLY the base class is allowed!
// addListener0(listener);
// return this;
// }
//
// // didn't successfully add the listener.
// throw new IllegalArgumentException("Unable to add incompatible connection type as a listener! : " + this.baseClass);
// }
//
// /**
// * INTERNAL USE ONLY
// */
// private
// void addListener0(final Listener listener) {
// boolean found = false;
// if (listener instanceof OnConnected) {
// onConnectedManager.add((Listener.OnConnected<C>) listener);
// found = true;
// }
// if (listener instanceof Listener.OnDisconnected) {
// onDisconnectedManager.add((Listener.OnDisconnected<C>) listener);
// found = true;
// }
//
// if (listener instanceof Listener.OnMessageReceived) {
// onMessageReceivedManager.add((Listener.OnMessageReceived) listener);
// found = true;
// }
//
// if (found) {
// hasAtLeastOneListener.set(true);
//
// if (logger.isTraceEnabled()) {
// logger.trace("listener added: {}",
// listener.getClass()
// .getName());
// }
// }
// else {
// logger.error("No matching listener types. Unable to add listener: {}",
// listener.getClass()
// .getName());
// }
// }
//
// /**
// * Removes a listener from this connection/endpoint to NO LONGER be notified of connect/disconnect/idle/receive(object) events.
// * <p/>
// * When called by a server, NORMALLY listeners are added at the GLOBAL level (meaning, I add one listener, and ALL connections are
// * notified of that listener.
// * <p/>
// * It is POSSIBLE to remove a server-connection 'non-global' listener (via connection.removeListener), meaning that ONLY that listener
// * attached to the connection is removed
// */
// @Override
// public final
// Listeners remove(final Listener listener) {
// if (listener == null) {
// throw new IllegalArgumentException("listener cannot be null.");
// }
//
// if (logger.isTraceEnabled()) {
// logger.trace("listener removed: {}",
// listener.getClass()
// .getName());
// }
//
// boolean found = false;
// int remainingListeners = 0;
//
// if (listener instanceof Listener.OnConnected) {
// int size = onConnectedManager.removeWithSize((OnConnected<C>) listener);
// if (size >= 0) {
// remainingListeners += size;
// found = true;
// }
// }
// if (listener instanceof Listener.OnDisconnected) {
// int size = onDisconnectedManager.removeWithSize((Listener.OnDisconnected<C>) listener);
// if (size >= 0) {
// remainingListeners += size;
// found |= true;
// }
// }
// if (listener instanceof Listener.OnMessageReceived) {
// int size = onMessageReceivedManager.removeWithSize((Listener.OnMessageReceived) listener);
// if (size >= 0) {
// remainingListeners += size;
// found |= true;
// }
// }
//
// if (found) {
// if (remainingListeners == 0) {
// hasAtLeastOneListener.set(false);
// }
// }
// else {
// logger.error("No matching listener types. Unable to remove listener: {}",
// listener.getClass()
// .getName());
//
// }
//
// return this;
// }
// /**
// * Removes all registered listeners from this connection/endpoint to NO LONGER be notified of connect/disconnect/idle/receive(object)
// * events.
// */
// override fun removeAll(): Listeners<C> {
// // onConnectedManager.clear();
// // onDisconnectedManager.clear();
// // onMessageReceivedManager.clear();
// logger.error("ALL listeners removed !!")
// return this
// }
// /**
// * Removes all registered listeners (of the object type) from this
// * connection/endpoint to NO LONGER be notified of
// * connect/disconnect/idle/receive(object) events.
// */
// override fun removeAll(classType: Class<*>): Listeners<C> {
// val logger2 = logger
// // if (onMessageReceivedManager.removeAll(classType)) {
// // if (logger2.isTraceEnabled()) {
// // logger2.trace("All listeners removed for type: {}",
// // classType.getClass()
// // .getName());
// // }
// // } else {
// // logger2.warn("No listeners found to remove for type: {}",
// // classType.getClass()
// // .getName());
// // }
// return this
// }
} }

View File

@ -1,104 +0,0 @@
/*
* Copyright 2010 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
/**
* Generic types are in place to make sure that users of the application do not
* accidentally add an incompatible connection type.
*/
interface Listeners<C : Connection> {
/**
* Adds a function that will be called BEFORE a client/server "connects" with
* each other, and used to determine if a connection should be allowed
*
* If the function returns TRUE, then the connection will continue to connect.
* If the function returns FALSE, then the other end of the connection will
* receive a connection error
*
* For a server, this function will be called for ALL clients.
*/
fun filter(function: (C) -> Boolean): Int
/**
* Adds a function that will be called when a client/server "connects" with
* each other
*
* For a server, this function will be called for ALL clients.
*/
fun onConnect(function: (C) -> Unit): Int
/**
* Adds a function that will be called when a client/server "disconnects" with
* each other
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun onDisconnect(function: (C) -> Unit): Int
/**
* Adds a function that will be called when a client/server encounters an error
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun onError(function: (C, throwable: Throwable) -> Unit): Int
/**
* Adds a function that will be called when a client/server receives a message
*
* For a server, this function will be called for ALL clients.
*
* It is POSSIBLE to add a server CONNECTION only (ie, not global) listener
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
fun <M : Any> onMessage(function: (C, M) -> Unit): Int
/**
* Removes a listener from this connection/endpoint to NO LONGER be notified
* of connect/disconnect/idle/receive(object) events.
*
*
* When called by a server, NORMALLY listeners are added at the GLOBAL level
* (meaning, I add one listener, and ALL connections are notified of that
* listener.
*
*
* It is POSSIBLE to remove a server-connection 'non-global' listener (via
* connection.removeListener), meaning that ONLY that listener attached to
* the connection is removed
*/
fun remove(listenerId: Int)
/**
* Removes all registered listeners from this connection/endpoint to NO
* LONGER be notified of connect/disconnect/idle/receive(object) events.
*/
fun removeAll()
}

View File

@ -34,14 +34,13 @@ interface MediaDriverConnection : AutoCloseable {
/** /**
* For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER * For a client, the ports specified here MUST be manually flipped because they are in the perspective of the SERVER
*/ */
class UdpMediaDriverConnection( class UdpMediaDriverConnection(override val address: String,
override val address: String, override val publicationPort: Int,
override val subscriptionPort: Int, override val subscriptionPort: Int,
override val publicationPort: Int, override val streamId: Int,
override val streamId: Int, override val sessionId: Int,
override val sessionId: Int, private val connectionTimeoutMS: Long = 0,
private val connectionTimeoutMS: Long = 0, override val isReliable: Boolean = true) : MediaDriverConnection {
override val isReliable: Boolean = true) : MediaDriverConnection {
override lateinit var subscription: Subscription override lateinit var subscription: Subscription
override lateinit var publication: Publication override lateinit var publication: Publication
@ -152,7 +151,7 @@ class UdpMediaDriverConnection(
return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
"Connecting to $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" "Connecting to $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else { } else {
"Connecting to $address [$subscriptionPort|$publicationPort] [$streamId] (reliable:$isReliable)" "Connecting to $address [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
} }
} }
@ -160,7 +159,7 @@ class UdpMediaDriverConnection(
return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) { return if (sessionId != EndPoint.RESERVED_SESSION_ID_INVALID) {
"Listening on $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)" "Listening on $address [$subscriptionPort|$publicationPort] [$streamId|$sessionId] (reliable:$isReliable)"
} else { } else {
"Listening on $address [$subscriptionPort|$publicationPort] [$streamId] (reliable:$isReliable)" "Listening on $address [$subscriptionPort|$publicationPort] [$streamId|*] (reliable:$isReliable)"
} }
} }

View File

@ -0,0 +1,6 @@
package dorkbox.network.connection
/**
* thrown when a message is received, and does not have any registered 'onMessage' handlers.
*/
class MessageNotRegisteredException(errorMessage: String) : Exception(errorMessage)

View File

@ -0,0 +1,5 @@
package dorkbox.network.connection
enum class PublicKeyValidationState {
VALID, INVALID, TAMPERED
}