diff --git a/src/dorkbox/network/connection/RegistrationWrapper.java b/src/dorkbox/network/connection/RegistrationWrapper.java index 398a9f74..01406b9a 100644 --- a/src/dorkbox/network/connection/RegistrationWrapper.java +++ b/src/dorkbox/network/connection/RegistrationWrapper.java @@ -207,10 +207,13 @@ class RegistrationWrapper { } public - void abortRegistrationIfClient() { - if (this.endPoint instanceof EndPointClient) { - ((EndPointClient) this.endPoint).abortRegistration(); - } + boolean verifyKryoRegistration(byte[] bytes) { + return this.endPoint.getSerialization().verifyKryoRegistration(bytes); + } + + public + byte[] getKryoRegistrationDetails() { + return this.endPoint.getSerialization().getKryoRegistrationDetails(); } public @@ -341,4 +344,6 @@ class RegistrationWrapper { channel.close(); } } + + } diff --git a/src/dorkbox/network/connection/registration/MetaChannel.java b/src/dorkbox/network/connection/registration/MetaChannel.java index 9263c63a..9188188b 100644 --- a/src/dorkbox/network/connection/registration/MetaChannel.java +++ b/src/dorkbox/network/connection/registration/MetaChannel.java @@ -48,11 +48,14 @@ class MetaChannel { public volatile byte[] aesKey; public volatile byte[] aesIV; - // indicates if the remote ECC key has changed for an IP address. If the client detects this, it will not connect. // If the server detects this, it has the option for additional security (two-factor auth, perhaps?) public volatile boolean changedRemoteKey = false; + public volatile byte remainingFragments; + public volatile byte[] fragmentedRegistrationDetails; + + public MetaChannel(final int sessionId) { this.sessionId = sessionId; diff --git a/src/dorkbox/network/connection/registration/Registration.java b/src/dorkbox/network/connection/registration/Registration.java index 600c494f..306fb9ca 100644 --- a/src/dorkbox/network/connection/registration/Registration.java +++ b/src/dorkbox/network/connection/registration/Registration.java @@ -36,6 +36,7 @@ class Registration { public boolean hasMore; // true when we are ready to setup the connection (hasMore will always be false if this is true). False when we are ready to connect + // ALSO used if there are fragmented frames for registration data (since we have to split it up to fit inside a single UDP packet without fragmentation) public boolean upgrade; // true when we are fully upgraded diff --git a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java index 6b7f0f1d..4bd567b4 100644 --- a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java +++ b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerClient.java @@ -211,9 +211,56 @@ class RegistrationRemoteHandlerClient extends RegistrationRemoteHandler { // setup the pipeline with the real connection upgradePipeline(metaChannel, remoteAddress); - // tell the server we are upgraded (it will bounce back telling us to connect) - registration.upgraded = true; - channel.writeAndFlush(registration); + // we don't verify anything on the CLIENT. We only verify on the server. + // we don't support registering NEW classes after the client starts. + byte[] details = registrationWrapper.getKryoRegistrationDetails(); + + int length = details.length; + if (length > 480) { + // it is too large to send in a single packet + + // child arrays have index 0 also as their 'index' and 1 is the total number of fragments + byte[][] fragments = divideArray(details, 480); + if (fragments == null) { + logger.error("Too many classes have been registered for Serialization. Please report this issue"); + // abort if something messed up! + shutdown(channel, registration.sessionID); + return; + } + + int allButLast = fragments.length - 1; + + for (int i = 0; i < allButLast; i++) { + final byte[] fragment = fragments[i]; + Registration fragmentedRegistration = new Registration(registration.sessionID); + fragmentedRegistration.payload = fragment; + + // tell the server we are fragmented + fragmentedRegistration.upgrade = true; + + // tell the server we are upgraded (it will bounce back telling us to connect) + fragmentedRegistration.upgraded = true; + channel.write(fragmentedRegistration); + } + + // now tell the server we are done with the fragments + Registration fragmentedRegistration = new Registration(registration.sessionID); + fragmentedRegistration.payload = fragments[allButLast]; + + // tell the server we are fragmented + fragmentedRegistration.upgrade = true; + + // tell the server we are upgraded (it will bounce back telling us to connect) + fragmentedRegistration.upgraded = true; + channel.writeAndFlush(fragmentedRegistration); + } else { + registration.payload = details; + + // tell the server we are upgraded (it will bounce back telling us to connect) + registration.upgraded = true; + channel.writeAndFlush(registration); + } + return; } @@ -229,6 +276,7 @@ class RegistrationRemoteHandlerClient extends RegistrationRemoteHandler { return; } + // remove the ConnectionWrapper (that was used to upgrade the connection) and cleanup the pipeline // always wait until AFTER the server calls "onConnect", then we do this cleanupPipeline(metaChannel, new Runnable() { @@ -275,4 +323,42 @@ class RegistrationRemoteHandlerClient extends RegistrationRemoteHandler { } }); } + + /** + * Split array into chunks, max of 256 chunks. + * byte[0] = chunk ID + * byte[1] = total chunks (0-255) (where 0->1, 2->3, 127->127 because this is indexed by a byte) + */ + private static + byte[][] divideArray(byte[] source, int chunksize) { + + int fragments = (int) Math.ceil(source.length / ((double) chunksize + 2)); + if (fragments > 127) { + // cannot allow more than 127 + return null; + } + + // pre-allocate the memory + byte[][] splitArray = new byte[fragments][chunksize + 2]; + int start = 0; + + for (int i = 0; i < splitArray.length; i++) { + int length; + + if (start + chunksize > source.length) { + length = source.length - start; + } + else { + length = chunksize; + } + splitArray[i] = new byte[length+2]; + splitArray[i][0] = (byte) i; + splitArray[i][1] = (byte) fragments; + System.arraycopy(source, start, splitArray[i], 2, length); + + start += chunksize; + } + + return splitArray; + } } diff --git a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java index b002e97a..f292039b 100644 --- a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java +++ b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandlerServer.java @@ -175,12 +175,12 @@ class RegistrationRemoteHandlerServer extends RegistrationRemoteHandler { return; } - // ALWAYS upgrade the connection at this point. - // IN: upgraded=false if we haven't upgraded to encryption yet (this will always be the case right after encryption is setup) - // NOTE: if we have more registrations, we will "bounce back" that status so the client knows what to do. // IN: hasMore=true if we have more registrations to do, false otherwise + // ALWAYS upgrade the connection at this point. + // IN: upgraded=false if we haven't upgraded to encryption yet (this will always be the case right after encryption is setup) + if (!registration.upgraded) { // upgrade the connection to an encrypted connection registration.upgrade = true; @@ -203,6 +203,50 @@ class RegistrationRemoteHandlerServer extends RegistrationRemoteHandler { return; } + // the client will send their class registration data + if (registration.upgrade) { + byte[] fragment = registration.payload; + + // this means that the registrations are FRAGMENTED! + // max size of ALL fragments is 480 * 127 + if (metaChannel.fragmentedRegistrationDetails == null) { + metaChannel.remainingFragments = fragment[1]; + metaChannel.fragmentedRegistrationDetails = new byte[480 * fragment[1]]; + } + + System.arraycopy(fragment, 2, metaChannel.fragmentedRegistrationDetails, fragment[0] * 480, fragment.length - 2); + metaChannel.remainingFragments--; + + + if (fragment[0] + 1 == fragment[1]) { + // this is the last fragment in the in byte array (but NOT necessarily the last fragment to arrive) + int correctSize = (480 * (fragment[1] - 1)) + (fragment.length - 2); + byte[] correctlySized = new byte[correctSize]; + System.arraycopy(metaChannel.fragmentedRegistrationDetails, 0, correctlySized, 0, correctSize); + metaChannel.fragmentedRegistrationDetails = correctlySized; + } + + if (metaChannel.remainingFragments == 0) { + // there are no more fragments available + byte[] details = metaChannel.fragmentedRegistrationDetails; + metaChannel.fragmentedRegistrationDetails = null; + + if (!registrationWrapper.verifyKryoRegistration(details)) { + shutdown(channel, registration.sessionID); + return; + } + } else { + // wait for more fragments + return; + } + } + else { + if (!registrationWrapper.verifyKryoRegistration(registration.payload)) { + shutdown(channel, registration.sessionID); + return; + } + } + // // diff --git a/src/dorkbox/network/serialization/ClassRegistration.java b/src/dorkbox/network/serialization/ClassRegistration.java new file mode 100644 index 00000000..62ff6d54 --- /dev/null +++ b/src/dorkbox/network/serialization/ClassRegistration.java @@ -0,0 +1,44 @@ +/* + * Copyright 2019 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dorkbox.network.serialization; + +import org.slf4j.Logger; + +import dorkbox.network.connection.CryptoConnection; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.rmi.RemoteObjectSerializer; + +class ClassRegistration { + Class clazz; + int id; + + ClassRegistration(final Class clazz) { + this.clazz = clazz; + } + + void register(final KryoExtra kryo, final RemoteObjectSerializer remoteObjectSerializer) { + if (clazz.isInterface()) { + id = kryo.register(clazz, remoteObjectSerializer).getId(); + } + else { + id = kryo.register(clazz).getId(); + } + } + + void log(final Logger logger) { + logger.trace("Registered {} -> {}", id, clazz.getName()); + } +} diff --git a/src/dorkbox/network/serialization/ClassSerializer.java b/src/dorkbox/network/serialization/ClassSerializer.java new file mode 100644 index 00000000..1e828960 --- /dev/null +++ b/src/dorkbox/network/serialization/ClassSerializer.java @@ -0,0 +1,41 @@ +/* + * Copyright 2019 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dorkbox.network.serialization; + +import org.slf4j.Logger; + +import com.esotericsoftware.kryo.Serializer; + +import dorkbox.network.connection.CryptoConnection; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.rmi.RemoteObjectSerializer; + +class ClassSerializer extends ClassRegistration { + final Serializer serializer; + + ClassSerializer(final Class clazz, final Serializer serializer) { + super(clazz); + this.serializer = serializer; + } + + void register(final KryoExtra kryo, final RemoteObjectSerializer remoteObjectSerializer) { + id = kryo.register(clazz, serializer).getId(); + } + + void log(final Logger logger) { + logger.trace("Registered {} -> {} using {}", id, clazz.getName(), serializer.getClass().getName()); + } +} diff --git a/src/dorkbox/network/serialization/ClassSerializer1.java b/src/dorkbox/network/serialization/ClassSerializer1.java new file mode 100644 index 00000000..36939eb0 --- /dev/null +++ b/src/dorkbox/network/serialization/ClassSerializer1.java @@ -0,0 +1,38 @@ +/* + * Copyright 2019 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dorkbox.network.serialization; + +import org.slf4j.Logger; + +import dorkbox.network.connection.CryptoConnection; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.rmi.RemoteObjectSerializer; + +class ClassSerializer1 extends ClassRegistration { + + ClassSerializer1(final Class clazz, final int id) { + super(clazz); + this.id = id; + } + + void register(final KryoExtra kryo, final RemoteObjectSerializer remoteObjectSerializer) { + kryo.register(clazz, id); + } + + void log(final Logger logger) { + logger.trace("Registered {} -> (specified) {}", id, clazz.getName()); + } +} diff --git a/src/dorkbox/network/serialization/ClassSerializer2.java b/src/dorkbox/network/serialization/ClassSerializer2.java new file mode 100644 index 00000000..a2577003 --- /dev/null +++ b/src/dorkbox/network/serialization/ClassSerializer2.java @@ -0,0 +1,42 @@ +/* + * Copyright 2019 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dorkbox.network.serialization; + +import org.slf4j.Logger; + +import com.esotericsoftware.kryo.Serializer; + +import dorkbox.network.connection.CryptoConnection; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.rmi.RemoteObjectSerializer; + +class ClassSerializer2 extends ClassRegistration { + final Serializer serializer; + + ClassSerializer2(final Class clazz, final Serializer serializer, final int id) { + super(clazz); + this.serializer = serializer; + this.id = id; + } + + void register(final KryoExtra kryo, final RemoteObjectSerializer remoteObjectSerializer) { + kryo.register(clazz, serializer, id); + } + + void log(final Logger logger) { + logger.trace("Registered {} -> (specified) {} using {}", id, clazz.getName(), serializer.getClass().getName()); + } +} diff --git a/src/dorkbox/network/serialization/ClassSerializerRmi.java b/src/dorkbox/network/serialization/ClassSerializerRmi.java new file mode 100644 index 00000000..8795af15 --- /dev/null +++ b/src/dorkbox/network/serialization/ClassSerializerRmi.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dorkbox.network.serialization; + +import org.slf4j.Logger; + +import dorkbox.network.connection.CryptoConnection; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.rmi.RemoteObjectSerializer; + +class ClassSerializerRmi extends ClassRegistration { + final Class implClass; + + ClassSerializerRmi(final Class ifaceClass, final Class implClass) { + super(ifaceClass); + this.implClass = implClass; + } + + void register(final KryoExtra kryo, final RemoteObjectSerializer remoteObjectSerializer) { + id = kryo.register(clazz, remoteObjectSerializer).getId(); + } + + void log(final Logger logger) { + logger.trace("Registered {} -> (RMI) {}", id, implClass.getName()); + } +} diff --git a/src/dorkbox/network/serialization/RmiSerializationManager.java b/src/dorkbox/network/serialization/RmiSerializationManager.java index de5cb348..32700885 100644 --- a/src/dorkbox/network/serialization/RmiSerializationManager.java +++ b/src/dorkbox/network/serialization/RmiSerializationManager.java @@ -95,6 +95,17 @@ interface RmiSerializationManager extends SerializationManager { */ void returnKryo(KryoExtra kryo); + /** + * @return true if the remote kryo registration are the same as our own + */ + boolean verifyKryoRegistration(byte[] bytes); + + /** + * @return the details of all registration IDs -> Class name used by kryo + */ + byte[] getKryoRegistrationDetails(); + + /** * Gets the RMI implementation based on the specified interface * diff --git a/src/dorkbox/network/serialization/Serialization.java b/src/dorkbox/network/serialization/Serialization.java index 6e90d798..825e9196 100644 --- a/src/dorkbox/network/serialization/Serialization.java +++ b/src/dorkbox/network/serialization/Serialization.java @@ -28,7 +28,6 @@ import org.bouncycastle.crypto.params.IESParameters; import org.bouncycastle.crypto.params.IESWithCipherParameters; import org.slf4j.Logger; -import com.esotericsoftware.kryo.ClassResolver; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoException; import com.esotericsoftware.kryo.Serializer; @@ -46,6 +45,8 @@ import com.esotericsoftware.kryo.util.Util; import dorkbox.network.connection.CryptoConnection; import dorkbox.network.connection.KryoExtra; import dorkbox.network.connection.ping.PingMessage; +import dorkbox.network.pipeline.ByteBufInput; +import dorkbox.network.pipeline.ByteBufOutput; import dorkbox.network.rmi.CachedMethod; import dorkbox.network.rmi.InvocationHandlerSerializer; import dorkbox.network.rmi.InvocationResultSerializer; @@ -68,6 +69,7 @@ import dorkbox.util.serialization.UnmodifiableCollectionsSerializer; import io.netty.bootstrap.DatagramCloseMessage; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; /** * Threads reading/writing, it messes up a single instance. it is possible to use a single kryo with the use of synchronize, however - that @@ -78,7 +80,7 @@ import io.netty.buffer.ByteBufUtil; */ @SuppressWarnings({"unused", "StaticNonFinalField"}) public -class Serialization implements CryptoSerializationManager, RmiSerializationManager { +class Serialization implements CryptoSerializationManager { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Serialization.class.getSimpleName()); @@ -89,55 +91,6 @@ class Serialization implements CryptoSerializationMa @Property public static boolean useUnsafeMemory = false; - - private static - class ClassSerializer { - final Class clazz; - final Serializer serializer; - - ClassSerializer(final Class clazz, final Serializer serializer) { - this.clazz = clazz; - this.serializer = serializer; - } - } - - - private static - class ClassSerializer1 { - final Class clazz; - final int id; - - ClassSerializer1(final Class clazz, final int id) { - this.clazz = clazz; - this.id = id; - } - } - - - private static - class ClassSerializer2 { - final Class clazz; - final Serializer serializer; - final int id; - - ClassSerializer2(final Class clazz, final Serializer serializer, final int id) { - this.clazz = clazz; - this.serializer = serializer; - this.id = id; - } - } - - private static - class RmiClassSerializer { - private final Class ifaceClass; - private final Class implClass; - - RmiClassSerializer(final Class ifaceClass, final Class implClass) { - this.ifaceClass = ifaceClass; - this.implClass = implClass; - } - } - public static Serialization DEFAULT() { return DEFAULT(true, true, null); @@ -179,6 +132,7 @@ class Serialization implements CryptoSerializationMa serialization.register(IESWithCipherParameters.class, new IesWithCipherParametersSerializer()); serialization.register(ECPublicKeyParameters.class, new EccPublicKeySerializer()); serialization.register(ECPrivateKeyParameters.class, new EccPrivateKeySerializer()); + serialization.register(ClassRegistration.class); serialization.register(dorkbox.network.connection.registration.Registration.class); // must use full package name! // necessary for the transport of exceptions. @@ -188,8 +142,7 @@ class Serialization implements CryptoSerializationMa // extra serializers //noinspection ArraysAsListWithZeroOrOneArgument - serialization.register(Arrays.asList("") - .getClass(), new ArraysAsListSerializer()); + serialization.register(Arrays.asList("").getClass(), new ArraysAsListSerializer()); UnmodifiableCollectionsSerializer.registerSerializers(serialization); @@ -223,10 +176,14 @@ class Serialization implements CryptoSerializationMa private boolean initialized = false; private final ObjectPool> kryoPool; + private final boolean registrationRequired; + // used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class) // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // Object checking is performed during actual registration. - private final List classesToRegister = new ArrayList(); + private final List classesToRegister = new ArrayList(); + private ClassRegistration[] mergedRegistrations; + private byte[] savedRegistrationDetails; private boolean usesRmi = false; @@ -275,6 +232,8 @@ class Serialization implements CryptoSerializationMa final boolean registrationRequired, final SerializerFactory factory) { + this.registrationRequired = registrationRequired; + this.kryoPool = ObjectPool.NonBlockingSoftReference(new PoolableObject>() { @Override public @@ -304,38 +263,8 @@ class Serialization implements CryptoSerializationMa // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // additionally, if a class registered is an INTERFACE, then we register it as an RMI class. - for (Object clazz : classesToRegister) { - if (clazz instanceof Class) { - // CAN register for RMI in this case - Class aClass = (Class) clazz; - - if (aClass.isInterface()) { - kryo.register(aClass, remoteObjectSerializer); - } - else { - kryo.register(aClass); - } - } - else if (clazz instanceof ClassSerializer) { - // cannot register for RMI in this case - ClassSerializer classSerializer = (ClassSerializer) clazz; - kryo.register(classSerializer.clazz, classSerializer.serializer); - } - else if (clazz instanceof ClassSerializer1) { - // cannot register for RMI in this case - ClassSerializer1 classSerializer = (ClassSerializer1) clazz; - kryo.register(classSerializer.clazz, classSerializer.id); - } - else if (clazz instanceof ClassSerializer2) { - // cannot register for RMI in this case - ClassSerializer2 classSerializer = (ClassSerializer2) clazz; - kryo.register(classSerializer.clazz, classSerializer.serializer, classSerializer.id); - } - else if (clazz instanceof RmiClassSerializer) { - // CAN register for RMI in this case - RmiClassSerializer rmiClass = (RmiClassSerializer) clazz; - kryo.register(rmiClass.ifaceClass, remoteObjectSerializer); - } + for (ClassRegistration clazz : classesToRegister) { + clazz.register(kryo, remoteObjectSerializer); } if (factory != null) { @@ -395,7 +324,7 @@ class Serialization implements CryptoSerializationMa usesRmi = true; } - classesToRegister.add(clazz); + classesToRegister.add(new ClassRegistration(clazz)); } return this; @@ -510,7 +439,7 @@ class Serialization implements CryptoSerializationMa } usesRmi = true; - classesToRegister.add(new RmiClassSerializer(ifaceClass, implClass)); + classesToRegister.add(new ClassSerializerRmi(ifaceClass, implClass)); // rmiIfaceToImpl tells us, "the server" how to create a (requested) remote object // this MUST BE UNIQUE otherwise unexpected and BAD things can happen. @@ -566,79 +495,169 @@ class Serialization implements CryptoSerializationMa // initialize the kryo pool with at least 1 kryo instance. This ALSO makes sure that all of our class registration is done // correctly and (if not) we are are notified on the initial thread (instead of on the network update thread) - KryoExtra kryo = null; + KryoExtra kryo = kryoPool.take(); try { - kryo = kryoPool.take(); - ClassResolver classResolver = kryo.getClassResolver(); + // now MERGE all of the registrations (since we can have registrations overwrite newer/specific registrations - boolean traceEnabled = logger.isTraceEnabled(); + int size = classesToRegister.size(); + ArrayList mergedRegistrations = new ArrayList(); - // now initialize the RMI cached methods, so that they are "final" when the network threads need access to it. - for (Object clazz : classesToRegister) { - // LOG CLASS ID REGISTRATIONS (if trace logging is enabled) - if (traceEnabled) { - if (clazz instanceof Class) { - Class aClass = (Class) clazz; + for (ClassRegistration registration : classesToRegister) { + Class clazz = registration.clazz; + int id = registration.id; - int id = classResolver.getRegistration(aClass).getId(); - logger.trace("Registered {} -> {}", id, aClass.getName()); - } - else if (clazz instanceof ClassSerializer) { - ClassSerializer classSerializer = (ClassSerializer) clazz; - - int id = classResolver.getRegistration(classSerializer.clazz).getId(); - logger.trace("Registered {} -> {} using {}", id, classSerializer.clazz.getName(), classSerializer.serializer.getClass().getName()); - } - else if (clazz instanceof ClassSerializer1) { - ClassSerializer1 classSerializer = (ClassSerializer1) clazz; - - logger.trace("Registered {} -> (specified) {}", classSerializer.id, classSerializer.clazz); - } - else if (clazz instanceof ClassSerializer2) { - ClassSerializer2 classSerializer = (ClassSerializer2) clazz; - - logger.trace("Registered {} -> (specified) {} using {}", classSerializer.id, classSerializer.clazz.getName(), classSerializer.serializer.getClass().getName()); - } - else if (clazz instanceof RmiClassSerializer) { - RmiClassSerializer remoteImplClass = (RmiClassSerializer) clazz; - - int id = classResolver.getRegistration(remoteImplClass.ifaceClass).getId(); - logger.trace("Registered {} -> (RMI) {}", id, remoteImplClass.ifaceClass.getName()); + // if we ALREADY contain this registration (based ONLY on ID), then overwrite the existing one and REMOVE the current one + boolean found = false; + for (int index = 0; index < mergedRegistrations.size(); index++) { + final ClassRegistration existingRegistration = mergedRegistrations.get(index); + if (existingRegistration.id == id) { + mergedRegistrations.set(index, registration); + found = true; + break; } } - if (clazz instanceof Class) { - Class aClass = (Class) clazz; - - if (aClass.isInterface()) { - int id = classResolver.getRegistration(aClass).getId(); - - CachedMethod[] cachedMethods = RmiUtils.getCachedMethods(Serialization.logger, kryo, useAsm, - aClass, - null, - id); - methodCache.put(id, cachedMethods); - } - } - else if (clazz instanceof RmiClassSerializer) { - // this is done on the endpoint that will HOST the remote object. The other endpoint will access this object via RMI objects - RmiClassSerializer rmiClass = (RmiClassSerializer) clazz; - int id = classResolver.getRegistration(rmiClass.ifaceClass).getId(); - - CachedMethod[] cachedMethods = RmiUtils.getCachedMethods(Serialization.logger, kryo, useAsm, - rmiClass.ifaceClass, - rmiClass.implClass, - id); - methodCache.put(id, cachedMethods); + if (!found) { + mergedRegistrations.add(registration); } } + // now all of the registrations are IN ORDER and MERGED + this.mergedRegistrations = mergedRegistrations.toArray(new ClassRegistration[0]); + + Object[][] registrationDetails = new Object[mergedRegistrations.size()][2]; + + for (int i = 0; i < mergedRegistrations.size(); i++) { + final ClassRegistration registration = mergedRegistrations.get(i); + registration.log(logger); + + // now save all of the registration IDs for quick verification/access + registrationDetails[i] = new Object[] {registration.id, registration.clazz.getName()}; + + + + // now we have to manage caching methods (only as necessary) + if (registration.clazz.isInterface()) { + // can be a normal class or an RMI class... + Class implClass = null; + + if (registration instanceof ClassSerializerRmi) { + implClass = ((ClassSerializerRmi) registration).implClass; + } + + CachedMethod[] cachedMethods = RmiUtils.getCachedMethods(Serialization.logger, kryo, useAsm, registration.clazz, implClass, registration.id); + methodCache.put(registration.id, cachedMethods); + } + } + + + // save this as a byte array (so registration is faster) + ByteBuf buffer = Unpooled.buffer(); + ByteBufOutput writer = new ByteBufOutput(); + writer.setBuffer(buffer); + + kryo.setRegistrationRequired(false); + kryo.writeObject(writer, registrationDetails); + + savedRegistrationDetails = new byte[buffer.writerIndex()]; + buffer.getBytes(0, savedRegistrationDetails); + + buffer.release(); + } finally { + if (registrationRequired) { + kryo.setRegistrationRequired(true); + } + + kryoPool.put(kryo); + } + } + + /** + * @return true if kryo registration is required for all classes sent over the wire + */ + @Override + public + boolean verifyKryoRegistration(byte[] otherRegistrationData) { + // verify the registration IDs if necessary with our own. The CLIENT does not verify anything, only the server! + byte[] kryoRegistrationDetails = savedRegistrationDetails; + boolean equals = java.util.Arrays.equals(kryoRegistrationDetails, otherRegistrationData); + if (equals) { + return true; + } + + + // now we need to figure out WHAT was screwed up so we know what to fix + KryoExtra kryo = takeKryo(); + + ByteBuf byteBuf = Unpooled.wrappedBuffer(otherRegistrationData); + + try { + ByteBufInput reader = new ByteBufInput(); + reader.setBuffer(byteBuf); + + kryo.setRegistrationRequired(false); + Object[][] classRegistrations = kryo.readObject(reader, Object[][].class); + + + int lengthOrg = mergedRegistrations.length; + int lengthNew = classRegistrations.length; + int index = 0; + + // list all of the registrations that are mis-matched between the server/client + for (; index < lengthOrg; index++) { + final ClassRegistration classOrg = mergedRegistrations[index]; + + if (index >= lengthNew) { + logger.error("Missing client registration for {} -> {}", classOrg.id, classOrg.clazz.getName()); + } + else { + Object[] classNew = classRegistrations[index]; + int idNew = (Integer) classNew[0]; + String nameNew = (String) classNew[1]; + + int idOrg = classOrg.id; + String nameOrg = classOrg.clazz.getName(); + + if (idNew != idOrg || !nameOrg.equals(nameNew)) { + logger.error("Server registration : {} -> {}", idOrg, nameOrg); + logger.error("Client registration : {} -> {}", idNew, nameNew); + } + } + } + + // list all of the registrations that are missing on the server + if (index < lengthNew) { + for (; index < lengthNew; index++) { + Object[] holderClass = classRegistrations[index]; + int id = (Integer) holderClass[0]; + String name = (String) holderClass[1]; + + logger.error("Missing server registration : {} -> {}", id, name); + } + } + } catch(Exception e) { + logger.error("{} during registration validation", e.getMessage()); + } finally { - if (kryo != null) { - kryoPool.put(kryo); + if (registrationRequired) { + kryo.setRegistrationRequired(true); } + + returnKryo(kryo); + byteBuf.release(); } + + return false; + } + + /** + * @return the details of all registration IDs -> Class name used by kryo + */ + @Override + public + byte[] getKryoRegistrationDetails() { + return savedRegistrationDetails; } /**