diff --git a/src/dorkbox/network/Client.java b/src/dorkbox/network/Client.java index ff860c62..d4ce9ad6 100644 --- a/src/dorkbox/network/Client.java +++ b/src/dorkbox/network/Client.java @@ -43,6 +43,8 @@ import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.epoll.EpollDatagramChannel; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.kqueue.KQueueDatagramChannel; +import io.netty.channel.kqueue.KQueueSocketChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.nio.NioEventLoopGroup; @@ -167,6 +169,10 @@ class Client extends EndPointClient implements Connecti // JNI network stack is MUCH faster (but only on linux) tcpBootstrap.channel(EpollSocketChannel.class); } + else if (OS.isMacOsX()) { + // JNI network stack is MUCH faster (but only on macosx) + tcpBootstrap.channel(KQueueSocketChannel.class); + } else { tcpBootstrap.channel(NioSocketChannel.class); } @@ -197,7 +203,12 @@ class Client extends EndPointClient implements Connecti // JNI network stack is MUCH faster (but only on linux) udpBootstrap.channel(EpollDatagramChannel.class); } + else if (OS.isMacOsX()) { + // JNI network stack is MUCH faster (but only on macosx) + udpBootstrap.channel(KQueueDatagramChannel.class); + } else { + // windows udpBootstrap.channel(NioDatagramChannel.class); } @@ -241,7 +252,7 @@ class Client extends EndPointClient implements Connecti * if the client is unable to reconnect in the requested time */ public - void reconnect(int connectionTimeout) throws IOException { + void reconnect(final int connectionTimeout) throws IOException { // close out all old connections closeConnections(); @@ -291,23 +302,8 @@ class Client extends EndPointClient implements Connecti } } - // have to start the registration process - registerNextProtocol(); - - // have to BLOCK - // don't want the client to run before registration is complete - synchronized (registrationLock) { - if (!registrationComplete) { - try { - registrationLock.wait(connectionTimeout); - } catch (InterruptedException e) { - throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds", e); - } - } - } - - // RMI methods are usually created during the connection phase. We should wait until they are finished - waitForRmi(connectionTimeout); + // have to start the registration process. This will wait until registration is complete and RMI methods are initialized + startRegistration(); } @Override diff --git a/src/dorkbox/network/Server.java b/src/dorkbox/network/Server.java index cda523e5..c21e63cd 100644 --- a/src/dorkbox/network/Server.java +++ b/src/dorkbox/network/Server.java @@ -36,10 +36,11 @@ import io.netty.channel.ChannelOption; import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.WriteBufferWaterMark; -import io.netty.channel.epoll.EpollChannelOption; import io.netty.channel.epoll.EpollDatagramChannel; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.kqueue.KQueueDatagramChannel; +import io.netty.channel.kqueue.KQueueServerSocketChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; @@ -48,6 +49,7 @@ import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.oio.OioDatagramChannel; import io.netty.channel.socket.oio.OioServerSocketChannel; +import io.netty.channel.unix.UnixChannelOption; /** * The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the @@ -200,6 +202,10 @@ class Server extends EndPointServer { // JNI network stack is MUCH faster (but only on linux) tcpBootstrap.channel(EpollServerSocketChannel.class); } + else if (OS.isMacOsX()) { + // JNI network stack is MUCH faster (but only on macosx) + tcpBootstrap.channel(KQueueServerSocketChannel.class); + } else { tcpBootstrap.channel(NioServerSocketChannel.class); } @@ -235,14 +241,21 @@ class Server extends EndPointServer { if (udpBootstrap != null) { if (OS.isAndroid()) { // android ONLY supports OIO (not NIO) - udpBootstrap.channel(OioDatagramChannel.class); + udpBootstrap.channel(OioDatagramChannel.class) + .option(UnixChannelOption.SO_REUSEPORT, true); } else if (OS.isLinux()) { // JNI network stack is MUCH faster (but only on linux) udpBootstrap.channel(EpollDatagramChannel.class) - .option(EpollChannelOption.SO_REUSEPORT, true); + .option(UnixChannelOption.SO_REUSEPORT, true); + } + else if (OS.isMacOsX()) { + // JNI network stack is MUCH faster (but only on macosx) + udpBootstrap.channel(KQueueDatagramChannel.class) + .option(UnixChannelOption.SO_REUSEPORT, true); } else { + // windows udpBootstrap.channel(NioDatagramChannel.class); } diff --git a/src/dorkbox/network/connection/ConnectionImpl.java b/src/dorkbox/network/connection/ConnectionImpl.java index 6023bd5b..7de26e7b 100644 --- a/src/dorkbox/network/connection/ConnectionImpl.java +++ b/src/dorkbox/network/connection/ConnectionImpl.java @@ -20,6 +20,7 @@ import java.lang.reflect.Field; import java.util.LinkedList; import java.util.Map; import java.util.WeakHashMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -93,7 +94,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn // while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error. private boolean remoteKeyChanged; - private final EndPointBase endPointBaseConnection; + private final EndPointBase endPoint; // when true, the connection will be closed (either as RMI or as 'normal' listener execution) when the thread execution returns control // back to the network stack @@ -107,6 +108,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn // // RMI fields // + protected CountDownLatch rmi; private final RmiBridge rmiBridge; private final Map proxyIdCache = new WeakHashMap(8); private final IntMap rmiRegistrationCallbacks = new IntMap(); @@ -117,9 +119,9 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn */ @SuppressWarnings({"rawtypes", "unchecked"}) public - ConnectionImpl(final Logger logger, final EndPointBase endPointBaseConnection, final RmiBridge rmiBridge) { + ConnectionImpl(final Logger logger, final EndPointBase endPoint, final RmiBridge rmiBridge) { this.logger = logger; - this.endPointBaseConnection = endPointBaseConnection; + this.endPoint = endPoint; this.rmiBridge = rmiBridge; } @@ -215,7 +217,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @Override public EndPointBase getEndPoint() { - return this.endPointBaseConnection; + return this.endPoint; } /** @@ -328,7 +330,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn */ private void controlBackPressure(ConnectionPoint c) { - while (!c.isWritable()) { + while (!closeInProgress.get() && !c.isWritable()) { needsLock.set(true); writeSignalNeeded.set(true); @@ -590,6 +592,10 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn .toString()); } + if (this.endPoint instanceof EndPointClient) { + ((EndPointClient) this.endPoint).abortRegistration(); + } + // our master channels are TCP/LOCAL (which are mutually exclusive). Only key disconnect events based on the status of them. if (isTCP || channelClass == LocalChannel.class) { // this is because channelInactive can ONLY happen when netty shuts down the channel. @@ -614,7 +620,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn void close() { // only close if we aren't already in the middle of closing. if (this.closeInProgress.compareAndSet(false, true)) { - int idleTimeoutMs = this.endPointBaseConnection.getIdleTimeout(); + int idleTimeoutMs = this.endPoint.getIdleTimeout(); if (idleTimeoutMs == 0) { // default is 2 second timeout, in milliseconds. idleTimeoutMs = 2000; @@ -714,7 +720,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @Override public final Listeners add(Listener listener) { - if (this.endPointBaseConnection instanceof EndPointServer) { + if (this.endPoint instanceof EndPointServer) { // when we are a server, NORMALLY listeners are added at the GLOBAL level // meaning -- // I add one listener, and ALL connections are notified of that listener. @@ -726,15 +732,15 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn // is empty, we can remove it from this connection. synchronized (this) { if (this.localListenerManager == null) { - this.localListenerManager = ((EndPointServer) this.endPointBaseConnection).addListenerManager(this); + this.localListenerManager = ((EndPointServer) this.endPoint).addListenerManager(this); } this.localListenerManager.add(listener); } } else { - this.endPointBaseConnection.listeners() - .add(listener); + this.endPoint.listeners() + .add(listener); } return this; @@ -756,7 +762,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @Override public final Listeners remove(Listener listener) { - if (this.endPointBaseConnection instanceof EndPointServer) { + if (this.endPoint instanceof EndPointServer) { // when we are a server, NORMALLY listeners are added at the GLOBAL level // meaning -- // I add one listener, and ALL connections are notified of that listener. @@ -771,14 +777,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn this.localListenerManager.remove(listener); if (!this.localListenerManager.hasListeners()) { - ((EndPointServer) this.endPointBaseConnection).removeListenerManager(this); + ((EndPointServer) this.endPoint).removeListenerManager(this); } } } } else { - this.endPointBaseConnection.listeners() - .remove(listener); + this.endPoint.listeners() + .remove(listener); } return this; @@ -791,7 +797,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @Override public final Listeners removeAll() { - if (this.endPointBaseConnection instanceof EndPointServer) { + if (this.endPoint instanceof EndPointServer) { // when we are a server, NORMALLY listeners are added at the GLOBAL level // meaning -- // I add one listener, and ALL connections are notified of that listener. @@ -806,13 +812,13 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn this.localListenerManager.removeAll(); this.localListenerManager = null; - ((EndPointServer) this.endPointBaseConnection).removeListenerManager(this); + ((EndPointServer) this.endPoint).removeListenerManager(this); } } } else { - this.endPointBaseConnection.listeners() - .removeAll(); + this.endPoint.listeners() + .removeAll(); } return this; @@ -826,7 +832,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @Override public final Listeners removeAll(Class classType) { - if (this.endPointBaseConnection instanceof EndPointServer) { + if (this.endPoint instanceof EndPointServer) { // when we are a server, NORMALLY listeners are added at the GLOBAL level // meaning -- // I add one listener, and ALL connections are notified of that listener. @@ -842,14 +848,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn if (!this.localListenerManager.hasListeners()) { this.localListenerManager = null; - ((EndPointServer) this.endPointBaseConnection).removeListenerManager(this); + ((EndPointServer) this.endPoint).removeListenerManager(this); } } } } else { - this.endPointBaseConnection.listeners() - .removeAll(classType); + this.endPoint.listeners() + .removeAll(classType); } return this; @@ -899,62 +905,10 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn // RMI methods // - /** - * Internal call CLIENT ONLY. - *

- * RMI methods are usually created during the connection phase. We should wait until they are finished - */ - void waitForRmi(final int connectionTimeout) { - synchronized (rmiRegistrationCallbacks) { - try { - rmiRegistrationCallbacks.wait(connectionTimeout); - } catch (InterruptedException e) { - logger.error("Interrupted waiting for RMI to finish.", e); - } - } - } - - /** - * Internal call CLIENT ONLY. - *

- * RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect(). - */ - boolean rmiCallbacksIsEmpty() { - synchronized (rmiRegistrationCallbacks) { - return rmiRegistrationCallbacks.size == 0; - } - } - - /** - * Internal call CLIENT ONLY. - *

- * RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect(). - */ - void rmiCallbacksNotify() { - synchronized (rmiRegistrationCallbacks) { - rmiRegistrationCallbacks.notify(); - } - } - - /** - * Internal call CLIENT ONLY. - *

- * RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect(). - */ - private - void rmiCallbacksNotifyIfEmpty() { - synchronized (rmiRegistrationCallbacks) { - if (rmiRegistrationCallbacks.size == 0) { - rmiRegistrationCallbacks.notify(); - } - } - } - - @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"}) @Override public final - void getRemoteObject(final Class interfaceClass, final RemoteObjectCallback callback) throws IOException { + void getRemoteObject(final Class interfaceClass, final RemoteObjectCallback callback) { if (!interfaceClass.isInterface()) { throw new IllegalArgumentException("Cannot create a proxy for RMI access. It must be an interface."); } @@ -977,7 +931,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"}) @Override public final - void getRemoteObject(final int objectId, final RemoteObjectCallback callback) throws IOException { + void getRemoteObject(final int objectId, final RemoteObjectCallback callback) { RmiRegistration message; synchronized (rmiRegistrationCallbacks) { @@ -1079,12 +1033,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn } } else { - // THIS IS ON THE LOCAL CONNECTION SIDE, which is the side that called 'getRemoteObject()' + // THIS IS ON THE LOCAL CONNECTION SIDE, which is the side that called 'getRemoteObject()' This can be Server or Client. // this will be null if there was an error Object remoteObject = remoteRegistration.remoteObject; + boolean noMoreRmiRemaining ; RemoteObjectCallback callback; + synchronized (rmiRegistrationCallbacks) { callback = rmiRegistrationCallbacks.remove(remoteRegistration.rmiID); } @@ -1095,9 +1051,6 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn } catch (Exception e) { logger.error("Error getting remote object " + remoteObject.getClass() + ", ID: " + rmiID, e); } - - // tell the client that we are finished with all RMI callbacks - rmiCallbacksNotifyIfEmpty(); } } @@ -1111,7 +1064,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn public int getRegisteredId(final T object) { // always check local before checking global, because less contention on the synchronization - RmiBridge globalRmiBridge = endPointBaseConnection.globalRmiBridge; + RmiBridge globalRmiBridge = endPoint.globalRmiBridge; if (globalRmiBridge == null) { throw new NullPointerException("Unable to call 'getRegisteredId' when the globalRmiBridge is null!"); @@ -1155,7 +1108,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn public Object getImplementationObject(final int objectID) { if (RmiBridge.isGlobal(objectID)) { - RmiBridge globalRmiBridge = endPointBaseConnection.globalRmiBridge; + RmiBridge globalRmiBridge = endPoint.globalRmiBridge; if (globalRmiBridge == null) { throw new NullPointerException("Unable to call 'getRegisteredId' when the gloablRmiBridge is null!"); diff --git a/src/dorkbox/network/connection/EndPointBase.java b/src/dorkbox/network/connection/EndPointBase.java index debbac51..c6c633b1 100644 --- a/src/dorkbox/network/connection/EndPointBase.java +++ b/src/dorkbox/network/connection/EndPointBase.java @@ -17,7 +17,6 @@ package dorkbox.network.connection; import java.io.IOException; import java.security.SecureRandom; -import java.util.Collection; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -277,7 +276,7 @@ class EndPointBase extends EndPoint { */ public void setIdleTimeout(int idleTimeoutMs) { - idleTimeoutMs = idleTimeoutMs; + this.idleTimeoutMs = idleTimeoutMs; } /** @@ -308,8 +307,8 @@ class EndPointBase extends EndPoint { * @return a new network connection */ protected - ConnectionImpl newConnection(final Logger logger, final EndPointBase endPointBaseConnection, final RmiBridge rmiBridge) { - return new ConnectionImpl(logger, endPointBaseConnection, rmiBridge); + ConnectionImpl newConnection(final Logger logger, final EndPointBase endPoint, final RmiBridge rmiBridge) { + return new ConnectionImpl(logger, endPoint, rmiBridge); } /** @@ -401,15 +400,6 @@ class EndPointBase extends EndPoint { return connectionManager.getConnections(); } - /** - * Returns a non-modifiable list of active connections - */ - @SuppressWarnings("unchecked") - public - Collection getConnectionsAs() { - return connectionManager.getConnections(); - } - /** * Expose methods to send objects to a destination. */ diff --git a/src/dorkbox/network/connection/EndPointClient.java b/src/dorkbox/network/connection/EndPointClient.java index cf9cbee9..bc4ee49f 100644 --- a/src/dorkbox/network/connection/EndPointClient.java +++ b/src/dorkbox/network/connection/EndPointClient.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import dorkbox.network.Client; import dorkbox.network.Configuration; @@ -32,19 +34,18 @@ import io.netty.channel.ChannelOption; * This serves the purpose of making sure that specific methods are not available to the end user. */ public -class EndPointClient extends EndPointBase implements Runnable { +class EndPointClient extends EndPointBase { protected C connection; - protected final Object registrationLock = new Object(); + private CountDownLatch registration; - protected final Object bootstrapLock = new Object(); + private final Object bootstrapLock = new Object(); protected List bootstraps = new LinkedList(); - protected Iterator bootstrapIterator; + private Iterator bootstrapIterator; + + protected volatile int connectionTimeout = 5000; // default is 5 seconds - protected volatile int connectionTimeout = 5000; // default - protected volatile boolean registrationComplete = false; - private volatile boolean rmiInitializationComplete = false; private volatile ConnectionBridge connectionBridgeFlushAlways; @@ -55,36 +56,36 @@ class EndPointClient extends EndPointBase implements Ru } protected - void registerNextProtocol() { - // always reset everything. - registrationComplete = false; - bootstrapIterator = bootstraps.iterator(); + void startRegistration() throws IOException { + synchronized (bootstrapLock) { + // always reset everything. + registration = new CountDownLatch(1); - startProtocolRegistration(); + bootstrapIterator = bootstraps.iterator(); + } + + doRegistration(); + + // have to BLOCK + // don't want the client to run before registration is complete + try { + if (!registration.await(connectionTimeout, TimeUnit.MILLISECONDS)) { + throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds"); + } + } catch (InterruptedException e) { + throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds", e); + } } - private - void startProtocolRegistration() { - new Thread(this, "Bootstrap registration").start(); - } - - // protected by bootstrapLock private boolean isRegistrationComplete() { return !bootstrapIterator.hasNext(); } - @SuppressWarnings("AutoBoxing") - @Override - public - void run() { - // NOTE: Throwing exceptions in this method is pointless, since it runs from it's own thread + // this is called by 2 threads. The startup thread, and the registration-in-progress thread + private void doRegistration() { synchronized (bootstrapLock) { - if (isRegistrationComplete()) { - return; - } - BootstrapWrapper bootstrapWrapper = bootstrapIterator.next(); ChannelFuture future; @@ -130,13 +131,14 @@ class EndPointClient extends EndPointBase implements Ru return; } - if (logger.isTraceEnabled()) { - logger.trace("Waiting for registration from server."); - } + logger.trace("Waiting for registration from server."); + manageForShutdown(future); } } + + /** * Internal call by the pipeline to notify the client to continue registering the different session protocols. * @@ -145,20 +147,16 @@ class EndPointClient extends EndPointBase implements Ru @Override protected boolean registerNextProtocol0() { + boolean registrationComplete; + synchronized (bootstrapLock) { registrationComplete = isRegistrationComplete(); if (!registrationComplete) { - startProtocolRegistration(); + doRegistration(); } - - // we're done with registration, so no need to keep this around - bootstrapIterator = null; } - - if (logger.isTraceEnabled()) { - logger.trace("Registered protocol from server."); - } + logger.trace("Registered protocol from server."); // only let us continue with connections (this starts up the client/server implementations) once ALL of the // bootstraps have connected @@ -172,9 +170,6 @@ class EndPointClient extends EndPointBase implements Ru @Override final void connectionConnected0(final ConnectionImpl connection) { - // invokes the listener.connection() method, and initialize the connection channels with whatever extra info they might need. - super.connectionConnected0(connection); - connectionBridgeFlushAlways = new ConnectionBridge() { @Override public @@ -217,27 +212,24 @@ class EndPointClient extends EndPointBase implements Ru //noinspection unchecked this.connection = (C) connection; - // check if there were any RMI callbacks during the connect phase. - rmiInitializationComplete = connection.rmiCallbacksIsEmpty(); - - // notify the registration we are done! - synchronized (registrationLock) { - registrationLock.notify(); + synchronized (bootstrapLock) { + // we're done with registration, so no need to keep this around + bootstrapIterator = null; + registration.countDown(); } + + // invokes the listener.connection() method, and initialize the connection channels with whatever extra info they might need. + // This will also start the RMI (if necessary) initialization/creation of objects + super.connectionConnected0(connection); } - /** - * Internal call. - *

- * RMI methods are usually created during the connection phase. We should wait until they are finished, but ONLY if there is - * something we need to wait for. - * - * This is called AFTER registration is finished. - */ - protected - void waitForRmi(final int connectionTimeout) { - if (!rmiInitializationComplete && connection instanceof ConnectionImpl) { - ((ConnectionImpl) connection).waitForRmi(connectionTimeout); + private + void registrationCompleted() { + // make sure we're not waiting on registration + synchronized (bootstrapLock) { + // we're done with registration, so no need to keep this around + bootstrapIterator = null; + registration.countDown(); } } @@ -263,34 +255,18 @@ class EndPointClient extends EndPointBase implements Ru void closeConnections() { super.closeConnections(); + // make sure we're not waiting on registration + registrationCompleted(); + // for the CLIENT only, we clear these connections! (the server only clears them on shutdown) shutdownChannels(); - - // make sure we're not waiting on registration - registrationComplete = true; - synchronized (registrationLock) { - registrationLock.notify(); - } - registrationComplete = false; - - // Always unblock the waiting client.connect(). - if (connection instanceof ConnectionImpl) { - ((ConnectionImpl) connection).rmiCallbacksNotify(); - } } /** * Internal call to abort registration if the shutdown command is issued during channel registration. */ void abortRegistration() { - synchronized (registrationLock) { - registrationLock.notify(); - } - - // Always unblock the waiting client.connect(). - if (connection instanceof ConnectionImpl) { - ((ConnectionImpl) connection).rmiCallbacksNotify(); - } - stop(); + // make sure we're not waiting on registration + registrationCompleted(); } } diff --git a/src/dorkbox/network/serialization/SerializationManager.java b/src/dorkbox/network/serialization/SerializationManager.java index 55f635a9..887ae6c0 100644 --- a/src/dorkbox/network/serialization/SerializationManager.java +++ b/src/dorkbox/network/serialization/SerializationManager.java @@ -84,11 +84,11 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati public static SerializationManager DEFAULT() { - return DEFAULT(true, true); + return DEFAULT(true, true, true); } public static - SerializationManager DEFAULT(final boolean references, final boolean registrationRequired) { + SerializationManager DEFAULT(final boolean references, final boolean registrationRequired, final boolean forbidInterfaceRegistration) { // ignore fields that have the "@IgnoreSerialization" annotation. Collection> marks = new ArrayList>(); marks.add(IgnoreSerialization.class); @@ -96,6 +96,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati final SerializationManager serializationManager = new SerializationManager(references, registrationRequired, + forbidInterfaceRegistration, disregardingFactory); serializationManager.register(PingMessage.class); @@ -170,9 +171,14 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati } } + + private boolean initialized = false; private final ObjectPool kryoPool; + // used to determine if we should forbid interface registration OUTSIDE of RMI registration. + private final boolean forbidInterfaceRegistration; + // used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class) // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // Object checking is performed during actual registration. @@ -207,6 +213,15 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati * Registered classes are serialized as an int id, avoiding the overhead of serializing the class name, but have the * drawback of needing to know the classes to be serialized up front. *

+ * @param forbidInterfaceRegistration + * If true, interfaces are not permitted to be registered, outside of the {@link #registerRmiInterface(Class)} and + * {@link #registerRmiImplementation(Class, Class)} methods. If false, then interfaces can also be registered. + *

+ * Enabling interface registration permits matching a different RMI client/server serialization scheme, since + * interfaces are generally in a "common" package, accessible to both the RMI client and server. + *

+ * Generally, one should not register interfaces, because they have no meaning (ignoring "default" implementations in + * newer versions of java...) * @param factory * Sets the serializer factory to use when no {@link Kryo#addDefaultSerializer(Class, Class) default serializers} match * an object's type. Default is {@link ReflectionSerializerFactory} with {@link FieldSerializer}. @see @@ -214,8 +229,9 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati *

*/ public - SerializationManager(final boolean references, final boolean registrationRequired, final SerializerFactory factory) { - kryoPool = ObjectPool.NonBlockingSoftReference(new PoolableObject() { + SerializationManager(final boolean references, final boolean registrationRequired, final boolean forbidInterfaceRegistration, final SerializerFactory factory) { + this.forbidInterfaceRegistration = forbidInterfaceRegistration; + this.kryoPool = ObjectPool.NonBlockingSoftReference(new PoolableObject() { @Override public KryoExtra create() { @@ -316,7 +332,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati if (initialized) { logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class) call."); } - else if (clazz.isInterface()) { + else if (forbidInterfaceRegistration && clazz.isInterface()) { throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); } else { classesToRegister.add(clazz); @@ -342,7 +358,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati if (initialized) { logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, int) call."); } - else if (clazz.isInterface()) { + else if (forbidInterfaceRegistration && clazz.isInterface()) { throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); } else { @@ -367,7 +383,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati if (initialized) { logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer) call."); } - else if (clazz.isInterface()) { + else if (forbidInterfaceRegistration && clazz.isInterface()) { throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); } else { @@ -394,7 +410,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati if (initialized) { logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer, int) call."); } - else if (clazz.isInterface()) { + else if (forbidInterfaceRegistration && clazz.isInterface()) { throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); } else { @@ -451,6 +467,13 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati return this; } + if (!ifaceClass.isInterface()) { + throw new IllegalArgumentException("Cannot register an implementation for RMI access. It must be an interface."); + } + if (implClass.isInterface()) { + throw new IllegalArgumentException("Cannot register an interface for RMI implementations. It must be an implementation."); + } + usesRmi = true; classesToRegister.add(new RemoteImplClass(ifaceClass, implClass)); diff --git a/test/dorkbox/network/IdleTest.java b/test/dorkbox/network/IdleTest.java index 9c90ad9c..384552ba 100644 --- a/test/dorkbox/network/IdleTest.java +++ b/test/dorkbox/network/IdleTest.java @@ -60,7 +60,7 @@ class IdleTest extends BaseTest { Configuration configuration = new Configuration(); configuration.tcpPort = tcpPort; configuration.host = host; - configuration.serialization = SerializationManager.DEFAULT(false, false); + configuration.serialization = SerializationManager.DEFAULT(false, false, true); streamSpecificType(largeDataSize, configuration, ConnectionType.TCP); @@ -70,7 +70,7 @@ class IdleTest extends BaseTest { configuration.tcpPort = tcpPort; configuration.udpPort = udpPort; configuration.host = host; - configuration.serialization = SerializationManager.DEFAULT(false, false); + configuration.serialization = SerializationManager.DEFAULT(false, false, true); streamSpecificType(largeDataSize, configuration, ConnectionType.UDP); } diff --git a/test/dorkbox/network/UnregisteredClassTest.java b/test/dorkbox/network/UnregisteredClassTest.java index c6b254a0..50b99209 100644 --- a/test/dorkbox/network/UnregisteredClassTest.java +++ b/test/dorkbox/network/UnregisteredClassTest.java @@ -54,7 +54,7 @@ class UnregisteredClassTest extends BaseTest { configuration.tcpPort = tcpPort; configuration.udpPort = udpPort; configuration.host = host; - configuration.serialization = SerializationManager.DEFAULT(false, false); + configuration.serialization = SerializationManager.DEFAULT(false, false, true); System.err.println("Running test " + this.tries + " times, please wait for it to finish.");