diff --git a/src/dorkbox/network/connection/RegistrationWrapper.java b/src/dorkbox/network/connection/RegistrationWrapper.java index 01406b9a..d8a28883 100644 --- a/src/dorkbox/network/connection/RegistrationWrapper.java +++ b/src/dorkbox/network/connection/RegistrationWrapper.java @@ -26,6 +26,7 @@ import org.bouncycastle.crypto.params.ECPublicKeyParameters; import org.slf4j.Logger; import dorkbox.network.connection.registration.MetaChannel; +import dorkbox.network.connection.registration.Registration; import dorkbox.network.pipeline.tcp.KryoEncoder; import dorkbox.network.pipeline.tcp.KryoEncoderCrypto; import dorkbox.network.pipeline.udp.KryoDecoderUdp; @@ -47,6 +48,9 @@ import io.netty.channel.Channel; */ public class RegistrationWrapper { + public + enum STATE { ERROR, WAIT, CONTINUE } + private final org.slf4j.Logger logger; public final KryoEncoder kryoTcpEncoder; @@ -206,15 +210,6 @@ class RegistrationWrapper { } } - public - boolean verifyKryoRegistration(byte[] bytes) { - return this.endPoint.getSerialization().verifyKryoRegistration(bytes); - } - - public - byte[] getKryoRegistrationDetails() { - return this.endPoint.getSerialization().getKryoRegistrationDetails(); - } public boolean isClient() { @@ -223,10 +218,6 @@ class RegistrationWrapper { - - - - /** * MetaChannel allow access to the same "session" across TCP/UDP/etc *

@@ -346,4 +337,140 @@ class RegistrationWrapper { } + public + boolean initClassRegistration(final Channel channel, final Registration registration) { + byte[] details = this.endPoint.getSerialization().getKryoRegistrationDetails(); + + int length = details.length; + if (length > 480) { + // it is too large to send in a single packet + + // child arrays have index 0 also as their 'index' and 1 is the total number of fragments + byte[][] fragments = divideArray(details, 480); + if (fragments == null) { + logger.error("Too many classes have been registered for Serialization. Please report this issue"); + + return false; + } + + int allButLast = fragments.length - 1; + + for (int i = 0; i < allButLast; i++) { + final byte[] fragment = fragments[i]; + Registration fragmentedRegistration = new Registration(registration.sessionID); + fragmentedRegistration.payload = fragment; + + // tell the server we are fragmented + fragmentedRegistration.upgrade = true; + + // tell the server we are upgraded (it will bounce back telling us to connect) + fragmentedRegistration.upgraded = true; + channel.writeAndFlush(fragmentedRegistration); + } + + // now tell the server we are done with the fragments + Registration fragmentedRegistration = new Registration(registration.sessionID); + fragmentedRegistration.payload = fragments[allButLast]; + + // tell the server we are fragmented + fragmentedRegistration.upgrade = true; + + // tell the server we are upgraded (it will bounce back telling us to connect) + fragmentedRegistration.upgraded = true; + channel.writeAndFlush(fragmentedRegistration); + } else { + registration.payload = details; + + // tell the server we are upgraded (it will bounce back telling us to connect) + registration.upgraded = true; + channel.writeAndFlush(registration); + } + + return true; + } + + public + STATE verifyClassRegistration(final MetaChannel metaChannel, final Registration registration) { + if (registration.upgrade) { + byte[] fragment = registration.payload; + + // this means that the registrations are FRAGMENTED! + // max size of ALL fragments is 480 * 127 + if (metaChannel.fragmentedRegistrationDetails == null) { + metaChannel.remainingFragments = fragment[1]; + metaChannel.fragmentedRegistrationDetails = new byte[480 * fragment[1]]; + } + + System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * 480, fragment.length - 2); + metaChannel.remainingFragments--; + + + if (fragment[0] + 1 == fragment[1]) { + // this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive) + int correctSize = (480 * (fragment[1] - 1)) + (fragment.length - 2); + byte[] correctlySized = new byte[correctSize]; + System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize); + metaChannel.fragmentedRegistrationDetails = correctlySized; + } + + if (metaChannel.remainingFragments == 0) { + // there are no more fragments available + byte[] details = metaChannel.fragmentedRegistrationDetails; + metaChannel.fragmentedRegistrationDetails = null; + + if (!this.endPoint.getSerialization().verifyKryoRegistration(details)) { + // error + return STATE.ERROR; + } + } else { + // wait for more fragments + return STATE.WAIT; + } + } + else { + if (!this.endPoint.getSerialization().verifyKryoRegistration(registration.payload)) { + return STATE.ERROR; + } + } + + return STATE.CONTINUE; + } + + /** + * Split array into chunks, max of 256 chunks. + * byte[0] = chunk ID + * byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte) + */ + private static + byte[][] divideArray(byte[] source, int chunksize) { + + int fragments = (int) Math.ceil(source.length / ((double) chunksize + 2)); + if (fragments > 127) { + // cannot allow more than 127 + return null; + } + + // pre-allocate the memory + byte[][] splitArray = new byte[fragments][chunksize + 2]; + int start = 0; + + for (int i = 0; i < splitArray.length; i++) { + int length; + + if (start + chunksize > source.length) { + length = source.length - start; + } + else { + length = chunksize; + } + splitArray[i] = new byte[length+2]; + splitArray[i][0] = (byte) i; + splitArray[i][1] = (byte) fragments; + System.arraycopy(source, start, splitArray[i], 2, length); + + start += chunksize; + } + + return splitArray; + } } diff --git a/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerClient.java b/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerClient.java index 6a31a154..9462a6ad 100644 --- a/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerClient.java +++ b/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerClient.java @@ -47,12 +47,22 @@ class RegistrationLocalHandlerClient extends RegistrationLocalHandler { channel.remoteAddress()); // client starts the registration process - channel.writeAndFlush(new Registration(0)); + Registration registration = new Registration(0); + + // ALSO make sure to verify registration details + + // we don't verify anything on the CLIENT. We only verify on the server. + // we don't support registering NEW classes after the client starts. + if (!registrationWrapper.initClassRegistration(channel, registration)) { + // abort if something messed up! + shutdown(channel, registration.sessionID); + } } @Override public void channelRead(ChannelHandlerContext context, Object message) throws Exception { + // the "server" bounces back the registration message when it's valid. ReferenceCountUtil.release(message); Channel channel = context.channel(); diff --git a/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerServer.java b/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerServer.java index 0779cb49..8ee213e1 100644 --- a/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerServer.java +++ b/src/dorkbox/network/connection/registration/local/RegistrationLocalHandlerServer.java @@ -17,7 +17,9 @@ package dorkbox.network.connection.registration.local; import dorkbox.network.connection.ConnectionImpl; import dorkbox.network.connection.RegistrationWrapper; +import dorkbox.network.connection.RegistrationWrapper.STATE; import dorkbox.network.connection.registration.MetaChannel; +import dorkbox.network.connection.registration.Registration; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -26,6 +28,7 @@ import io.netty.util.ReferenceCountUtil; public class RegistrationLocalHandlerServer extends RegistrationLocalHandler { + public RegistrationLocalHandlerServer(String name, RegistrationWrapper registrationWrapper) { super(name, registrationWrapper); @@ -53,25 +56,59 @@ class RegistrationLocalHandlerServer extends RegistrationLocalHandler { Channel channel = context.channel(); ChannelPipeline pipeline = channel.pipeline(); + if (!(message instanceof Registration)) { + logger.error("Expected registration message was [{}] instead!", message.getClass()); + shutdown(channel, 0); + ReferenceCountUtil.release(message); + return; + } + + MetaChannel metaChannel = channel.attr(META_CHANNEL).get(); + + + if (metaChannel == null) { + logger.error("Server MetaChannel was null. It shouldn't be."); + shutdown(channel, 0); + ReferenceCountUtil.release(message); + return; + } + + Registration registration = (Registration) message; + + // verify the class ID registration details. + // the client will send their class registration data. VERIFY IT IS CORRECT! + STATE state = registrationWrapper.verifyClassRegistration(metaChannel, registration); + if (state == STATE.ERROR) { + // abort! There was an error + shutdown(channel, 0); + return; + } + else if (state == STATE.WAIT) { + return; + } + // else, continue. + + + // have to remove the pipeline FIRST, since if we don't, and we expect to receive a message --- when we REMOVE "this" from the pipeline, // we will ALSO REMOVE all it's messages, which we want to receive! pipeline.remove(this); - channel.writeAndFlush(message); + registration.payload = null; - ReferenceCountUtil.release(message); + // we no longer need the meta channel, so remove it + channel.attr(META_CHANNEL).set(null); + channel.writeAndFlush(registration); + + ReferenceCountUtil.release(registration); logger.trace("Sent registration"); - MetaChannel metaChannel = channel.attr(META_CHANNEL) - .getAndSet(null); - if (metaChannel != null) { - ConnectionImpl connection = registrationWrapper.connection0(metaChannel, null); + ConnectionImpl connection = registrationWrapper.connection0(metaChannel, null); - if (connection != null) { - // have to setup connection handler - pipeline.addLast(CONNECTION_HANDLER, connection); - registrationWrapper.connectionConnected0(connection); - } + if (connection != null) { + // have to setup connection handler + pipeline.addLast(CONNECTION_HANDLER, connection); + registrationWrapper.connectionConnected0(connection); } } } diff --git a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java index 4bd567b4..6ed73870 100644 --- a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java +++ b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java @@ -213,52 +213,9 @@ class RegistrationRemoteHandlerClient extends RegistrationRemoteHandler { // we don't verify anything on the CLIENT. We only verify on the server. // we don't support registering NEW classes after the client starts. - byte[] details = registrationWrapper.getKryoRegistrationDetails(); - - int length = details.length; - if (length > 480) { - // it is too large to send in a single packet - - // child arrays have index 0 also as their 'index' and 1 is the total number of fragments - byte[][] fragments = divideArray(details, 480); - if (fragments == null) { - logger.error("Too many classes have been registered for Serialization. Please report this issue"); - // abort if something messed up! - shutdown(channel, registration.sessionID); - return; - } - - int allButLast = fragments.length - 1; - - for (int i = 0; i < allButLast; i++) { - final byte[] fragment = fragments[i]; - Registration fragmentedRegistration = new Registration(registration.sessionID); - fragmentedRegistration.payload = fragment; - - // tell the server we are fragmented - fragmentedRegistration.upgrade = true; - - // tell the server we are upgraded (it will bounce back telling us to connect) - fragmentedRegistration.upgraded = true; - channel.write(fragmentedRegistration); - } - - // now tell the server we are done with the fragments - Registration fragmentedRegistration = new Registration(registration.sessionID); - fragmentedRegistration.payload = fragments[allButLast]; - - // tell the server we are fragmented - fragmentedRegistration.upgrade = true; - - // tell the server we are upgraded (it will bounce back telling us to connect) - fragmentedRegistration.upgraded = true; - channel.writeAndFlush(fragmentedRegistration); - } else { - registration.payload = details; - - // tell the server we are upgraded (it will bounce back telling us to connect) - registration.upgraded = true; - channel.writeAndFlush(registration); + if (!registrationWrapper.initClassRegistration(channel, registration)) { + // abort if something messed up! + shutdown(channel, registration.sessionID); } return; @@ -323,42 +280,4 @@ class RegistrationRemoteHandlerClient extends RegistrationRemoteHandler { } }); } - - /** - * Split array into chunks, max of 256 chunks. - * byte[0] = chunk ID - * byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte) - */ - private static - byte[][] divideArray(byte[] source, int chunksize) { - - int fragments = (int) Math.ceil(source.length / ((double) chunksize + 2)); - if (fragments > 127) { - // cannot allow more than 127 - return null; - } - - // pre-allocate the memory - byte[][] splitArray = new byte[fragments][chunksize + 2]; - int start = 0; - - for (int i = 0; i < splitArray.length; i++) { - int length; - - if (start + chunksize > source.length) { - length = source.length - start; - } - else { - length = chunksize; - } - splitArray[i] = new byte[length+2]; - splitArray[i][0] = (byte) i; - splitArray[i][1] = (byte) fragments; - System.arraycopy(source, start, splitArray[i], 2, length); - - start += chunksize; - } - - return splitArray; - } } diff --git a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java index f292039b..4fd8a2b3 100644 --- a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java +++ b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java @@ -34,6 +34,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import dorkbox.network.connection.RegistrationWrapper; +import dorkbox.network.connection.RegistrationWrapper.STATE; import dorkbox.network.connection.registration.MetaChannel; import dorkbox.network.connection.registration.Registration; import dorkbox.util.crypto.CryptoECC; @@ -203,49 +204,18 @@ class RegistrationRemoteHandlerServer extends RegistrationRemoteHandler { return; } - // the client will send their class registration data - if (registration.upgrade) { - byte[] fragment = registration.payload; - - // this means that the registrations are FRAGMENTED! - // max size of ALL fragments is 480 * 127 - if (metaChannel.fragmentedRegistrationDetails == null) { - metaChannel.remainingFragments = fragment[1]; - metaChannel.fragmentedRegistrationDetails = new byte[480 * fragment[1]]; - } - - System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * 480, fragment.length - 2); - metaChannel.remainingFragments--; - - - if (fragment[0] + 1 == fragment[1]) { - // this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive) - int correctSize = (480 * (fragment[1] - 1)) + (fragment.length - 2); - byte[] correctlySized = new byte[correctSize]; - System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize); - metaChannel.fragmentedRegistrationDetails = correctlySized; - } - - if (metaChannel.remainingFragments == 0) { - // there are no more fragments available - byte[] details = metaChannel.fragmentedRegistrationDetails; - metaChannel.fragmentedRegistrationDetails = null; - - if (!registrationWrapper.verifyKryoRegistration(details)) { - shutdown(channel, registration.sessionID); - return; - } - } else { - // wait for more fragments - return; - } + // the client will send their class registration data. VERIFY IT IS CORRECT! + STATE state = registrationWrapper.verifyClassRegistration(metaChannel, registration); + if (state == STATE.ERROR) { + // abort! There was an error + shutdown(channel, registration.sessionID); + return; } - else { - if (!registrationWrapper.verifyKryoRegistration(registration.payload)) { - shutdown(channel, registration.sessionID); - return; - } + else if (state == STATE.WAIT) { + return; } + // else, continue. + //