Code polish. Removed waiting for RMI to startup. Changed how

abortRegistration works, added MACOSX native event loops, Cleaned up how
 waiting for registration works (it's now a CountDownLatch. It's much
 more stable now.) Removed duplicate/dead methods. Serialization can now
  accept/reject interfaces for object serialization (as a way to permit
  "stubs" so class IDs match)
This commit is contained in:
nathan 2018-01-14 23:01:09 +01:00
parent 1d3cd06130
commit 758a93d1b9
8 changed files with 158 additions and 207 deletions

View File

@ -43,6 +43,8 @@ import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.epoll.EpollDatagramChannel; import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel; import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
@ -167,6 +169,10 @@ class Client<C extends Connection> extends EndPointClient<C> implements Connecti
// JNI network stack is MUCH faster (but only on linux) // JNI network stack is MUCH faster (but only on linux)
tcpBootstrap.channel(EpollSocketChannel.class); tcpBootstrap.channel(EpollSocketChannel.class);
} }
else if (OS.isMacOsX()) {
// JNI network stack is MUCH faster (but only on macosx)
tcpBootstrap.channel(KQueueSocketChannel.class);
}
else { else {
tcpBootstrap.channel(NioSocketChannel.class); tcpBootstrap.channel(NioSocketChannel.class);
} }
@ -197,7 +203,12 @@ class Client<C extends Connection> extends EndPointClient<C> implements Connecti
// JNI network stack is MUCH faster (but only on linux) // JNI network stack is MUCH faster (but only on linux)
udpBootstrap.channel(EpollDatagramChannel.class); udpBootstrap.channel(EpollDatagramChannel.class);
} }
else if (OS.isMacOsX()) {
// JNI network stack is MUCH faster (but only on macosx)
udpBootstrap.channel(KQueueDatagramChannel.class);
}
else { else {
// windows
udpBootstrap.channel(NioDatagramChannel.class); udpBootstrap.channel(NioDatagramChannel.class);
} }
@ -241,7 +252,7 @@ class Client<C extends Connection> extends EndPointClient<C> implements Connecti
* if the client is unable to reconnect in the requested time * if the client is unable to reconnect in the requested time
*/ */
public public
void reconnect(int connectionTimeout) throws IOException { void reconnect(final int connectionTimeout) throws IOException {
// close out all old connections // close out all old connections
closeConnections(); closeConnections();
@ -291,23 +302,8 @@ class Client<C extends Connection> extends EndPointClient<C> implements Connecti
} }
} }
// have to start the registration process // have to start the registration process. This will wait until registration is complete and RMI methods are initialized
registerNextProtocol(); startRegistration();
// have to BLOCK
// don't want the client to run before registration is complete
synchronized (registrationLock) {
if (!registrationComplete) {
try {
registrationLock.wait(connectionTimeout);
} catch (InterruptedException e) {
throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds", e);
}
}
}
// RMI methods are usually created during the connection phase. We should wait until they are finished
waitForRmi(connectionTimeout);
} }
@Override @Override

View File

@ -36,10 +36,11 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.epoll.EpollChannelOption;
import io.netty.channel.epoll.EpollDatagramChannel; import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueServerSocketChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
@ -48,6 +49,7 @@ import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.oio.OioDatagramChannel; import io.netty.channel.socket.oio.OioDatagramChannel;
import io.netty.channel.socket.oio.OioServerSocketChannel; import io.netty.channel.socket.oio.OioServerSocketChannel;
import io.netty.channel.unix.UnixChannelOption;
/** /**
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the * The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
@ -200,6 +202,10 @@ class Server<C extends Connection> extends EndPointServer<C> {
// JNI network stack is MUCH faster (but only on linux) // JNI network stack is MUCH faster (but only on linux)
tcpBootstrap.channel(EpollServerSocketChannel.class); tcpBootstrap.channel(EpollServerSocketChannel.class);
} }
else if (OS.isMacOsX()) {
// JNI network stack is MUCH faster (but only on macosx)
tcpBootstrap.channel(KQueueServerSocketChannel.class);
}
else { else {
tcpBootstrap.channel(NioServerSocketChannel.class); tcpBootstrap.channel(NioServerSocketChannel.class);
} }
@ -235,14 +241,21 @@ class Server<C extends Connection> extends EndPointServer<C> {
if (udpBootstrap != null) { if (udpBootstrap != null) {
if (OS.isAndroid()) { if (OS.isAndroid()) {
// android ONLY supports OIO (not NIO) // android ONLY supports OIO (not NIO)
udpBootstrap.channel(OioDatagramChannel.class); udpBootstrap.channel(OioDatagramChannel.class)
.option(UnixChannelOption.SO_REUSEPORT, true);
} }
else if (OS.isLinux()) { else if (OS.isLinux()) {
// JNI network stack is MUCH faster (but only on linux) // JNI network stack is MUCH faster (but only on linux)
udpBootstrap.channel(EpollDatagramChannel.class) udpBootstrap.channel(EpollDatagramChannel.class)
.option(EpollChannelOption.SO_REUSEPORT, true); .option(UnixChannelOption.SO_REUSEPORT, true);
}
else if (OS.isMacOsX()) {
// JNI network stack is MUCH faster (but only on macosx)
udpBootstrap.channel(KQueueDatagramChannel.class)
.option(UnixChannelOption.SO_REUSEPORT, true);
} }
else { else {
// windows
udpBootstrap.channel(NioDatagramChannel.class); udpBootstrap.channel(NioDatagramChannel.class);
} }

View File

@ -20,6 +20,7 @@ import java.lang.reflect.Field;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Map; import java.util.Map;
import java.util.WeakHashMap; import java.util.WeakHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -93,7 +94,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
// while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error. // while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error.
private boolean remoteKeyChanged; private boolean remoteKeyChanged;
private final EndPointBase<Connection> endPointBaseConnection; private final EndPointBase endPoint;
// when true, the connection will be closed (either as RMI or as 'normal' listener execution) when the thread execution returns control // when true, the connection will be closed (either as RMI or as 'normal' listener execution) when the thread execution returns control
// back to the network stack // back to the network stack
@ -107,6 +108,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
// //
// RMI fields // RMI fields
// //
protected CountDownLatch rmi;
private final RmiBridge rmiBridge; private final RmiBridge rmiBridge;
private final Map<Integer, RemoteObject> proxyIdCache = new WeakHashMap<Integer, RemoteObject>(8); private final Map<Integer, RemoteObject> proxyIdCache = new WeakHashMap<Integer, RemoteObject>(8);
private final IntMap<RemoteObjectCallback> rmiRegistrationCallbacks = new IntMap<RemoteObjectCallback>(); private final IntMap<RemoteObjectCallback> rmiRegistrationCallbacks = new IntMap<RemoteObjectCallback>();
@ -117,9 +119,9 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
*/ */
@SuppressWarnings({"rawtypes", "unchecked"}) @SuppressWarnings({"rawtypes", "unchecked"})
public public
ConnectionImpl(final Logger logger, final EndPointBase endPointBaseConnection, final RmiBridge rmiBridge) { ConnectionImpl(final Logger logger, final EndPointBase endPoint, final RmiBridge rmiBridge) {
this.logger = logger; this.logger = logger;
this.endPointBaseConnection = endPointBaseConnection; this.endPoint = endPoint;
this.rmiBridge = rmiBridge; this.rmiBridge = rmiBridge;
} }
@ -215,7 +217,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@Override @Override
public public
EndPointBase<Connection> getEndPoint() { EndPointBase<Connection> getEndPoint() {
return this.endPointBaseConnection; return this.endPoint;
} }
/** /**
@ -328,7 +330,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
*/ */
private private
void controlBackPressure(ConnectionPoint c) { void controlBackPressure(ConnectionPoint c) {
while (!c.isWritable()) { while (!closeInProgress.get() && !c.isWritable()) {
needsLock.set(true); needsLock.set(true);
writeSignalNeeded.set(true); writeSignalNeeded.set(true);
@ -590,6 +592,10 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
.toString()); .toString());
} }
if (this.endPoint instanceof EndPointClient) {
((EndPointClient) this.endPoint).abortRegistration();
}
// our master channels are TCP/LOCAL (which are mutually exclusive). Only key disconnect events based on the status of them. // our master channels are TCP/LOCAL (which are mutually exclusive). Only key disconnect events based on the status of them.
if (isTCP || channelClass == LocalChannel.class) { if (isTCP || channelClass == LocalChannel.class) {
// this is because channelInactive can ONLY happen when netty shuts down the channel. // this is because channelInactive can ONLY happen when netty shuts down the channel.
@ -614,7 +620,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
void close() { void close() {
// only close if we aren't already in the middle of closing. // only close if we aren't already in the middle of closing.
if (this.closeInProgress.compareAndSet(false, true)) { if (this.closeInProgress.compareAndSet(false, true)) {
int idleTimeoutMs = this.endPointBaseConnection.getIdleTimeout(); int idleTimeoutMs = this.endPoint.getIdleTimeout();
if (idleTimeoutMs == 0) { if (idleTimeoutMs == 0) {
// default is 2 second timeout, in milliseconds. // default is 2 second timeout, in milliseconds.
idleTimeoutMs = 2000; idleTimeoutMs = 2000;
@ -714,7 +720,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@Override @Override
public final public final
Listeners add(Listener listener) { Listeners add(Listener listener) {
if (this.endPointBaseConnection instanceof EndPointServer) { if (this.endPoint instanceof EndPointServer) {
// when we are a server, NORMALLY listeners are added at the GLOBAL level // when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning -- // meaning --
// I add one listener, and ALL connections are notified of that listener. // I add one listener, and ALL connections are notified of that listener.
@ -726,15 +732,15 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
// is empty, we can remove it from this connection. // is empty, we can remove it from this connection.
synchronized (this) { synchronized (this) {
if (this.localListenerManager == null) { if (this.localListenerManager == null) {
this.localListenerManager = ((EndPointServer<Connection>) this.endPointBaseConnection).addListenerManager(this); this.localListenerManager = ((EndPointServer) this.endPoint).addListenerManager(this);
} }
this.localListenerManager.add(listener); this.localListenerManager.add(listener);
} }
} }
else { else {
this.endPointBaseConnection.listeners() this.endPoint.listeners()
.add(listener); .add(listener);
} }
return this; return this;
@ -756,7 +762,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@Override @Override
public final public final
Listeners remove(Listener listener) { Listeners remove(Listener listener) {
if (this.endPointBaseConnection instanceof EndPointServer) { if (this.endPoint instanceof EndPointServer) {
// when we are a server, NORMALLY listeners are added at the GLOBAL level // when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning -- // meaning --
// I add one listener, and ALL connections are notified of that listener. // I add one listener, and ALL connections are notified of that listener.
@ -771,14 +777,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
this.localListenerManager.remove(listener); this.localListenerManager.remove(listener);
if (!this.localListenerManager.hasListeners()) { if (!this.localListenerManager.hasListeners()) {
((EndPointServer<Connection>) this.endPointBaseConnection).removeListenerManager(this); ((EndPointServer<Connection>) this.endPoint).removeListenerManager(this);
} }
} }
} }
} }
else { else {
this.endPointBaseConnection.listeners() this.endPoint.listeners()
.remove(listener); .remove(listener);
} }
return this; return this;
@ -791,7 +797,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@Override @Override
public final public final
Listeners removeAll() { Listeners removeAll() {
if (this.endPointBaseConnection instanceof EndPointServer) { if (this.endPoint instanceof EndPointServer) {
// when we are a server, NORMALLY listeners are added at the GLOBAL level // when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning -- // meaning --
// I add one listener, and ALL connections are notified of that listener. // I add one listener, and ALL connections are notified of that listener.
@ -806,13 +812,13 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
this.localListenerManager.removeAll(); this.localListenerManager.removeAll();
this.localListenerManager = null; this.localListenerManager = null;
((EndPointServer<Connection>) this.endPointBaseConnection).removeListenerManager(this); ((EndPointServer<Connection>) this.endPoint).removeListenerManager(this);
} }
} }
} }
else { else {
this.endPointBaseConnection.listeners() this.endPoint.listeners()
.removeAll(); .removeAll();
} }
return this; return this;
@ -826,7 +832,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@Override @Override
public final public final
Listeners removeAll(Class<?> classType) { Listeners removeAll(Class<?> classType) {
if (this.endPointBaseConnection instanceof EndPointServer) { if (this.endPoint instanceof EndPointServer) {
// when we are a server, NORMALLY listeners are added at the GLOBAL level // when we are a server, NORMALLY listeners are added at the GLOBAL level
// meaning -- // meaning --
// I add one listener, and ALL connections are notified of that listener. // I add one listener, and ALL connections are notified of that listener.
@ -842,14 +848,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
if (!this.localListenerManager.hasListeners()) { if (!this.localListenerManager.hasListeners()) {
this.localListenerManager = null; this.localListenerManager = null;
((EndPointServer<Connection>) this.endPointBaseConnection).removeListenerManager(this); ((EndPointServer<Connection>) this.endPoint).removeListenerManager(this);
} }
} }
} }
} }
else { else {
this.endPointBaseConnection.listeners() this.endPoint.listeners()
.removeAll(classType); .removeAll(classType);
} }
return this; return this;
@ -899,62 +905,10 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
// RMI methods // RMI methods
// //
/**
* Internal call CLIENT ONLY.
* <p>
* RMI methods are usually created during the connection phase. We should wait until they are finished
*/
void waitForRmi(final int connectionTimeout) {
synchronized (rmiRegistrationCallbacks) {
try {
rmiRegistrationCallbacks.wait(connectionTimeout);
} catch (InterruptedException e) {
logger.error("Interrupted waiting for RMI to finish.", e);
}
}
}
/**
* Internal call CLIENT ONLY.
* <p>
* RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect().
*/
boolean rmiCallbacksIsEmpty() {
synchronized (rmiRegistrationCallbacks) {
return rmiRegistrationCallbacks.size == 0;
}
}
/**
* Internal call CLIENT ONLY.
* <p>
* RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect().
*/
void rmiCallbacksNotify() {
synchronized (rmiRegistrationCallbacks) {
rmiRegistrationCallbacks.notify();
}
}
/**
* Internal call CLIENT ONLY.
* <p>
* RMI methods are usually created during the connection phase. If there are none, we should unblock the waiting client.connect().
*/
private
void rmiCallbacksNotifyIfEmpty() {
synchronized (rmiRegistrationCallbacks) {
if (rmiRegistrationCallbacks.size == 0) {
rmiRegistrationCallbacks.notify();
}
}
}
@SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"}) @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"})
@Override @Override
public final public final
<Iface> void getRemoteObject(final Class<Iface> interfaceClass, final RemoteObjectCallback<Iface> callback) throws IOException { <Iface> void getRemoteObject(final Class<Iface> interfaceClass, final RemoteObjectCallback<Iface> callback) {
if (!interfaceClass.isInterface()) { if (!interfaceClass.isInterface()) {
throw new IllegalArgumentException("Cannot create a proxy for RMI access. It must be an interface."); throw new IllegalArgumentException("Cannot create a proxy for RMI access. It must be an interface.");
} }
@ -977,7 +931,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
@SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"}) @SuppressWarnings({"UnnecessaryLocalVariable", "unchecked", "Duplicates"})
@Override @Override
public final public final
<Iface> void getRemoteObject(final int objectId, final RemoteObjectCallback<Iface> callback) throws IOException { <Iface> void getRemoteObject(final int objectId, final RemoteObjectCallback<Iface> callback) {
RmiRegistration message; RmiRegistration message;
synchronized (rmiRegistrationCallbacks) { synchronized (rmiRegistrationCallbacks) {
@ -1079,12 +1033,14 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
} }
} }
else { else {
// THIS IS ON THE LOCAL CONNECTION SIDE, which is the side that called 'getRemoteObject()' // THIS IS ON THE LOCAL CONNECTION SIDE, which is the side that called 'getRemoteObject()' This can be Server or Client.
// this will be null if there was an error // this will be null if there was an error
Object remoteObject = remoteRegistration.remoteObject; Object remoteObject = remoteRegistration.remoteObject;
boolean noMoreRmiRemaining ;
RemoteObjectCallback callback; RemoteObjectCallback callback;
synchronized (rmiRegistrationCallbacks) { synchronized (rmiRegistrationCallbacks) {
callback = rmiRegistrationCallbacks.remove(remoteRegistration.rmiID); callback = rmiRegistrationCallbacks.remove(remoteRegistration.rmiID);
} }
@ -1095,9 +1051,6 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
} catch (Exception e) { } catch (Exception e) {
logger.error("Error getting remote object " + remoteObject.getClass() + ", ID: " + rmiID, e); logger.error("Error getting remote object " + remoteObject.getClass() + ", ID: " + rmiID, e);
} }
// tell the client that we are finished with all RMI callbacks
rmiCallbacksNotifyIfEmpty();
} }
} }
@ -1111,7 +1064,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
public public
<T> int getRegisteredId(final T object) { <T> int getRegisteredId(final T object) {
// always check local before checking global, because less contention on the synchronization // always check local before checking global, because less contention on the synchronization
RmiBridge globalRmiBridge = endPointBaseConnection.globalRmiBridge; RmiBridge globalRmiBridge = endPoint.globalRmiBridge;
if (globalRmiBridge == null) { if (globalRmiBridge == null) {
throw new NullPointerException("Unable to call 'getRegisteredId' when the globalRmiBridge is null!"); throw new NullPointerException("Unable to call 'getRegisteredId' when the globalRmiBridge is null!");
@ -1155,7 +1108,7 @@ class ConnectionImpl extends ChannelInboundHandlerAdapter implements ICryptoConn
public public
Object getImplementationObject(final int objectID) { Object getImplementationObject(final int objectID) {
if (RmiBridge.isGlobal(objectID)) { if (RmiBridge.isGlobal(objectID)) {
RmiBridge globalRmiBridge = endPointBaseConnection.globalRmiBridge; RmiBridge globalRmiBridge = endPoint.globalRmiBridge;
if (globalRmiBridge == null) { if (globalRmiBridge == null) {
throw new NullPointerException("Unable to call 'getRegisteredId' when the gloablRmiBridge is null!"); throw new NullPointerException("Unable to call 'getRegisteredId' when the gloablRmiBridge is null!");

View File

@ -17,7 +17,6 @@ package dorkbox.network.connection;
import java.io.IOException; import java.io.IOException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -277,7 +276,7 @@ class EndPointBase<C extends Connection> extends EndPoint {
*/ */
public public
void setIdleTimeout(int idleTimeoutMs) { void setIdleTimeout(int idleTimeoutMs) {
idleTimeoutMs = idleTimeoutMs; this.idleTimeoutMs = idleTimeoutMs;
} }
/** /**
@ -308,8 +307,8 @@ class EndPointBase<C extends Connection> extends EndPoint {
* @return a new network connection * @return a new network connection
*/ */
protected protected
ConnectionImpl newConnection(final Logger logger, final EndPointBase<C> endPointBaseConnection, final RmiBridge rmiBridge) { ConnectionImpl newConnection(final Logger logger, final EndPointBase endPoint, final RmiBridge rmiBridge) {
return new ConnectionImpl(logger, endPointBaseConnection, rmiBridge); return new ConnectionImpl(logger, endPoint, rmiBridge);
} }
/** /**
@ -401,15 +400,6 @@ class EndPointBase<C extends Connection> extends EndPoint {
return connectionManager.getConnections(); return connectionManager.getConnections();
} }
/**
* Returns a non-modifiable list of active connections
*/
@SuppressWarnings("unchecked")
public
Collection<C> getConnectionsAs() {
return connectionManager.getConnections();
}
/** /**
* Expose methods to send objects to a destination. * Expose methods to send objects to a destination.
*/ */

View File

@ -19,6 +19,8 @@ import java.io.IOException;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import dorkbox.network.Client; import dorkbox.network.Client;
import dorkbox.network.Configuration; import dorkbox.network.Configuration;
@ -32,19 +34,18 @@ import io.netty.channel.ChannelOption;
* This serves the purpose of making sure that specific methods are not available to the end user. * This serves the purpose of making sure that specific methods are not available to the end user.
*/ */
public public
class EndPointClient<C extends Connection> extends EndPointBase<C> implements Runnable { class EndPointClient<C extends Connection> extends EndPointBase<C> {
protected C connection; protected C connection;
protected final Object registrationLock = new Object(); private CountDownLatch registration;
protected final Object bootstrapLock = new Object(); private final Object bootstrapLock = new Object();
protected List<BootstrapWrapper> bootstraps = new LinkedList<BootstrapWrapper>(); protected List<BootstrapWrapper> bootstraps = new LinkedList<BootstrapWrapper>();
protected Iterator<BootstrapWrapper> bootstrapIterator; private Iterator<BootstrapWrapper> bootstrapIterator;
protected volatile int connectionTimeout = 5000; // default is 5 seconds
protected volatile int connectionTimeout = 5000; // default
protected volatile boolean registrationComplete = false;
private volatile boolean rmiInitializationComplete = false;
private volatile ConnectionBridge connectionBridgeFlushAlways; private volatile ConnectionBridge connectionBridgeFlushAlways;
@ -55,36 +56,36 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
} }
protected protected
void registerNextProtocol() { void startRegistration() throws IOException {
// always reset everything. synchronized (bootstrapLock) {
registrationComplete = false; // always reset everything.
bootstrapIterator = bootstraps.iterator(); registration = new CountDownLatch(1);
startProtocolRegistration(); bootstrapIterator = bootstraps.iterator();
}
doRegistration();
// have to BLOCK
// don't want the client to run before registration is complete
try {
if (!registration.await(connectionTimeout, TimeUnit.MILLISECONDS)) {
throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds");
}
} catch (InterruptedException e) {
throw new IOException("Unable to complete registration within '" + connectionTimeout + "' milliseconds", e);
}
} }
private
void startProtocolRegistration() {
new Thread(this, "Bootstrap registration").start();
}
// protected by bootstrapLock // protected by bootstrapLock
private private
boolean isRegistrationComplete() { boolean isRegistrationComplete() {
return !bootstrapIterator.hasNext(); return !bootstrapIterator.hasNext();
} }
@SuppressWarnings("AutoBoxing") // this is called by 2 threads. The startup thread, and the registration-in-progress thread
@Override private void doRegistration() {
public
void run() {
// NOTE: Throwing exceptions in this method is pointless, since it runs from it's own thread
synchronized (bootstrapLock) { synchronized (bootstrapLock) {
if (isRegistrationComplete()) {
return;
}
BootstrapWrapper bootstrapWrapper = bootstrapIterator.next(); BootstrapWrapper bootstrapWrapper = bootstrapIterator.next();
ChannelFuture future; ChannelFuture future;
@ -130,13 +131,14 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
return; return;
} }
if (logger.isTraceEnabled()) { logger.trace("Waiting for registration from server.");
logger.trace("Waiting for registration from server.");
}
manageForShutdown(future); manageForShutdown(future);
} }
} }
/** /**
* Internal call by the pipeline to notify the client to continue registering the different session protocols. * Internal call by the pipeline to notify the client to continue registering the different session protocols.
* *
@ -145,20 +147,16 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
@Override @Override
protected protected
boolean registerNextProtocol0() { boolean registerNextProtocol0() {
boolean registrationComplete;
synchronized (bootstrapLock) { synchronized (bootstrapLock) {
registrationComplete = isRegistrationComplete(); registrationComplete = isRegistrationComplete();
if (!registrationComplete) { if (!registrationComplete) {
startProtocolRegistration(); doRegistration();
} }
// we're done with registration, so no need to keep this around
bootstrapIterator = null;
} }
logger.trace("Registered protocol from server.");
if (logger.isTraceEnabled()) {
logger.trace("Registered protocol from server.");
}
// only let us continue with connections (this starts up the client/server implementations) once ALL of the // only let us continue with connections (this starts up the client/server implementations) once ALL of the
// bootstraps have connected // bootstraps have connected
@ -172,9 +170,6 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
@Override @Override
final final
void connectionConnected0(final ConnectionImpl connection) { void connectionConnected0(final ConnectionImpl connection) {
// invokes the listener.connection() method, and initialize the connection channels with whatever extra info they might need.
super.connectionConnected0(connection);
connectionBridgeFlushAlways = new ConnectionBridge() { connectionBridgeFlushAlways = new ConnectionBridge() {
@Override @Override
public public
@ -217,27 +212,24 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
//noinspection unchecked //noinspection unchecked
this.connection = (C) connection; this.connection = (C) connection;
// check if there were any RMI callbacks during the connect phase. synchronized (bootstrapLock) {
rmiInitializationComplete = connection.rmiCallbacksIsEmpty(); // we're done with registration, so no need to keep this around
bootstrapIterator = null;
// notify the registration we are done! registration.countDown();
synchronized (registrationLock) {
registrationLock.notify();
} }
// invokes the listener.connection() method, and initialize the connection channels with whatever extra info they might need.
// This will also start the RMI (if necessary) initialization/creation of objects
super.connectionConnected0(connection);
} }
/** private
* Internal call. void registrationCompleted() {
* <p> // make sure we're not waiting on registration
* RMI methods are usually created during the connection phase. We should wait until they are finished, but ONLY if there is synchronized (bootstrapLock) {
* something we need to wait for. // we're done with registration, so no need to keep this around
* bootstrapIterator = null;
* This is called AFTER registration is finished. registration.countDown();
*/
protected
void waitForRmi(final int connectionTimeout) {
if (!rmiInitializationComplete && connection instanceof ConnectionImpl) {
((ConnectionImpl) connection).waitForRmi(connectionTimeout);
} }
} }
@ -263,34 +255,18 @@ class EndPointClient<C extends Connection> extends EndPointBase<C> implements Ru
void closeConnections() { void closeConnections() {
super.closeConnections(); super.closeConnections();
// make sure we're not waiting on registration
registrationCompleted();
// for the CLIENT only, we clear these connections! (the server only clears them on shutdown) // for the CLIENT only, we clear these connections! (the server only clears them on shutdown)
shutdownChannels(); shutdownChannels();
// make sure we're not waiting on registration
registrationComplete = true;
synchronized (registrationLock) {
registrationLock.notify();
}
registrationComplete = false;
// Always unblock the waiting client.connect().
if (connection instanceof ConnectionImpl) {
((ConnectionImpl) connection).rmiCallbacksNotify();
}
} }
/** /**
* Internal call to abort registration if the shutdown command is issued during channel registration. * Internal call to abort registration if the shutdown command is issued during channel registration.
*/ */
void abortRegistration() { void abortRegistration() {
synchronized (registrationLock) { // make sure we're not waiting on registration
registrationLock.notify(); registrationCompleted();
}
// Always unblock the waiting client.connect().
if (connection instanceof ConnectionImpl) {
((ConnectionImpl) connection).rmiCallbacksNotify();
}
stop();
} }
} }

View File

@ -84,11 +84,11 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
public static public static
SerializationManager DEFAULT() { SerializationManager DEFAULT() {
return DEFAULT(true, true); return DEFAULT(true, true, true);
} }
public static public static
SerializationManager DEFAULT(final boolean references, final boolean registrationRequired) { SerializationManager DEFAULT(final boolean references, final boolean registrationRequired, final boolean forbidInterfaceRegistration) {
// ignore fields that have the "@IgnoreSerialization" annotation. // ignore fields that have the "@IgnoreSerialization" annotation.
Collection<Class<? extends Annotation>> marks = new ArrayList<Class<? extends Annotation>>(); Collection<Class<? extends Annotation>> marks = new ArrayList<Class<? extends Annotation>>();
marks.add(IgnoreSerialization.class); marks.add(IgnoreSerialization.class);
@ -96,6 +96,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
final SerializationManager serializationManager = new SerializationManager(references, final SerializationManager serializationManager = new SerializationManager(references,
registrationRequired, registrationRequired,
forbidInterfaceRegistration,
disregardingFactory); disregardingFactory);
serializationManager.register(PingMessage.class); serializationManager.register(PingMessage.class);
@ -170,9 +171,14 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
} }
} }
private boolean initialized = false; private boolean initialized = false;
private final ObjectPool<KryoExtra> kryoPool; private final ObjectPool<KryoExtra> kryoPool;
// used to determine if we should forbid interface registration OUTSIDE of RMI registration.
private final boolean forbidInterfaceRegistration;
// used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class) // used by operations performed during kryo initialization, which are by default package access (since it's an anon-inner class)
// All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems. // All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems.
// Object checking is performed during actual registration. // Object checking is performed during actual registration.
@ -207,6 +213,15 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
* Registered classes are serialized as an int id, avoiding the overhead of serializing the class name, but have the * Registered classes are serialized as an int id, avoiding the overhead of serializing the class name, but have the
* drawback of needing to know the classes to be serialized up front. * drawback of needing to know the classes to be serialized up front.
* <p> * <p>
* @param forbidInterfaceRegistration
* If true, interfaces are not permitted to be registered, outside of the {@link #registerRmiInterface(Class)} and
* {@link #registerRmiImplementation(Class, Class)} methods. If false, then interfaces can also be registered.
* <p>
* Enabling interface registration permits matching a different RMI client/server serialization scheme, since
* interfaces are generally in a "common" package, accessible to both the RMI client and server.
* <p>
* Generally, one should not register interfaces, because they have no meaning (ignoring "default" implementations in
* newer versions of java...)
* @param factory * @param factory
* Sets the serializer factory to use when no {@link Kryo#addDefaultSerializer(Class, Class) default serializers} match * Sets the serializer factory to use when no {@link Kryo#addDefaultSerializer(Class, Class) default serializers} match
* an object's type. Default is {@link ReflectionSerializerFactory} with {@link FieldSerializer}. @see * an object's type. Default is {@link ReflectionSerializerFactory} with {@link FieldSerializer}. @see
@ -214,8 +229,9 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
* <p> * <p>
*/ */
public public
SerializationManager(final boolean references, final boolean registrationRequired, final SerializerFactory factory) { SerializationManager(final boolean references, final boolean registrationRequired, final boolean forbidInterfaceRegistration, final SerializerFactory factory) {
kryoPool = ObjectPool.NonBlockingSoftReference(new PoolableObject<KryoExtra>() { this.forbidInterfaceRegistration = forbidInterfaceRegistration;
this.kryoPool = ObjectPool.NonBlockingSoftReference(new PoolableObject<KryoExtra>() {
@Override @Override
public public
KryoExtra create() { KryoExtra create() {
@ -316,7 +332,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
if (initialized) { if (initialized) {
logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class) call."); logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class) call.");
} }
else if (clazz.isInterface()) { else if (forbidInterfaceRegistration && clazz.isInterface()) {
throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation.");
} else { } else {
classesToRegister.add(clazz); classesToRegister.add(clazz);
@ -342,7 +358,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
if (initialized) { if (initialized) {
logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, int) call."); logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, int) call.");
} }
else if (clazz.isInterface()) { else if (forbidInterfaceRegistration && clazz.isInterface()) {
throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation.");
} }
else { else {
@ -367,7 +383,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
if (initialized) { if (initialized) {
logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer) call."); logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer) call.");
} }
else if (clazz.isInterface()) { else if (forbidInterfaceRegistration && clazz.isInterface()) {
throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation.");
} }
else { else {
@ -394,7 +410,7 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
if (initialized) { if (initialized) {
logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer, int) call."); logger.warn("Serialization manager already initialized. Ignoring duplicate register(Class, Serializer, int) call.");
} }
else if (clazz.isInterface()) { else if (forbidInterfaceRegistration && clazz.isInterface()) {
throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation."); throw new IllegalArgumentException("Cannot register an interface for serialization. It must be an implementation.");
} }
else { else {
@ -451,6 +467,13 @@ class SerializationManager implements CryptoSerializationManager, RmiSerializati
return this; return this;
} }
if (!ifaceClass.isInterface()) {
throw new IllegalArgumentException("Cannot register an implementation for RMI access. It must be an interface.");
}
if (implClass.isInterface()) {
throw new IllegalArgumentException("Cannot register an interface for RMI implementations. It must be an implementation.");
}
usesRmi = true; usesRmi = true;
classesToRegister.add(new RemoteImplClass(ifaceClass, implClass)); classesToRegister.add(new RemoteImplClass(ifaceClass, implClass));

View File

@ -60,7 +60,7 @@ class IdleTest extends BaseTest {
Configuration configuration = new Configuration(); Configuration configuration = new Configuration();
configuration.tcpPort = tcpPort; configuration.tcpPort = tcpPort;
configuration.host = host; configuration.host = host;
configuration.serialization = SerializationManager.DEFAULT(false, false); configuration.serialization = SerializationManager.DEFAULT(false, false, true);
streamSpecificType(largeDataSize, configuration, ConnectionType.TCP); streamSpecificType(largeDataSize, configuration, ConnectionType.TCP);
@ -70,7 +70,7 @@ class IdleTest extends BaseTest {
configuration.tcpPort = tcpPort; configuration.tcpPort = tcpPort;
configuration.udpPort = udpPort; configuration.udpPort = udpPort;
configuration.host = host; configuration.host = host;
configuration.serialization = SerializationManager.DEFAULT(false, false); configuration.serialization = SerializationManager.DEFAULT(false, false, true);
streamSpecificType(largeDataSize, configuration, ConnectionType.UDP); streamSpecificType(largeDataSize, configuration, ConnectionType.UDP);
} }

View File

@ -54,7 +54,7 @@ class UnregisteredClassTest extends BaseTest {
configuration.tcpPort = tcpPort; configuration.tcpPort = tcpPort;
configuration.udpPort = udpPort; configuration.udpPort = udpPort;
configuration.host = host; configuration.host = host;
configuration.serialization = SerializationManager.DEFAULT(false, false); configuration.serialization = SerializationManager.DEFAULT(false, false, true);
System.err.println("Running test " + this.tries + " times, please wait for it to finish."); System.err.println("Running test " + this.tries + " times, please wait for it to finish.");