diff --git a/src/dorkbox/network/connection/ConnectionImpl.java b/src/dorkbox/network/connection/ConnectionImpl.java index 061540e5..c0498cc7 100644 --- a/src/dorkbox/network/connection/ConnectionImpl.java +++ b/src/dorkbox/network/connection/ConnectionImpl.java @@ -16,20 +16,12 @@ package dorkbox.network.connection; import java.io.IOException; -import java.lang.reflect.Field; -import java.lang.reflect.Proxy; -import java.util.AbstractMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import org.bouncycastle.crypto.params.ParametersWithIV; -import org.slf4j.Logger; import dorkbox.network.Client; import dorkbox.network.connection.bridge.ConnectionBridge; @@ -42,21 +34,12 @@ import dorkbox.network.connection.ping.PingTuple; import dorkbox.network.connection.wrapper.ChannelNetworkWrapper; import dorkbox.network.connection.wrapper.ChannelNull; import dorkbox.network.connection.wrapper.ChannelWrapper; -import dorkbox.network.rmi.InvokeMethod; -import dorkbox.network.rmi.InvokeMethodResult; +import dorkbox.network.rmi.ConnectionRmiSupport; +import dorkbox.network.rmi.ConnectionSupport; import dorkbox.network.rmi.RemoteObject; import dorkbox.network.rmi.RemoteObjectCallback; -import dorkbox.network.rmi.Rmi; -import dorkbox.network.rmi.RmiBridge; -import dorkbox.network.rmi.RmiMessage; import dorkbox.network.rmi.RmiObjectHandler; -import dorkbox.network.rmi.RmiProxyHandler; -import dorkbox.network.rmi.RmiRegistration; import dorkbox.network.rmi.TimeoutException; -import dorkbox.network.serialization.CryptoSerializationManager; -import dorkbox.util.collections.LockFreeHashMap; -import dorkbox.util.collections.LockFreeIntMap; -import dorkbox.util.generics.ClassHelper; import io.netty.bootstrap.DatagramSessionChannel; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler.Sharable; @@ -119,8 +102,8 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne private final Object messageInProgressLock = new Object(); private final AtomicBoolean messageInProgress = new AtomicBoolean(false); - private ISessionManager sessionManager; - private ChannelWrapper channelWrapper; + private final ISessionManager sessionManager; + private final ChannelWrapper channelWrapper; private volatile PingFuture pingFuture = null; @@ -146,73 +129,72 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne private CountDownLatch closeLatch; - // - // RMI fields - // - private final RmiBridge rmiBridge; - private final Map proxyIdCache; - private final List> proxyListeners; - - private final LockFreeIntMap rmiRegistrationCallbacks; - private volatile int rmiCallbackId = 0; - + // RMI support for this connection + private final ConnectionSupport rmiSupport; /** * All of the parameters can be null, when metaChannel wants to get the base class type */ public - ConnectionImpl(final Logger logger, final EndPoint endPoint, final RmiBridge rmiBridge) { - this.logger = logger; + ConnectionImpl(final EndPoint endPoint, final ChannelWrapper channelWrapper) { this.endPoint = endPoint; - this.rmiBridge = rmiBridge; - if (endPoint != null && endPoint.globalRmiBridge != null) { - // rmi is enabled. - // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent - // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections - proxyIdCache = new LockFreeHashMap(); - proxyListeners = new CopyOnWriteArrayList>(); - rmiRegistrationCallbacks = new LockFreeIntMap(); + if (endPoint != null) { + this.channelWrapper = channelWrapper; + this.logger = endPoint.logger; + this.sessionManager = endPoint.connectionManager; + + boolean isNetworkChannel = this.channelWrapper instanceof ChannelNetworkWrapper; + + if (endPoint.rmiEnabled) { + + RmiObjectHandler handler; + if (isNetworkChannel) { + handler = endPoint.rmiNetworkHandler; + } + else { + handler = endPoint.rmiLocalHandler; + } + + // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent access, but + // there WILL be issues with thread visibility because a different worker thread can be called for different connections + this.rmiSupport = new ConnectionRmiSupport(endPoint.rmiGlobalBridge, handler); + } else { + this.rmiSupport = new ConnectionSupport(); + } + + + + if (isNetworkChannel) { + this.remoteKeyChanged = ((ChannelNetworkWrapper) channelWrapper).remoteKeyChanged(); + + int count = 0; + if (channelWrapper.tcp() != null) { + count++; + } + + if (channelWrapper.udp() != null) { + count++; + } + + // when closing this connection, HOW MANY endpoints need to be closed? + this.closeLatch = new CountDownLatch(count); + } + else { + this.remoteKeyChanged = false; + + // when closing this connection, HOW MANY endpoints need to be closed? + this.closeLatch = new CountDownLatch(1); + } + } else { - proxyIdCache = null; - proxyListeners = null; - rmiRegistrationCallbacks = null; + this.logger = null; + this.sessionManager = null; + this.channelWrapper = null; + this.rmiSupport = new ConnectionSupport(); } } - /** - * Initialize the connection with any extra info that is needed but was unavailable at the channel construction. - */ - final - void init(final ChannelWrapper channelWrapper, final ISessionManager sessionManager) { - this.sessionManager = sessionManager; - this.channelWrapper = channelWrapper; - - //noinspection SimplifiableIfStatement - if (this.channelWrapper instanceof ChannelNetworkWrapper) { - this.remoteKeyChanged = ((ChannelNetworkWrapper) this.channelWrapper).remoteKeyChanged(); - - int count = 0; - if (channelWrapper.tcp() != null) { - count++; - } - - if (channelWrapper.udp() != null) { - count++; - } - - // when closing this connection, HOW MANY endpoints need to be closed? - closeLatch = new CountDownLatch(count); - } - else { - this.remoteKeyChanged = false; - - // when closing this connection, HOW MANY endpoints need to be closed? - closeLatch = new CountDownLatch(1); - } - } - - /** * @return a threadlocal AES key + IV. key=32 byte, iv=12 bytes (AES-GCM implementation). This is a threadlocal * because multiple protocols can be performing crypto AT THE SAME TIME, and so we have to make sure that operations don't @@ -777,14 +759,9 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne removeAll(); } - // proxy listeners are cleared in the removeAll() call - if (proxyIdCache != null) { - proxyIdCache.clear(); - } - if (rmiRegistrationCallbacks != null) { - rmiRegistrationCallbacks.clear(); - } + // remove all RMI listeners + rmiSupport.close(); } } @@ -925,9 +902,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne @Override public final Listeners removeAll() { - if (proxyListeners != null) { - proxyListeners.clear(); - } + rmiSupport.removeAllListeners(); if (this.endPoint instanceof EndPointServer) { // when we are a server, NORMALLY listeners are added at the GLOBAL level @@ -1042,48 +1017,13 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne @Override public final void createRemoteObject(final Class interfaceClass, final RemoteObjectCallback callback) { - if (!interfaceClass.isInterface()) { - throw new IllegalArgumentException("Cannot create a proxy for RMI access. It must be an interface."); - } - - // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent - // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections - //noinspection NonAtomicOperationOnVolatileField - int nextRmiCallbackId = rmiCallbackId++; - rmiRegistrationCallbacks.put(nextRmiCallbackId, callback); - RmiRegistration message = new RmiRegistration(interfaceClass, RmiBridge.INVALID_RMI, nextRmiCallbackId); - - // We use a callback to notify us when the object is ready. We can't "create this on the fly" because we - // have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here. - - // this means we are creating a NEW object on the server, bound access to only this connection - send(message).flush(); + rmiSupport.createRemoteObject(this, interfaceClass, callback); } @Override public final void getRemoteObject(final int objectId, final RemoteObjectCallback callback) { - if (objectId < 0) { - throw new IllegalStateException("Object ID cannot be < 0"); - } - if (objectId >= RmiBridge.INVALID_RMI) { - throw new IllegalStateException("Object ID cannot be >= " + RmiBridge.INVALID_RMI); - } - - Class iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(RemoteObjectCallback.class, callback.getClass(), 0); - - // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent - // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections - //noinspection NonAtomicOperationOnVolatileField - int nextRmiCallbackId = rmiCallbackId++; - rmiRegistrationCallbacks.put(nextRmiCallbackId, callback); - RmiRegistration message = new RmiRegistration(iFaceClass, objectId, nextRmiCallbackId); - - // We use a callback to notify us when the object is ready. We can't "create this on the fly" because we - // have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here. - - // this means we are creating a NEW object on the server, bound access to only this connection - send(message).flush(); + rmiSupport.getRemoteObject(this, objectId, callback); } @@ -1093,25 +1033,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne * @return true if there was RMI stuff done, false if the message was "normal" and nothing was done */ boolean manageRmi(final Object message) { - if (message instanceof RmiMessage) { - RmiObjectHandler rmiObjectHandler = channelWrapper.manageRmi(); - - if (message instanceof InvokeMethod) { - rmiObjectHandler.invoke(this, (InvokeMethod) message, rmiBridge.getListener()); - } - else if (message instanceof InvokeMethodResult) { - for (Listener.OnMessageReceived proxyListener : proxyListeners) { - proxyListener.received(this, (InvokeMethodResult) message); - } - } - else if (message instanceof RmiRegistration) { - rmiObjectHandler.registration(this, (RmiRegistration) message); - } - - return true; - } - - return false; + return rmiSupport.manage(this, message); } /** @@ -1119,159 +1041,31 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne */ Object fixupRmi(final Object message) { // "local RMI" objects have to be modified, this part does that - RmiObjectHandler rmiObjectHandler = channelWrapper.manageRmi(); - return rmiObjectHandler.normalMessages(this, message); + return rmiSupport.fixupRmi(this, message); } /** - * This will remove the invoke and invoke response listeners for this the remote object + * This will remove the invoke and invoke response listeners for this remote object */ public void removeRmiListeners(final int objectID, final Listener listener) { - } - - /** - * For network connections, the interface class kryo ID == implementation class kryo ID, so they switch automatically. - * For local connections, we have to switch it appropriately in the LocalRmiProxy - */ - public final - RmiRegistration createNewRmiObject(final Class interfaceClass, final Class implementationClass, final int callbackId) { - CryptoSerializationManager manager = getEndPoint().serializationManager; - - KryoExtra kryo = null; - Object object = null; - int rmiId = 0; - - try { - kryo = manager.takeKryo(); - - // because the INTERFACE is what is registered with kryo (not the impl) we have to temporarily permit unregistered classes (which have an ID of -1) - // so we can cache the instantiator for this class. - boolean registrationRequired = kryo.isRegistrationRequired(); - - kryo.setRegistrationRequired(false); - - // this is what creates a new instance of the impl class, and stores it as an ID. - object = kryo.newInstance(implementationClass); - - if (registrationRequired) { - // only if it's different should we call this again. - kryo.setRegistrationRequired(true); - } - - - rmiId = rmiBridge.register(object); - - if (rmiId == RmiBridge.INVALID_RMI) { - // this means that there are too many RMI ids (either global or connection specific!) - object = null; - } - else { - // if we are invalid, skip going over fields that might also be RMI objects, BECAUSE our object will be NULL! - - // the @Rmi annotation allows an RMI object to have fields with objects that are ALSO RMI objects - LinkedList, Object>> classesToCheck = new LinkedList, Object>>(); - classesToCheck.add(new AbstractMap.SimpleEntry, Object>(implementationClass, object)); - - - Map.Entry, Object> remoteClassObject; - while (!classesToCheck.isEmpty()) { - remoteClassObject = classesToCheck.removeFirst(); - - // we have to check the IMPLEMENTATION for any additional fields that will have proxy information. - // we use getDeclaredFields() + walking the object hierarchy, so we get ALL the fields possible (public + private). - for (Field field : remoteClassObject.getKey() - .getDeclaredFields()) { - if (field.getAnnotation(Rmi.class) != null) { - final Class type = field.getType(); - - if (!type.isInterface()) { - // the type must be an interface, otherwise RMI cannot create a proxy object - logger.error("Error checking RMI fields for: {}.{} -- It is not an interface!", - remoteClassObject.getKey(), - field.getName()); - continue; - } - - - boolean prev = field.isAccessible(); - field.setAccessible(true); - final Object o; - try { - o = field.get(remoteClassObject.getValue()); - - rmiBridge.register(o); - classesToCheck.add(new AbstractMap.SimpleEntry, Object>(type, o)); - } catch (IllegalAccessException e) { - logger.error("Error checking RMI fields for: {}.{}", remoteClassObject.getKey(), field.getName(), e); - } finally { - field.setAccessible(prev); - } - } - } - - - // have to check the object hierarchy as well - Class superclass = remoteClassObject.getKey() - .getSuperclass(); - if (superclass != null && superclass != Object.class) { - classesToCheck.add(new AbstractMap.SimpleEntry, Object>(superclass, remoteClassObject.getValue())); - } - } - } - } catch (Exception e) { - logger.error("Error registering RMI class " + implementationClass, e); - } finally { - if (kryo != null) { - // we use kryo to create a new instance - so only return it on error or when it's done creating a new instance - manager.returnKryo(kryo); - } - } - - return new RmiRegistration(interfaceClass, rmiId, callbackId, object); - } - - public final - RmiRegistration getExistingRmiObject(final Class interfaceClass, final int rmiId, final int callbackId) { - Object object = getImplementationObject(rmiId); - - return new RmiRegistration(interfaceClass, rmiId, callbackId, object); + rmiSupport.removeAllListeners(); //? this is called from close(), when the "RMI" object is closed. TODO: REMOVE THIS? } public final void runRmiCallback(final Class interfaceClass, final int callbackId, final Object remoteObject) { - RemoteObjectCallback callback = rmiRegistrationCallbacks.remove(callbackId); - - try { - //noinspection unchecked - callback.created(remoteObject); - } catch (Exception e) { - logger.error("Error getting or creating the remote object " + interfaceClass, e); - } + rmiSupport.runCallback(interfaceClass, callbackId, remoteObject, logger); } /** * Used by RMI by the LOCAL side when setting up the to fetch an object for the REMOTE side * - * @return the registered ID for a specific object. + * @return the registered ID for a specific object, or RmiBridge.INVALID_RMI if there was no ID. */ @Override public int getRegisteredId(final T object) { - // always check global before checking local, because less contention on the synchronization - RmiBridge globalRmiBridge = endPoint.globalRmiBridge; - - if (globalRmiBridge == null) { - throw new NullPointerException("Unable to call 'getRegisteredId' when the globalRmiBridge is null!"); - } - - int objectId = globalRmiBridge.getRegisteredId(object); - if (objectId != RmiBridge.INVALID_RMI) { - return objectId; - } - else { - return rmiBridge.getRegisteredId(object); - } + return rmiSupport.getRegisteredId(object); } /** @@ -1299,51 +1093,17 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements CryptoConne @Override public RemoteObject getProxyObject(final int rmiId, final Class iFace) { - if (iFace == null) { - throw new IllegalArgumentException("iface cannot be null."); - } - if (!iFace.isInterface()) { - throw new IllegalArgumentException("iface must be an interface."); - } - - // we want to have a connection specific cache of IDs - // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent - // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections - RemoteObject remoteObject = proxyIdCache.get(rmiId); - - if (remoteObject == null) { - // duplicates are fine, as they represent the same object (as specified by the ID) on the remote side. - - // the ACTUAL proxy is created in the connection impl. - RmiProxyHandler proxyObject = new RmiProxyHandler(this, rmiId, iFace); - proxyListeners.add(proxyObject.getListener()); - - Class[] temp = new Class[2]; - temp[0] = RemoteObject.class; - temp[1] = iFace; - - remoteObject = (RemoteObject) Proxy.newProxyInstance(RmiBridge.class.getClassLoader(), temp, proxyObject); - - proxyIdCache.put(rmiId, remoteObject); - } - - return remoteObject; + return rmiSupport.getProxyObject(this, rmiId, iFace); } /** * This is used by RMI for the REMOTE side, to get the implementation + * + * @param objectId this is the RMI object ID */ @Override public - Object getImplementationObject(final int objectID) { - if (RmiBridge.isGlobal(objectID)) { - RmiBridge globalRmiBridge = endPoint.globalRmiBridge; - - assert globalRmiBridge != null; - - return globalRmiBridge.getRegisteredObject(objectID); - } else { - return rmiBridge.getRegisteredObject(objectID); - } + Object getImplementationObject(final int objectId) { + return rmiSupport.getImplementationObject(objectId); } } diff --git a/src/dorkbox/network/connection/EndPoint.java b/src/dorkbox/network/connection/EndPoint.java index f787e483..fb94849b 100644 --- a/src/dorkbox/network/connection/EndPoint.java +++ b/src/dorkbox/network/connection/EndPoint.java @@ -20,7 +20,6 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.SecureRandom; import java.util.List; -import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import org.bouncycastle.crypto.AsymmetricCipherKeyPair; @@ -36,7 +35,6 @@ import dorkbox.network.connection.wrapper.ChannelLocalWrapper; import dorkbox.network.connection.wrapper.ChannelNetworkWrapper; import dorkbox.network.connection.wrapper.ChannelWrapper; import dorkbox.network.rmi.RmiBridge; -import dorkbox.network.rmi.RmiObjectHandler; import dorkbox.network.rmi.RmiObjectLocalHandler; import dorkbox.network.rmi.RmiObjectNetworkHandler; import dorkbox.network.serialization.Serialization; @@ -157,6 +155,8 @@ class EndPoint extends Shutdownable { @Property public static int udpMaxSize = 508; + protected final Configuration config; + protected final ConnectionManager connectionManager; protected final dorkbox.network.serialization.CryptoSerializationManager serializationManager; protected final RegistrationWrapper registrationWrapper; @@ -166,14 +166,13 @@ class EndPoint extends Shutdownable { final SecureRandom secureRandom; - // we only want one instance of these created. These will be called appropriately - private final RmiObjectHandler rmiHandler; - private final RmiObjectLocalHandler localRmiHandler; - private final RmiObjectNetworkHandler networkRmiHandler; - final RmiBridge globalRmiBridge; + final boolean rmiEnabled; + + // we only want one instance of these created. These will be called appropriately + final RmiObjectLocalHandler rmiLocalHandler; + final RmiObjectNetworkHandler rmiNetworkHandler; + final RmiBridge rmiGlobalBridge; - private final Executor rmiExecutor; - private final boolean rmiEnabled; SettingsStore propertyStore; boolean disableRemoteKeyValidation; @@ -196,6 +195,7 @@ class EndPoint extends Shutdownable { public EndPoint(Class type, final Configuration config) throws SecurityException { super(type); + this.config = config; // make sure that 'localhost' is ALWAYS our specific loopback IP address if (config.host != null && (config.host.equals("localhost") || config.host.startsWith("127."))) { @@ -212,7 +212,6 @@ class EndPoint extends Shutdownable { // setup our RMI serialization managers. Can only be called once rmiEnabled = serializationManager.initRmiSerialization(); - rmiExecutor = config.rmiExecutor; // The registration wrapper permits the registration process to access protected/package fields/methods, that we don't want @@ -283,16 +282,14 @@ class EndPoint extends Shutdownable { connectionManager = new ConnectionManager(type.getSimpleName(), connection0(null, null).getClass()); if (rmiEnabled) { - rmiHandler = null; - localRmiHandler = new RmiObjectLocalHandler(logger); - networkRmiHandler = new RmiObjectNetworkHandler(logger); - globalRmiBridge = new RmiBridge(logger, config.rmiExecutor, true); + rmiLocalHandler = new RmiObjectLocalHandler(logger); + rmiNetworkHandler = new RmiObjectNetworkHandler(logger); + rmiGlobalBridge = new RmiBridge(logger, true); } else { - rmiHandler = new RmiObjectHandler(); - localRmiHandler = null; - networkRmiHandler = null; - globalRmiBridge = null; + rmiLocalHandler = null; + rmiNetworkHandler = null; + rmiGlobalBridge = null; } Logger readLogger = LoggerFactory.getLogger(type.getSimpleName() + ".READ"); @@ -380,8 +377,8 @@ class EndPoint extends Shutdownable { * @return a new network connection */ protected - ConnectionImpl newConnection(final Logger logger, final E endPoint, final RmiBridge rmiBridge) { - return new ConnectionImpl(logger, endPoint, rmiBridge); + ConnectionImpl newConnection(final E endPoint, final ChannelWrapper wrapper) { + return new ConnectionImpl(endPoint, wrapper); } /** @@ -392,42 +389,24 @@ class EndPoint extends Shutdownable { * @param metaChannel can be NULL (when getting the baseClass) * @param remoteAddress be NULL (when getting the baseClass or when creating a local channel) */ - protected final + final ConnectionImpl connection0(final MetaChannel metaChannel, final InetSocketAddress remoteAddress) { ConnectionImpl connection; - RmiBridge rmiBridge = null; - if (metaChannel != null && rmiEnabled) { - rmiBridge = new RmiBridge(logger, rmiExecutor, false); - } - // setup the extras needed by the network connection. // These properties are ASSIGNED in the same thread that CREATED the object. Only the AES info needs to be // volatile since it is the only thing that changes. if (metaChannel != null) { ChannelWrapper wrapper; - connection = newConnection(logger, this, rmiBridge); - if (metaChannel.localChannel != null) { - if (rmiEnabled) { - wrapper = new ChannelLocalWrapper(metaChannel, localRmiHandler); - } - else { - wrapper = new ChannelLocalWrapper(metaChannel, rmiHandler); - } + wrapper = new ChannelLocalWrapper(metaChannel); } else { - if (rmiEnabled) { - wrapper = new ChannelNetworkWrapper(metaChannel, remoteAddress, networkRmiHandler); - } - else { - wrapper = new ChannelNetworkWrapper(metaChannel, remoteAddress, rmiHandler); - } + wrapper = new ChannelNetworkWrapper(metaChannel, remoteAddress); } - // now initialize the connection channels with whatever extra info they might need. - connection.init(wrapper, connectionManager); + connection = newConnection(this, wrapper); isConnected.set(true); connectionManager.addConnection(connection); @@ -436,7 +415,7 @@ class EndPoint extends Shutdownable { // getting the connection baseClass // have to add the networkAssociate to a map of "connected" computers - connection = newConnection(null, null, null); + connection = newConnection(null, null); } return connection; @@ -568,7 +547,7 @@ class EndPoint extends Shutdownable { */ public int createGlobalObject(final T globalObject) { - return globalRmiBridge.register(globalObject); + return rmiGlobalBridge.register(globalObject); } /** @@ -581,6 +560,6 @@ class EndPoint extends Shutdownable { @SuppressWarnings("unchecked") public T getGlobalObject(final int objectRmiId) { - return (T) globalRmiBridge.getRegisteredObject(objectRmiId); + return (T) rmiGlobalBridge.getRegisteredObject(objectRmiId); } } diff --git a/src/dorkbox/network/connection/KryoExtra.java b/src/dorkbox/network/connection/KryoExtra.java index 6b9c8e26..499724aa 100644 --- a/src/dorkbox/network/connection/KryoExtra.java +++ b/src/dorkbox/network/connection/KryoExtra.java @@ -122,6 +122,14 @@ class KryoExtra extends Kryo { return readClassAndObject(reader); // this properly sets the readerIndex, but only if it's the correct buffer } + /** + * This is NOT ENCRYPTED (and is only done on the loopback connection!) + */ + public synchronized + void writeCompressed(final ByteBuf buffer, final Object message) throws IOException { + writeCompressed(null, buffer, message); + } + /** * This is NOT ENCRYPTED (and is only done on the loopback connection!) */ @@ -215,6 +223,14 @@ class KryoExtra extends Kryo { buffer.writeBytes(inputArray, inputOffset, compressedLength + lengthLength); } + /** + * This is NOT ENCRYPTED (and is only done on the loopback connection!) + */ + public + Object readCompressed(final ByteBuf buffer, int length) throws IOException { + return readCompressed(null, buffer, length); + } + /** * This is NOT ENCRYPTED (and is only done on the loopback connection!) */ diff --git a/src/dorkbox/network/connection/registration/ConnectionWrapper.java b/src/dorkbox/network/connection/registration/ConnectionRegistrationImpl.java similarity index 97% rename from src/dorkbox/network/connection/registration/ConnectionWrapper.java rename to src/dorkbox/network/connection/registration/ConnectionRegistrationImpl.java index 9706f73f..86acea8b 100644 --- a/src/dorkbox/network/connection/registration/ConnectionWrapper.java +++ b/src/dorkbox/network/connection/registration/ConnectionRegistrationImpl.java @@ -36,11 +36,11 @@ import io.netty.channel.ChannelHandlerContext; * This is to prevent race conditions where onMessage() can happen BEFORE a "connection" is "connected" */ public -class ConnectionWrapper implements CryptoConnection, ChannelHandler { +class ConnectionRegistrationImpl implements CryptoConnection, ChannelHandler { public final ConnectionImpl connection; public - ConnectionWrapper(final ConnectionImpl connection) { + ConnectionRegistrationImpl(final ConnectionImpl connection) { this.connection = connection; } @@ -59,7 +59,6 @@ class ConnectionWrapper implements CryptoConnection, ChannelHandler { void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception { } - @Override public long getNextGcmSequence() { diff --git a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandler.java b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandler.java index 5006a3cb..8133c1d6 100644 --- a/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandler.java +++ b/src/dorkbox/network/connection/registration/remote/RegistrationRemoteHandler.java @@ -26,7 +26,7 @@ import org.bouncycastle.jce.spec.ECParameterSpec; import dorkbox.network.connection.ConnectionImpl; import dorkbox.network.connection.EndPoint; import dorkbox.network.connection.RegistrationWrapper; -import dorkbox.network.connection.registration.ConnectionWrapper; +import dorkbox.network.connection.registration.ConnectionRegistrationImpl; import dorkbox.network.connection.registration.MetaChannel; import dorkbox.network.connection.registration.Registration; import dorkbox.network.connection.registration.RegistrationHandler; @@ -285,7 +285,7 @@ class RegistrationRemoteHandler extends RegistrationHandler { // add the "connected"/"normal" handler now that we have established a "new" connection. // This will have state, etc. for this connection. THIS MUST BE 100% TCP/UDP created, otherwise it will break connections! ConnectionImpl connection = this.registrationWrapper.connection0(metaChannel, remoteAddress); - metaChannel.connection = new ConnectionWrapper(connection); + metaChannel.connection = new ConnectionRegistrationImpl(connection); // Now setup our meta-channel to migrate to the correct connection handler for all regular data. @@ -334,7 +334,7 @@ class RegistrationRemoteHandler extends RegistrationHandler { try { // REMOVE our channel wrapper (only used for encryption) with the actual connection - ChannelHandler handler = metaChannel.connection = ((ConnectionWrapper) metaChannel.connection).connection; + ChannelHandler handler = metaChannel.connection = ((ConnectionRegistrationImpl) metaChannel.connection).connection; Channel channel; if (metaChannel.tcpChannel != null) { @@ -410,7 +410,7 @@ class RegistrationRemoteHandler extends RegistrationHandler { } } - pipeline.remove(ConnectionWrapper.class); + pipeline.remove(ConnectionRegistrationImpl.class); if (idleTimeout > 0) { pipeline.replace(IDLE_HANDLER, IDLE_HANDLER_FULL, new IdleStateHandler(0, 0, idleTimeout, TimeUnit.MILLISECONDS)); diff --git a/src/dorkbox/network/connection/wrapper/ChannelLocalWrapper.java b/src/dorkbox/network/connection/wrapper/ChannelLocalWrapper.java index 5d15d92d..06a09e28 100644 --- a/src/dorkbox/network/connection/wrapper/ChannelLocalWrapper.java +++ b/src/dorkbox/network/connection/wrapper/ChannelLocalWrapper.java @@ -24,7 +24,6 @@ import dorkbox.network.connection.ConnectionPoint; import dorkbox.network.connection.EndPoint; import dorkbox.network.connection.ISessionManager; import dorkbox.network.connection.registration.MetaChannel; -import dorkbox.network.rmi.RmiObjectHandler; import io.netty.channel.Channel; import io.netty.channel.local.LocalAddress; import io.netty.util.concurrent.Promise; @@ -33,15 +32,13 @@ public class ChannelLocalWrapper implements ChannelWrapper, ConnectionPoint { private final Channel channel; - private final RmiObjectHandler rmiObjectHandler; private final AtomicBoolean shouldFlush = new AtomicBoolean(false); private String remoteAddress; public - ChannelLocalWrapper(MetaChannel metaChannel, final RmiObjectHandler rmiObjectHandler) { + ChannelLocalWrapper(MetaChannel metaChannel) { this.channel = metaChannel.localChannel; - this.rmiObjectHandler = rmiObjectHandler; this.remoteAddress = ((LocalAddress) this.channel.remoteAddress()).id(); } @@ -108,12 +105,6 @@ class ChannelLocalWrapper implements ChannelWrapper, ConnectionPoint { return true; } - @Override - public - RmiObjectHandler manageRmi() { - return rmiObjectHandler; - } - @Override public final String getRemoteHost() { diff --git a/src/dorkbox/network/connection/wrapper/ChannelNetworkWrapper.java b/src/dorkbox/network/connection/wrapper/ChannelNetworkWrapper.java index effe011b..5eeca347 100644 --- a/src/dorkbox/network/connection/wrapper/ChannelNetworkWrapper.java +++ b/src/dorkbox/network/connection/wrapper/ChannelNetworkWrapper.java @@ -25,7 +25,6 @@ import dorkbox.network.connection.ConnectionPoint; import dorkbox.network.connection.EndPoint; import dorkbox.network.connection.ISessionManager; import dorkbox.network.connection.registration.MetaChannel; -import dorkbox.network.rmi.RmiObjectHandler; import dorkbox.util.FastThreadLocal; import io.netty.bootstrap.DatagramCloseMessage; import io.netty.util.NetUtil; @@ -49,16 +48,11 @@ class ChannelNetworkWrapper implements ChannelWrapper { private final byte[] aesIV; // AES-GCM requires 12 bytes private final FastThreadLocal cryptoParameters; - private final RmiObjectHandler rmiObjectHandler; - /** - * @param rmiObjectHandler is a no-op handler if RMI is disabled, otherwise handles RMI object registration - */ public - ChannelNetworkWrapper(final MetaChannel metaChannel, final InetSocketAddress remoteAddress, final RmiObjectHandler rmiObjectHandler) { + ChannelNetworkWrapper(final MetaChannel metaChannel, final InetSocketAddress remoteAddress) { this.sessionId = metaChannel.sessionId; - this.rmiObjectHandler = rmiObjectHandler; this.isLoopback = remoteAddress.getAddress().equals(NetUtil.LOCALHOST); if (metaChannel.tcpChannel != null) { @@ -140,12 +134,6 @@ class ChannelNetworkWrapper implements ChannelWrapper { return isLoopback; } - @Override - public - RmiObjectHandler manageRmi() { - return rmiObjectHandler; - } - @Override public String getRemoteHost() { diff --git a/src/dorkbox/network/connection/wrapper/ChannelWrapper.java b/src/dorkbox/network/connection/wrapper/ChannelWrapper.java index c18fb5b3..22d4eb67 100644 --- a/src/dorkbox/network/connection/wrapper/ChannelWrapper.java +++ b/src/dorkbox/network/connection/wrapper/ChannelWrapper.java @@ -20,7 +20,6 @@ import org.bouncycastle.crypto.params.ParametersWithIV; import dorkbox.network.connection.ConnectionImpl; import dorkbox.network.connection.ConnectionPoint; import dorkbox.network.connection.ISessionManager; -import dorkbox.network.rmi.RmiObjectHandler; public interface ChannelWrapper { @@ -46,8 +45,6 @@ interface ChannelWrapper { */ boolean isLoopback(); - RmiObjectHandler manageRmi(); - /** * @return the remote host (can be local, tcp, udp) */ diff --git a/src/dorkbox/network/rmi/ConnectionRmiSupport.java b/src/dorkbox/network/rmi/ConnectionRmiSupport.java new file mode 100644 index 00000000..4cfa7ba4 --- /dev/null +++ b/src/dorkbox/network/rmi/ConnectionRmiSupport.java @@ -0,0 +1,396 @@ +package dorkbox.network.rmi; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Proxy; +import java.util.AbstractMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.slf4j.Logger; + +import dorkbox.network.connection.Connection; +import dorkbox.network.connection.ConnectionImpl; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.connection.Listener; +import dorkbox.network.connection.Listener.OnMessageReceived; +import dorkbox.network.serialization.CryptoSerializationManager; +import dorkbox.util.collections.LockFreeHashMap; +import dorkbox.util.collections.LockFreeIntMap; +import dorkbox.util.generics.ClassHelper; + +public +class ConnectionRmiSupport extends ConnectionSupport { + private final RmiBridge rmiGlobalBridge; + private final RmiBridge rmiLocalBridge; + private final RmiObjectHandler rmiHandler; + + private final Map proxyIdCache; + private final List> proxyListeners; + + private final LockFreeIntMap rmiRegistrationCallbacks; + private final Logger logger; + private volatile int rmiCallbackId = 0; + + + public + ConnectionRmiSupport(final RmiBridge rmiGlobalBridge, final RmiObjectHandler rmiHandler) { + if (rmiGlobalBridge == null || rmiHandler == null) { + throw new NullPointerException("RMI cannot be null if using RMI support!"); + } + + this.rmiGlobalBridge = rmiGlobalBridge; + this.rmiHandler = rmiHandler; + + logger = rmiGlobalBridge.logger; + + // * @param executor + // * Sets the executor used to invoke methods when an invocation is received from a remote endpoint. By default, no + // * executor is set and invocations occur on the network thread, which should not be blocked for long, May be null. + + rmiLocalBridge = new RmiBridge(logger, false); + + + proxyIdCache = new LockFreeHashMap(); + proxyListeners = new CopyOnWriteArrayList>(); + rmiRegistrationCallbacks = new LockFreeIntMap(); + } + + public + void close() { + // proxy listeners are cleared in the removeAll() call (which happens BEFORE close) + proxyIdCache.clear(); + + rmiRegistrationCallbacks.clear(); + } + + public + void removeAllListeners() { + proxyListeners.clear(); + } + + public + void createRemoteObject(final ConnectionImpl connection, final Class interfaceClass, final RemoteObjectCallback callback) { + if (!interfaceClass.isInterface()) { + throw new IllegalArgumentException("Cannot create a proxy for RMI access. It must be an interface."); + } + + // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent + // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections + //noinspection NonAtomicOperationOnVolatileField + int nextRmiCallbackId = rmiCallbackId++; + rmiRegistrationCallbacks.put(nextRmiCallbackId, callback); + RmiRegistration message = new RmiRegistration(interfaceClass, RmiBridge.INVALID_RMI, nextRmiCallbackId); + + // We use a callback to notify us when the object is ready. We can't "create this on the fly" because we + // have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here. + + // this means we are creating a NEW object on the server, bound access to only this connection + connection.send(message).flush(); + } + + public + void getRemoteObject(final ConnectionImpl connection, final int objectId, final RemoteObjectCallback callback) { + if (objectId < 0) { + throw new IllegalStateException("Object ID cannot be < 0"); + } + if (objectId >= RmiBridge.INVALID_RMI) { + throw new IllegalStateException("Object ID cannot be >= " + RmiBridge.INVALID_RMI); + } + + Class iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(RemoteObjectCallback.class, callback.getClass(), 0); + + // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent + // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections + //noinspection NonAtomicOperationOnVolatileField + int nextRmiCallbackId = rmiCallbackId++; + rmiRegistrationCallbacks.put(nextRmiCallbackId, callback); + RmiRegistration message = new RmiRegistration(iFaceClass, objectId, nextRmiCallbackId); + + // We use a callback to notify us when the object is ready. We can't "create this on the fly" because we + // have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here. + + // this means we are getting an EXISTING object on the server, bound access to only this connection + connection.send(message).flush(); + } + + /** + * Manages the RMI stuff for a connection. + */ + public + boolean manage(final ConnectionImpl connection, final Object message) { + if (message instanceof InvokeMethod) { + CryptoSerializationManager serialization = connection.getEndPoint().getSerialization(); + + InvokeMethod invokeMethod = rmiHandler.getInvokeMethod(serialization, connection, (InvokeMethod) message); + + + + int objectID = invokeMethod.objectID; + + // have to make sure to get the correct object (global vs local) + // This is what is overridden when registering interfaces/classes for RMI. + // objectID is the interface ID, and this returns the implementation ID. + final Object target = getImplementationObject(objectID); + + if (target == null) { + logger.warn("Ignoring remote invocation request for unknown object ID: {}", objectID); + + return true; // maybe false? + } + + // Executor executor2 = RmiBridge.this.executor; + // if (executor2 == null) { + try { + RmiBridge.invoke(connection, target, invokeMethod, logger); + } catch (IOException e) { + logger.error("Unable to invoke method.", e); + } + // } + // else { + // executor2.execute(new Runnable() { + // @Override + // public + // void run() { + // try { + // RmiBridge.invoke(connection, target, invokeMethod, logger); + // } catch (IOException e) { + // logger.error("Unable to invoke method.", e); + // } + // } + // }); + // } + return true; + } + else if (message instanceof InvokeMethodResult) { + for (Listener.OnMessageReceived proxyListener : proxyListeners) { + proxyListener.received(connection, (InvokeMethodResult) message); + } + return true; + } + else if (message instanceof RmiRegistration) { + rmiHandler.registration(this, connection, (RmiRegistration) message); + return true; + } + + // not the correct type + return false; + } + + /** + * For network connections, the interface class kryo ID == implementation class kryo ID, so they switch automatically. + * For local connections, we have to switch it appropriately in the LocalRmiProxy + */ + public + RmiRegistration createNewRmiObject(final CryptoSerializationManager serialization, + final Class interfaceClass, + final Class implementationClass, + final int callbackId, + final Logger logger) { + KryoExtra kryo = null; + Object object = null; + int rmiId = 0; + + try { + kryo = serialization.takeKryo(); + + // because the INTERFACE is what is registered with kryo (not the impl) we have to temporarily permit unregistered classes (which have an ID of -1) + // so we can cache the instantiator for this class. + boolean registrationRequired = kryo.isRegistrationRequired(); + + kryo.setRegistrationRequired(false); + + // this is what creates a new instance of the impl class, and stores it as an ID. + object = kryo.newInstance(implementationClass); + + if (registrationRequired) { + // only if it's different should we call this again. + kryo.setRegistrationRequired(true); + } + + + rmiId = rmiLocalBridge.register(object); + + if (rmiId == RmiBridge.INVALID_RMI) { + // this means that there are too many RMI ids (either global or connection specific!) + object = null; + } + else { + // if we are invalid, skip going over fields that might also be RMI objects, BECAUSE our object will be NULL! + + // the @Rmi annotation allows an RMI object to have fields with objects that are ALSO RMI objects + LinkedList, Object>> classesToCheck = new LinkedList, Object>>(); + classesToCheck.add(new AbstractMap.SimpleEntry, Object>(implementationClass, object)); + + + Map.Entry, Object> remoteClassObject; + while (!classesToCheck.isEmpty()) { + remoteClassObject = classesToCheck.removeFirst(); + + // we have to check the IMPLEMENTATION for any additional fields that will have proxy information. + // we use getDeclaredFields() + walking the object hierarchy, so we get ALL the fields possible (public + private). + for (Field field : remoteClassObject.getKey() + .getDeclaredFields()) { + if (field.getAnnotation(Rmi.class) != null) { + final Class type = field.getType(); + + if (!type.isInterface()) { + // the type must be an interface, otherwise RMI cannot create a proxy object + logger.error("Error checking RMI fields for: {}.{} -- It is not an interface!", + remoteClassObject.getKey(), + field.getName()); + continue; + } + + + boolean prev = field.isAccessible(); + field.setAccessible(true); + final Object o; + try { + o = field.get(remoteClassObject.getValue()); + + rmiLocalBridge.register(o); + classesToCheck.add(new AbstractMap.SimpleEntry, Object>(type, o)); + } catch (IllegalAccessException e) { + logger.error("Error checking RMI fields for: {}.{}", remoteClassObject.getKey(), field.getName(), e); + } finally { + field.setAccessible(prev); + } + } + } + + + // have to check the object hierarchy as well + Class superclass = remoteClassObject.getKey() + .getSuperclass(); + if (superclass != null && superclass != Object.class) { + classesToCheck.add(new AbstractMap.SimpleEntry, Object>(superclass, remoteClassObject.getValue())); + } + } + } + } catch (Exception e) { + logger.error("Error registering RMI class " + implementationClass, e); + } finally { + if (kryo != null) { + // we use kryo to create a new instance - so only return it on error or when it's done creating a new instance + serialization.returnKryo(kryo); + } + } + + return new RmiRegistration(interfaceClass, rmiId, callbackId, object); + } + + public + void runCallback(final Class interfaceClass, final int callbackId, final Object remoteObject, final Logger logger) { + RemoteObjectCallback callback = rmiRegistrationCallbacks.remove(callbackId); + + try { + //noinspection unchecked + callback.created(remoteObject); + } catch (Exception e) { + logger.error("Error getting or creating the remote object " + interfaceClass, e); + } + } + + public + int getRegisteredId(final T object) { + // always check global before checking local, because less contention on the synchronization + int objectId = rmiGlobalBridge.getRegisteredId(object); + if (objectId != RmiBridge.INVALID_RMI) { + return objectId; + } + else { + // might return RmiBridge.INVALID_RMI; + return rmiLocalBridge.getRegisteredId(object); + } + } + + public + Object getImplementationObject(final int objectID) { + if (RmiBridge.isGlobal(objectID)) { + return rmiGlobalBridge.getRegisteredObject(objectID); + } else { + return rmiLocalBridge.getRegisteredObject(objectID); + } + } + + public + RemoteObject getProxyObject(final ConnectionImpl connection, final int rmiId, final Class iFace) { + if (iFace == null) { + throw new IllegalArgumentException("iface cannot be null."); + } + if (!iFace.isInterface()) { + throw new IllegalArgumentException("iface must be an interface."); + } + + // we want to have a connection specific cache of IDs + // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent + // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections + RemoteObject remoteObject = proxyIdCache.get(rmiId); + + if (remoteObject == null) { + // duplicates are fine, as they represent the same object (as specified by the ID) on the remote side. + + // the ACTUAL proxy is created in the connection impl. + RmiProxyNetworkHandler proxyObject = new RmiProxyNetworkHandler(connection, rmiId, iFace); + proxyListeners.add(proxyObject.getListener()); + + // This is the interface inheritance by the proxy object + Class[] temp = new Class[2]; + temp[0] = RemoteObject.class; + temp[1] = iFace; + + remoteObject = (RemoteObject) Proxy.newProxyInstance(RmiBridge.class.getClassLoader(), temp, proxyObject); + + proxyIdCache.put(rmiId, remoteObject); + } + + return remoteObject; + } + + public + RemoteObject getLocalProxyObject(final ConnectionImpl connection, final int rmiId, final Class iFace, final Object object) { + if (iFace == null) { + throw new IllegalArgumentException("iface cannot be null."); + } + if (!iFace.isInterface()) { + throw new IllegalArgumentException("iface must be an interface."); + } + if (object == null) { + throw new IllegalArgumentException("object cannot be null."); + } + + + // we want to have a connection specific cache of IDs + // because this is PER CONNECTION, there is no need for synchronize(), since there will not be any issues with concurrent + // access, but there WILL be issues with thread visibility because a different worker thread can be called for different connections + RemoteObject remoteObject = proxyIdCache.get(rmiId); + + if (remoteObject == null) { + // duplicates are fine, as they represent the same object (as specified by the ID) on the remote side. + + // the ACTUAL proxy is created in the connection impl. + RmiProxyLocalHandler proxyObject = new RmiProxyLocalHandler(connection, rmiId, iFace, object); + proxyListeners.add(proxyObject.getListener()); + + Class[] temp = new Class[2]; + temp[0] = RemoteObject.class; + temp[1] = iFace; + + remoteObject = (RemoteObject) Proxy.newProxyInstance(RmiBridge.class.getClassLoader(), temp, proxyObject); + + proxyIdCache.put(rmiId, remoteObject); + } + + return remoteObject; + } + + public + Object fixupRmi(final ConnectionImpl connection, final Object message) { + // "local RMI" objects have to be modified, this part does that + return rmiHandler.normalMessages(this, message); + } +} diff --git a/src/dorkbox/network/rmi/ConnectionSupport.java b/src/dorkbox/network/rmi/ConnectionSupport.java new file mode 100644 index 00000000..b6d5a862 --- /dev/null +++ b/src/dorkbox/network/rmi/ConnectionSupport.java @@ -0,0 +1,61 @@ +package dorkbox.network.rmi; + +import org.slf4j.Logger; + +import dorkbox.network.connection.ConnectionImpl; + +/** + * + */ +public +class ConnectionSupport { + public + void close() { + } + + public + void removeAllListeners() { + } + + public + void createRemoteObject(final ConnectionImpl connection, final Class interfaceClass, final RemoteObjectCallback callback) { + } + + public + void getRemoteObject(final ConnectionImpl connection, final int objectId, final RemoteObjectCallback callback) { + } + + public + boolean manage(final ConnectionImpl connection, final Object message) { + return false; + } + + public + Object fixupRmi(final ConnectionImpl connection, final Object message) { + return message; + } + + public + int getRegisteredId(final T object) { + return RmiBridge.INVALID_RMI; + } + + public + void runCallback(final Class interfaceClass, final int callbackId, final Object remoteObject, final Logger logger) { + } + + public + RemoteObject getProxyObject(final ConnectionImpl connection, final int rmiId, final Class iFace) { + return null; + } + + public + RemoteObject getLocalProxyObject(final ConnectionImpl connection, final int rmiId, final Class iFace, final Object object) { + return null; + } + + public + Object getImplementationObject(final int objectId) { + return null; + } +} diff --git a/src/dorkbox/network/rmi/InvocationHandlerSerializer.java b/src/dorkbox/network/rmi/InvocationHandlerSerializer.java index 66c55ccb..a6872a02 100644 --- a/src/dorkbox/network/rmi/InvocationHandlerSerializer.java +++ b/src/dorkbox/network/rmi/InvocationHandlerSerializer.java @@ -38,7 +38,7 @@ class InvocationHandlerSerializer extends Serializer { @Override public void write(Kryo kryo, Output output, Object object) { - RmiProxyHandler handler = (RmiProxyHandler) Proxy.getInvocationHandler(object); + RmiProxyNetworkHandler handler = (RmiProxyNetworkHandler) Proxy.getInvocationHandler(object); output.writeInt(handler.rmiObjectId, true); } diff --git a/src/dorkbox/network/rmi/RmiBridge.java b/src/dorkbox/network/rmi/RmiBridge.java index 8ad6fcae..eaaa7a59 100644 --- a/src/dorkbox/network/rmi/RmiBridge.java +++ b/src/dorkbox/network/rmi/RmiBridge.java @@ -36,15 +36,12 @@ package dorkbox.network.rmi; import java.io.IOException; import java.util.Arrays; -import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import dorkbox.network.connection.Connection; -import dorkbox.network.connection.ConnectionImpl; import dorkbox.network.connection.EndPoint; -import dorkbox.network.connection.Listener; import dorkbox.network.serialization.RmiSerializationManager; import dorkbox.util.Property; import dorkbox.util.collections.LockFreeIntBiMap; @@ -111,9 +108,7 @@ class RmiBridge { } // the name of who created this RmiBridge - private final org.slf4j.Logger logger; - - private final Executor executor; + final org.slf4j.Logger logger; // we start at 1, because 0 (INVALID_RMI) means we access connection only objects @@ -122,66 +117,17 @@ class RmiBridge { // this is the ID -> Object RMI map. The RMI ID is used (not the kryo ID) private final LockFreeIntBiMap objectMap = new LockFreeIntBiMap(INVALID_MAP_ID); - private final Listener.OnMessageReceived invokeListener = new Listener.OnMessageReceived() { - @SuppressWarnings("AutoBoxing") - @Override - public - void received(final ConnectionImpl connection, final InvokeMethod invokeMethod) { - int objectID = invokeMethod.objectID; - - // have to make sure to get the correct object (global vs local) - // This is what is overridden when registering interfaces/classes for RMI. - // objectID is the interface ID, and this returns the implementation ID. - final Object target = connection.getImplementationObject(objectID); - - if (target == null) { - Logger logger2 = RmiBridge.this.logger; - if (logger2.isWarnEnabled()) { - logger2.warn("Ignoring remote invocation request for unknown object ID: {}", objectID); - } - - return; - } - - Executor executor2 = RmiBridge.this.executor; - if (executor2 == null) { - try { - invoke(connection, target, invokeMethod); - } catch (IOException e) { - logger.error("Unable to invoke method.", e); - } - } - else { - executor2.execute(new Runnable() { - @Override - public - void run() { - try { - invoke(connection, target, invokeMethod); - } catch (IOException e) { - logger.error("Unable to invoke method.", e); - } - } - }); - } - } - }; - /** * Creates an RmiBridge with no connections. Connections must be {@link RmiBridge#register(int, Object)} added to allow the remote end * of the connections to access objects in this ObjectSpace. * - * @param executor - * Sets the executor used to invoke methods when an invocation is received from a remote endpoint. By default, no - * executor is set and invocations occur on the network thread, which should not be blocked for long, May be null. * @param isGlobal * specify if this RmiBridge is a "global" bridge, meaning connections will prefer objects from this bridge instead of * the connection-local bridge. */ public - RmiBridge(final org.slf4j.Logger logger, final Executor executor, final boolean isGlobal) { + RmiBridge(final org.slf4j.Logger logger, final boolean isGlobal) { this.logger = logger; - this.executor = executor; if (isGlobal) { rmiObjectIdCounter = new AtomicInteger(0); @@ -191,15 +137,6 @@ class RmiBridge { } } - /** - * @return the invocation listener - */ - @SuppressWarnings("rawtypes") - public - Listener.OnMessageReceived getListener() { - return this.invokeListener; - } - /** * Invokes the method on the object and, if necessary, sends the result back to the connection that made the invocation request. This * method is invoked on the update thread of the {@link EndPoint} for this RmiBridge and unless an executor has been set. @@ -210,12 +147,11 @@ class RmiBridge { * The remote side of this connection requested the invocation. */ @SuppressWarnings("NumericCastThatLosesPrecision") - protected - void invoke(final Connection connection, final Object target, final InvokeMethod invokeMethod) throws IOException { + protected static + void invoke(final Connection connection, final Object target, final InvokeMethod invokeMethod, final Logger logger) throws IOException { CachedMethod cachedMethod = invokeMethod.cachedMethod; - Logger logger2 = this.logger; - if (logger2.isTraceEnabled()) { + if (logger.isTraceEnabled()) { String argString = ""; if (invokeMethod.args != null) { argString = Arrays.deepToString(invokeMethod.args); @@ -240,7 +176,7 @@ class RmiBridge { // did we override our cached method? This is not common. stringBuilder.append(" [Connection method override]"); } - logger2.trace(stringBuilder.toString()); + logger.trace(stringBuilder.toString()); } byte responseData = invokeMethod.responseData; diff --git a/src/dorkbox/network/rmi/RmiObjectHandler.java b/src/dorkbox/network/rmi/RmiObjectHandler.java index f030c996..98dcc973 100644 --- a/src/dorkbox/network/rmi/RmiObjectHandler.java +++ b/src/dorkbox/network/rmi/RmiObjectHandler.java @@ -16,25 +16,14 @@ package dorkbox.network.rmi; import dorkbox.network.connection.ConnectionImpl; -import dorkbox.network.connection.Listener; +import dorkbox.network.serialization.CryptoSerializationManager; public -class RmiObjectHandler { +interface RmiObjectHandler { - public - RmiObjectHandler() { - } + InvokeMethod getInvokeMethod(final CryptoSerializationManager serialization, final ConnectionImpl connection, final InvokeMethod invokeMethod); - public - void invoke(final ConnectionImpl connection, final InvokeMethod message, final Listener.OnMessageReceived rmiInvokeListener) { - } + void registration(final ConnectionRmiSupport rmiSupport, final ConnectionImpl connection, final RmiRegistration message); - public - void registration(final ConnectionImpl connection, final RmiRegistration message) { - } - - public - Object normalMessages(final ConnectionImpl connection, final Object message) { - return message; - } + Object normalMessages(final ConnectionRmiSupport connection, final Object message); } diff --git a/src/dorkbox/network/rmi/RmiObjectLocalHandler.java b/src/dorkbox/network/rmi/RmiObjectLocalHandler.java index 1cf48146..6d2251a2 100644 --- a/src/dorkbox/network/rmi/RmiObjectLocalHandler.java +++ b/src/dorkbox/network/rmi/RmiObjectLocalHandler.java @@ -29,7 +29,6 @@ import com.esotericsoftware.kryo.util.IdentityMap; import dorkbox.network.connection.ConnectionImpl; import dorkbox.network.connection.EndPoint; import dorkbox.network.connection.KryoExtra; -import dorkbox.network.connection.Listener; import dorkbox.network.serialization.CryptoSerializationManager; /** @@ -40,7 +39,7 @@ import dorkbox.network.serialization.CryptoSerializationManager; * This is for a LOCAL connection (same-JVM) */ public -class RmiObjectLocalHandler extends RmiObjectHandler { +class RmiObjectLocalHandler implements RmiObjectHandler { private static final boolean ENABLE_PROXY_OBJECTS = RmiBridge.ENABLE_PROXY_OBJECTS; private static final Field[] NO_REMOTE_FIELDS = new Field[0]; @@ -64,17 +63,12 @@ class RmiObjectLocalHandler extends RmiObjectHandler { this.logger = logger; } - @Override public - void invoke(final ConnectionImpl connection, final InvokeMethod invokeMethod, final Listener.OnMessageReceived rmiInvokeListener) { + InvokeMethod getInvokeMethod(final CryptoSerializationManager serialization, final ConnectionImpl connection, final InvokeMethod invokeMethod) { int methodClassID = invokeMethod.cachedMethod.methodClassID; int methodIndex = invokeMethod.cachedMethod.methodIndex; // have to replace the cached methods with the correct (remote) version, otherwise the wrong methods CAN BE invoked. - CryptoSerializationManager serialization = connection.getEndPoint() - .getSerialization(); - - CachedMethod cachedMethod; try { cachedMethod = serialization.getMethods(methodClassID)[methodIndex]; @@ -122,13 +116,12 @@ class RmiObjectLocalHandler extends RmiObjectHandler { invokeMethod.cachedMethod = cachedMethod; invokeMethod.args = args; - // default action, now that we have swapped out things - rmiInvokeListener.received(connection, invokeMethod); + return invokeMethod; } @Override public - void registration(final ConnectionImpl connection, final RmiRegistration registration) { + void registration(final ConnectionRmiSupport rmiSupport, final ConnectionImpl connection, final RmiRegistration registration) { // manage creating/getting/notifying this RMI object // these fields are ALWAYS present! @@ -148,7 +141,7 @@ class RmiObjectLocalHandler extends RmiObjectHandler { Class rmiImpl = serialization.getRmiImpl(registration.interfaceClass); - RmiRegistration registrationResult = connection.createNewRmiObject(interfaceClass, rmiImpl, callbackId); + RmiRegistration registrationResult = rmiSupport.createNewRmiObject(serialization, interfaceClass, rmiImpl, callbackId, logger); connection.send(registrationResult); // connection transport is flushed in calling method (don't need to do it here) } @@ -156,8 +149,8 @@ class RmiObjectLocalHandler extends RmiObjectHandler { // Check if we are getting an already existing REMOTE object. This check is always AFTER the check to create a new object else { // GET a LOCAL rmi object, if none get a specific, GLOBAL rmi object (objects that are not bound to a single connection). - RmiRegistration registrationResult = connection.getExistingRmiObject(interfaceClass, registration.rmiId, callbackId); - connection.send(registrationResult); + Object implementationObject = rmiSupport.getImplementationObject(registration.rmiId); + connection.send(new RmiRegistration(interfaceClass, registration.rmiId, callbackId, implementationObject)); // connection transport is flushed in calling method (don't need to do it here) } } @@ -182,7 +175,7 @@ class RmiObjectLocalHandler extends RmiObjectHandler { else { // override the implementation object with the proxy. This is required because RMI must be the same between "network" and "local" // connections -- even if this "slows down" the speed/performance of what "local" connections offer. - proxyObject = connection.getProxyObject(registration.rmiId, interfaceClass); + proxyObject = rmiSupport.getLocalProxyObject(connection, registration.rmiId, interfaceClass, registration.remoteObject); if (proxyObject != null && registration.remoteObject != null) { // have to save A and B so we can correctly switch as necessary @@ -205,7 +198,7 @@ class RmiObjectLocalHandler extends RmiObjectHandler { @SuppressWarnings("unchecked") @Override public - Object normalMessages(final ConnectionImpl connection, final Object message) { + Object normalMessages(final ConnectionRmiSupport rmiSupport, final Object message) { // else, this was "just a local message" // because we NORMALLY pass around just the object (there is no serialization going on...) we have to explicitly check to see @@ -250,10 +243,10 @@ class RmiObjectLocalHandler extends RmiObjectHandler { o = field.get(message); if (o instanceof RemoteObject) { - RmiProxyHandler handler = (RmiProxyHandler) Proxy.getInvocationHandler(o); + RmiProxyLocalHandler handler = (RmiProxyLocalHandler) Proxy.getInvocationHandler(o); int id = handler.rmiObjectId; - field.set(message, connection.getImplementationObject(id)); + field.set(message, rmiSupport.getImplementationObject(id)); fields.add(field); } else { @@ -283,7 +276,7 @@ class RmiObjectLocalHandler extends RmiObjectHandler { array = NO_REMOTE_FIELDS; } else { - array = fields.toArray(new Field[fields.size()]); + array = fields.toArray(new Field[0]); } //noinspection SynchronizeOnNonFinalField @@ -303,10 +296,10 @@ class RmiObjectLocalHandler extends RmiObjectHandler { o = field.get(message); if (o instanceof RemoteObject) { - RmiProxyHandler handler = (RmiProxyHandler) Proxy.getInvocationHandler(o); + RmiProxyNetworkHandler handler = (RmiProxyNetworkHandler) Proxy.getInvocationHandler(o); int id = handler.rmiObjectId; - field.set(message, connection.getImplementationObject(id)); + field.set(message, rmiSupport.getImplementationObject(id)); } else { // is a field supposed to be a proxy? diff --git a/src/dorkbox/network/rmi/RmiObjectNetworkHandler.java b/src/dorkbox/network/rmi/RmiObjectNetworkHandler.java index d9390b00..4827fca8 100644 --- a/src/dorkbox/network/rmi/RmiObjectNetworkHandler.java +++ b/src/dorkbox/network/rmi/RmiObjectNetworkHandler.java @@ -18,10 +18,10 @@ package dorkbox.network.rmi; import org.slf4j.Logger; import dorkbox.network.connection.ConnectionImpl; -import dorkbox.network.connection.Listener; +import dorkbox.network.serialization.CryptoSerializationManager; public -class RmiObjectNetworkHandler extends RmiObjectHandler { +class RmiObjectNetworkHandler implements RmiObjectHandler { private final Logger logger; @@ -30,16 +30,15 @@ class RmiObjectNetworkHandler extends RmiObjectHandler { this.logger = logger; } - @Override public - void invoke(final ConnectionImpl connection, final InvokeMethod message, final Listener.OnMessageReceived rmiInvokeListener) { - // default, nothing fancy - rmiInvokeListener.received(connection, message); + InvokeMethod getInvokeMethod(final CryptoSerializationManager serialization, final ConnectionImpl connection, final InvokeMethod invokeMethod) { + // everything is fine, there is nothing necessary to fix + return invokeMethod; } @Override public - void registration(final ConnectionImpl connection, final RmiRegistration registration) { + void registration(final ConnectionRmiSupport rmiSupport, final ConnectionImpl connection, final RmiRegistration registration) { // manage creating/getting/notifying this RMI object // these fields are ALWAYS present! @@ -55,11 +54,13 @@ class RmiObjectNetworkHandler extends RmiObjectHandler { // CREATE a new ID, and register the ID and new object (must create a new one) in the object maps // have to lookup the implementation class - Class rmiImpl = connection.getEndPoint().getSerialization().getRmiImpl(interfaceClass); + CryptoSerializationManager serialization = connection.getEndPoint().getSerialization(); + + Class rmiImpl = serialization.getRmiImpl(interfaceClass); // For network connections, the interface class kryo ID == implementation class kryo ID, so they switch automatically. - RmiRegistration registrationResult = connection.createNewRmiObject(interfaceClass, rmiImpl, callbackId); + RmiRegistration registrationResult = rmiSupport.createNewRmiObject(serialization, interfaceClass, rmiImpl, callbackId, logger); connection.send(registrationResult); // connection transport is flushed in calling method (don't need to do it here) } @@ -69,8 +70,8 @@ class RmiObjectNetworkHandler extends RmiObjectHandler { // THIS IS ON THE REMOTE CONNECTION (where the object implementation will really exist) // // GET a LOCAL rmi object, if none get a specific, GLOBAL rmi object (objects that are not bound to a single connection). - RmiRegistration registrationResult = connection.getExistingRmiObject(interfaceClass, registration.rmiId, callbackId); - connection.send(registrationResult); + Object implementationObject = rmiSupport.getImplementationObject(registration.rmiId); + connection.send(new RmiRegistration(interfaceClass, registration.rmiId, callbackId, implementationObject)); // connection transport is flushed in calling method (don't need to do it here) } } @@ -84,4 +85,10 @@ class RmiObjectNetworkHandler extends RmiObjectHandler { connection.runRmiCallback(interfaceClass, callbackId, registration.remoteObject); } } + + @Override + public + Object normalMessages(final ConnectionRmiSupport connection, final Object message) { + return message; + } } diff --git a/src/dorkbox/network/rmi/RmiProxyLocalHandler.java b/src/dorkbox/network/rmi/RmiProxyLocalHandler.java new file mode 100644 index 00000000..bfe2e19a --- /dev/null +++ b/src/dorkbox/network/rmi/RmiProxyLocalHandler.java @@ -0,0 +1,441 @@ +/* + * Copyright 2010 dorkbox, llc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Copyright (c) 2008, Nathan Sweet + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided with the distribution. + * - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, + * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + * SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package dorkbox.network.rmi; + + +import java.io.IOException; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import dorkbox.network.connection.Connection; +import dorkbox.network.connection.ConnectionImpl; +import dorkbox.network.connection.EndPoint; +import dorkbox.network.connection.KryoExtra; +import dorkbox.network.connection.Listener; +import dorkbox.network.serialization.RmiSerializationManager; + +/** + * Handles network communication when methods are invoked on a proxy. + *

+ * If the method return type is 'void', then we don't have to explicitly set 'transmitReturnValue' to false + *

+ * If there are no checked exceptions thrown, then we don't have to explicitly set 'transmitExceptions' to false + */ +public +class RmiProxyLocalHandler implements InvocationHandler { + private final Logger logger; + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition responseCondition = this.lock.newCondition(); + + private final InvokeMethodResult[] responseTable = new InvokeMethodResult[64]; + private final boolean[] pendingResponses = new boolean[64]; + + private final ConnectionImpl connection; + public final int rmiObjectId; // this is the RMI id + public final int ID; // this is the KRYO id + + + private final String proxyString; + private final + Listener.OnMessageReceived responseListener; + + private int timeoutMillis = 3000; + private boolean isAsync = false; + + // if the return type is 'void', then this has no meaning. + private boolean transmitReturnValue = false; + + // if there are no checked exceptions thrown, then this has no meaning + private boolean transmitExceptions = false; + + private boolean enableToString; + + private boolean udp; + + private Byte lastResponseID; + private byte nextResponseId = (byte) 1; + + /** + * @param connection this is really the network client -- there is ONLY ever 1 connection + * @param rmiId this is the remote object ID (assigned by RMI). This is NOT the kryo registration ID + * @param iFace this is the RMI interface + * @param object + */ + public + RmiProxyLocalHandler(final ConnectionImpl connection, final int rmiId, final Class iFace, final Object object) { + super(); + + this.connection = connection; + this.rmiObjectId = rmiId; + this.proxyString = ""; + + EndPoint endPointConnection = this.connection.getEndPoint(); + final RmiSerializationManager serializationManager = endPointConnection.getSerialization(); + + KryoExtra kryoExtra = null; + try { + kryoExtra = serializationManager.takeKryo(); + this.ID = kryoExtra.getRegistration(iFace).getId(); + } finally { + if (kryoExtra != null) { + serializationManager.returnKryo(kryoExtra); + } + } + + this.logger = LoggerFactory.getLogger(connection.getEndPoint().getName() + ":" + this.getClass().getSimpleName()); + + this.responseListener = new Listener.OnMessageReceived() { + @Override + public + void received(Connection connection, InvokeMethodResult invokeMethodResult) { + byte responseID = invokeMethodResult.responseId; + + if (invokeMethodResult.rmiObjectId != rmiId) { + return; + } + + synchronized (this) { + if (RmiProxyLocalHandler.this.pendingResponses[responseID]) { + RmiProxyLocalHandler.this.responseTable[responseID] = invokeMethodResult; + } + } + + RmiProxyLocalHandler.this.lock.lock(); + try { + RmiProxyLocalHandler.this.responseCondition.signalAll(); + } finally { + RmiProxyLocalHandler.this.lock.unlock(); + } + } + }; + } + + public + Listener.OnMessageReceived getListener() { + return responseListener; + } + + @SuppressWarnings({"AutoUnboxing", "AutoBoxing", "NumericCastThatLosesPrecision", "IfCanBeSwitch"}) + @Override + public + Object invoke(final Object proxy, final Method method, final Object[] args) throws Exception { + final Class declaringClass = method.getDeclaringClass(); + if (declaringClass == RemoteObject.class) { + // manage all of the RemoteObject proxy methods + + String name = method.getName(); + if (name.equals("close")) { + connection.removeRmiListeners(rmiObjectId, getListener()); + return null; + } + else if (name.equals("setResponseTimeout")) { + this.timeoutMillis = (Integer) args[0]; + return null; + } + else if (name.equals("getResponseTimeout")) { + return this.timeoutMillis; + } + else if (name.equals("setAsync")) { + this.isAsync = (Boolean) args[0]; + return null; + } + else if (name.equals("setTransmitReturnValue")) { + this.transmitReturnValue = (Boolean) args[0]; + return null; + } + else if (name.equals("setTransmitExceptions")) { + this.transmitExceptions = (Boolean) args[0]; + return null; + } + else if (name.equals("setTCP")) { + this.udp = false; + return null; + } + else if (name.equals("setUDP")) { + this.udp = true; + return null; + } + else if (name.equals("enableToString")) { + this.enableToString = (Boolean) args[0]; + return null; + } + else if (name.equals("waitForLastResponse")) { + if (this.lastResponseID == null) { + throw new IllegalStateException("There is no last response to wait for."); + } + return waitForResponse(this.lastResponseID); + } + else if (name.equals("getLastResponseID")) { + if (this.lastResponseID == null) { + throw new IllegalStateException("There is no last response ID."); + } + return this.lastResponseID; + } + else if (name.equals("waitForResponse")) { + if (!this.transmitReturnValue && !this.transmitExceptions && this.isAsync) { + throw new IllegalStateException("This RemoteObject is currently set to ignore all responses."); + } + return waitForResponse((Byte) args[0]); + } + + // Should never happen, for debugging purposes only! + throw new Exception("Invocation handler could not find RemoteObject method for " + name); + } + else if (!this.enableToString && declaringClass == Object.class && method.getName() + .equals("toString")) { + return proxyString; + } + + InvokeMethod invokeMethod = new InvokeMethod(); + invokeMethod.objectID = this.rmiObjectId; + invokeMethod.args = args; + + // which method do we access? We always want to access the IMPLEMENTATION (if available!) + CachedMethod[] cachedMethods = connection.getEndPoint() + .getSerialization() + .getMethods(ID); + + for (int i = 0, n = cachedMethods.length; i < n; i++) { + CachedMethod cachedMethod = cachedMethods[i]; + Method checkMethod = cachedMethod.method; + + if (checkMethod.equals(method)) { + invokeMethod.cachedMethod = cachedMethod; + break; + } + } + + if (invokeMethod.cachedMethod == null) { + String msg = "Method not found: " + method; + logger.error(msg); + return msg; + } + + + byte responseID = (byte) 0; + Class returnType = method.getReturnType(); + + // If the method return type is 'void', then we don't have to explicitly set 'transmitReturnValue' to false + boolean shouldReturnValue = returnType != void.class || this.transmitReturnValue; + + // If there are no checked exceptions thrown, then we don't have to explicitly set 'transmitExceptions' to false + boolean shouldTransmitExceptions = (method.getExceptionTypes().length != 0 || method.getGenericExceptionTypes().length != 0) || this.transmitExceptions; + + // If we are async (but still have a return type or throw checked exceptions) then we ignore the response + // If we are 'void' return type and do not throw checked exceptions then we ignore the response + boolean ignoreResponse = (this.isAsync || returnType == void.class) && !(shouldReturnValue || shouldTransmitExceptions); + + if (ignoreResponse) { + invokeMethod.responseData = (byte) 0; // 0 means do not respond. + } + else { + synchronized (this) { + // Increment the response counter and put it into the low bits of the responseID. + responseID = this.nextResponseId++; + if (this.nextResponseId > RmiBridge.responseIdMask) { + this.nextResponseId = (byte) 1; + } + this.pendingResponses[responseID] = true; + } + // Pack other data into the high bits. + byte responseData = responseID; + if (shouldReturnValue) { + responseData |= (byte) RmiBridge.returnValueMask; + } + if (shouldTransmitExceptions) { + responseData |= (byte) RmiBridge.returnExceptionMask; + } + invokeMethod.responseData = responseData; + } + + byte lastResponseID = (byte) (invokeMethod.responseData & RmiBridge.responseIdMask); + this.lastResponseID = lastResponseID; + + // Sends our invokeMethod to the remote connection, which the RmiBridge listens for + if (this.udp) { + // flush is necessary in case this is called outside of a network worker thread + this.connection.UDP(invokeMethod).flush(); + } + else { + // flush is necessary in case this is called outside of a network worker thread + this.connection.send(invokeMethod).flush(); + } + + if (logger.isTraceEnabled()) { + String argString = ""; + if (args != null) { + argString = Arrays.deepToString(args); + argString = argString.substring(1, argString.length() - 1); + } + logger.trace(this.connection + " sent: " + method.getDeclaringClass() + .getSimpleName() + + "#" + method.getName() + "(" + argString + ")"); + } + + // MUST use 'waitForLastResponse()' or 'waitForResponse'('getLastResponseID()') to get the response + // If we are async then we return immediately + // If we are 'void' return type and do not throw checked exceptions then we return immediately + boolean respondImmediately = this.isAsync || (returnType == void.class) && !(shouldReturnValue || shouldTransmitExceptions); + if (respondImmediately) { + if (returnType.isPrimitive()) { + if (returnType == int.class) { + return 0; + } + if (returnType == boolean.class) { + return Boolean.FALSE; + } + if (returnType == float.class) { + return 0.0f; + } + if (returnType == char.class) { + return (char) 0; + } + if (returnType == long.class) { + return 0L; + } + if (returnType == short.class) { + return (short) 0; + } + if (returnType == byte.class) { + return (byte) 0; + } + if (returnType == double.class) { + return 0.0d; + } + } + return null; + } + + try { + Object result = waitForResponse(lastResponseID); + if (result instanceof Exception) { + throw (Exception) result; + } + else { + return result; + } + } catch (TimeoutException ex) { + throw new TimeoutException("Response timed out: " + method.getDeclaringClass() + .getName() + "." + method.getName()); + } finally { + synchronized (this) { + this.pendingResponses[responseID] = false; + this.responseTable[responseID] = null; + } + } + } + + /** + * A timeout of 0 means that we want to disable waiting, otherwise - it waits in milliseconds + */ + private + Object waitForResponse(final byte responseID) throws IOException { + // if timeout == 0, we wait "forever" + long remaining; + long endTime; + + if (this.timeoutMillis != 0) { + remaining = this.timeoutMillis; + endTime = System.currentTimeMillis() + remaining; + } else { + // not forever, but close enough + remaining = Long.MAX_VALUE; + endTime = Long.MAX_VALUE; + } + + // wait for the specified time + while (remaining > 0) { + InvokeMethodResult invokeMethodResult; + synchronized (this) { + invokeMethodResult = this.responseTable[responseID]; + } + + if (invokeMethodResult != null) { + this.lastResponseID = null; + return invokeMethodResult.result; + } + else { + this.lock.lock(); + try { + this.responseCondition.await(remaining, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread() + .interrupt(); + throw new IOException("Response timed out.", e); + } finally { + this.lock.unlock(); + } + } + + remaining = endTime - System.currentTimeMillis(); + } + + // only get here if we timeout + throw new TimeoutException("Response timed out."); + } + + @Override + public + int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + this.rmiObjectId; + return result; + } + + @Override + public + boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + RmiProxyLocalHandler other = (RmiProxyLocalHandler) obj; + return this.rmiObjectId == other.rmiObjectId; + } +} diff --git a/src/dorkbox/network/rmi/RmiProxyHandler.java b/src/dorkbox/network/rmi/RmiProxyNetworkHandler.java similarity index 96% rename from src/dorkbox/network/rmi/RmiProxyHandler.java rename to src/dorkbox/network/rmi/RmiProxyNetworkHandler.java index 27cba0e2..efb874e3 100644 --- a/src/dorkbox/network/rmi/RmiProxyHandler.java +++ b/src/dorkbox/network/rmi/RmiProxyNetworkHandler.java @@ -61,7 +61,7 @@ import dorkbox.network.serialization.RmiSerializationManager; * If there are no checked exceptions thrown, then we don't have to explicitly set 'transmitExceptions' to false */ public -class RmiProxyHandler implements InvocationHandler { +class RmiProxyNetworkHandler implements InvocationHandler { private final Logger logger; private final ReentrantLock lock = new ReentrantLock(); @@ -101,7 +101,7 @@ class RmiProxyHandler implements InvocationHandler { * @param iFace this is the RMI interface */ public - RmiProxyHandler(final ConnectionImpl connection, final int rmiId, final Class iFace) { + RmiProxyNetworkHandler(final ConnectionImpl connection, final int rmiId, final Class iFace) { super(); this.connection = connection; @@ -134,16 +134,16 @@ class RmiProxyHandler implements InvocationHandler { } synchronized (this) { - if (RmiProxyHandler.this.pendingResponses[responseID]) { - RmiProxyHandler.this.responseTable[responseID] = invokeMethodResult; + if (RmiProxyNetworkHandler.this.pendingResponses[responseID]) { + RmiProxyNetworkHandler.this.responseTable[responseID] = invokeMethodResult; } } - RmiProxyHandler.this.lock.lock(); + RmiProxyNetworkHandler.this.lock.lock(); try { - RmiProxyHandler.this.responseCondition.signalAll(); + RmiProxyNetworkHandler.this.responseCondition.signalAll(); } finally { - RmiProxyHandler.this.lock.unlock(); + RmiProxyNetworkHandler.this.lock.unlock(); } } }; @@ -434,7 +434,7 @@ class RmiProxyHandler implements InvocationHandler { if (getClass() != obj.getClass()) { return false; } - RmiProxyHandler other = (RmiProxyHandler) obj; + RmiProxyNetworkHandler other = (RmiProxyNetworkHandler) obj; return this.rmiObjectId == other.rmiObjectId; } } diff --git a/test/dorkbox/network/ListenerTest.java b/test/dorkbox/network/ListenerTest.java index 500e88d2..5e9e0410 100644 --- a/test/dorkbox/network/ListenerTest.java +++ b/test/dorkbox/network/ListenerTest.java @@ -19,17 +19,23 @@ */ package dorkbox.network; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.IOException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; -import org.slf4j.Logger; -import dorkbox.network.connection.*; -import dorkbox.network.rmi.RmiBridge; +import dorkbox.network.connection.Connection; +import dorkbox.network.connection.ConnectionImpl; +import dorkbox.network.connection.EndPoint; +import dorkbox.network.connection.Listener; +import dorkbox.network.connection.Listeners; +import dorkbox.network.connection.wrapper.ChannelWrapper; import dorkbox.util.exceptions.InitializationException; import dorkbox.util.exceptions.SecurityException; @@ -57,8 +63,8 @@ class ListenerTest extends BaseTest { // quick and dirty test to also test connection sub-classing class TestConnectionA extends ConnectionImpl { public - TestConnectionA(final Logger logger, final EndPoint endPointConnection, final RmiBridge rmiBridge) { - super(logger, endPointConnection, rmiBridge); + TestConnectionA(final EndPoint endPointConnection, final ChannelWrapper wrapper) { + super(endPointConnection, wrapper); } public @@ -70,8 +76,8 @@ class ListenerTest extends BaseTest { class TestConnectionB extends TestConnectionA { public - TestConnectionB(final Logger logger, final EndPoint endPointConnection, final RmiBridge rmiBridge) { - super(logger, endPointConnection, rmiBridge); + TestConnectionB(final EndPoint endPointConnection, final ChannelWrapper wrapper) { + super(endPointConnection, wrapper); } @Override @@ -102,8 +108,8 @@ class ListenerTest extends BaseTest { Server server = new Server(configuration) { @Override public - TestConnectionA newConnection(final Logger logger, final EndPoint endPoint, final RmiBridge rmiBridge) { - return new TestConnectionA(logger, endPoint, rmiBridge); + TestConnectionA newConnection(final EndPoint endPoint, final ChannelWrapper wrapper) { + return new TestConnectionA(endPoint, wrapper); } }; diff --git a/test/dorkbox/network/rmi/RmiSendObjectTest.java b/test/dorkbox/network/rmi/RmiSendObjectTest.java index 98c0acd4..f505136c 100644 --- a/test/dorkbox/network/rmi/RmiSendObjectTest.java +++ b/test/dorkbox/network/rmi/RmiSendObjectTest.java @@ -175,7 +175,7 @@ class RmiSendObjectTest extends BaseTest { } }); - client.connect(5000); + client.connect(0); waitForThreads(); } diff --git a/test/dorkbox/network/rmi/RmiTest.java b/test/dorkbox/network/rmi/RmiTest.java index 807f025a..90f37453 100644 --- a/test/dorkbox/network/rmi/RmiTest.java +++ b/test/dorkbox/network/rmi/RmiTest.java @@ -330,7 +330,7 @@ class RmiTest extends BaseTest { } }); - client.connect(5000); + client.connect(0); waitForThreads(); }