From da864b3263110e2171e543c63bbb140013a83212 Mon Sep 17 00:00:00 2001 From: nathan Date: Mon, 20 Jul 2015 23:30:46 +0200 Subject: [PATCH] Added "global" RMI objects, which are per endpoint, and are accessed via their ID. Odd=global, Even=local. Also added a unit test --- .../network/connection/Connection.java | 36 +- .../network/connection/ConnectionImpl.java | 127 ++++--- .../dorkbox/network/connection/EndPoint.java | 32 +- .../KryoCryptoSerializationManager.java | 10 +- .../connection/ObjectRegistrationLatch.java | 9 + .../network/rmi/RemoteObjectSerializer.java | 2 +- .../src/dorkbox/network/rmi/RmiBridge.java | 91 +++-- .../dorkbox/network/rmi/RmiRegistration.java | 9 + .../dorkbox/network/rmi/RmiGlobalTest.java | 331 ++++++++++++++++++ .../network/rmi/RmiSendObjectTest.java | 27 +- .../test/dorkbox/network/rmi/RmiTest.java | 44 ++- 11 files changed, 625 insertions(+), 93 deletions(-) create mode 100644 Dorkbox-Network/src/dorkbox/network/connection/ObjectRegistrationLatch.java create mode 100644 Dorkbox-Network/test/dorkbox/network/rmi/RmiGlobalTest.java diff --git a/Dorkbox-Network/src/dorkbox/network/connection/Connection.java b/Dorkbox-Network/src/dorkbox/network/connection/Connection.java index 0e4e2e95..170f205f 100644 --- a/Dorkbox-Network/src/dorkbox/network/connection/Connection.java +++ b/Dorkbox-Network/src/dorkbox/network/connection/Connection.java @@ -15,13 +15,14 @@ */ package dorkbox.network.connection; +import org.bouncycastle.crypto.params.ParametersWithIV; + import dorkbox.network.connection.bridge.ConnectionBridge; import dorkbox.network.connection.idle.IdleBridge; import dorkbox.network.connection.idle.IdleSender; import dorkbox.network.rmi.RemoteObject; import dorkbox.network.rmi.TimeoutException; import dorkbox.util.exceptions.NetException; -import org.bouncycastle.crypto.params.ParametersWithIV; @SuppressWarnings("unused") public @@ -112,8 +113,8 @@ interface Connection { * Returns a new proxy object implements the specified interface. Methods invoked on the proxy object will be * invoked remotely on the object with the specified ID in the ObjectSpace for the current connection. *

- * This will request a registration ID from the remote endpoint, and will block until the registration - * ID has been returned. + * This will request a registration ID from the remote endpoint, and will block until the object + * has been returned. *

* Methods that return a value will throw {@link TimeoutException} if the * response is not received with the @@ -131,6 +132,31 @@ interface Connection { * * @see RemoteObject */ - Iface createRemoteObject(final Class remoteImplementationInterface, - final Class remoteImplementationClass) throws NetException; + Iface createRemoteObject(final Class remoteImplementationClass) throws NetException; + + + /** + * Returns a new proxy object implements the specified interface. Methods invoked on the proxy object will be + * invoked remotely on the object with the specified ID in the ObjectSpace for the current connection. + *

+ * This will REUSE a registration ID from the remote endpoint, and will block until the object + * has been returned. + *

+ * Methods that return a value will throw {@link TimeoutException} if the + * response is not received with the + * {@link RemoteObject#setResponseTimeout(int) response timeout}. + *

+ * If {@link RemoteObject#setNonBlocking(boolean) non-blocking} is false + * (the default), then methods that return a value must not be called from + * the update thread for the connection. An exception will be thrown if this + * occurs. Methods with a void return value can be called on the update + * thread. + *

+ * If a proxy returned from this method is part of an object graph sent over + * the network, the object graph on the receiving side will have the proxy + * object replaced with the registered (non-proxy) object. + * + * @see RemoteObject + */ + Iface getRemoteObject(final int objectId) throws NetException; } diff --git a/Dorkbox-Network/src/dorkbox/network/connection/ConnectionImpl.java b/Dorkbox-Network/src/dorkbox/network/connection/ConnectionImpl.java index defebd0e..290ea028 100644 --- a/Dorkbox-Network/src/dorkbox/network/connection/ConnectionImpl.java +++ b/Dorkbox-Network/src/dorkbox/network/connection/ConnectionImpl.java @@ -25,6 +25,7 @@ import dorkbox.network.connection.ping.PingTuple; import dorkbox.network.connection.wrapper.ChannelNetworkWrapper; import dorkbox.network.connection.wrapper.ChannelNull; import dorkbox.network.connection.wrapper.ChannelWrapper; +import dorkbox.network.rmi.RemoteObject; import dorkbox.network.rmi.RemoteProxy; import dorkbox.network.rmi.RmiBridge; import dorkbox.network.rmi.RmiRegistration; @@ -50,10 +51,8 @@ import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.util.LinkedList; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; /** @@ -86,7 +85,10 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, private final EndPoint endPoint; - public final RmiBridge rmiBridge; + + private volatile ObjectRegistrationLatch objectRegistrationLatch; + private final Object remoteObjectLock = new Object(); + private final RmiBridge rmiBridge; /** @@ -779,33 +781,21 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, // RMI methods // - volatile RegistrationLatch registrationLatch; - - class RegistrationLatch { - final CountDownLatch latch = new CountDownLatch(1); - Object remoteObject; - boolean hasError = false; - } - - - private final AtomicInteger rmiObjectIdCounter = new AtomicInteger(0); - @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked"}) @Override - public - Iface createRemoteObject(final Class remoteImplementationInterface, - final Class remoteImplementationClass) throws NetException { - + public final + Iface createRemoteObject(final Class remoteImplementationClass) throws NetException { // only one register can happen at a time - synchronized (rmiObjectIdCounter) { - registrationLatch = new RegistrationLatch(); + synchronized (remoteObjectLock) { + objectRegistrationLatch = new ObjectRegistrationLatch(); // since this synchronous, we want to wait for the response before we continue + // this means we are creating a NEW object on the server, bound access to only this connection TCP(new RmiRegistration(remoteImplementationClass.getName())).flush(); try { - if (!registrationLatch.latch.await(2, TimeUnit.SECONDS)) { + if (!objectRegistrationLatch.latch.await(2, TimeUnit.SECONDS)) { final String errorMessage = "Timed out getting registration ID for: " + remoteImplementationClass; logger.error(errorMessage); throw new NetException(errorMessage); @@ -817,7 +807,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, } // local var to prevent double hit on volatile field - final RegistrationLatch latch = registrationLatch; + final ObjectRegistrationLatch latch = objectRegistrationLatch; if (latch.hasError) { final String errorMessage = "Error getting registration ID for: " + remoteImplementationClass; logger.error(errorMessage); @@ -828,12 +818,49 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, } } + @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked"}) + @Override + public final + Iface getRemoteObject(final int objectId) throws NetException { + // only one register can happen at a time + synchronized (remoteObjectLock) { + objectRegistrationLatch = new ObjectRegistrationLatch(); + + // since this synchronous, we want to wait for the response before we continue + // this means that we are ACCESSING a remote object on the server, the server checks GLOBAL, then LOCAL for this object + TCP(new RmiRegistration(objectId)).flush(); + + try { + if (!objectRegistrationLatch.latch.await(2, TimeUnit.SECONDS)) { + final String errorMessage = "Timed out getting registration for ID: " + objectId; + logger.error(errorMessage); + throw new NetException(errorMessage); + } + } catch (InterruptedException e) { + final String errorMessage = "Error getting registration for ID: " + objectId; + logger.error(errorMessage, e); + throw new NetException(errorMessage, e); + } + + // local var to prevent double hit on volatile field + final ObjectRegistrationLatch latch = objectRegistrationLatch; + if (latch.hasError) { + final String errorMessage = "Error getting registration for ID: " + objectId; + logger.error(errorMessage); + throw new NetException(errorMessage); + } + + return (Iface) latch.remoteObject; + } + } + void registerInternal(final ConnectionImpl connection, final RmiRegistration remoteRegistration) { final String implementationClassName = remoteRegistration.remoteImplementationClass; if (implementationClassName != null) { // THIS IS ON THE SERVER SIDE + // // create a new ID, and register the ID and new object (must create a new one) in the object maps Class implementationClass; @@ -848,7 +875,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, try { final Object remotePrimaryObject = implementationClass.newInstance(); - rmiBridge.register(rmiObjectIdCounter.getAndIncrement(), remotePrimaryObject); + rmiBridge.register(rmiBridge.nextObjectId(), remotePrimaryObject); LinkedList remoteClasses = new LinkedList(); remoteClasses.add(new ClassObject(implementationClass, remotePrimaryObject)); @@ -868,7 +895,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, field.setAccessible(prev); final Class type = field.getType(); - rmiBridge.register(rmiObjectIdCounter.getAndIncrement(), o); + rmiBridge.register(rmiBridge.nextObjectId(), o); remoteClasses.offerLast(new ClassObject(type, o)); } @@ -877,17 +904,29 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, } } -// connection.TCP(new RmiRegistration()).flush(); connection.TCP(new RmiRegistration(remotePrimaryObject)).flush(); } catch (Exception e) { logger.error("Error registering RMI class " + implementationClassName, e); connection.TCP(new RmiRegistration()).flush(); } - } else { + } + else if (remoteRegistration.remoteObjectId > RmiBridge.INVALID_RMI) { + // THIS IS ON THE SERVER SIDE + // + // Get a LOCAL rmi object, if none get a specific, GLOBAL rmi object (objects that are not bound to a single connection). + Object object = getRegisteredObject(remoteRegistration.remoteObjectId); + + if (object != null) { + connection.TCP(new RmiRegistration(object)).flush(); + } else { + connection.TCP(new RmiRegistration()).flush(); + } + } + else { // THIS IS ON THE CLIENT SIDE // the next two use a local var, so that there isn't a double hit for volatile access - final RegistrationLatch latch = this.registrationLatch; + final ObjectRegistrationLatch latch = this.objectRegistrationLatch; latch.hasError = remoteRegistration.hasError; if (!remoteRegistration.hasError) { @@ -895,24 +934,32 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements Connection, } // notify the original register that it may continue. We access the volatile field directly, so that it's members are updated - registrationLatch.latch.countDown(); + objectRegistrationLatch.latch.countDown(); } } + public + int getRegisteredId(final T object) { + // always check local before checking global, because less contention on the synchronization + int object1 = endPoint.globalRmiBridge.getRegisteredId(object); + if (object1 == Integer.MAX_VALUE) { + return rmiBridge.getRegisteredId(object); + } else { + return object1; + } + } + + public + RemoteObject getRemoteObject(final int objectID, final Class type) { + return RmiBridge.getRemoteObject(this, objectID, type); + } - /** - * Returns the object registered with the specified ID. - */ public Object getRegisteredObject(final int objectID) { - return rmiBridge.getRegisteredObject(objectID); - } - - /** - * Returns the ID registered for the specified object, or Integer.MAX_VALUE if not found. - */ - public - int getRegisteredId(final Object object) { - return rmiBridge.getRegisteredId(object); + if (RmiBridge.isGlobal(objectID)) { + return endPoint.globalRmiBridge.getRegisteredObject(objectID); + } else { + return rmiBridge.getRegisteredObject(objectID); + } } } diff --git a/Dorkbox-Network/src/dorkbox/network/connection/EndPoint.java b/Dorkbox-Network/src/dorkbox/network/connection/EndPoint.java index 83a15482..c2a7c7d5 100644 --- a/Dorkbox-Network/src/dorkbox/network/connection/EndPoint.java +++ b/Dorkbox-Network/src/dorkbox/network/connection/EndPoint.java @@ -27,15 +27,15 @@ import dorkbox.network.pipeline.KryoEncoderCrypto; import dorkbox.network.rmi.RmiBridge; import dorkbox.network.util.CryptoSerializationManager; import dorkbox.network.util.EndPointTool; -import dorkbox.util.entropy.Entropy; -import dorkbox.util.exceptions.InitializationException; -import dorkbox.util.exceptions.SecurityException; import dorkbox.network.util.store.NullSettingsStore; import dorkbox.network.util.store.SettingsStore; import dorkbox.util.Sys; import dorkbox.util.collections.IntMap; import dorkbox.util.collections.IntMap.Entries; import dorkbox.util.crypto.Crypto; +import dorkbox.util.entropy.Entropy; +import dorkbox.util.exceptions.InitializationException; +import dorkbox.util.exceptions.SecurityException; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.EventLoopGroup; @@ -137,7 +137,7 @@ class EndPoint { final ECPrivateKeyParameters privateKey; final ECPublicKeyParameters publicKey; final SecureRandom secureRandom; - + final RmiBridge globalRmiBridge; private final CountDownLatch blockUntilDone = new CountDownLatch(1); private final Executor rmiExecutor; private final boolean rmiEnabled; @@ -145,12 +145,16 @@ class EndPoint { private final List eventLoopGroups = new ArrayList(8); private final List shutdownChannelList = new ArrayList(); private final ConcurrentHashMap, EndPointTool> toolMap = new ConcurrentHashMap, EndPointTool>(); + // make sure that the endpoint is closed on JVM shutdown (if it's still open at that point in time) protected Thread shutdownHook; + protected AtomicBoolean stopCalled = new AtomicBoolean(false); protected AtomicBoolean isConnected = new AtomicBoolean(false); + SettingsStore propertyStore; boolean disableRemoteKeyValidation; + /** * in milliseconds. default is disabled! */ @@ -284,6 +288,10 @@ class EndPoint { if (this.rmiEnabled) { // these register the listener for registering a class implementation for RMI (internal use only) this.connectionManager.add(new RegisterRmiSystemListener()); + this.globalRmiBridge = new RmiBridge(logger, options.rmiExecutor, true); + } + else { + this.globalRmiBridge = null; } } @@ -426,7 +434,7 @@ class EndPoint { RmiBridge rmiBridge = null; if (metaChannel != null && rmiEnabled) { - rmiBridge = new RmiBridge(logger, rmiExecutor); + rmiBridge = new RmiBridge(logger, rmiExecutor, false); } // setup the extras needed by the network connection. @@ -455,7 +463,8 @@ class EndPoint { if (rmiBridge != null) { // notify our remote object space that it is able to receive method calls. - connection.listeners().add(rmiBridge.getListener()); + connection.listeners() + .add(rmiBridge.getListener()); } } else { @@ -812,4 +821,15 @@ class EndPoint { String getName() { return this.type.getSimpleName(); } + + /** + * Creates a "global" RMI object for use by multiple connections. + * @return the ID assigned to this RMI object + */ + public + int createGlobalObject(final T globalObject) { + int globalObjectId = globalRmiBridge.nextObjectId(); + globalRmiBridge.register(globalObjectId, globalObject); + return globalObjectId; + } } diff --git a/Dorkbox-Network/src/dorkbox/network/connection/KryoCryptoSerializationManager.java b/Dorkbox-Network/src/dorkbox/network/connection/KryoCryptoSerializationManager.java index 57540190..00959d2f 100644 --- a/Dorkbox-Network/src/dorkbox/network/connection/KryoCryptoSerializationManager.java +++ b/Dorkbox-Network/src/dorkbox/network/connection/KryoCryptoSerializationManager.java @@ -26,19 +26,12 @@ import com.esotericsoftware.kryo.util.MapReferenceResolver; import dorkbox.network.connection.ping.PingMessage; import dorkbox.network.rmi.*; import dorkbox.network.util.CryptoSerializationManager; -import dorkbox.util.serialization.ArraysAsListSerializer; -import dorkbox.util.serialization.FieldAnnotationAwareSerializer; -import dorkbox.util.serialization.IgnoreSerialization; -import dorkbox.util.serialization.UnmodifiableCollectionsSerializer; import dorkbox.util.crypto.Crypto; -import dorkbox.util.serialization.EccPrivateKeySerializer; -import dorkbox.util.serialization.EccPublicKeySerializer; -import dorkbox.util.serialization.IesParametersSerializer; -import dorkbox.util.serialization.IesWithCipherParametersSerializer; import dorkbox.util.exceptions.NetException; import dorkbox.util.objectPool.ObjectPool; import dorkbox.util.objectPool.ObjectPoolFactory; import dorkbox.util.objectPool.PoolableObject; +import dorkbox.util.serialization.*; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.handler.codec.compression.CompressionException; @@ -137,6 +130,7 @@ class KryoCryptoSerializationManager implements CryptoSerializationManager { serializationManager.register(StackTraceElement[].class); // extra serializers + //noinspection ArraysAsListWithZeroOrOneArgument serializationManager.register(Arrays.asList("").getClass(), new ArraysAsListSerializer()); UnmodifiableCollectionsSerializer.registerSerializers(serializationManager); diff --git a/Dorkbox-Network/src/dorkbox/network/connection/ObjectRegistrationLatch.java b/Dorkbox-Network/src/dorkbox/network/connection/ObjectRegistrationLatch.java new file mode 100644 index 00000000..e95a9307 --- /dev/null +++ b/Dorkbox-Network/src/dorkbox/network/connection/ObjectRegistrationLatch.java @@ -0,0 +1,9 @@ +package dorkbox.network.connection; + +import java.util.concurrent.CountDownLatch; + +class ObjectRegistrationLatch { + final CountDownLatch latch = new CountDownLatch(1); + Object remoteObject; + boolean hasError = false; +} diff --git a/Dorkbox-Network/src/dorkbox/network/rmi/RemoteObjectSerializer.java b/Dorkbox-Network/src/dorkbox/network/rmi/RemoteObjectSerializer.java index 83f3aafa..a15f4266 100644 --- a/Dorkbox-Network/src/dorkbox/network/rmi/RemoteObjectSerializer.java +++ b/Dorkbox-Network/src/dorkbox/network/rmi/RemoteObjectSerializer.java @@ -75,6 +75,6 @@ class RemoteObjectSerializer extends Serializer { KryoExtra kryoExtra = (KryoExtra) kryo; int objectID = input.readInt(true); final ConnectionImpl connection = kryoExtra.connection; - return (T) connection.rmiBridge.getRemoteObject(connection, objectID, type); + return (T) connection.getRemoteObject(objectID, type); } } diff --git a/Dorkbox-Network/src/dorkbox/network/rmi/RmiBridge.java b/Dorkbox-Network/src/dorkbox/network/rmi/RmiBridge.java index 4a44efe6..8c4ea174 100644 --- a/Dorkbox-Network/src/dorkbox/network/rmi/RmiBridge.java +++ b/Dorkbox-Network/src/dorkbox/network/rmi/RmiBridge.java @@ -34,23 +34,27 @@ */ package dorkbox.network.rmi; -import com.esotericsoftware.kryo.util.IntMap; -import dorkbox.network.connection.Connection; -import dorkbox.network.connection.EndPoint; -import dorkbox.network.connection.ListenerRaw; -import dorkbox.util.exceptions.NetException; -import dorkbox.util.collections.ObjectIntMap; -import dorkbox.util.objectPool.ObjectPool; -import dorkbox.util.objectPool.ObjectPoolFactory; -import org.slf4j.Logger; - import java.lang.reflect.Proxy; import java.util.Arrays; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock; import java.util.concurrent.locks.ReentrantReadWriteLock.WriteLock; +import org.slf4j.Logger; + +import com.esotericsoftware.kryo.util.IntMap; + +import dorkbox.network.connection.Connection; +import dorkbox.network.connection.ConnectionImpl; +import dorkbox.network.connection.EndPoint; +import dorkbox.network.connection.ListenerRaw; +import dorkbox.util.collections.ObjectIntMap; +import dorkbox.util.exceptions.NetException; +import dorkbox.util.objectPool.ObjectPool; +import dorkbox.util.objectPool.ObjectPoolFactory; + /** * Allows methods on objects to be invoked remotely over TCP, UDP, or UDT. Objects are * {@link dorkbox.network.util.RMISerializationManager#registerRemote(Class, Class)}, and endpoint connections @@ -70,10 +74,26 @@ class RmiBridge { static final int returnExceptionMask = 1 << 6; static final int responseIdMask = 0xFF & ~returnValueMask & ~returnExceptionMask; + // global RMI objects -> ODD in range 1-16380 (max 2 bytes) throws error on outside of range + // connection local RMI -> EVEN in range 1-16380 (max 2 bytes) throws error on outside of range + private static final int MAX_RMI_VALUE = 16380; + public static final int INVALID_RMI = 0; + + /** + * @return true if the objectId is a "global" id (it's odd) otherwise, false (it's connection local) + */ + public static + boolean isGlobal(final int objectId) { + return (objectId & 1) != 0; + } + // the name of who created this RmiBridge private final org.slf4j.Logger logger; + // we start at 1, because 0 (INVALID_RMI) means we access connection only objects + private final AtomicInteger rmiObjectIdCounter; + // can be accessed by DIFFERENT threads. private final ReentrantReadWriteLock objectLock = new ReentrantReadWriteLock(); private final IntMap idToObject = new IntMap(); @@ -82,24 +102,21 @@ class RmiBridge { private final Executor executor; // 4096 concurrent method invocations max - private final ObjectPool invokeMethodPool = ObjectPoolFactory.create(new InvokeMethodPoolable(), 4096); + private static final ObjectPool invokeMethodPool = ObjectPoolFactory.create(new InvokeMethodPoolable(), 4096); - private final ListenerRaw invokeListener = new ListenerRaw() { + private final ListenerRaw invokeListener = new ListenerRaw() { @Override public - void received(final Connection connection, final InvokeMethod invokeMethod) { - ReadLock readLock = RmiBridge.this.objectLock.readLock(); - readLock.lock(); - - final Object target = RmiBridge.this.idToObject.get(invokeMethod.objectID); - - readLock.unlock(); + void received(final ConnectionImpl connection, final InvokeMethod invokeMethod) { + int objectID = invokeMethod.objectID; + // have to make sure to get the correct object (global vs local) + final Object target = connection.getRegisteredObject(objectID); if (target == null) { Logger logger2 = RmiBridge.this.logger; if (logger2.isWarnEnabled()) { - logger2.warn("Ignoring remote invocation request for unknown object ID: {}", invokeMethod.objectID); + logger2.warn("Ignoring remote invocation request for unknown object ID: {}", objectID); } return; @@ -129,11 +146,19 @@ class RmiBridge { * @param executor Sets the executor used to invoke methods when an invocation is received * from a remote endpoint. By default, no executor is set and invocations * occur on the network thread, which should not be blocked for long, May be null. + * @param isGlobal specify if this RmiBridge is a "global" bridge, meaning connections will prefer + * objects from this bridge instead of the connection-local bridge. */ public - RmiBridge(final org.slf4j.Logger logger, final Executor executor) { + RmiBridge(final org.slf4j.Logger logger, final Executor executor, final boolean isGlobal) { this.logger = logger; this.executor = executor; + + if (isGlobal) { + rmiObjectIdCounter = new AtomicInteger(1); + } else { + rmiObjectIdCounter = new AtomicInteger(2); + } } /** @@ -237,6 +262,20 @@ class RmiBridge { // logger.error("{} sent data: {} with id ({})", connection, result, invokeMethod.responseID); } + public + int nextObjectId() { + // always increment by 2 + // global RMI objects -> ODD in range 1-16380 (max 2 bytes) throws error on outside of range + // connection local RMI -> EVEN in range 1-16380 (max 2 bytes) throws error on outside of range + int value = rmiObjectIdCounter.getAndAdd(2); + if (value > MAX_RMI_VALUE) { + rmiObjectIdCounter.set(MAX_RMI_VALUE); // prevent wrapping by spammy callers + throw new NetException("RMI next value has exceeded maximum limits."); + } + return value; + } + + /** * Registers an object to allow the remote end of the RmiBridge connections to access it using the specified ID. * @@ -336,7 +375,7 @@ class RmiBridge { * * @see RemoteObject */ - public + public static RemoteObject getRemoteObject(Connection connection, int objectID, Class iface) { if (connection == null) { throw new IllegalArgumentException("connection cannot be null."); @@ -351,11 +390,9 @@ class RmiBridge { return (RemoteObject) Proxy.newProxyInstance(RmiBridge.class.getClassLoader(), temp, - new RemoteInvocationHandler(this.invokeMethodPool, connection, objectID)); + new RemoteInvocationHandler(invokeMethodPool, connection, objectID)); } - - /** * Returns the object registered with the specified ID. */ @@ -375,7 +412,7 @@ class RmiBridge { * Returns the ID registered for the specified object, or Integer.MAX_VALUE if not found. */ public - int getRegisteredId(final Object object) { + int getRegisteredId(final T object) { // Find an ID with the object. ReadLock readLock = this.objectLock.readLock(); @@ -385,4 +422,6 @@ class RmiBridge { return id; } + + } diff --git a/Dorkbox-Network/src/dorkbox/network/rmi/RmiRegistration.java b/Dorkbox-Network/src/dorkbox/network/rmi/RmiRegistration.java index da868cb3..8b0c7aec 100644 --- a/Dorkbox-Network/src/dorkbox/network/rmi/RmiRegistration.java +++ b/Dorkbox-Network/src/dorkbox/network/rmi/RmiRegistration.java @@ -24,6 +24,9 @@ class RmiRegistration { public String remoteImplementationClass; public boolean hasError; + // this is used to get specific, GLOBAL rmi objects (objects that are not bound to a single connection) + public int remoteObjectId; + public RmiRegistration() { hasError = true; @@ -40,4 +43,10 @@ class RmiRegistration { this.remoteObject = remoteObject; hasError = false; } + + public + RmiRegistration(final int remoteObjectId) { + this.remoteObjectId = remoteObjectId; + hasError = false; + } } diff --git a/Dorkbox-Network/test/dorkbox/network/rmi/RmiGlobalTest.java b/Dorkbox-Network/test/dorkbox/network/rmi/RmiGlobalTest.java new file mode 100644 index 00000000..05975a24 --- /dev/null +++ b/Dorkbox-Network/test/dorkbox/network/rmi/RmiGlobalTest.java @@ -0,0 +1,331 @@ +package dorkbox.network.rmi; + + +import dorkbox.network.BaseTest; +import dorkbox.network.Client; +import dorkbox.network.Configuration; +import dorkbox.network.Server; +import dorkbox.network.connection.Connection; +import dorkbox.network.connection.ConnectionImpl; +import dorkbox.network.connection.KryoCryptoSerializationManager; +import dorkbox.network.connection.Listener; +import dorkbox.network.util.CryptoSerializationManager; +import dorkbox.util.exceptions.InitializationException; +import dorkbox.util.exceptions.SecurityException; +import org.junit.Test; + +import java.io.IOException; +import java.io.Serializable; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public +class RmiGlobalTest extends BaseTest { + + private int CLIENT_GLOBAL_OBJECT_ID = 0; + private int SERVER_GLOBAL_OBJECT_ID = 0; + + private TestObject globalRemoteServerObject = new TestObjectImpl(); + private TestObject globalRemoteClientObject = new TestObjectImpl(); + + private static + void runTest(final Connection connection, final Object remoteObject, final int remoteObjectID) { + new Thread() { + @Override + public + void run() { + TestObject test = connection.getRemoteObject(remoteObjectID); + + System.err.println("Starting test for: " + remoteObjectID); + + //TestObject test = connection.getRemoteObject(id, TestObject.class); + assertEquals(remoteObject.hashCode(), test.hashCode()); + RemoteObject remoteObject = (RemoteObject) test; + + // Default behavior. RMI is transparent, method calls behave like normal + // (return values and exceptions are returned, call is synchronous) + System.err.println("hashCode: " + test.hashCode()); + System.err.println("toString: " + test); + test.moo(); + test.moo("Cow"); + assertEquals(remoteObjectID, test.id()); + + + // UDP calls that ignore the return value + remoteObject.setUDP(true); + test.moo("Meow"); + assertEquals(0, test.id()); + remoteObject.setUDP(false); + + + // Test that RMI correctly waits for the remotely invoked method to exit + remoteObject.setResponseTimeout(5000); + test.moo("You should see this two seconds before...", 2000); + System.out.println("...This"); + remoteObject.setResponseTimeout(3000); + + // Try exception handling + boolean caught = false; + try { + test.throwException(); + } catch (UnsupportedOperationException ex) { + System.err.println("\tExpected."); + caught = true; + } + assertTrue(caught); + + // Return values are ignored, but exceptions are still dealt with properly + + remoteObject.setTransmitReturnValue(false); + test.moo("Baa"); + test.id(); + caught = false; + try { + test.throwException(); + } catch (UnsupportedOperationException ex) { + caught = true; + } + assertTrue(caught); + + // Non-blocking call that ignores the return value + remoteObject.setNonBlocking(true); + remoteObject.setTransmitReturnValue(false); + test.moo("Meow"); + assertEquals(0, test.id()); + + // Non-blocking call that returns the return value + remoteObject.setTransmitReturnValue(true); + test.moo("Foo"); + + assertEquals(0, test.id()); + // wait for the response to id() + assertEquals(remoteObjectID, remoteObject.waitForLastResponse()); + + assertEquals(0, test.id()); + byte responseID = remoteObject.getLastResponseID(); + // wait for the response to id() + assertEquals(remoteObjectID, remoteObject.waitForResponse(responseID)); + + // Non-blocking call that errors out + remoteObject.setTransmitReturnValue(false); + test.throwException(); + assertEquals(remoteObject.waitForLastResponse() + .getClass(), UnsupportedOperationException.class); + + // Call will time out if non-blocking isn't working properly + remoteObject.setTransmitExceptions(false); + test.moo("Mooooooooo", 3000); + + + // should wait for a small time + remoteObject.setTransmitReturnValue(true); + remoteObject.setNonBlocking(false); + remoteObject.setResponseTimeout(6000); + System.out.println("You should see this 2 seconds before"); + float slow = test.slow(); + System.out.println("...This"); + assertEquals(123f, slow, .0001D); + + + // Test sending a reference to a remote object. + MessageWithTestObject m = new MessageWithTestObject(); + m.number = 678; + m.text = "sometext"; + m.testObject = test; + connection.send() + .TCP(m) + .flush(); + } + }.start(); + } + + + + public static + void register(CryptoSerializationManager kryoMT) { + kryoMT.register(Object.class); // Needed for Object#toString, hashCode, etc. + + kryoMT.registerRemote(TestObject.class, TestObjectImpl.class); + kryoMT.register(MessageWithTestObject.class); + + kryoMT.register(UnsupportedOperationException.class); + } + + @Test + public + void rmi() throws InitializationException, SecurityException, IOException { + KryoCryptoSerializationManager.DEFAULT = KryoCryptoSerializationManager.DEFAULT(); + register(KryoCryptoSerializationManager.DEFAULT); + + + Configuration configuration = new Configuration(); + configuration.tcpPort = tcpPort; + configuration.udpPort = udpPort; + configuration.host = host; + configuration.rmiEnabled = true; + + final Server server = new Server(configuration); + server.disableRemoteKeyValidation(); + server.setIdleTimeout(0); + + register(server.getSerialization()); + + // register this object as a global object that the client will get + SERVER_GLOBAL_OBJECT_ID = server.createGlobalObject(globalRemoteServerObject); + + addEndPoint(server); + server.bind(false); + + server.listeners() + .add(new Listener() { + @Override + public + void connected(final Connection connection) { + RmiGlobalTest.runTest(connection, globalRemoteClientObject, CLIENT_GLOBAL_OBJECT_ID); + } + + @Override + public + void received(Connection connection, MessageWithTestObject m) { + TestObject object = m.testObject; + final int id = object.id(); + assertEquals(1, id); + System.err.println("Client/Server Finished!"); + + stopEndPoints(2000); + } + + }); + + + // ---- + + final Client client = new Client(configuration); + client.setIdleTimeout(0); + client.disableRemoteKeyValidation(); + + // register this object as a global object that the server will get + CLIENT_GLOBAL_OBJECT_ID = client.createGlobalObject(globalRemoteClientObject); + + addEndPoint(client); + + client.listeners() + .add(new Listener() { + @Override + public + void received(Connection connection, MessageWithTestObject m) { + TestObject object = m.testObject; + final int id = object.id(); + assertEquals(1, id); + System.err.println("Server/Client Finished!"); + + // normally this is in the 'connected', but we do it here, so that it's more linear and easier to debug + runTest(connection, globalRemoteServerObject, SERVER_GLOBAL_OBJECT_ID); + } + }); + + client.connect(5000); + waitForThreads(); + } + + public + interface TestObject extends Serializable { + void throwException(); + + void moo(); + + void moo(String value); + + void moo(String value, long delay); + + int id(); + + float slow(); + } + + + public static class ConnectionAware { + private + ConnectionImpl connection; + + public + ConnectionImpl getConnection() { + return connection; + } + + public + void setConnection(final ConnectionImpl connection) { + this.connection = connection; + } + } + + public static + class TestObjectImpl extends ConnectionAware implements TestObject { + public long value = System.currentTimeMillis(); + public int moos; + private final int id = 1; + + public + TestObjectImpl() { + } + + @Override + public + void throwException() { + throw new UnsupportedOperationException("Why would I do that?"); + } + + @Override + public + void moo() { + this.moos++; + System.out.println("Moo!"); + } + + @Override + public + void moo(String value) { + this.moos += 2; + System.out.println("Moo: " + value); + } + + @Override + public + void moo(String value, long delay) { + this.moos += 4; + System.out.println("Moo: " + value + " (" + delay + ")"); + try { + Thread.sleep(delay); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + @Override + public + int id() { + return id; + } + + @Override + public + float slow() { + System.out.println("Slowdown!!"); + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + return 123f; + } + } + + + public static + class MessageWithTestObject implements RmiMessages { + public int number; + public String text; + public TestObject testObject; + } +} diff --git a/Dorkbox-Network/test/dorkbox/network/rmi/RmiSendObjectTest.java b/Dorkbox-Network/test/dorkbox/network/rmi/RmiSendObjectTest.java index 7167bd15..fa98a5ee 100644 --- a/Dorkbox-Network/test/dorkbox/network/rmi/RmiSendObjectTest.java +++ b/Dorkbox-Network/test/dorkbox/network/rmi/RmiSendObjectTest.java @@ -9,9 +9,11 @@ import dorkbox.network.connection.KryoCryptoSerializationManager; import dorkbox.network.connection.Listener; import dorkbox.util.exceptions.InitializationException; import dorkbox.util.exceptions.SecurityException; +import dorkbox.util.serialization.IgnoreSerialization; import org.junit.Test; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -75,7 +77,7 @@ class RmiSendObjectTest extends BaseTest { @Override public void run() { - TestObject test = connection.createRemoteObject(TestObject.class, TestObjectImpl.class); + TestObject test = connection.createRemoteObject(TestObjectImpl.class); test.setOther(43.21f); // Normal remote method call. assertEquals(43.21f, test.other(), .0001f); @@ -119,10 +121,16 @@ class RmiSendObjectTest extends BaseTest { } + private static final AtomicInteger idCounter = new AtomicInteger(); + + public static class TestObjectImpl implements TestObject { + @IgnoreSerialization + private final int ID = idCounter.getAndIncrement(); + @RemoteProxy - private OtherObject otherObject = new OtherObjectImpl(); + private final OtherObject otherObject = new OtherObjectImpl(); private float aFloat; @@ -143,11 +151,20 @@ class RmiSendObjectTest extends BaseTest { OtherObject getOtherObject() { return this.otherObject; } + + @Override + public + int hashCode() { + return ID; + } } public static class OtherObjectImpl implements OtherObject { + @IgnoreSerialization + private final int ID = idCounter.getAndIncrement(); + private float aFloat; @Override @@ -161,5 +178,11 @@ class RmiSendObjectTest extends BaseTest { float value() { return aFloat; } + + @Override + public + int hashCode() { + return ID; + } } } diff --git a/Dorkbox-Network/test/dorkbox/network/rmi/RmiTest.java b/Dorkbox-Network/test/dorkbox/network/rmi/RmiTest.java index dc121cf4..28b1a4ec 100644 --- a/Dorkbox-Network/test/dorkbox/network/rmi/RmiTest.java +++ b/Dorkbox-Network/test/dorkbox/network/rmi/RmiTest.java @@ -5,6 +5,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.io.Serializable; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -29,7 +30,7 @@ class RmiTest extends BaseTest { @Override public void run() { - TestObject test = connection.createRemoteObject(TestObject.class, TestObjectImpl.class); + TestObject test = connection.createRemoteObject(TestObjectImpl.class); System.err.println("Starting test for: " + remoteObjectID); @@ -110,6 +111,15 @@ class RmiTest extends BaseTest { remoteObject.setTransmitExceptions(false); test.moo("Mooooooooo", 3000); + // should wait for a small time + remoteObject.setTransmitReturnValue(true); + remoteObject.setNonBlocking(false); + remoteObject.setResponseTimeout(6000); + System.out.println("You should see this 2 seconds before"); + float slow = test.slow(); + System.out.println("...This"); + assertEquals(slow, 123, .0001D); + // Test sending a reference to a remote object. MessageWithTestObject m = new MessageWithTestObject(); m.number = 678; @@ -205,7 +215,7 @@ class RmiTest extends BaseTest { } public - interface TestObject { + interface TestObject extends Serializable { void throwException(); void moo(); @@ -274,11 +284,35 @@ class RmiTest extends BaseTest { @Override public float slow() { + System.out.println("Slowdown!!"); try { - Thread.sleep(300); - } catch (InterruptedException ignored) { + Thread.sleep(2000); + } catch (InterruptedException e) { + e.printStackTrace(); } - return 666; + return 123f; + } + + @Override + public + boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final TestObjectImpl that = (TestObjectImpl) o; + + return id == that.id; + + } + + @Override + public + int hashCode() { + return id; } }