Optimized types and methods for the Registration wrapper

This commit is contained in:
nathan 2019-06-14 20:35:36 +02:00
parent bbd5aa0f59
commit 78d42e985a
3 changed files with 300 additions and 259 deletions

View File

@ -25,8 +25,6 @@ import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.slf4j.Logger;
import dorkbox.network.connection.registration.MetaChannel;
import dorkbox.network.connection.registration.Registration;
import dorkbox.network.connection.registration.UpgradeType;
import dorkbox.network.pipeline.tcp.KryoEncoderTcp;
import dorkbox.network.pipeline.tcp.KryoEncoderTcpCompression;
import dorkbox.network.pipeline.tcp.KryoEncoderTcpCrypto;
@ -40,27 +38,29 @@ import dorkbox.network.pipeline.udp.KryoEncoderUdpCompression;
import dorkbox.network.pipeline.udp.KryoEncoderUdpCrypto;
import dorkbox.network.pipeline.udp.KryoEncoderUdpNone;
import dorkbox.network.serialization.NetworkSerializationManager;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.RandomUtil;
import dorkbox.util.collections.IntMap.Values;
import dorkbox.util.collections.LockFreeIntMap;
import dorkbox.util.crypto.CryptoECC;
import dorkbox.util.exceptions.SecurityException;
import io.netty.channel.Channel;
import io.netty.util.NetUtil;
/**
* Just wraps common/needed methods of the client/server endpoint by the registration stage/handshake.
* <p/>
* This is in the connection package, so it can access the endpoint methods that it needs to without having to publicly expose them
*/
public
public abstract
class RegistrationWrapper {
public
enum STATE { ERROR, WAIT, CONTINUE }
private final org.slf4j.Logger logger;
final org.slf4j.Logger logger;
final EndPoint endPoint;
// keeps track of connections/sessions (TCP/UDP/Local). The session ID '0' is reserved to mean "no session ID yet"
final LockFreeIntMap<MetaChannel> sessionMap = new LockFreeIntMap<MetaChannel>(32, ConnectionManager.LOAD_FACTOR);
public final KryoEncoderTcp kryoTcpEncoder;
public final KryoEncoderTcpNone kryoTcpEncoderNone;
@ -77,10 +77,6 @@ class RegistrationWrapper {
public final KryoDecoderUdpCompression kryoUdpDecoderCompression;
public final KryoDecoderUdpCrypto kryoUdpDecoderCrypto;
private final EndPoint endPoint;
// keeps track of connections/sessions (TCP/UDP/Local). The session ID '0' is reserved to mean "no session ID yet"
private final LockFreeIntMap<MetaChannel> sessionMap = new LockFreeIntMap<MetaChannel>(32, ConnectionManager.LOAD_FACTOR);
public
RegistrationWrapper(final EndPoint endPoint,
@ -119,25 +115,6 @@ class RegistrationWrapper {
return this.endPoint.getIdleTimeout();
}
/**
* Internal call by the pipeline to check if the client has more protocol registrations to complete.
*
* @return true if there are more registrations to process, false if we are 100% done with all types to register (TCP/UDP/etc)
*/
public
boolean hasMoreRegistrations() {
return this.endPoint.hasMoreRegistrations();
}
/**
* Internal call by the pipeline to notify the client to continue registering the different session protocols. The server does not use
* this.
*/
public
void startNextProtocolRegistration() {
this.endPoint.startNextProtocolRegistration();
}
/**
* Internal call by the pipeline to notify the "Connection" object that it has "connected".
*/
@ -169,23 +146,7 @@ class RegistrationWrapper {
/**
* Only called by the server!
*
* If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic.
*/
public
byte getConnectionUpgradeType(final InetSocketAddress remoteAddress) {
if (isClient()) {
throw new IllegalArgumentException("This should never be called by the client!");
}
if (remoteAddress.getAddress().equals(NetUtil.LOCALHOST)) {
return UpgradeType.COMPRESS;
}
return ((EndPointServer) this.endPoint).getConnectionUpgradeType(remoteAddress);
}
/**
* If the key does not match AND we have disabled remote key validation, then metachannel.changedRemoteKey = true. OTHERWISE, key validation is REQUIRED!
@ -229,72 +190,8 @@ class RegistrationWrapper {
return true;
}
public
void removeRegisteredServerKey(final byte[] hostAddress) throws SecurityException {
ECPublicKeyParameters savedPublicKey = this.endPoint.propertyStore.getRegisteredServerKey(hostAddress);
if (savedPublicKey != null) {
Logger logger2 = this.logger;
if (logger2.isDebugEnabled()) {
logger2.debug("Deleting remote IP address key {}.{}.{}.{}",
hostAddress[0],
hostAddress[1],
hostAddress[2],
hostAddress[3]);
}
this.endPoint.propertyStore.removeRegisteredServerKey(hostAddress);
}
}
public
boolean isClient() {
return (this.endPoint instanceof EndPointClient);
}
/**
* MetaChannel allow access to the same "session" across TCP/UDP/etc
* <p>
* The connection ID '0' is reserved to mean "no channel ID yet"
*/
public
MetaChannel createSessionClient(int sessionId) {
MetaChannel metaChannel = new MetaChannel(sessionId);
sessionMap.put(sessionId, metaChannel);
return metaChannel;
}
/**
* MetaChannel allow access to the same "session" across TCP/UDP/etc.
* <p>
* The connection ID '0' is reserved to mean "no channel ID yet"
*/
public
MetaChannel createSessionServer() {
int sessionId = RandomUtil.int_();
while (sessionId == 0 && sessionMap.containsKey(sessionId)) {
sessionId = RandomUtil.int_();
}
MetaChannel metaChannel;
synchronized (sessionMap) {
// one final check, but slower...
while (sessionId == 0 && sessionMap.containsKey(sessionId)) {
sessionId = RandomUtil.int_();
}
metaChannel = new MetaChannel(sessionId);
sessionMap.put(sessionId, metaChannel);
// TODO: clean out sessions that are stale!
}
return metaChannel;
}
/**
* The session ID '0' is reserved to mean "no session ID yet"
@ -304,17 +201,6 @@ class RegistrationWrapper {
return sessionMap.get(sessionId);
}
/**
* @return the first session we have available. This is for the CLIENT to track sessions (between TCP/UDP) to a server
*/
public MetaChannel getFirstSession() {
Values<MetaChannel> values = sessionMap.values();
if (values.hasNext) {
return values.next();
}
return null;
}
/**
* The SERVER AND CLIENT will stop tracking a session once the session is complete.
*/
@ -327,7 +213,7 @@ class RegistrationWrapper {
}
/**
* The SERVER will stop tracking a session if there are errors
* The CLIENT/SERVER will stop tracking a session if there are errors
*/
public
void closeSession(final int sessionId) {
@ -371,141 +257,4 @@ class RegistrationWrapper {
channel.close();
}
}
public
boolean initClassRegistration(final Channel channel, final Registration registration) {
byte[] details = this.endPoint.getSerialization().getKryoRegistrationDetails();
int length = details.length;
if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) {
// it is too large to send in a single packet
// child arrays have index 0 also as their 'index' and 1 is the total number of fragments
byte[][] fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE);
if (fragments == null) {
logger.error("Too many classes have been registered for Serialization. Please report this issue");
return false;
}
int allButLast = fragments.length - 1;
for (int i = 0; i < allButLast; i++) {
final byte[] fragment = fragments[i];
Registration fragmentedRegistration = new Registration(registration.sessionID);
fragmentedRegistration.payload = fragment;
// tell the server we are fragmented
fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED;
// tell the server we are upgraded (it will bounce back telling us to connect)
fragmentedRegistration.upgraded = true;
channel.writeAndFlush(fragmentedRegistration);
}
// now tell the server we are done with the fragments
Registration fragmentedRegistration = new Registration(registration.sessionID);
fragmentedRegistration.payload = fragments[allButLast];
// tell the server we are fragmented
fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED;
// tell the server we are upgraded (it will bounce back telling us to connect)
fragmentedRegistration.upgraded = true;
channel.writeAndFlush(fragmentedRegistration);
} else {
registration.payload = details;
// tell the server we are upgraded (it will bounce back telling us to connect)
registration.upgraded = true;
channel.writeAndFlush(registration);
}
return true;
}
public
STATE verifyClassRegistration(final MetaChannel metaChannel, final Registration registration) {
if (registration.upgradeType == UpgradeType.FRAGMENTED) {
byte[] fragment = registration.payload;
// this means that the registrations are FRAGMENTED!
// max size of ALL fragments is xxx * 127
if (metaChannel.fragmentedRegistrationDetails == null) {
metaChannel.remainingFragments = fragment[1];
metaChannel.fragmentedRegistrationDetails = new byte[Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * fragment[1]];
}
System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE, fragment.length - 2);
metaChannel.remainingFragments--;
if (fragment[0] + 1 == fragment[1]) {
// this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive)
int correctSize = (Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * (fragment[1] - 1)) + (fragment.length - 2);
byte[] correctlySized = new byte[correctSize];
System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize);
metaChannel.fragmentedRegistrationDetails = correctlySized;
}
if (metaChannel.remainingFragments == 0) {
// there are no more fragments available
byte[] details = metaChannel.fragmentedRegistrationDetails;
metaChannel.fragmentedRegistrationDetails = null;
if (!this.endPoint.getSerialization().verifyKryoRegistration(details)) {
// error
return STATE.ERROR;
}
} else {
// wait for more fragments
return STATE.WAIT;
}
}
else {
if (!this.endPoint.getSerialization().verifyKryoRegistration(registration.payload)) {
return STATE.ERROR;
}
}
return STATE.CONTINUE;
}
/**
* Split array into chunks, max of 256 chunks.
* byte[0] = chunk ID
* byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte)
*/
private static
byte[][] divideArray(byte[] source, int chunksize) {
int fragments = (int) Math.ceil(source.length / ((double) chunksize + 2));
if (fragments > 127) {
// cannot allow more than 127
return null;
}
// pre-allocate the memory
byte[][] splitArray = new byte[fragments][chunksize + 2];
int start = 0;
for (int i = 0; i < splitArray.length; i++) {
int length;
if (start + chunksize > source.length) {
length = source.length - start;
}
else {
length = chunksize;
}
splitArray[i] = new byte[length+2];
splitArray[i][0] = (byte) i;
splitArray[i][1] = (byte) fragments;
System.arraycopy(source, start, splitArray[i], 2, length);
start += chunksize;
}
return splitArray;
}
}

View File

@ -0,0 +1,178 @@
package dorkbox.network.connection;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.slf4j.Logger;
import dorkbox.network.connection.registration.MetaChannel;
import dorkbox.network.connection.registration.Registration;
import dorkbox.network.connection.registration.UpgradeType;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.collections.IntMap.Values;
import dorkbox.util.exceptions.SecurityException;
import io.netty.channel.Channel;
/**
*
*/
public
class RegistrationWrapperClient extends RegistrationWrapper {
public
RegistrationWrapperClient(final EndPoint endPoint, final Logger logger) {
super(endPoint, logger);
}
/**
* MetaChannel allow access to the same "session" across TCP/UDP/etc
* <p>
* The connection ID '0' is reserved to mean "no channel ID yet"
*/
public
MetaChannel createSession(int sessionId) {
MetaChannel metaChannel = new MetaChannel(sessionId);
sessionMap.put(sessionId, metaChannel);
return metaChannel;
}
/**
* @return the first session we have available. This is for the CLIENT to track sessions (between TCP/UDP) to a server
*/
public MetaChannel getFirstSession() {
Values<MetaChannel> values = sessionMap.values();
if (values.hasNext) {
return values.next();
}
return null;
}
public
boolean isClient() {
return true;
}
/**
* Internal call by the pipeline to check if the client has more protocol registrations to complete.
*
* @return true if there are more registrations to process, false if we are 100% done with all types to register (TCP/UDP/etc)
*/
public
boolean hasMoreRegistrations() {
return this.endPoint.hasMoreRegistrations();
}
/**
* Internal call by the pipeline to notify the client to continue registering the different session protocols. The server does not use
* this.
*/
public
void startNextProtocolRegistration() {
this.endPoint.startNextProtocolRegistration();
}
public
void removeRegisteredServerKey(final byte[] hostAddress) throws SecurityException {
ECPublicKeyParameters savedPublicKey = this.endPoint.propertyStore.getRegisteredServerKey(hostAddress);
if (savedPublicKey != null) {
Logger logger2 = this.logger;
if (logger2.isDebugEnabled()) {
logger2.debug("Deleting remote IP address key {}.{}.{}.{}",
hostAddress[0],
hostAddress[1],
hostAddress[2],
hostAddress[3]);
}
this.endPoint.propertyStore.removeRegisteredServerKey(hostAddress);
}
}
public
boolean initClassRegistration(final Channel channel, final Registration registration) {
byte[] details = this.endPoint.getSerialization().getKryoRegistrationDetails();
int length = details.length;
if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) {
// it is too large to send in a single packet
// child arrays have index 0 also as their 'index' and 1 is the total number of fragments
byte[][] fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE);
if (fragments == null) {
logger.error("Too many classes have been registered for Serialization. Please report this issue");
return false;
}
int allButLast = fragments.length - 1;
for (int i = 0; i < allButLast; i++) {
final byte[] fragment = fragments[i];
Registration fragmentedRegistration = new Registration(registration.sessionID);
fragmentedRegistration.payload = fragment;
// tell the server we are fragmented
fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED;
// tell the server we are upgraded (it will bounce back telling us to connect)
fragmentedRegistration.upgraded = true;
channel.writeAndFlush(fragmentedRegistration);
}
// now tell the server we are done with the fragments
Registration fragmentedRegistration = new Registration(registration.sessionID);
fragmentedRegistration.payload = fragments[allButLast];
// tell the server we are fragmented
fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED;
// tell the server we are upgraded (it will bounce back telling us to connect)
fragmentedRegistration.upgraded = true;
channel.writeAndFlush(fragmentedRegistration);
} else {
registration.payload = details;
// tell the server we are upgraded (it will bounce back telling us to connect)
registration.upgraded = true;
channel.writeAndFlush(registration);
}
return true;
}
/**
* Split array into chunks, max of 256 chunks.
* byte[0] = chunk ID
* byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte)
*/
private static
byte[][] divideArray(byte[] source, int chunksize) {
int fragments = (int) Math.ceil(source.length / ((double) chunksize + 2));
if (fragments > 127) {
// cannot allow more than 127
return null;
}
// pre-allocate the memory
byte[][] splitArray = new byte[fragments][chunksize + 2];
int start = 0;
for (int i = 0; i < splitArray.length; i++) {
int length;
if (start + chunksize > source.length) {
length = source.length - start;
}
else {
length = chunksize;
}
splitArray[i] = new byte[length+2];
splitArray[i][0] = (byte) i;
splitArray[i][1] = (byte) fragments;
System.arraycopy(source, start, splitArray[i], 2, length);
start += chunksize;
}
return splitArray;
}
}

View File

@ -0,0 +1,114 @@
package dorkbox.network.connection;
import java.net.InetSocketAddress;
import org.slf4j.Logger;
import dorkbox.network.connection.registration.MetaChannel;
import dorkbox.network.connection.registration.Registration;
import dorkbox.network.connection.registration.UpgradeType;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.RandomUtil;
/**
*
*/
public
class RegistrationWrapperServer extends RegistrationWrapper {
public
RegistrationWrapperServer(final EndPoint endPoint, final Logger logger) {
super(endPoint, logger);
}
/**
* MetaChannel allow access to the same "session" across TCP/UDP/etc.
* <p>
* The connection ID '0' is reserved to mean "no channel ID yet"
*/
public
MetaChannel createSession() {
int sessionId = RandomUtil.int_();
while (sessionId == 0 && sessionMap.containsKey(sessionId)) {
sessionId = RandomUtil.int_();
}
MetaChannel metaChannel;
synchronized (sessionMap) {
// one final check, but slower...
while (sessionId == 0 && sessionMap.containsKey(sessionId)) {
sessionId = RandomUtil.int_();
}
metaChannel = new MetaChannel(sessionId);
sessionMap.put(sessionId, metaChannel);
// TODO: clean out sessions that are stale!
}
return metaChannel;
}
public
boolean acceptRemoteConnection(final InetSocketAddress remoteAddress) {
return ((EndPointServer) this.endPoint).acceptRemoteConnection(remoteAddress);
}
/**
* Only called by the server!
*
* If we are loopback or the client is a specific IP/CIDR address, then we do things differently. The LOOPBACK address will never encrypt or compress the traffic.
*/
public
byte getConnectionUpgradeType(final InetSocketAddress remoteAddress) {
return ((EndPointServer) this.endPoint).getConnectionUpgradeType(remoteAddress);
}
public
STATE verifyClassRegistration(final MetaChannel metaChannel, final Registration registration) {
if (registration.upgradeType == UpgradeType.FRAGMENTED) {
byte[] fragment = registration.payload;
// this means that the registrations are FRAGMENTED!
// max size of ALL fragments is xxx * 127
if (metaChannel.fragmentedRegistrationDetails == null) {
metaChannel.remainingFragments = fragment[1];
metaChannel.fragmentedRegistrationDetails = new byte[Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * fragment[1]];
}
System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE, fragment.length - 2);
metaChannel.remainingFragments--;
if (fragment[0] + 1 == fragment[1]) {
// this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive)
int correctSize = (Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE * (fragment[1] - 1)) + (fragment.length - 2);
byte[] correctlySized = new byte[correctSize];
System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize);
metaChannel.fragmentedRegistrationDetails = correctlySized;
}
if (metaChannel.remainingFragments == 0) {
// there are no more fragments available
byte[] details = metaChannel.fragmentedRegistrationDetails;
metaChannel.fragmentedRegistrationDetails = null;
if (!this.endPoint.getSerialization().verifyKryoRegistration(details)) {
// error
return STATE.ERROR;
}
} else {
// wait for more fragments
return STATE.WAIT;
}
}
else {
if (!this.endPoint.getSerialization().verifyKryoRegistration(registration.payload)) {
return STATE.ERROR;
}
}
return STATE.CONTINUE;
}
}