Compare commits

...

307 Commits

Author SHA1 Message Date
Robinson 3f77af5894
More responsive shutdown logic during handshake and fixed crashes when forcing a shutdown 2024-02-19 10:09:03 +01:00
Robinson c197a2f627
logging 2024-02-05 21:58:00 +01:00
Robinson 0500de29c8
updated copyright 2024-02-05 21:57:40 +01:00
Robinson 4fc0cb7541
fixed comments 2024-01-15 10:39:58 +01:00
Robinson e4721b4c7c
fixed typo 2023-12-14 22:50:42 +01:00
Robinson bd6476059b
ConnectionCounts are now threadsafe 2023-12-12 13:31:07 +01:00
Robinson e80cc93c72
Use ThreadSafe publication instead of exclusive publication 2023-12-12 13:25:01 +01:00
Robinson a706cdb228
If we have "max connections" specified, then obey the limit 2023-12-12 13:24:33 +01:00
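The max-connections and thread-safe connection-count commits above describe a standard admission-control pattern. A minimal, hypothetical sketch (not this library's actual classes) of enforcing such a limit without locks:

```kotlin
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical limiter, for illustration only: admit a new connection only while the
// current count is below the configured maximum, using CAS instead of a lock.
class ConnectionLimiter(private val maxConnections: Int) {
    private val count = AtomicInteger(0)

    fun tryAcquire(): Boolean {
        while (true) {
            val current = count.get()
            if (current >= maxConnections) return false              // limit reached: reject the handshake
            if (count.compareAndSet(current, current + 1)) return true
        }
    }

    fun release() {
        count.decrementAndGet()                                      // called when the connection closes
    }
}
```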
Robinson 98d8321902
updated config comments 2023-12-12 13:24:01 +01:00
Robinson 331f921df0
Added extra information for handshake/connection timeout potential issues 2023-12-12 13:23:51 +01:00
Robinson 9c2fa4b65b
updated logging 2023-12-12 13:22:28 +01:00
Robinson db15b62c8c
Added more unit tests, increased multi-threading for multi-client unit test 2023-12-04 13:37:18 +01:00
Robinson 6f50618040
Now properly waits for event dispatcher to shutdown in unit tests 2023-12-04 13:36:47 +01:00
Robinson 6bbf62f886
Added more unit tests for aeron connectivity 2023-12-04 11:10:52 +01:00
Robinson cdc056f3a1
logging tweak 2023-12-04 10:47:51 +01:00
Robinson cb7f8b2990
Send all buffered messages at once, instead of 1-at-a-time. 2023-12-04 10:47:41 +01:00
Robinson b9f7a552f0
update license 2023-11-28 21:39:14 +01:00
Robinson 2c9d3c119c
Updated build deps 2023-11-28 21:08:53 +01:00
Robinson 3a5295efe8
Cleaned logging 2023-11-28 20:53:37 +01:00
Robinson 5c4d64f3f1
fixed up comments 2023-11-28 20:52:58 +01:00
Robinson 41b3acf147
Fixed api name 2023-11-27 14:40:09 +01:00
Robinson cf2e7ffc77
Removed exception stacktrace during reconnect 2023-11-27 13:11:47 +01:00
Robinson bf0cd3f0e6
Added support for PER-CONNECTION buffering of messages (default is enabled) 2023-11-27 11:14:52 +01:00
Robinson f1a06fd8fd
Better comments/docs 2023-11-27 11:13:54 +01:00
Robinson 13e8501255
Commented out remaining logic for connection rules (not impl yet) 2023-11-27 11:12:53 +01:00
Robinson a1db866375
updated deps 2023-11-22 20:40:05 +01:00
Robinson b496f83e64
100 concurrent connections in a unit test kills the machine. 2023-11-22 20:39:06 +01:00
Robinson 2cfc2e41e6
Make sure that errors during unit tests now properly fail the test (or are ignored), as appropriate. 2023-11-22 20:38:46 +01:00
Robinson 76f42c900c
Guarantee that connect occurs AFTER the current close events are finished running before redispatching on the connect dispatcher 2023-11-22 09:18:17 +01:00
Robinson 88bac6ef84
Version 6.15 2023-11-16 12:08:30 +01:00
Robinson f0493beca1
updated minlog 2023-11-13 22:31:04 +01:00
Robinson d4fd773ea0
spaces 2023-11-13 22:30:43 +01:00
Robinson 644d28ea70
Any exceptions will cause a unit test failure now 2023-11-13 18:45:17 +01:00
Robinson 35020adac9
Updated to 100 concurrent connections (on 50 separate threads) 2023-11-13 18:44:59 +01:00
Robinson bae5b41d1c
disconnect period is as short as possible to improve unit test performance 2023-11-13 14:15:13 +01:00
Robinson 8e32e0980c
Shutdown is now atomic instead of volatile 2023-11-13 14:10:19 +01:00
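The atomic-shutdown change above reflects a common pattern: a volatile flag alone cannot make shutdown run exactly once, because two threads can both read "not shut down" before either writes. A minimal sketch (illustrative only, not this library's actual code):

```kotlin
import java.util.concurrent.atomic.AtomicBoolean

private val shutdown = AtomicBoolean(false)

fun close() {
    // compareAndSet guarantees the body executes at most once, no matter how many
    // threads call close() concurrently.
    if (shutdown.compareAndSet(false, true)) {
        // stop pollers, close publications/subscriptions, release resources
    }
}
```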
Robinson cbfe51f746
Added Handshake dispatch (was required, and must be single threaded) 2023-11-13 14:10:00 +01:00
Robinson 2aebbe6116
Added more logging 2023-11-08 12:44:15 +01:00
Robinson 4bd77515d8
Increased connection timeouts. 2023-11-07 16:48:05 +01:00
Robinson a5286899b7
When creating publications and handshaking, we CANNOT do this on the main processing thread 2023-11-03 18:21:14 +01:00
Robinson af19049519
Added multi-dispatch for the server when conducting handshakes (and waiting for a connection to complete). Under load, we cannot block the main thread 2023-11-03 18:15:34 +01:00
Robinson f40e8cf14d
More detailed error message 2023-11-03 18:14:48 +01:00
Robinson 2162131b17
updated shadowclass file 2023-11-02 22:37:09 +01:00
Robinson 91deea8b1a
Added support for callbacks on a message, so there can be 'happens-before' logic. 2023-11-02 22:36:50 +01:00
Robinson 58535a923b
Added support for connection tags (so the client can set a name for its connection, and the server will get that name). This is useful for identifying different connections (and doing different things) based on their tag name. 2023-10-28 20:55:49 +02:00
Robinson fe98763712
All connections are now buffered - in the event there is a network issue, or a quick reconnect, and messages are sent DURING this disconnected phase, these messages will be resent on the new connection once it is connected 2023-10-28 20:54:40 +02:00
Robinson 27b4b0421e
disabled test debug 2023-10-26 21:19:50 +02:00
Robinson 0c4c442b3a
Fixed/cleaned up connection polling and restarts 2023-10-26 21:19:36 +02:00
Robinson ba57447169
fixed connect notify 2023-10-26 21:19:07 +02:00
Robinson 1b235e21aa
waiting for endpoint to shutdown better supports restarts 2023-10-26 21:14:05 +02:00
Robinson 046ece160f
code polish 2023-10-26 21:13:35 +02:00
Robinson 737b68549c
Wrapped potential RMI errors in exception catching 2023-10-26 21:13:21 +02:00
Robinson f531f61a53
Better support for polling and sending dc message 2023-10-26 21:12:58 +02:00
Robinson 7f2ad97aa7
code cleanup and comments 2023-10-26 14:57:48 +02:00
Robinson d7884c4d8d
Updated logs 2023-10-26 12:02:09 +02:00
Robinson 9d303beade
Better session management + logs 2023-10-26 09:30:11 +02:00
Robinson 70825708a3
More careful event dispatch (no longer global, but per endpoint) 2023-10-26 08:09:47 +02:00
Robinson 4b58a63dc1
Added extra (general) log message when a network error occurs 2023-10-24 20:38:20 +02:00
Robinson 495cb954d8
cleanAllStackTrace() returns itself 2023-10-24 15:14:39 +02:00
Robinson 59d17ea367
Better logic for unit test 2023-10-24 13:47:10 +02:00
Robinson 8c2b6b39cd
wait for close is not explicitly necessary 2023-10-24 13:46:46 +02:00
Robinson 60a26202b4
Added extra debug info 2023-10-24 13:46:28 +02:00
Robinson 044ce8771f
Fixed sigint close command issues 2023-10-24 13:46:16 +02:00
Robinson 90d087054e
Added for remote server testing 2023-10-24 13:45:58 +02:00
Robinson 4906e94aef
Code cleanup 2023-10-24 13:13:06 +02:00
Robinson 14544d3296
Removed delayed close from event poller 2023-10-24 12:06:55 +02:00
Robinson 2270b815b4
better logging 2023-10-23 23:27:55 +02:00
Robinson 706cf5b3e8
Fixed edge case with session connections and sending data on a publication that is not connected (either yet, or is an old one) 2023-10-23 23:24:57 +02:00
Robinson 01ab0bf1d8
Properly cleanup the remote object storage/cache 2023-10-23 23:23:32 +02:00
Robinson d40c080311
API parameter clarification 2023-10-19 23:42:36 +02:00
Robinson 83a9a5762d
Connection timeout is based from when connection is created 2023-10-19 23:42:08 +02:00
Robinson 0e16747dc2
logging 2023-10-19 23:41:43 +02:00
Robinson a15478c535
modified delayLingerTimeout() to be more intelligent with regards to the delay linger timeout 2023-10-18 20:00:54 +02:00
Robinson cfc08a2f4b
SessionManager expiration now using the correct expire time 2023-10-18 19:55:28 +02:00
Robinson c38fa13f11
More careful return values when adding data to aeron buffer 2023-10-18 19:54:33 +02:00
Robinson 7cacc63dca
renamed function: isConnected -> isClosedWithTimeout 2023-10-18 19:54:07 +02:00
Robinson d51c878a65
Fixed cast exception with sessions 2023-10-18 19:52:50 +02:00
Robinson 6fb5dbb833
added sendDisconnectMessage to API when closing 2023-10-18 19:52:31 +02:00
Robinson 2245f0bfc5
more comments 2023-10-18 19:50:04 +02:00
Robinson d4e3e2e41d
Better/easier checking if we are a session or not 2023-10-18 19:49:33 +02:00
Robinson 2a8ac38e55
ResponseManager now uses a special TimeoutException instead of generic exception. 2023-10-18 19:47:02 +02:00
Robinson 46cb174183
more logging 2023-10-18 19:46:34 +02:00
Robinson de6d22f808
Added logs when closing storage 2023-10-18 19:33:28 +02:00
Robinson 099f9de834
Enhanced logging for session connection type 2023-10-18 19:33:13 +02:00
Robinson 047d938386
Safely attempt to close the settings store (permissions might not allow it) 2023-10-17 22:54:34 +02:00
Robinson 4d09999f0a
Updated logging 2023-10-17 16:48:31 +02:00
Robinson ee296c602a
updated gradle 2023-10-17 16:48:02 +02:00
Robinson 3a69d1a525
updated build deps 2023-10-17 16:47:52 +02:00
Robinson 53f7cd8cf1
Simplified connection log info for debug output 2023-10-17 16:47:40 +02:00
Robinson c62016dad9
version 6.14 2023-10-05 13:19:10 +02:00
Robinson decec8641b
updated license 2023-10-05 13:18:40 +02:00
Robinson 66a922b6b5
Increased default macos devShm virtual drive 2023-10-05 13:18:28 +02:00
Robinson 7eac9699c9
Fixed session/connection lateinit errors 2023-10-05 13:18:00 +02:00
Robinson 8f9ee52b36
Fixed null pointer 2023-10-05 13:17:28 +02:00
Robinson 48325ee846
version 6.13 2023-10-03 22:11:48 +02:00
Robinson 1fded5575b
updated build deps 2023-10-03 22:11:40 +02:00
Robinson 1287eb8c6e
AeronDriverInternal will now restart the network if there is an `unexpected close of heartbeat timestamp counter` error 2023-10-03 22:01:54 +02:00
Robinson c69a33f1a9
version 6.12 2023-09-28 01:55:14 +02:00
Robinson 0825274bd0
Cleaned up ordering of connection initialization 2023-09-28 01:55:03 +02:00
Robinson b55168a3eb
Now safely try to close a connection when it's not possible (just log, don't throw exception) 2023-09-26 19:53:27 +02:00
Robinson 653236a7e2
version 6.11 2023-09-25 14:00:52 +02:00
Robinson b45826da80
Code cleanup + better unit tests 2023-09-25 13:59:46 +02:00
Robinson 78374e4dfc
More clearly defined session management. Fixed problem when reconnecting + RMI create callbacks. 2023-09-25 13:59:26 +02:00
Robinson b2217f66ee
Added additional cast method 2023-09-25 13:58:32 +02:00
Robinson 1b6880bf7d
Fixed issues when deleting RMI objects/proxies 2023-09-25 13:58:24 +02:00
Robinson babee06372
Fixed issues surrounding RMI + timeouts 2023-09-25 13:58:02 +02:00
Robinson 00d444bde7
Fixed out-of-order signalling 2023-09-25 13:57:04 +02:00
Robinson 8dd70e9e0e
Cleaned up comments 2023-09-25 13:56:23 +02:00
Robinson 72a7121762
More clear logging 2023-09-25 13:56:02 +02:00
Robinson e399f4948d
Stronger checks for RMI ID minimum values 2023-09-25 13:54:37 +02:00
Robinson 2ab8d7b3bd
Added comments 2023-09-25 13:54:11 +02:00
Robinson be78d498dc
Add connection before init, so init happens before polling of events 2023-09-25 13:53:37 +02:00
Robinson 8bbaa6df18
Cleaned up sessions 2023-09-22 15:54:01 +02:00
Robinson c0227fee06
Updated deps, now pure support for jpms 2023-09-21 12:55:09 +02:00
Robinson 1125785b21
added unit test for sessions 2023-09-21 12:54:42 +02:00
Robinson 19d6d6ebaf
Code cleanup 2023-09-21 12:54:07 +02:00
Robinson e6b4cbd386
Added support for sessions 2023-09-21 12:53:47 +02:00
Robinson 3e9109b4c7
Code cleanup 2023-09-21 11:22:36 +02:00
Robinson c8047987b4
removed unnecessary dir 2023-09-19 22:44:48 +02:00
Robinson 851fcfd1fc
Code cleanup 2023-09-17 02:39:00 +02:00
Robinson b2d349d17c
Merge branch 'session' 2023-09-17 02:37:36 +02:00
Robinson f269684ea5
newConnection method reverted back to function, which allows for easier extension of class types 2023-09-17 02:36:45 +02:00
Robinson e32ecda7ac
updated dep 2023-09-15 20:09:29 +02:00
Robinson 5b43d44b22
Initial work on cache 2023-09-15 20:09:09 +02:00
Robinson 67b4443ade
Ensure the aeron context is always closed and the driver is cleaned up, even if it hasn't been started 2023-09-14 13:39:45 +02:00
Robinson 00a1f4c66b
Limit logging messages for RMI when args are long. 2023-09-14 11:47:30 +02:00
Robinson b2b6cfdc10
Removed KotlinLogging (it has a niche usage that did not apply) 2023-09-13 17:04:25 +02:00
Robinson 560b5bc743
Fixed issue with trace logging + RMI when arguments were large 2023-09-13 16:01:54 +02:00
Robinson 5021eb5136
made read accessible 2023-09-13 16:01:32 +02:00
Robinson 81380fe633
Wrapped logger.debug/trace into if statements to prevent the JVM from creating unnecessary lambdas 2023-09-13 16:01:14 +02:00
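The log-guard commit above refers to a general pattern. A minimal sketch, assuming plain SLF4J (which is among the declared dependencies); the guard prevents the message string, and any capturing lambda such as kotlin-logging's `logger.trace { ... }`, from being allocated when the level is disabled:

```kotlin
import org.slf4j.LoggerFactory

private val logger = LoggerFactory.getLogger("Network")

fun onFrame(sessionId: Int, length: Int) {
    // Without the guard, the string template is evaluated on every call, even when TRACE is off.
    if (logger.isTraceEnabled) {
        logger.trace("session $sessionId received frame of $length bytes")
    }
}
```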
Robinson d895e04af5
Fixed issues with streaming for RMI and added another streaming test 2023-09-13 15:57:54 +02:00
Robinson 3fff69757c
Only notify exceptions when message-send-during-close, when we did not explicitly close the connection 2023-09-13 13:52:45 +02:00
Robinson 8e9e0441ed
Fixed issue with RMI sync/async. 2023-09-13 13:49:13 +02:00
Robinson 3abbdf8825
removed stacktrace output 2023-09-08 14:18:48 +02:00
Robinson 3d8c5275ac
Better error details during connect phase 2023-09-08 14:18:35 +02:00
Robinson c69512eda4
pollerClosedLatch is now only created once we've fully started (prevent blocking forever when shutting down) 2023-09-08 13:21:12 +02:00
Robinson dafcc97eac
cleanStackTrace() now returns itself. 2023-09-08 13:20:24 +02:00
Robinson 78ae3a38e4
Added unit test for closing endpoint while they haven't fully started 2023-09-08 13:19:56 +02:00
Robinson f9b30012b1
RMI fix/cleanup 2023-09-08 02:49:52 +02:00
Robinson e3a565f291
Fixed issues with streaming (it MUST be the aeron thread) 2023-09-08 02:49:35 +02:00
Robinson 50f212b834
updated deps 2023-09-07 18:36:03 +02:00
Robinson d772088eed
Streaming data now goes onto its own context instead of on the aeron polling thread 2023-09-07 18:35:58 +02:00
Robinson c2c45b9ffe
updated deps 2023-09-07 18:32:34 +02:00
Robinson 2aef58b507
version 6.10 2023-09-07 18:15:20 +02:00
Robinson daf289c7b7
updated deps 2023-09-07 18:15:10 +02:00
Robinson df11e40222
Moved all `app` non-unit test code into the app package 2023-09-07 18:08:39 +02:00
Robinson fa03be5e89
various steps to optimize RMI calls (they are ~1.5x as slow as standard message passing) 2023-09-07 18:08:22 +02:00
Robinson 56a42e5b7f
comments 2023-09-07 17:26:37 +02:00
Robinson 2e8382eb2f
removed debug code 2023-09-07 17:23:27 +02:00
Robinson a85c647598
Updated collections to use LongMap 2023-09-07 12:24:10 +02:00
Robinson 8428d9899d
Enable to dynamically enable IPC when explicitly called 2023-09-07 12:23:47 +02:00
Robinson 9d0d8efdc0
Initial value of threadID is 0, so we don't have to initialize the poller in order to close it 2023-09-07 12:23:23 +02:00
Robinson ba4df9b33b
Comments 2023-09-07 10:36:35 +02:00
Robinson 94b5226a5a
Fixed issues with RMI not throwing exceptions properly 2023-09-07 10:36:22 +02:00
Robinson 1bb052fed4
comments fixed 2023-09-07 01:02:17 +02:00
Robinson 63dd14015c
Converted to using threads instead of coroutines 2023-09-07 01:01:55 +02:00
Robinson 0173ef7b91
GC performance optimization 2023-09-07 01:01:36 +02:00
Robinson e11287b31e
Converted to executors. 2023-09-07 01:01:22 +02:00
Robinson 94ae22716d
Fixed issues with heap garbage generation and performance (suspend is better than blocking, but only with short execution stacks) 2023-09-07 01:00:53 +02:00
Robinson 7ac284bc1b
Code cleanup 2023-09-07 00:59:57 +02:00
Robinson 9e20a20bbb
Streaming data now supports random placement 2023-09-07 00:45:40 +02:00
Robinson cbb5038eb6
code cleanup 2023-09-06 16:44:40 +02:00
Robinson 9b1650ae31
Code cleanup 2023-09-06 16:24:45 +02:00
Robinson 2a485bd097
code cleanup 2023-09-06 16:20:23 +02:00
Robinson c5b9691bb1
Moved dispatcher to EventDispatcher 2023-09-06 16:20:15 +02:00
Robinson 54eab9d6c8
code cleanup 2023-09-06 16:19:40 +02:00
Robinson 0bd725b2d8
Removed withKryo{} lambda (was causing heap issues) 2023-09-06 16:17:36 +02:00
Robinson b30b024849
Added readKryos for streaming 2023-09-06 16:17:03 +02:00
Robinson 464fbadbd1
Removed coroutine trampoline from JPMS 2023-09-06 12:05:55 +02:00
Robinson 26e6da555b
Now catch and watch `Throwable` instead of just `Exception` 2023-09-05 23:39:40 +02:00
Robinson ae5a48b309
Added back a/syncSuspend 2023-09-05 23:38:53 +02:00
Robinson 48f1555ace
Removed unnecessary Suspend Trampoline 2023-09-05 23:38:32 +02:00
Robinson 3704ae25e7
RmiUtils now accepts `Throwable` instead of `Exception` 2023-09-05 23:38:13 +02:00
Robinson 7c326d180c
Converted the RMI response manager to use blocking instead of suspending calls. 2023-09-05 12:59:49 +02:00
Robinson 3e9a8f9c74
Moved ping to the connection object 2023-09-05 12:58:38 +02:00
Robinson effed36faf
Server Handshake has its own dispatcher, and it's in the HandshakePollers object 2023-09-05 12:58:21 +02:00
Robinson 1b2487daec
Error notifications have their own dispatcher now (and it's in the ListenerManager) 2023-09-05 12:57:41 +02:00
Robinson d64a4bb1e1
Moved ping() to the connection object 2023-09-05 12:57:16 +02:00
Robinson 769bad6aac
tweaked function names and timeouts 2023-09-05 12:56:14 +02:00
Robinson 2d061220f5
Cleaned up how errors are managed 2023-09-05 12:55:58 +02:00
Robinson 8b62dbb063
Removed more coroutine, simplified methods 2023-09-04 14:23:06 +02:00
Robinson e185f496ec
Removed coroutines/suspending calls 2023-09-04 00:48:00 +02:00
Robinson 6291e1aa77
ResponseManager uses its own, internal dispatcher for events 2023-09-04 00:01:27 +02:00
Robinson e7999d3095
WIP - removing heap allocations 2023-09-03 21:17:37 +02:00
Robinson ac2cf56fb9
code cleanup 2023-09-03 21:10:25 +02:00
Robinson 4d2ee10c02
fixed error in logic for unit test 2023-09-03 21:10:16 +02:00
Robinson 620e74a506
Added package-info.java 2023-09-03 21:09:54 +02:00
Robinson f631dea046
If reconnect is called on a client WITHOUT being first closed, it will close first. 2023-08-30 12:02:04 +02:00
Robinson 364b29fd0c
version 6.9.1 2023-08-21 19:53:10 +02:00
Robinson b639ec1372
UDP frame size information moved to startup. 2023-08-21 19:52:19 +02:00
Robinson 0747802f0d
updated deps, version 6.9 2023-08-21 02:20:41 +02:00
Robinson 87173af0b7
version 6.8 2023-08-20 14:28:12 +02:00
Robinson 6df290cfd3
updated deps 2023-08-20 13:56:44 +02:00
Robinson bdfc293167
updated gradle 2023-08-20 13:55:11 +02:00
Robinson c95e811fde
version 6.7 2023-08-11 16:23:49 -06:00
Robinson c856c23e3c
Server connections are checked for isConnected() status during poll events 2023-08-11 10:02:57 -06:00
Robinson b39db65027
code polish 2023-08-11 09:57:38 -06:00
Robinson e8724ea4c5
Client connections are checked for isConnected() status during poll events 2023-08-11 09:57:24 -06:00
Robinson 95b1b44890
Only set the send/recv buffer sizes if they have been configured 2023-08-10 23:34:27 -06:00
Robinson 8e7c47abcc
code cleanup 2023-08-10 20:11:23 -06:00
Robinson 96f5406ae6
More careful checks when closing endpoints during restart
code polish
2023-08-10 20:05:49 -06:00
Robinson ad9771263c
driver endpoint list is now concurrent 2023-08-10 20:04:23 -06:00
Robinson 466363901c
code cleanup 2023-08-10 20:03:06 -06:00
Robinson 77d56b8804
Direct access to critical error now instead of proxy 2023-08-10 20:02:38 -06:00
Robinson a36947af5b
updated deps 2023-08-09 22:35:41 -06:00
Robinson 91aed612cc
updated license 2023-08-09 22:35:30 -06:00
Robinson b8a6f5436d
Updated API for unittests 2023-08-09 22:35:15 -06:00
Robinson 50ab7fc72f
config.id -> mediaDriverId() 2023-08-09 22:13:55 -06:00
Robinson def935214f
comment cleanup 2023-08-09 22:13:39 -06:00
Robinson 07e1da3660
Tweaked how waiting for close works 2023-08-09 22:13:12 -06:00
Robinson 2b5e943369
can optionally notifyDisconnect when closing a connection 2023-08-09 22:12:52 -06:00
Robinson 1ded010b89
code cleanup 2023-08-09 22:12:29 -06:00
Robinson 19b36bde9f
driver.start/close are now reentrant 2023-08-09 22:12:18 -06:00
Robinson 90d218637c
Cleaned up how new aeron drivers are created 2023-08-09 22:11:57 -06:00
Robinson 96cd987238
config.id -> mediaDriverId() 2023-08-09 22:09:44 -06:00
Robinson 4d73d4802c
cleaned up imports 2023-08-09 22:06:22 -06:00
Robinson e2b5f522e0
AddError is no longer suspending 2023-08-09 21:47:29 -06:00
Robinson ce311fea86
Added more support for criticalDriverErrors 2023-08-09 21:47:06 -06:00
Robinson d9bac748f8
added endpoint to inUse() 2023-08-09 21:45:54 -06:00
Robinson 6e76160c83
changed config.id -> mediaDriverId() 2023-08-09 21:45:28 -06:00
Robinson 836c8abce6
Cleaned up/tweaked endpoint.close() 2023-08-09 21:35:40 -06:00
Robinson 3852677feb
connect event dispatch check only redispatches when it's ON the EDT, but NOT in the correct one 2023-08-09 21:31:00 -06:00
Robinson e9f7172b62
code polish for event poller 2023-08-09 21:30:17 -06:00
Robinson 4c3135028a
driver.close method cleanup 2023-08-09 21:29:28 -06:00
Robinson a7533d2c91
closed check is now volatile 2023-08-09 21:28:03 -06:00
Robinson db385d0c1a
inUse check now uses the endpoint for extra checks 2023-08-09 21:23:52 -06:00
Robinson 28d170c25c
Added support for detecting critical driver errors 2023-08-09 21:18:47 -06:00
Robinson 3dcd2af495
Moved aeron.send() logic to the driver 2023-08-09 21:17:10 -06:00
Robinson 9fcbabd061
cleaned up logging 2023-08-09 21:10:41 -06:00
Robinson eaafc0f0c4
reset the endpoint config (not the initial config) when resetting. 2023-08-09 16:37:09 -06:00
Robinson 9a30c031ef
Added extra checks when adding pub/sub for when there is an ERRORED state 2023-08-07 22:30:53 -06:00
Robinson 296c600245
Added more detailed info when reconnecting 2023-08-07 19:56:38 -06:00
Robinson 8aa919b28a
simplified connect redispatch logic 2023-08-07 19:56:14 -06:00
Robinson 59bc934dc1
moved checks to earlier in the connect process 2023-08-07 19:55:51 -06:00
Robinson 4d2de085a5
Added data success checks when streaming messages.
Expanded exceptions when thrown
2023-08-07 19:54:49 -06:00
Robinson 6dc7e6bc41
If we close the event poller WHILE ON the event poller, re-dispatch the close event to the CLOSE dispatch 2023-08-07 19:53:59 -06:00
Robinson 08d58fd6fd
Fixed issues with recursive aeron directory name 2023-08-07 00:09:14 -06:00
Robinson ac42a8be7e
updated deps 2023-08-06 01:11:14 -06:00
Robinson 342abd495d
updated deps 2023-08-06 01:00:29 -06:00
Robinson 1c8b9d5023
removed upstream dependency (no longer needed) 2023-08-05 18:41:39 -06:00
Robinson b7f4a09f46
Updated classutils 2023-08-05 13:24:29 -06:00
Robinson ae08ff2c2f
Updated classutils 2023-08-05 13:24:21 -06:00
Robinson 72b4c93206
Removed moshi, updated deps 2023-08-04 23:32:44 -06:00
Robinson 00dffa78e0
Updated for new config project 2023-08-04 23:32:16 -06:00
Robinson 4a80c2c0b8
Reconnect now can have a specified timeout 2023-08-04 23:32:00 -06:00
Robinson 16c8386ae1
Updated API for collections 2023-08-04 23:31:44 -06:00
Robinson 2e904b8ac5
Tweaks for testing performance 2023-07-24 02:03:04 +02:00
Robinson 53cd6ac382
Tweaks for testing performance 2023-07-24 02:02:28 +02:00
Robinson e5786550a6
Updated version 2023-07-24 02:00:03 +02:00
Robinson 8da5215455
enhanced the basic performance test tool 2023-07-24 01:43:21 +02:00
Robinson 3016618b1c
Added support for also changing the aeron driver idle strategies 2023-07-24 01:42:27 +02:00
Robinson 57480735c3
By default, create dev/shm for macos (ram drive). Windows still uses the disk. 2023-07-23 23:36:36 +02:00
Robinson 15c7fb2a3d
Code cleanup 2023-07-23 23:02:29 +02:00
Robinson ccf7a37d3c
Commented out unnecessary code 2023-07-23 16:05:07 +02:00
Robinson fa04185234
Fixed equals 2023-07-23 15:49:25 +02:00
Robinson a140c844db
Code cleanup 2023-07-23 13:41:29 +02:00
Robinson 7f6550f1c1
Now use defaults for idle strategies 2023-07-23 13:40:42 +02:00
Robinson 6754e35c61
Code cleanup 2023-07-23 13:40:16 +02:00
Robinson 06b5f30948
Only check if a connection is closed now. We now wait for pub+sub to be "connected" before continuing to build the connection object (so it will always be in the connected state) 2023-07-23 13:39:27 +02:00
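The commit above waits for the Aeron publication and subscription to report connected before the connection object is built. A minimal sketch of that kind of bounded wait using the public Aeron/Agrona API (illustrative only, not this library's internal code):

```kotlin
import io.aeron.Publication
import org.agrona.concurrent.SleepingMillisIdleStrategy
import java.util.concurrent.TimeUnit

// Poll (with a polite idle strategy) until the publication sees a connected subscriber,
// or give up after the timeout so a handshake cannot hang forever.
fun waitForConnected(publication: Publication, timeoutMs: Long): Boolean {
    val idle = SleepingMillisIdleStrategy(1)
    val deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMs)
    while (!publication.isConnected) {
        if (System.nanoTime() >= deadline) return false
        idle.idle()
    }
    return true
}
```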
Robinson ad3fdfc64d
Updated text names for idle strategies 2023-07-23 13:30:38 +02:00
Robinson daec762e30
More specific errors when connection is closed during poll event 2023-07-23 01:20:16 +02:00
Robinson 949a863aca
Fixed instance assignment 2023-07-23 01:17:37 +02:00
Robinson a087dfa9bd
Faster startup when aeron is already running and we force-allow a driver to be running on startup (usually we don't want this) 2023-07-23 01:17:09 +02:00
Robinson 936a5e2d67
Added ability for subscription to wait for a publication to connect 2023-07-23 01:16:23 +02:00
Robinson 2d8956c78c
Client waits for server publication to connect before continuing. 2023-07-22 14:18:38 +02:00
Robinson 781d530294
added comments 2023-07-21 22:46:41 +02:00
Robinson ed2ddb239d
Tweaked aeron idle strategies 2023-07-21 22:46:30 +02:00
Robinson ee558e666d
Added additional idle strategies 2023-07-21 21:16:49 +02:00
Robinson ed89b634a2
Added comments 2023-07-21 21:16:37 +02:00
Robinson 4e232aa18e
Fixed unit tests shutdown lifecycle ordering 2023-07-21 00:20:47 +02:00
Robinson c4129f25fa
Updated test app 2023-07-21 00:20:23 +02:00
Robinson 2620a06409
Cleaned up how kryo's are used
Changed idleStrategy
StreamingManager no longer copies bytes (it just uses a pooled kryo instance)
2023-07-21 00:19:31 +02:00
Robinson 7ed474111a
Code cleanup 2023-07-20 22:31:27 +02:00
Robinson 081ee42a2e
removed dead code 2023-07-20 20:41:33 +02:00
Robinson 916ddb857f
fixed comment typo 2023-07-20 20:41:23 +02:00
Robinson 2c0680b513
Simplified setting aeron initial window length 2023-07-20 20:41:13 +02:00
Robinson 0e37689c2c
better logs when retrying the connect sequence 2023-07-20 20:40:43 +02:00
Robinson 7bd653db2a
Better IPC checking 2023-07-20 20:40:17 +02:00
Robinson e2a4887a19
Changed order of cleanup when done with handshake 2023-07-20 20:39:58 +02:00
Robinson 0e3cc803b2
More detailed logging when in debug mode 2023-07-20 20:39:26 +02:00
Robinson 80d77f2f51
Better lock-step and checks when closing an endpoint 2023-07-20 20:39:12 +02:00
Robinson d787045149
Updated version 2023-07-16 14:58:01 +02:00
Robinson 94048cfe8f
Added file transport to streaming manager 2023-07-16 14:57:04 +02:00
Robinson bb026f377b
WIP compression/crypto 2023-07-15 13:12:59 +02:00
Robinson 411a4c54b8
Better error checking. added kryo-exception checking/failures for unittests 2023-07-15 13:12:25 +02:00
Robinson c4eda86bfe
Optimized how we send data (we use our own stream/block data structures for fragmentation/reassembly). 2023-07-15 13:11:50 +02:00
Robinson 93a7c9008d
Code cleanup and fixed issues when sending non-perfect multiples of our data limit. 2023-07-15 13:07:17 +02:00
Robinson 85d716e572
Changed which data structure is evaluated when saving data 2023-07-15 13:05:57 +02:00
Robinson 5583948961
Added AeronWriter size initialization 2023-07-15 13:05:19 +02:00
Robinson 307b8f558f
Updated comments/dependencies 2023-07-14 13:47:59 +02:00
Robinson 2f8c78ddee
Added ipcMTU to aeron config (it will be the same as the network MTU). This must be the same value because of our internal read/write serialization buffers. 2023-07-14 13:39:08 +02:00
Robinson 215ed20056
Changed wording of chunk -> block 2023-07-14 13:33:36 +02:00
Robinson 290c5bd768
Cleaned up comments 2023-07-12 14:20:04 +02:00
Robinson 09748326c9
Cleaned up crypto management, removed dead code 2023-07-12 14:19:54 +02:00
Robinson 6bf870bd7b
Split kryo TYPES into read/write types, so usage is very clear. Now use a kryo pool for concurrent serialization 2023-07-12 14:08:46 +02:00
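The commit above moves to a Kryo pool for concurrent serialization. A minimal sketch using Kryo 5's own Pool class (the message type and registration here are hypothetical, for illustration only):

```kotlin
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.Pool

// Hypothetical message type, only so there is something to register.
class PingMessage(var timestamp: Long = 0)

// Kryo instances are not thread-safe; a pool lets concurrent read/write paths each
// borrow their own instance instead of sharing one.
val kryoPool = object : Pool<Kryo>(true, false, 16) {
    override fun create(): Kryo {
        val kryo = Kryo()
        kryo.register(PingMessage::class.java) // registrations must match on both ends
        return kryo
    }
}

fun serialize(message: Any): ByteArray {
    val kryo = kryoPool.obtain()
    try {
        val output = Output(1024, -1) // initial 1 KB buffer, grows as needed
        kryo.writeClassAndObject(output, message)
        return output.toBytes()
    } finally {
        kryoPool.free(kryo)
    }
}
```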
Robinson d3c3bf50d6
Changed applicationId -> appId 2023-07-12 14:04:32 +02:00
Robinson 2f7a365f75
Moved `errorCodeName` into the driver 2023-07-11 11:50:48 +02:00
Robinson 90830128e6
Removed dead code 2023-07-11 11:50:32 +02:00
Robinson f1ebd076bf
Removed dependency on aeron-aal (which was only for samples) 2023-07-11 09:48:26 +02:00
Robinson 87b65d061a
Now supports JPMS (kotlin-only 9+ projects must use a workaround) 2023-07-11 00:27:39 +02:00
Robinson c4ddfe8675
Fixed unnecessary non-null assertions 2023-07-11 00:23:08 +02:00
Robinson 990652288e
Code cleanup 2023-07-11 00:12:09 +02:00
Robinson 4797b7e816
Added port1/2 settings to server + client.
Fixed relevant unit tests
2023-07-11 00:11:58 +02:00
Robinson ebad4d234b
Fixed driver liveliness checks 2023-07-11 00:02:40 +02:00
Robinson cba66a6959
Removed SigInt catch. It should be managed by the application, not the library. 2023-07-05 12:53:21 +02:00
Robinson 897db748e7
Removed dead code 2023-07-05 12:52:36 +02:00
Robinson ce6ffec197
When an endpoint restarts too quickly, wait a more appropriate timeout 2023-07-05 12:52:23 +02:00
160 changed files with 10887 additions and 8211 deletions

LICENSE (1651 lines changed)

File diff suppressed because it is too large.


@@ -72,7 +72,7 @@ Maven Info
<dependency>
<groupId>com.dorkbox</groupId>
<artifactId>Network</artifactId>
<version>6.4</version>
<version>6.15</version>
</dependency>
</dependencies>
```
@@ -82,7 +82,7 @@ Gradle Info
```
dependencies {
...
implementation("com.dorkbox:Network:6.4")
implementation("com.dorkbox:Network:6.15")
}
```


@@ -23,21 +23,22 @@
gradle.startParameter.showStacktrace = ShowStacktrace.ALWAYS // always show the stacktrace!
plugins {
id("com.dorkbox.GradleUtils") version "3.17"
id("com.dorkbox.Licensing") version "2.24"
id("com.dorkbox.GradleUtils") version "3.18"
id("com.dorkbox.Licensing") version "2.28"
id("com.dorkbox.VersionUpdate") version "2.8"
id("com.dorkbox.GradlePublish") version "1.18"
id("com.dorkbox.GradlePublish") version "1.22"
id("com.github.johnrengelman.shadow") version "7.1.2"
id("com.github.johnrengelman.shadow") version "8.1.1"
kotlin("jvm") version "1.8.0"
kotlin("jvm") version "1.9.0"
}
@Suppress("ConstPropertyName")
object Extras {
// set for the project
const val description = "High-performance, event-driven/reactive network stack for Java 11+"
const val group = "com.dorkbox"
const val version = "6.4"
const val version = "6.15"
// set as project.ext
const val name = "Network"
@@ -52,6 +53,12 @@ object Extras {
///////////////////////////////
GradleUtils.load("$projectDir/../../gradle.properties", Extras)
GradleUtils.defaults()
// because of the api changes for stacktrace stuff, it's best for us to ONLY support 11+
GradleUtils.compileConfiguration(JavaVersion.VERSION_11) {
// see: https://kotlinlang.org/docs/reference/using-gradle.html
// enable the use of inline classes. see https://kotlinlang.org/docs/reference/inline-classes.html
freeCompilerArgs = listOf("-Xinline-classes")
}
//val kotlin = project.extensions.getByType(org.jetbrains.kotlin.gradle.dsl.KotlinJvmProjectExtension::class.java).sourceSets.getByName("main").kotlin
@@ -60,18 +67,10 @@ GradleUtils.defaults()
// include("**/*.kt") // want to include kotlin files for the source. 'setSrcDirs' resets includes...
//}
// TODO: driver name resolution: https://github.com/real-logic/aeron/wiki/Driver-Name-Resolution
// this keeps us from having to restart the media driver when a connection changes IP addresses
// because of the api changes for stacktrace stuff, it's best for us to ONLY support 11+
GradleUtils.compileConfiguration(JavaVersion.VERSION_11) {
// see: https://kotlinlang.org/docs/reference/using-gradle.html
// enable the use of inline classes. see https://kotlinlang.org/docs/reference/inline-classes.html
freeCompilerArgs = listOf("-Xinline-classes")
}
//GradleUtils.jpms(JavaVersion.VERSION_11)
//NOTE: we do not support JPMS yet, as there are some libraries missing support for it still, notably kotlin!
// TODO: virtual threads in java21 for polling?
// if we are sending a SMALL byte array, then we SEND IT DIRECTLY in a more optimized manner (because we can count size info!)
// other side has to be able to parse/know that this was sent directly as bytes. It could be game state data, or voice data, etc.
@@ -84,14 +83,14 @@ GradleUtils.compileConfiguration(JavaVersion.VERSION_11) {
// --- this remote outputStream is a file, raw??? this is setup by createInputStream() on the remote end
// - state-machine for kryo class registrations?
// ratelimiter, "other" package
// ratelimiter, "other" package, send-on-idle
// rest of unit tests
// getConnectionUpgradeType
// ability to send with a function callback (using RMI waiter type stuff for callbacks)
// java 14 is faster with aeron!
// NOTE: now using aeron instead of netty
// todo: remove BC! use or native java? (if possible. we are java 11 now, instead of 1.6)
licensing {
@@ -152,7 +151,7 @@ shadowJar.apply {
manifest.inheritFrom(tasks.jar.get().manifest)
manifest.attributes.apply {
put("Main-Class", "dorkboxTest.network.AeronRmiClientServer")
put("Main-Class", "dorkboxTest.network.app.AeronClientServerForever")
}
mergeServiceFiles()
@@ -167,66 +166,66 @@ shadowJar.apply {
dependencies {
api("org.jetbrains.kotlinx:atomicfu:0.21.0")
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.1")
api("org.jetbrains.kotlinx:atomicfu:0.23.0")
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
// https://github.com/dorkbox
api("com.dorkbox:ByteUtilities:1.10")
api("com.dorkbox:Collections:1.6")
api("com.dorkbox:MinLog:2.5")
api("com.dorkbox:NetworkDNS:2.9")
api("com.dorkbox:NetworkUtils:2.22")
api("com.dorkbox:OS:1.6")
api("com.dorkbox:ByteUtilities:2.1")
api("com.dorkbox:ClassUtils:1.3")
api("com.dorkbox:Collections:2.7")
api("com.dorkbox:HexUtilities:1.1")
api("com.dorkbox:JNA:1.4")
api("com.dorkbox:MinLog:2.7")
api("com.dorkbox:NetworkDNS:2.16")
api("com.dorkbox:NetworkUtils:2.23")
api("com.dorkbox:OS:1.11")
api("com.dorkbox:Serializers:2.9")
api("com.dorkbox:Storage:1.5")
api("com.dorkbox:Storage:1.11")
api("com.dorkbox:Updates:1.1")
api("com.dorkbox:Utilities:1.42")
api("com.dorkbox:Utilities:1.48")
// how we bypass using reflection/jpms to access fields for java17+
api("org.javassist:javassist:3.29.2-GA")
api("com.dorkbox:JNA:1.0")
val jnaVersion = "5.12.1"
val jnaVersion = "5.13.0"
api("net.java.dev.jna:jna-jpms:${jnaVersion}")
api("net.java.dev.jna:jna-platform-jpms:${jnaVersion}")
// we include ALL of aeron, in case we need to debug aeron behavior
// https://github.com/real-logic/aeron
val aeronVer = "1.41.4"
api("io.aeron:aeron-all:$aeronVer")
// api("org.agrona:agrona:1.16.0") // sources for this isn't included in aeron-all!
val aeronVer = "1.42.1"
api("io.aeron:aeron-driver:$aeronVer")
// ALL of aeron, in case we need to debug aeron behavior
// api("io.aeron:aeron-all:$aeronVer")
// api("org.agrona:agrona:1.18.2") // sources for this aren't included in aeron-all!
// https://github.com/EsotericSoftware/kryo
api("com.esotericsoftware:kryo:5.5.0") {
exclude("com.esotericsoftware", "minlog") // we use our own minlog, that logs to SLF4j instead
}
// https://github.com/jpountz/lz4-java
// implementation("net.jpountz.lz4:lz4:1.3.0")
// this is NOT the same thing as LMAX disruptor.
// This is just a slightly faster queue than LMAX. (LMAX is a fast queue + other things w/ a difficult DSL)
// https://github.com/conversant/disruptor_benchmark
// https://www.youtube.com/watch?v=jVMOgQgYzWU
//api("com.conversantmedia:disruptor:1.2.19")
// https://github.com/lz4/lz4-java
api("org.lz4:lz4-java:1.8.0")
// https://github.com/jhalterman/typetools
api("net.jodah:typetools:0.6.3")
// Expiring Map (A high performance thread-safe map that expires entries)
// https://github.com/jhalterman/expiringmap
api("net.jodah:expiringmap:0.5.10")
api("net.jodah:expiringmap:0.5.11")
// https://github.com/MicroUtils/kotlin-logging
api("io.github.microutils:kotlin-logging:3.0.5")
api("org.slf4j:slf4j-api:2.0.7")
// api("io.github.microutils:kotlin-logging:3.0.5")
implementation("org.slf4j:slf4j-api:2.0.9")
testImplementation("junit:junit:4.13.2")
testImplementation("ch.qos.logback:logback-classic:1.4.5")
testImplementation("io.aeron:aeron-all:$aeronVer")
testImplementation("com.dorkbox:Config:2.1")
}
publishToSonatype {

Binary file not shown.


@@ -1,5 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-all.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-all.zip
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

gradlew (vendored, 25 lines changed)

@@ -55,7 +55,7 @@
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
@@ -80,13 +80,11 @@ do
esac
done
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
APP_NAME="Gradle"
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
@@ -133,22 +131,29 @@ location of your Java installation.
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
if ! command -v java >/dev/null 2>&1
then
die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
@@ -193,6 +198,10 @@ if "$cygwin" || "$msys" ; then
done
fi
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in

gradlew.bat (vendored, 1 line changed)

@@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%


@@ -1,293 +0,0 @@
package dorkbox.network.other
import java.math.BigInteger
import java.security.GeneralSecurityException
import java.security.KeyFactory
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.PrivateKey
import java.security.SecureRandom
import java.security.interfaces.ECPrivateKey
import java.security.interfaces.ECPublicKey
import java.security.spec.ECField
import java.security.spec.ECFieldFp
import java.security.spec.ECParameterSpec
import java.security.spec.ECPoint
import java.security.spec.ECPublicKeySpec
import java.security.spec.EllipticCurve
import java.security.spec.NamedParameterSpec
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import javax.crypto.Cipher
/**
*
*/
private object CryptoEccNative {
// see: https://openjdk.java.net/jeps/324
const val curve25519 = "curve25519"
const val default_curve = curve25519
const val macSize = 512
// on NIST vs 25519 vs Brainpool, see:
// - http://ogryb.blogspot.de/2014/11/why-i-dont-trust-nist-p-256.html
// - http://credelius.com/credelius/?p=97
// - http://safecurves.cr.yp.to/
// we should be using 25519, because NIST and brainpool are "unsafe". Brainpool is "more random" than 25519, but is still not considered safe.
// more info about ECC from:
// http://www.johannes-bauer.com/compsci/ecc/?menuid=4
// http://stackoverflow.com/questions/7419183/problems-implementing-ecdh-on-android-using-bouncycastle
// http://tools.ietf.org/html/draft-jivsov-openpgp-ecc-06#page-4
// http://www.nsa.gov/ia/programs/suiteb_cryptography/
// https://github.com/nelenkov/ecdh-kx/blob/master/src/org/nick/ecdhkx/Crypto.java
// http://nelenkov.blogspot.com/2011/12/using-ecdh-on-android.html
// http://www.secg.org/collateral/sec1_final.pdf
// More info about 25519 key types (ed25519 and X25519)
// https://blog.filippo.io/using-ed25519-keys-for-encryption/
fun createKeyPair(secureRandom: SecureRandom): KeyPair {
val kpg: KeyPairGenerator = KeyPairGenerator.getInstance("XDH")
kpg.initialize(NamedParameterSpec.X25519, secureRandom)
return kpg.generateKeyPair()
// println("--- Public Key ---")
// val publicKey = kp.public
//
// System.out.println(publicKey.algorithm) // XDH
// System.out.println(publicKey.format) // X.509
//
// // save this public key
// val pubKey = publicKey.encoded
//
// println("---")
//
// println("--- Private Key ---")
// val privateKey = kp.private
//
// System.out.println(privateKey.algorithm); // XDH
// System.out.println(privateKey.format); // PKCS#8
//
// // save this private key
// val priKey = privateKey.encoded
// val kf: KeyFactory = KeyFactory.getInstance("XDH");
// //BigInteger u = ...
// val pubSpec: XECPublicKeySpec = XECPublicKeySpec(paramSpec, u);
// val pubKey: PublicKey = kf.generatePublic(pubSpec);
// //
//
// val ka: KeyAgreement = KeyAgreement.getInstance("XDH");
// ka.init(kp.private);
//ka.doPhase(pubKey, true);
//byte[] secret = ka.generateSecret();
}
private val FieldP_2: BigInteger = BigInteger.TWO // constant for scalar operations
private val FieldP_3: BigInteger = BigInteger.valueOf(3) // constant for scalar operations
private const val byteVal1 = 1.toByte()
@Throws(GeneralSecurityException::class)
fun getPublicKey(pk: ECPrivateKey): ECPublicKey? {
val params: ECParameterSpec = pk.params
val w: ECPoint = scalmultNew(params, params.generator, pk.s)
//final ECPoint w = scalmult(params.getCurve(), pk.getParams().getGenerator(), pk.getS());
val kg: KeyFactory = KeyFactory.getInstance("EC")
return kg.generatePublic(ECPublicKeySpec(w, params)) as ECPublicKey
}
private fun scalmultNew(params: ECParameterSpec, g: ECPoint, kin: BigInteger): ECPoint {
val curve = params.curve
val field = curve.field
if (field !is ECFieldFp) throw java.lang.UnsupportedOperationException(field::class.java.canonicalName)
val p = field.p
val a = curve.a
var R = ECPoint.POINT_INFINITY
// value only valid for curve secp256k1, code taken from https://www.secg.org/sec2-v2.pdf,
// see "Finally the order n of G and the cofactor are: n = "FF.."
val SECP256K1_Q = params.order
//BigInteger SECP256K1_Q = new BigInteger("00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",16);
var k = kin.mod(SECP256K1_Q) // uses this !
// BigInteger k = kin.mod(p); // do not use this ! wrong as per comment from President James Moveon Polk
val length = k.bitLength()
val binarray = ByteArray(length)
for (i in 0..length - 1) {
binarray[i] = k.mod(FieldP_2).byteValueExact()
k = k.shiftRight(1)
}
for (i in length - 1 downTo 0) {
R = doublePoint(p, a, R)
if (binarray[i] == byteVal1) R = addPoint(p, a, R, g)
}
return R
}
fun scalmultOrg(curve: EllipticCurve, g: ECPoint, kin: BigInteger): ECPoint {
val field: ECField = curve.getField()
if (field !is ECFieldFp) throw UnsupportedOperationException(field::class.java.canonicalName)
val p: BigInteger = (field as ECFieldFp).getP()
val a: BigInteger = curve.getA()
var R = ECPoint.POINT_INFINITY
// value only valid for curve secp256k1, code taken from https://www.secg.org/sec2-v2.pdf,
// see "Finally the order n of G and the cofactor are: n = "FF.."
val SECP256K1_Q = BigInteger("00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
var k = kin.mod(SECP256K1_Q) // uses this !
// wrong as per comment from President James Moveon Polk
// BigInteger k = kin.mod(p); // do not use this !
println(" SECP256K1_Q: $SECP256K1_Q")
println(" p: $p")
System.out.println("curve: " + curve.toString())
val length = k.bitLength()
val binarray = ByteArray(length)
for (i in 0..length - 1) {
binarray[i] = k.mod(FieldP_2).byteValueExact()
k = k.shiftRight(1)
}
for (i in length - 1 downTo 0) {
R = doublePoint(p, a, R)
if (binarray[i] == byteVal1) R = addPoint(p, a, R, g)
}
return R
}
// scalar operations for native java
// https://stackoverflow.com/a/42797410/8166854
// written by author: SkateScout
private fun doublePoint(p: BigInteger, a: BigInteger, R: ECPoint): ECPoint? {
if (R == ECPoint.POINT_INFINITY) return R
var slope = R.affineX.pow(2).multiply(FieldP_3)
slope = slope.add(a)
slope = slope.multiply(R.affineY.multiply(FieldP_2).modInverse(p))
val Xout = slope.pow(2).subtract(R.affineX.multiply(FieldP_2)).mod(p)
val Yout = R.affineY.negate().add(slope.multiply(R.affineX.subtract(Xout))).mod(p)
return ECPoint(Xout, Yout)
}
private fun addPoint(p: BigInteger, a: BigInteger, r: ECPoint, g: ECPoint): ECPoint? {
if (r == ECPoint.POINT_INFINITY) return g
if (g == ECPoint.POINT_INFINITY) return r
if (r == g || r == g) return doublePoint(p, a, r)
val gX = g.affineX
val sY = g.affineY
val rX = r.affineX
val rY = r.affineY
val slope = rY.subtract(sY).multiply(rX.subtract(gX).modInverse(p)).mod(p)
val Xout = slope.modPow(FieldP_2, p).subtract(rX).subtract(gX).mod(p)
var Yout = sY.negate().mod(p)
Yout = Yout.add(slope.multiply(gX.subtract(Xout))).mod(p)
return ECPoint(Xout, Yout)
}
private fun byteArrayToHexString(a: ByteArray): String {
val sb = StringBuilder(a.size * 2)
for (b in a) sb.append(String.format("%02X", b))
return sb.toString()
}
fun hexStringToByteArray(s: String): ByteArray {
val len = s.length
val data = ByteArray(len / 2)
var i = 0
while (i < len) {
data[i / 2] = ((Character.digit(s[i], 16) shl 4)
+ Character.digit(s[i + 1], 16)).toByte()
i += 2
}
return data
}
@Throws(GeneralSecurityException::class)
@JvmStatic
fun main(args: Array<String>) {
val cryptoText = "i23j4jh234kjh234kjh23lkjnfa9s8egfuypuh325"
// NOTE: THIS IS NOT 25519!!
println("Generate ECPublicKey from PrivateKey (String) for curve secp256k1 (final)")
println("Check keys with https://gobittest.appspot.com/Address")
// https://gobittest.appspot.com/Address
val privateKey = "D12D2FACA9AD92828D89683778CB8DFCCDBD6C9E92F6AB7D6065E8AACC1FF6D6"
val publicKeyExpected = "04661BA57FED0D115222E30FE7E9509325EE30E7E284D3641E6FB5E67368C2DB185ADA8EFC5DC43AF6BF474A41ED6237573DC4ED693D49102C42FFC88510500799"
println("\nprivatekey given : $privateKey")
println("publicKeyExpected: $publicKeyExpected")
// // routine with bouncy castle
// println("\nGenerate PublicKey from PrivateKey with BouncyCastle")
// val spec: ECNamedCurveParameterSpec = ECNamedCurveTable.getParameterSpec("secp256k1") // this ec curve is used for bitcoin operations
// val pointQ: org.bouncycastle.math.ec.ECPoint = spec.getG().multiply(BigInteger(1, ch.qos.logback.core.encoder.ByteArrayUtil.hexStringToByteArray(privateKey)))
// val publickKeyByte = pointQ.getEncoded(false)
// val publicKeyBc: String = byteArrayToHexString(publickKeyByte)
// println("publicKeyExpected: $publicKeyExpected")
// println("publicKey BC : $publicKeyBc")
// println("publicKeys match : " + publicKeyBc.contentEquals(publicKeyExpected))
// regeneration of ECPublicKey with java native starts here
println("\nGenerate PublicKey from PrivateKey with Java native routines")
// the preset "303E.." only works for elliptic curve secp256k1
// see answer by user dave_thompson_085
// https://stackoverflow.com/questions/48832170/generate-ec-public-key-from-byte-array-private-key-in-native-java-7
val privateKeyFull = "303E020100301006072A8648CE3D020106052B8104000A042730250201010420" + privateKey
val privateKeyFullByte: ByteArray = hexStringToByteArray(privateKeyFull)
println("privateKey full : $privateKeyFull")
val keyFactory = KeyFactory.getInstance("EC")
val privateKeyNative: PrivateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateKeyFullByte))
val ecPrivateKeyNative = privateKeyNative as ECPrivateKey
val ecPublicKeyNative = getPublicKey(ecPrivateKeyNative)
val ecPublicKeyNativeByte = ecPublicKeyNative!!.encoded
val testPubKey = keyFactory.generatePublic(X509EncodedKeySpec(ecPublicKeyNativeByte)) as ECPublicKey
val equal = ecPublicKeyNativeByte.contentEquals(testPubKey.encoded)
val publicKeyNativeFull: String = byteArrayToHexString(ecPublicKeyNativeByte)
val publicKeyNativeHeader = publicKeyNativeFull.substring(0, 46)
val publicKeyNativeKey = publicKeyNativeFull.substring(46, 176)
println("ecPublicKeyFull : $publicKeyNativeFull")
println("ecPublicKeyHeader: $publicKeyNativeHeader")
println("ecPublicKeyKey : $publicKeyNativeKey")
println("publicKeyExpected: $publicKeyExpected")
println("publicKeys match : " + publicKeyNativeKey.contentEquals(publicKeyExpected))
// encrypt
val encryptCipher: Cipher = Cipher.getInstance("RSA")
encryptCipher.init(Cipher.ENCRYPT_MODE, ecPublicKeyNative)
val cipherText: ByteArray = encryptCipher.doFinal(cryptoText.toByteArray())
// decrypt
val decryptCipher = Cipher.getInstance("RSA");
decryptCipher.init(Cipher.DECRYPT_MODE, ecPrivateKeyNative);
val outputBytes = decryptCipher.doFinal(cipherText)
println("Crypto round passed: ${String(outputBytes) == cryptoText}")
}
}


@@ -1,159 +0,0 @@
/* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
* BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
* SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
import static org.junit.Assert.fail;
import java.io.IOException;
import java.security.SecureRandom;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import dorkbox.network.connection.Connection;
import dorkbox.network.connection.Listener;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.exceptions.SecurityException;
import dorkbox.util.serialization.SerializationManager;
public
class LargeResizeBufferTest extends BaseTest {
private static final int OBJ_SIZE = 1024 * 100;
private volatile int finalCheckAmount = 0;
private volatile int serverCheck = -1;
private volatile int clientCheck = -1;
@Test
public
void manyLargeMessages() throws SecurityException, IOException {
final int messageCount = 1024;
Configuration configuration = new Configuration();
configuration.tcpPort = tcpPort;
configuration.udpPort = udpPort;
configuration.host = host;
register(configuration.serialization);
Server server = new Server(configuration);
addEndPoint(server);
server.bind(false);
server.listeners()
.add(new Listener.OnMessageReceived<Connection, LargeMessage>() {
AtomicInteger received = new AtomicInteger();
AtomicInteger receivedBytes = new AtomicInteger();
@Override
public
void received(Connection connection, LargeMessage object) {
// System.err.println("Server ack message: " + received.get());
connection.send()
.TCP(object);
this.receivedBytes.addAndGet(object.bytes.length);
if (this.received.incrementAndGet() == messageCount) {
System.out.println("Server received all " + messageCount + " messages!");
System.out.println("Server received and sent " + this.receivedBytes.get() + " bytes.");
LargeResizeBufferTest.this.serverCheck = LargeResizeBufferTest.this.finalCheckAmount - this.receivedBytes.get();
System.out.println("Server missed " + LargeResizeBufferTest.this.serverCheck + " bytes.");
stopEndPoints();
}
}
});
Client client = new Client(configuration);
addEndPoint(client);
client.listeners()
.add(new Listener.OnMessageReceived<Connection, LargeMessage>() {
AtomicInteger received = new AtomicInteger();
AtomicInteger receivedBytes = new AtomicInteger();
@Override
public
void received(Connection connection, LargeMessage object) {
this.receivedBytes.addAndGet(object.bytes.length);
int count = this.received.getAndIncrement();
// System.out.println("Client received message: " + count);
if (count == messageCount) {
System.out.println("Client received all " + messageCount + " messages!");
System.out.println("Client received and sent " + this.receivedBytes.get() + " bytes.");
LargeResizeBufferTest.this.clientCheck = LargeResizeBufferTest.this.finalCheckAmount - this.receivedBytes.get();
System.out.println("Client missed " + LargeResizeBufferTest.this.clientCheck + " bytes.");
}
}
});
client.connect(5000);
SecureRandom random = new SecureRandom();
System.err.println(" Client sending " + messageCount + " messages");
for (int i = 0; i < messageCount; i++) {
this.finalCheckAmount += OBJ_SIZE; // keep increasing size
byte[] b = new byte[OBJ_SIZE];
random.nextBytes(b);
// set some of the bytes to be all `244`, just so some compression can occur (to test that as well)
for (int j = 0; j < 400; j++) {
b[j] = (byte) 244;
}
// System.err.println("Sending " + b.length + " bytes");
client.send()
.TCP(new LargeMessage(b));
}
System.err.println("Client has queued " + messageCount + " messages.");
waitForThreads();
if (this.clientCheck > 0) {
fail("Client missed " + this.clientCheck + " bytes.");
}
if (this.serverCheck > 0) {
fail("Server missed " + this.serverCheck + " bytes.");
}
}
private
void register(SerializationManager manager) {
manager.register(byte[].class);
manager.register(LargeMessage.class);
}
public static
class LargeMessage {
public byte[] bytes;
public
LargeMessage() {
}
public
LargeMessage(byte[] bytes) {
this.bytes = bytes;
}
}
}


@@ -1,197 +0,0 @@
package dorkbox.network.other
import kotlin.math.ceil
/**
*
*/
object Misc {
private fun annotations() {
// internal val classesWithRmiFields = IdentityMap<Class<*>, Array<Field>>()
// // get all classes that have fields with @Rmi field annotation.
// // THESE classes must be customized with our special RmiFieldSerializer serializer so that the @Rmi field is properly handled
//
// // SPECIFICALLY, these fields must also be an IFACE for the field type!
//
// // NOTE: The @Rmi field type will already have to be a registered type with kryo!
// // we can use this information on WHERE to scan for classes.
// val filesToScan = mutableSetOf<File>()
//
// classesToRegister.forEach { registration ->
// val clazz = registration.clazz
//
// // can't do anything if codeSource is null!
// val codeSource = clazz.protectionDomain.codeSource ?: return@forEach
// // file:/Users/home/java/libs/xyz-123.jar
// // file:/projects/classes
// val jarOrClassPath = codeSource.location.toString()
//
// if (jarOrClassPath.endsWith(".jar")) {
// val fileName: String = URLDecoder.decode(jarOrClassPath.substring("file:".length), Charset.defaultCharset())
// filesToScan.add(File(fileName).absoluteFile)
// } else {
// val classPath: String = URLDecoder.decode(jarOrClassPath.substring("file:".length), Charset.defaultCharset())
// filesToScan.add(File(classPath).absoluteFile)
// }
// }
//
// val toTypedArray = filesToScan.toTypedArray()
// if (logger.isTraceEnabled) {
// toTypedArray.forEach {
// logger.trace { "Adding location to annotation scanner: $it"}
// }
// }
//
//
//
// // now scan these jars/directories
// val fieldsWithRmiAnnotation = AnnotationDetector.scanFiles(*toTypedArray)
// .forAnnotations(Rmi::class.java)
// .on(ElementType.FIELD)
// .collect { cursor -> Pair(cursor.type, cursor.field!!) }
//
// // have to make sure that the field type is specified as an interface (and not an implementation)
// fieldsWithRmiAnnotation.forEach { pair ->
// require(pair.second.type.isInterface) { "@Rmi annotated fields must be an interface!" }
// }
//
// if (fieldsWithRmiAnnotation.isNotEmpty()) {
// logger.info { "Verifying scanned classes containing @Rmi field annotations" }
// }
//
// // have to put this in a map, so we can quickly lookup + get the fields later on.
// // NOTE: a single class can have MULTIPLE fields with @Rmi annotations!
// val rmiAnnotationMap = IdentityMap<Class<*>, MutableList<Field>>()
// fieldsWithRmiAnnotation.forEach {
// var fields = rmiAnnotationMap[it.first]
// if (fields == null) {
// fields = mutableListOf()
// }
//
// fields.add(it.second)
// rmiAnnotationMap.put(it.first, fields)
// }
//
// // now make it an array for fast lookup for the [parent class] -> [annotated fields]
// rmiAnnotationMap.forEach {
// classesWithRmiFields.put(it.key, it.value.toTypedArray())
// }
//
// // this will set up the class registration information
// initKryo()
//
// // now everything is REGISTERED, possibly with custom serializers, we have to go back and change them to use our RmiFieldSerializer
// fieldsWithRmiAnnotation.forEach FIELD_SCAN@{ pair ->
// // the parent class must be an IMPL. The reason is that THIS FIELD will be sent as a RMI object, and this can only
// // happen on objects that exist
//
// // NOTE: it IS necessary for the rmi-client to be aware of the @Rmi annotation (because it also has to have the correct serialization)
//
// // also, it is possible for the class that has the @Rmi field to be a NORMAL object (and not an RMI object)
// // this means we found the registration for the @Rmi field annotation
//
// val parentRmiRegistration = classesToRegister.firstOrNull { it is ClassRegistrationForRmi && it.implClass == pair.first}
//
//
// // if we have a parent-class registration, this means we are the rmi-server
// //
// // AND BECAUSE OF THIS
// //
// // we must also have the field type registered as RMI
// if (parentRmiRegistration != null) {
// // rmi-server
//
// // is the field type registered also?
// val fieldRmiRegistration = classesToRegister.firstOrNull { it.clazz == pair.second.type}
// require(fieldRmiRegistration is ClassRegistrationForRmi) { "${pair.second.type} is not registered for RMI! Unable to continue"}
//
// logger.trace { "Found @Rmi field annotation '${pair.second.type}' in class '${pair.first}'" }
// } else {
// // rmi-client
//
// // NOTE: rmi-server MUST have the field IMPL registered (ie: via RegisterRmi)
// // rmi-client will have the serialization updated from the rmi-server during connection handshake
// }
// }
}
/**
* Split an array into chunks, max of 127 chunks (the chunk count must fit in a single signed byte).
* byte[0] = chunk index
* byte[1] = total number of chunks (1-127)
*/
private fun divideArray(source: ByteArray, chunksize: Int): Array<ByteArray>? {
val fragments = ceil(source.size / chunksize.toDouble()).toInt()
if (fragments > 127) {
// cannot allow more than 127
return null
}
// pre-allocate the memory
val splitArray = Array(fragments) { ByteArray(chunksize + 2) }
var start = 0
for (i in splitArray.indices) {
var length = if (start + chunksize > source.size) {
source.size - start
} else {
chunksize
}
splitArray[i] = ByteArray(length + 2)
splitArray[i][0] = i.toByte() // index
splitArray[i][1] = fragments.toByte() // total number of fragments
System.arraycopy(source, start, splitArray[i], 2, length)
start += chunksize
}
return splitArray
}
}
// fun initClassRegistration(channel: Channel, registration: Registration): Boolean {
// val details = serialization.getKryoRegistrationDetails()
// val length = details.size
// if (length > Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE) {
// // it is too large to send in a single packet
//
// // child arrays have index 0 also as their 'index' and 1 is the total number of fragments
// val fragments = divideArray(details, Serialization.CLASS_REGISTRATION_VALIDATION_FRAGMENT_SIZE)
// if (fragments == null) {
// logger.error("Too many classes have been registered for Serialization. Please report this issue")
// return false
// }
// val allButLast = fragments.size - 1
// for (i in 0 until allButLast) {
// val fragment = fragments[i]
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragment
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// }
//
// // now tell the server we are done with the fragments
// val fragmentedRegistration = Registration.hello(registration.oneTimePad, config.settingsStore.getPublicKey())
// fragmentedRegistration.payload = fragments[allButLast]
//
// // tell the server we are fragmented
// fragmentedRegistration.upgradeType = UpgradeType.FRAGMENTED
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// fragmentedRegistration.upgraded = true
// channel.writeAndFlush(fragmentedRegistration)
// } else {
// registration.payload = details
//
// // tell the server we are upgraded (it will bounce back telling us to connect)
// registration.upgraded = true
// channel.writeAndFlush(registration)
// }
// return true
// }
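For reference only (not in the diff): a sketch of the receiving-side counterpart to divideArray, reassembling fragments whose first two bytes carry the chunk index and the total chunk count, as described in the comment above.

import java.io.ByteArrayOutputStream

// Sketch: rebuild the original array from fragments produced by divideArray
// (byte[0] = chunk index, byte[1] = total number of chunks, payload starts at offset 2).
fun reassemble(fragments: List<ByteArray>): ByteArray {
    require(fragments.isNotEmpty()) { "no fragments" }
    val total = fragments.first()[1].toInt()
    require(fragments.size == total) { "expected $total fragments, got ${fragments.size}" }

    val out = ByteArrayOutputStream()
    fragments.sortedBy { it[0].toInt() }             // order by chunk index
        .forEach { out.write(it, 2, it.size - 2) }   // strip the 2-byte header
    return out.toByteArray()
}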


@ -1,242 +0,0 @@
/* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
* BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
* SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import dorkbox.network.connection.Connection;
import dorkbox.network.connection.Listener;
import dorkbox.network.connection.Listeners;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.exceptions.SecurityException;
public
class MultipleThreadTest extends BaseTest {
private final Object lock = new Object();
private volatile boolean stillRunning = false;
private final Object finalRunLock = new Object();
private volatile boolean finalStillRunning = false;
private final int messageCount = 150;
private final int threadCount = 15;
private final int clientCount = 13;
private final List<Client> clients = new ArrayList<Client>(this.clientCount);
int perClientReceiveTotal = (this.messageCount * this.threadCount);
int serverReceiveTotal = perClientReceiveTotal * this.clientCount;
AtomicInteger sent = new AtomicInteger(0);
AtomicInteger totalClientReceived = new AtomicInteger(0);
AtomicInteger receivedServer = new AtomicInteger(1);
ConcurrentHashMap<Integer, DataClass> sentStringsToClientDebug = new ConcurrentHashMap<Integer, DataClass>();
@Test
public
void multipleThreads() throws SecurityException, IOException {
// our clients should receive messageCount * threadCount * clientCount TOTAL messages
final int totalClientReceivedCountExpected = this.clientCount * this.messageCount * this.threadCount;
final int totalServerReceivedCountExpected = this.clientCount * this.messageCount;
System.err.println("CLIENT RECEIVES: " + totalClientReceivedCountExpected);
System.err.println("SERVER RECEIVES: " + totalServerReceivedCountExpected);
Configuration configuration = new Configuration();
configuration.tcpPort = tcpPort;
configuration.host = host;
configuration.serialization.register(String[].class);
configuration.serialization.register(DataClass.class);
final Server server = new Server(configuration);
server.disableRemoteKeyValidation();
addEndPoint(server);
server.bind(false);
final Listeners listeners = server.listeners();
listeners.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(final Connection connection) {
System.err.println("Client connected to server.");
// kickoff however many threads we need, and send data to the client.
for (int i = 1; i <= MultipleThreadTest.this.threadCount; i++) {
final int index = i;
new Thread() {
@Override
public
void run() {
for (int i = 1; i <= MultipleThreadTest.this.messageCount; i++) {
int incrementAndGet = MultipleThreadTest.this.sent.getAndIncrement();
DataClass dataClass = new DataClass("Server -> client. Thread #" + index + " message# " + incrementAndGet,
incrementAndGet);
//System.err.println(dataClass.data);
MultipleThreadTest.this.sentStringsToClientDebug.put(incrementAndGet, dataClass);
connection.send()
.TCP(dataClass)
.flush();
}
}
}.start();
}
}
});
listeners.add(new Listener.OnMessageReceived<Connection, DataClass>() {
@Override
public
void received(Connection connection, DataClass object) {
int incrementAndGet = MultipleThreadTest.this.receivedServer.getAndIncrement();
//System.err.println("server #" + incrementAndGet);
if (incrementAndGet % MultipleThreadTest.this.messageCount == 0) {
System.err.println("Server receive DONE for client " + incrementAndGet);
stillRunning = false;
synchronized (MultipleThreadTest.this.lock) {
MultipleThreadTest.this.lock.notifyAll();
}
}
if (incrementAndGet == totalServerReceivedCountExpected) {
System.err.println("Server DONE: " + incrementAndGet);
finalStillRunning = false;
synchronized (MultipleThreadTest.this.finalRunLock) {
MultipleThreadTest.this.finalRunLock.notifyAll();
}
}
}
});
// ----
finalStillRunning = true;
for (int i = 1; i <= this.clientCount; i++) {
final int index = i;
Client client = new Client(configuration);
this.clients.add(client);
addEndPoint(client);
client.listeners()
.add(new Listener.OnMessageReceived<Connection, DataClass>() {
final int clientIndex = index;
final AtomicInteger received = new AtomicInteger(1);
@Override
public
void received(Connection connection, DataClass object) {
totalClientReceived.getAndIncrement();
int clientLocalCounter = this.received.getAndIncrement();
MultipleThreadTest.this.sentStringsToClientDebug.remove(object.index);
//System.err.println(object.data);
// we finished!!
if (clientLocalCounter == perClientReceiveTotal) {
//System.err.println("Client #" + clientIndex + " received " + clientLocalCounter + " Sending back " +
// MultipleThreadTest.this.messageCount + " messages.");
// now spam back messages!
for (int i = 0; i < MultipleThreadTest.this.messageCount; i++) {
connection.send()
.TCP(new DataClass("Client #" + clientIndex + " -> Server message " + i, index));
}
}
}
});
stillRunning = true;
client.connect(5000);
while (stillRunning) {
synchronized (this.lock) {
try {
this.lock.wait(5 * 1000); // 5 secs
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
while (finalStillRunning) {
synchronized (this.finalRunLock) {
try {
this.finalRunLock.wait(5 * 1000); // 5 secs
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
// CLIENT will wait until it's done connecting, but SERVER is async.
// the ONLY way to safely work in the server is with LISTENERS. Everything else can FAIL, because of its async nature.
if (!this.sentStringsToClientDebug.isEmpty()) {
System.err.println("MISSED DATA: " + this.sentStringsToClientDebug.size());
for (Map.Entry<Integer, DataClass> i : this.sentStringsToClientDebug.entrySet()) {
System.err.println(i.getKey() + " : " + i.getValue().data);
}
}
stopEndPoints();
assertEquals(totalClientReceivedCountExpected, totalClientReceived.get());
// offset by 1 since we start at 1
assertEquals(totalServerReceivedCountExpected, receivedServer.get()-1);
}
public static
class DataClass {
public String data;
public Integer index;
public
DataClass() {
}
public
DataClass(String data, Integer index) {
this.data = data;
this.index = index;
}
}
}
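As a standalone sketch (using the constants declared at the top of this test), the expected message counts asserted above work out as follows:

fun main() {
    val messageCount = 150
    val threadCount = 15
    val clientCount = 13

    // each client receives messageCount messages from each of the server's threadCount sender threads
    val perClientReceiveTotal = messageCount * threadCount                        // 2_250
    val totalClientReceivedCountExpected = perClientReceiveTotal * clientCount    // 29_250

    // once a client has its full share, it sends messageCount messages back to the server
    val totalServerReceivedCountExpected = clientCount * messageCount             // 1_950

    println("clients: $totalClientReceivedCountExpected, server: $totalServerReceivedCountExpected")
}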


@ -1,326 +0,0 @@
/* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
* BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
* SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
import static org.junit.Assert.fail;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import dorkbox.network.connection.Connection;
import dorkbox.network.connection.Listener;
import dorkbox.network.connection.Listeners;
import dorkbox.network.serialization.Serialization;
import dorkbox.util.exceptions.SecurityException;
import dorkbox.util.serialization.SerializationManager;
public
class PingPongLocalTest extends BaseTest {
int tries = 10000;
private volatile String fail;
@Test
public void pingPongLocal() throws SecurityException, IOException {
this.fail = "Data not received.";
final Data dataLOCAL = new Data();
populateData(dataLOCAL);
Configuration configuration = Configuration.localOnly();
register(configuration.serialization);
Server server = new Server(configuration);
addEndPoint(server);
server.bind(false);
final Listeners listeners = server.listeners();
listeners.add(new Listener.OnError<Connection>() {
@Override
public
void error(Connection connection, Throwable throwable) {
PingPongLocalTest.this.fail = "Error during processing. " + throwable;
}
});
listeners.add(new Listener.OnMessageReceived<Connection, Data>() {
@Override
public
void received(Connection connection, Data data) {
connection.id();
if (!data.equals(dataLOCAL)) {
PingPongLocalTest.this.fail = "data is not equal on server.";
throw new RuntimeException("Fail! " + PingPongLocalTest.this.fail);
}
connection.send()
.TCP(data);
}
});
// ----
Client client = new Client(configuration);
addEndPoint(client);
final Listeners listeners1 = client.listeners();
listeners1.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(Connection connection) {
PingPongLocalTest.this.fail = null;
connection.send()
.TCP(dataLOCAL);
// connection.sendUDP(dataUDP); // TCP and UDP are the same for a local channel.
}
});
listeners1.add(new Listener.OnError<Connection>() {
@Override
public
void error(Connection connection, Throwable throwable) {
PingPongLocalTest.this.fail = "Error during processing. " + throwable;
System.err.println(PingPongLocalTest.this.fail);
}
});
listeners1.add(new Listener.OnMessageReceived<Connection, Data>() {
AtomicInteger check = new AtomicInteger(0);
@Override
public
void received(Connection connection, Data data) {
if (!data.equals(dataLOCAL)) {
PingPongLocalTest.this.fail = "data is not equal on client.";
throw new RuntimeException("Fail! " + PingPongLocalTest.this.fail);
}
if (this.check.getAndIncrement() <= PingPongLocalTest.this.tries) {
connection.send()
.TCP(data);
}
else {
System.err.println("Ran LOCAL " + PingPongLocalTest.this.tries + " times");
stopEndPoints();
}
}
});
client.connect(5000);
waitForThreads();
if (this.fail != null) {
fail(this.fail);
}
}
private void populateData(Data data) {
StringBuilder buffer = new StringBuilder();
for (int i = 0; i < 3000; i++) {
buffer.append('a');
}
data.string = buffer.toString();
data.strings = new String[] {"abcdefghijklmnopqrstuvwxyz0123456789","",null,"!@#$","<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"};
data.ints = new int[] {-1234567,1234567,-1,0,1,Integer.MAX_VALUE,Integer.MIN_VALUE};
data.shorts = new short[] {-12345,12345,-1,0,1,Short.MAX_VALUE,Short.MIN_VALUE};
data.floats = new float[] {0,-0,1,-1,123456,-123456,0.1f,0.2f,-0.3f,(float) Math.PI,Float.MAX_VALUE,
Float.MIN_VALUE};
data.doubles = new double[] {0,-0,1,-1,123456,-123456,0.1d,0.2d,-0.3d,Math.PI,Double.MAX_VALUE,Double.MIN_VALUE};
data.longs = new long[] {0,-0,1,-1,123456,-123456,99999999999l,-99999999999l,Long.MAX_VALUE,Long.MIN_VALUE};
data.bytes = new byte[] {-123,123,-1,0,1,Byte.MAX_VALUE,Byte.MIN_VALUE};
data.chars = new char[] {32345,12345,0,1,63,Character.MAX_VALUE,Character.MIN_VALUE};
data.booleans = new boolean[] {true,false};
data.Ints = new Integer[] {-1234567,1234567,-1,0,1,Integer.MAX_VALUE,Integer.MIN_VALUE};
data.Shorts = new Short[] {-12345,12345,-1,0,1,Short.MAX_VALUE,Short.MIN_VALUE};
data.Floats = new Float[] {0f,-0f,1f,-1f,123456f,-123456f,0.1f,0.2f,-0.3f,(float) Math.PI,Float.MAX_VALUE,
Float.MIN_VALUE};
data.Doubles = new Double[] {0d,-0d,1d,-1d,123456d,-123456d,0.1d,0.2d,-0.3d,Math.PI,Double.MAX_VALUE,
Double.MIN_VALUE};
data.Longs = new Long[] {0l,-0l,1l,-1l,123456l,-123456l,99999999999l,-99999999999l,Long.MAX_VALUE,
Long.MIN_VALUE};
data.Bytes = new Byte[] {-123,123,-1,0,1,Byte.MAX_VALUE,Byte.MIN_VALUE};
data.Chars = new Character[] {32345,12345,0,1,63,Character.MAX_VALUE,Character.MIN_VALUE};
data.Booleans = new Boolean[] {true,false};
}
private void register(SerializationManager manager) {
manager.register(int[].class);
manager.register(short[].class);
manager.register(float[].class);
manager.register(double[].class);
manager.register(long[].class);
manager.register(byte[].class);
manager.register(char[].class);
manager.register(boolean[].class);
manager.register(String[].class);
manager.register(Integer[].class);
manager.register(Short[].class);
manager.register(Float[].class);
manager.register(Double[].class);
manager.register(Long[].class);
manager.register(Byte[].class);
manager.register(Character[].class);
manager.register(Boolean[].class);
manager.register(Data.class);
}
static public class Data {
public String string;
public String[] strings;
public int[] ints;
public short[] shorts;
public float[] floats;
public double[] doubles;
public long[] longs;
public byte[] bytes;
public char[] chars;
public boolean[] booleans;
public Integer[] Ints;
public Short[] Shorts;
public Float[] Floats;
public Double[] Doubles;
public Long[] Longs;
public Byte[] Bytes;
public Character[] Chars;
public Boolean[] Booleans;
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(this.Booleans);
result = prime * result + Arrays.hashCode(this.Bytes);
result = prime * result + Arrays.hashCode(this.Chars);
result = prime * result + Arrays.hashCode(this.Doubles);
result = prime * result + Arrays.hashCode(this.Floats);
result = prime * result + Arrays.hashCode(this.Ints);
result = prime * result + Arrays.hashCode(this.Longs);
result = prime * result + Arrays.hashCode(this.Shorts);
result = prime * result + Arrays.hashCode(this.booleans);
result = prime * result + Arrays.hashCode(this.bytes);
result = prime * result + Arrays.hashCode(this.chars);
result = prime * result + Arrays.hashCode(this.doubles);
result = prime * result + Arrays.hashCode(this.floats);
result = prime * result + Arrays.hashCode(this.ints);
result = prime * result + Arrays.hashCode(this.longs);
result = prime * result + Arrays.hashCode(this.shorts);
result = prime * result + (this.string == null ? 0 : this.string.hashCode());
result = prime * result + Arrays.hashCode(this.strings);
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
Data other = (Data) obj;
if (!Arrays.equals(this.Booleans, other.Booleans)) {
return false;
}
if (!Arrays.equals(this.Bytes, other.Bytes)) {
return false;
}
if (!Arrays.equals(this.Chars, other.Chars)) {
return false;
}
if (!Arrays.equals(this.Doubles, other.Doubles)) {
return false;
}
if (!Arrays.equals(this.Floats, other.Floats)) {
return false;
}
if (!Arrays.equals(this.Ints, other.Ints)) {
return false;
}
if (!Arrays.equals(this.Longs, other.Longs)) {
return false;
}
if (!Arrays.equals(this.Shorts, other.Shorts)) {
return false;
}
if (!Arrays.equals(this.booleans, other.booleans)) {
return false;
}
if (!Arrays.equals(this.bytes, other.bytes)) {
return false;
}
if (!Arrays.equals(this.chars, other.chars)) {
return false;
}
if (!Arrays.equals(this.doubles, other.doubles)) {
return false;
}
if (!Arrays.equals(this.floats, other.floats)) {
return false;
}
if (!Arrays.equals(this.ints, other.ints)) {
return false;
}
if (!Arrays.equals(this.longs, other.longs)) {
return false;
}
if (!Arrays.equals(this.shorts, other.shorts)) {
return false;
}
if (this.string == null) {
if (other.string != null) {
return false;
}
} else if (!this.string.equals(other.string)) {
return false;
}
if (!Arrays.equals(this.strings, other.strings)) {
return false;
}
return true;
}
@Override
public String toString() {
return "Data";
}
}
}


@ -1,213 +0,0 @@
/*
* Copyright 2014 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.other
import com.conversantmedia.util.concurrent.MultithreadConcurrentQueue
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.minlog.Log
import dorkbox.network.serialization.ClassRegistration
import dorkbox.network.serialization.ClassRegistration0
import dorkbox.network.serialization.ClassRegistration1
import dorkbox.network.serialization.ClassRegistration2
import dorkbox.network.serialization.ClassRegistration3
import dorkbox.network.serialization.KryoExtra
import dorkbox.util.serialization.SerializationDefaults
import kotlinx.atomicfu.atomic
class PooledSerialization {
companion object {
init {
Log.set(Log.LEVEL_ERROR)
}
}
private var initialized = atomic(false)
private val classesToRegister = mutableListOf<ClassRegistration>()
private var kryoPoolSize = 16
private val kryoInUse = atomic(0)
@Volatile
private var kryoPool = MultithreadConcurrentQueue<KryoExtra>(kryoPoolSize)
/**
* If you customize anything, you will want to register custom types before init() is called!
*/
fun init() {
// NOTE: there are problems if our serializer is THE SAME serializer used by the network stack!
// We explicitly use different types to prevent that from happening
initialized.value = true
}
private fun initKryo(): KryoExtra {
val kryo = KryoExtra()
SerializationDefaults.register(kryo)
classesToRegister.forEach { registration ->
registration.register(kryo)
}
return kryo
}
/**
* Registers the class using the lowest, next available integer ID and the [default serializer][Kryo.getDefaultSerializer].
* If the class is already registered, the existing entry is updated with the new serializer.
*
*
* Registering a primitive also affects the corresponding primitive wrapper.
*
* Because the ID assigned is affected by the IDs registered before it, the order classes are registered is important when using this
* method.
*
* The order must be the same at deserialization as it was for serialization.
*
* This must happen before the creation of the client/server
*/
fun <T> register(clazz: Class<T>): PooledSerialization {
require(!initialized.value) { "Serialization 'register(class)' cannot happen after initialization!" }
// The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather
// with object types... EVEN IF THERE IS A SERIALIZER
require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." }
classesToRegister.add(ClassRegistration3(clazz))
return this
}
/**
* Registers the class using the specified ID. If the ID is already in use by the same type, the old entry is overwritten. If the ID
* is already in use by a different type, an exception is thrown.
*
*
* Registering a primitive also affects the corresponding primitive wrapper.
*
* IDs must be the same at deserialization as they were for serialization.
*
* This must happen before the creation of the client/server
*
* @param id Must be >= 0. Smaller IDs are serialized more efficiently. IDs 0-8 are used by default for primitive types and String, but
* these IDs can be repurposed.
*/
fun <T> register(clazz: Class<T>, id: Int): PooledSerialization {
require(!initialized.value) { "Serialization 'register(Class, int)' cannot happen after initialization!" }
// The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather
// with object types... EVEN IF THERE IS A SERIALIZER
require(!clazz.isInterface) { "Cannot register '${clazz}' with specified ID for serialization. It must be an implementation." }
classesToRegister.add(ClassRegistration1(clazz, id))
return this
}
/**
* Registers the class using the lowest, next available integer ID and the specified serializer. If the class is already registered,
* the existing entry is updated with the new serializer.
*
*
* Registering a primitive also affects the corresponding primitive wrapper.
*
*
* Because the ID assigned is affected by the IDs registered before it, the order classes are registered is important when using this
* method. The order must be the same at deserialization as it was for serialization.
*/
@Synchronized
fun <T> register(clazz: Class<T>, serializer: Serializer<T>): PooledSerialization {
require(!initialized.value) { "Serialization 'register(Class, Serializer)' cannot happen after initialization!" }
// The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather
// with object types... EVEN IF THERE IS A SERIALIZER
require(!clazz.isInterface) { "Cannot register '${clazz.name}' with a serializer. It must be an implementation." }
classesToRegister.add(ClassRegistration0(clazz, serializer))
return this
}
/**
* Registers the class using the specified ID and serializer. If the ID is already in use by the same type, the old entry is
* overwritten. If the ID is already in use by a different type, an exception is thrown.
*
*
* Registering a primitive also affects the corresponding primitive wrapper.
*
*
* IDs must be the same at deserialization as they were for serialization.
*
* @param id Must be >= 0. Smaller IDs are serialized more efficiently. IDs 0-8 are used by default for primitive types and String, but
* these IDs can be repurposed.
*/
@Synchronized
fun <T> register(clazz: Class<T>, serializer: Serializer<T>, id: Int): PooledSerialization {
require(!initialized.value) { "Serialization 'register(Class, Serializer, int)' cannot happen after initialization!" }
// The reason it must be an implementation, is because the reflection serializer DOES NOT WORK with field types, but rather
// with object types... EVEN IF THERE IS A SERIALIZER
require(!clazz.isInterface) { "Cannot register '${clazz.name}'. It must be an implementation." }
classesToRegister.add(ClassRegistration2(clazz, serializer, id))
return this
}
/**
* @return takes a kryo instance from the pool, or creates one if the pool was empty
*/
fun takeKryo(): KryoExtra {
kryoInUse.getAndIncrement()
// ALWAYS get as many as needed. We recycle them (with an auto-growing pool) to prevent too many getting created
return kryoPool.poll() ?: initKryo()
}
/**
* Returns a kryo instance to the pool for re-use later on
*/
fun returnKryo(kryo: KryoExtra) {
val kryoCount = kryoInUse.getAndDecrement()
if (kryoCount > kryoPoolSize) {
// this is CLEARLY a problem, as we have more kryos in use than our pool can support.
// This happens when we send messages REALLY fast.
//
// We fix this by increasing the size of the pool, so kryos aren't thrown away (and create a GC hit)
synchronized(kryoInUse) {
// we have a double check here on purpose. only 1 will work
if (kryoCount > kryoPoolSize) {
val oldPool = kryoPool
val oldSize = kryoPoolSize
val newSize = kryoPoolSize * 2
kryoPoolSize = newSize
kryoPool = MultithreadConcurrentQueue<KryoExtra>(kryoPoolSize)
// take all of the old kryos and put them in the new one
val array = arrayOfNulls<KryoExtra>(oldSize)
val count = oldPool.remove(array)
for (i in 0 until count) {
kryoPool.offer(array[i])
}
}
}
}
kryoPool.offer(kryo)
}
}
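A short usage sketch (not part of this file): register classes in a fixed order before init(), then borrow and return Kryo instances around each use. SomeMessage is a hypothetical placeholder type.

// Hypothetical message class, for illustration only.
class SomeMessage(var name: String = "")

fun example() {
    val serialization = PooledSerialization()
    // registration order matters: it must be identical on every peer exchanging data
    serialization.register(SomeMessage::class.java)
    serialization.init()

    val kryo = serialization.takeKryo()       // borrowed from the pool, or created if the pool is empty
    try {
        // ... use kryo to read/write objects ...
    } finally {
        serialization.returnKryo(kryo)        // always return it, so the (auto-growing) pool can be reused
    }
}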


@ -1,296 +0,0 @@
/* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
* BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
* SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import dorkbox.network.connection.Connection;
import dorkbox.network.connection.Listener;
import dorkbox.network.connection.Listeners;
import dorkbox.util.exceptions.SecurityException;
// NOTE: UDP is unreliable, EVEN ON LOOPBACK! So this can fail with UDP. TCP will never fail.
public
class ReconnectTest extends BaseTest {
private final AtomicInteger receivedCount = new AtomicInteger(0);
private static final Logger logger = LoggerFactory.getLogger(ReconnectTest.class.getSimpleName());
@Test
public
void socketReuseUDP() throws IOException, SecurityException {
socketReuse(false, true);
}
@Test
public
void socketReuseTCP() throws IOException, SecurityException {
socketReuse(true, false);
}
@Test
public
void socketReuseTCPUDP() throws IOException, SecurityException {
socketReuse(true, true);
}
private
void socketReuse(final boolean useTCP, final boolean useUDP) throws SecurityException, IOException {
receivedCount.set(0);
Configuration configuration = new Configuration();
configuration.host = host;
if (useTCP) {
configuration.tcpPort = tcpPort;
}
if (useUDP) {
configuration.udpPort = udpPort;
}
AtomicReference<CountDownLatch> latch = new AtomicReference<CountDownLatch>();
Server server = new Server(configuration);
addEndPoint(server);
final Listeners listeners = server.listeners();
listeners.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(Connection connection) {
if (useTCP) {
connection.send()
.TCP("-- TCP from server");
}
if (useUDP) {
connection.send()
.UDP("-- UDP from server");
}
}
});
listeners.add(new Listener.OnMessageReceived<Connection, String>() {
@Override
public
void received(Connection connection, String object) {
int incrementAndGet = ReconnectTest.this.receivedCount.incrementAndGet();
logger.error("----- <S " + connection + "> " + incrementAndGet + " : " + object);
latch.get().countDown();
}
});
server.bind(false);
// ----
Client client = new Client(configuration);
addEndPoint(client);
final Listeners listeners1 = client.listeners();
listeners1.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(Connection connection) {
if (useTCP) {
connection.send()
.TCP("-- TCP from client");
}
if (useUDP) {
connection.send()
.UDP("-- UDP from client");
}
}
});
listeners1.add(new Listener.OnMessageReceived<Connection, String>() {
@Override
public
void received(Connection connection, String object) {
int incrementAndGet = ReconnectTest.this.receivedCount.incrementAndGet();
logger.error("----- <C " + connection + "> " + incrementAndGet + " : " + object);
latch.get().countDown();
}
});
int latchCount = 2;
int count = 100;
int initialCount = 2;
if (useTCP && useUDP) {
initialCount += 2;
latchCount += 2;
}
try {
for (int i = 1; i < count + 1; i++) {
logger.error(".....");
latch.set(new CountDownLatch(latchCount));
try {
client.connect(5000);
} catch (IOException e) {
e.printStackTrace();
}
int retryCount = 20;
int lastRetryCount;
int target = i * initialCount;
boolean failed = false;
synchronized (receivedCount) {
while (this.receivedCount.get() != target) {
lastRetryCount = this.receivedCount.get();
try {
latch.get().await(1, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
// check to see if we changed at all...
if (lastRetryCount == this.receivedCount.get()) {
if (retryCount-- < 0) {
logger.error("Aborting unit test... wrong count!");
if (useUDP) {
// If TCP and UDP both fill the pipe, THERE WILL BE FRAGMENTATION and dropped UDP packets!
// it results in severe UDP packet loss and contention.
//
// http://www.isoc.org/INET97/proceedings/F3/F3_1.HTM
// also, a google search on just "INET97/proceedings/F3/F3_1.HTM" turns up interesting problems.
// Usually it's with ISPs.
logger.error("NOTE: UDP can fail, even on loopback! See: http://www.isoc.org/INET97/proceedings/F3/F3_1.HTM");
}
failed = true;
break;
}
} else {
retryCount = 20;
}
}
}
client.close();
logger.error(".....");
if (failed) {
break;
}
}
int specified = count * initialCount;
int received = this.receivedCount.get();
if (specified != received) {
logger.error("NOTE: UDP can fail, even on loopback! See: http://www.isoc.org/INET97/proceedings/F3/F3_1.HTM");
}
assertEquals(specified, received);
} finally {
stopEndPoints();
waitForThreads(10);
}
}
@Test
public
void localReuse() throws SecurityException, IOException {
receivedCount.set(0);
Server server = new Server();
addEndPoint(server);
server.listeners()
.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(Connection connection) {
connection.send()
.self("-- LOCAL from server");
}
});
server.listeners()
.add(new Listener.OnMessageReceived<Connection, String>() {
@Override
public
void received(Connection connection, String object) {
int incrementAndGet = ReconnectTest.this.receivedCount.incrementAndGet();
System.out.println("----- <S " + connection + "> " + incrementAndGet + " : " + object);
}
});
// ----
Client client = new Client();
addEndPoint(client);
client.listeners()
.add(new Listener.OnConnected<Connection>() {
@Override
public
void connected(Connection connection) {
connection.send()
.self("-- LOCAL from client");
}
});
client.listeners()
.add(new Listener.OnMessageReceived<Connection, String>() {
@Override
public
void received(Connection connection, String object) {
int incrementAndGet = ReconnectTest.this.receivedCount.incrementAndGet();
System.out.println("----- <C " + connection + "> " + incrementAndGet + " : " + object);
}
});
server.bind(false);
int count = 10;
for (int i = 1; i < count + 1; i++) {
client.connect(5000);
int target = i * 2;
while (this.receivedCount.get() != target) {
System.out.println("----- Waiting...");
try {
Thread.sleep(100);
} catch (InterruptedException ignored) {
}
}
client.close();
}
assertEquals(count * 2, this.receivedCount.get());
stopEndPoints();
waitForThreads(10);
}
}
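Illustration only (not in the diff): the wait-until-the-counter-reaches-its-target-or-give-up pattern that socketReuse() implements inline, extracted as a small helper. The test itself uses a CountDownLatch with a one-second timeout instead of a plain sleep.

import java.util.concurrent.atomic.AtomicInteger

// Returns true when the counter reached the target, false if it stalled for too many checks.
fun awaitCount(counter: AtomicInteger, target: Int, maxStalledChecks: Int = 20, pollMillis: Long = 1000): Boolean {
    var stalled = 0
    var last = counter.get()
    while (counter.get() != target) {
        Thread.sleep(pollMillis)
        val now = counter.get()
        if (now == last) {                     // no progress since the last check
            if (++stalled > maxStalledChecks) return false
        } else {
            stalled = 0
            last = now
        }
    }
    return true
}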


@ -1,5 +1,5 @@
/*
* Copyright 2018 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,3 +13,4 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
rootProject.name = "Network"

File diff suppressed because it is too large


@ -19,9 +19,6 @@ package dorkbox.network
import dorkbox.netUtil.IPv4
import dorkbox.netUtil.IPv6
import dorkbox.network.aeron.CoroutineBackoffIdleStrategy
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.aeron.CoroutineSleepingMillisIdleStrategy
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.serialization.Serialization
@ -32,12 +29,11 @@ import io.aeron.driver.Configuration
import io.aeron.driver.ThreadingMode
import io.aeron.driver.exceptions.InvalidChannelException
import io.aeron.exceptions.DriverTimeoutException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import mu.KLogger
import mu.KotlinLogging
import org.agrona.SystemUtil
import org.agrona.concurrent.AgentTerminationException
import org.agrona.concurrent.BackoffIdleStrategy
import org.agrona.concurrent.IdleStrategy
import org.slf4j.Logger
import org.slf4j.helpers.NOPLogger
import java.io.File
import java.net.BindException
@ -45,13 +41,6 @@ import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.*
class ServerConfiguration : dorkbox.network.Configuration() {
companion object {
/**
* Gets the version number.
*/
const val version = "6.4"
}
/**
* The address for the server to listen on. "*" will accept connections from all interfaces, otherwise specify
* the hostname (or IP) to bind to.
@ -72,7 +61,7 @@ class ServerConfiguration : dorkbox.network.Configuration() {
}
/**
* The maximum number of client connection allowed per IP address. IPC is unlimited
* The maximum number of client connections allowed per IP address. The default is unlimited, and IPC is always unlimited
*/
var maxConnectionsPerIpAddress = 0
set(value) {
@ -80,6 +69,18 @@ class ServerConfiguration : dorkbox.network.Configuration() {
field = value
}
/**
* If a connection is in a transient state (in the middle of a reconnect) and buffered connections are in use, this controls how long
* a new connection from the same client is considered part of the same "session".
*
* The session timeout cannot be shorter than 60 seconds, and the server will send this configuration to the client
*/
var bufferedConnectionTimeoutSeconds = TimeUnit.MINUTES.toSeconds(2)
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Allows the user to change how endpoint settings and public key information are saved.
*/
@ -89,7 +90,6 @@ class ServerConfiguration : dorkbox.network.Configuration() {
field = value
}
/**
* Validates the current configuration
*/
@ -108,7 +108,7 @@ class ServerConfiguration : dorkbox.network.Configuration() {
require(listenIpAddress.isNotBlank()) { "Blank listen IP address, cannot continue." }
}
override fun initialize(logger: KLogger): dorkbox.network.ServerConfiguration {
override fun initialize(logger: Logger): dorkbox.network.ServerConfiguration {
return super.initialize(logger) as dorkbox.network.ServerConfiguration
}
@ -118,6 +118,7 @@ class ServerConfiguration : dorkbox.network.Configuration() {
config.listenIpAddress = listenIpAddress
config.maxClientCount = maxClientCount
config.maxConnectionsPerIpAddress = maxConnectionsPerIpAddress
config.bufferedConnectionTimeoutSeconds = bufferedConnectionTimeoutSeconds
config.settingsStore = settingsStore
super.copy(config)
@ -133,6 +134,7 @@ class ServerConfiguration : dorkbox.network.Configuration() {
if (listenIpAddress != other.listenIpAddress) return false
if (maxClientCount != other.maxClientCount) return false
if (maxConnectionsPerIpAddress != other.maxConnectionsPerIpAddress) return false
if (bufferedConnectionTimeoutSeconds != other.bufferedConnectionTimeoutSeconds) return false
if (settingsStore != other.settingsStore) return false
return true
@ -143,6 +145,7 @@ class ServerConfiguration : dorkbox.network.Configuration() {
result = 31 * result + listenIpAddress.hashCode()
result = 31 * result + maxClientCount
result = 31 * result + maxConnectionsPerIpAddress
result = 31 * result + bufferedConnectionTimeoutSeconds.hashCode()
result = 31 * result + settingsStore.hashCode()
return result
}
@ -164,6 +167,17 @@ class ClientConfiguration : dorkbox.network.Configuration() {
field = value
}
/**
* The tag name assigned to this connection; the server receives this tag name during the handshake.
* The max length is 32 characters.
*/
var tag: String = ""
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Validates the current configuration. Throws an exception if there are problems.
*/
@ -172,15 +186,16 @@ class ClientConfiguration : dorkbox.network.Configuration() {
super.validate()
// have to do some basic validation of our configuration
if (port != -1) {
// this means it was configured!
require(port > 0) { "Client listen port must be > 0" }
require(port < 65535) { "Client listen port must be < 65535" }
}
require(tag.length <= 32) { "Client tag name length must be <= 32" }
}
override fun initialize(logger: KLogger): dorkbox.network.ClientConfiguration {
override fun initialize(logger: Logger): dorkbox.network.ClientConfiguration {
return super.initialize(logger) as dorkbox.network.ClientConfiguration
}
@ -189,6 +204,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
super.copy(config)
config.port = port
config.tag = tag
return config
}
@ -199,6 +215,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
if (!super.equals(other)) return false
if (port != other.port) return false
if (tag != other.tag) return false
return true
}
@ -206,6 +223,7 @@ class ClientConfiguration : dorkbox.network.Configuration() {
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + port.hashCode()
result = 31 * result + tag.hashCode()
return result
}
}
@ -213,24 +231,24 @@ class ClientConfiguration : dorkbox.network.Configuration() {
abstract class Configuration protected constructor() {
@OptIn(ExperimentalCoroutinesApi::class)
companion object {
internal val NOP_LOGGER = KotlinLogging.logger(NOPLogger.NOP_LOGGER)
/**
* Gets the version number.
*/
const val version = "6.15"
internal val NOP_LOGGER = NOPLogger.NOP_LOGGER
internal const val errorMessage = "Cannot set a property after the configuration context has been created!"
private val appIdRegexString = Regex("a-zA-Z0-9_.-")
private val appIdRegex = Regex("^[$appIdRegexString]+$")
@Volatile
private var alreadyShownTempFsTips = false
internal val networkThreadGroup = ThreadGroup("Network")
internal val aeronThreadFactory = NamedThreadFactory( "Aeron", networkThreadGroup, true)
const val UDP_HANDSHAKE_STREAM_ID: Int = 0x1337cafe // 322423550
const val IPC_HANDSHAKE_STREAM_ID: Int = 0x1337c0de // 322420958
private val defaultMessageCoroutineScope = Dispatchers.Default
private val defaultAeronFilter: (error: Throwable) -> Boolean = { error ->
// we suppress these because they are already handled
when {
@ -253,7 +271,7 @@ abstract class Configuration protected constructor() {
/**
* Depending on the OS, different base locations for the Aeron log directory are preferred.
*/
fun defaultAeronLogLocation(logger: KLogger = NOP_LOGGER): File {
fun defaultAeronLogLocation(logger: Logger = NOP_LOGGER): File {
return when {
OS.isMacOsX -> {
// does the recommended location exist??
@ -263,16 +281,79 @@ abstract class Configuration protected constructor() {
if (suggestedLocation.exists()) {
suggestedLocation
}
else {
if (logger !== NOP_LOGGER) {
if (!alreadyShownTempFsTips) {
alreadyShownTempFsTips = true
logger.info(
"It is recommended to create a RAM drive for best performance. For example\n" + "\$ diskutil erasevolume HFS+ \"DevShm\" `hdiutil attach -nomount ram://\$((2048 * 2048))`"
)
}
}
else if (logger !== NOP_LOGGER) {
// don't ALWAYS create it!
/*
* Note: Since Mac OS does not have a built-in support for /dev/shm, we automatically create a RAM disk for the Aeron directory (aeron.dir).
*
* You can create a RAM disk with the following command:
*
* $ diskutil erasevolume APFS "DISK_NAME" `hdiutil attach -nomount ram://$((SIZE_IN_MB * 2048))`
*
* where:
*
* DISK_NAME should be replaced with a name of your choice.
* SIZE_IN_MB is the size in megabytes for the disk (e.g. 4096 for a 4GB disk).
*
* For example, the following command creates a RAM disk named DevShm which is 8GB in size:
*
* $ diskutil erasevolume APFS "DevShm" `hdiutil attach -nomount ram://$((8 * 1024 * 2048))`
*
* After this command is executed the new disk will be mounted under /Volumes/DevShm.
*/
val sizeInGB = 4
// on macos, we cannot rely on users to actually create this -- so we automatically do it for them.
logger.info("Creating a $sizeInGB GB RAM drive for best performance.")
// hdiutil attach -nobrowse -nomount ram://4194304
val newDevice = dorkbox.executor.Executor()
.command("hdiutil", "attach", "-nomount", "ram://${sizeInGB * 1024 * 2048}")
.destroyOnExit()
.enableRead()
.startBlocking(60, TimeUnit.SECONDS)
.output
.string().trim().also { if (logger.isTraceEnabled) { logger.trace("Created new disk: $it") } }
// diskutil apfs createContainer /dev/disk4
val lines = dorkbox.executor.Executor()
.command("diskutil", "apfs", "createContainer", newDevice)
.destroyOnExit()
.enableRead()
.startBlocking(60, TimeUnit.SECONDS)
.output
.lines().onEach { line -> logger.trace(line) }
val newDiskLine = lines[lines.lastIndex-1]
val disk = newDiskLine.substring(newDiskLine.lastIndexOf(':')+1).trim()
// diskutil apfs addVolume disk5 APFS DevShm -nomount
dorkbox.executor.Executor()
.command("diskutil", "apfs", "addVolume", disk, "APFS", "DevShm", "-nomount")
.destroyOnExit()
.enableRead()
.startBlocking(60, TimeUnit.SECONDS)
.output
.string().also { if (logger.isTraceEnabled) { logger.trace(it) } }
// diskutil mount nobrowse "DevShm"
dorkbox.executor.Executor()
.command("diskutil", "mount", "nobrowse", "DevShm")
.destroyOnExit()
.enableRead()
.startBlocking(60, TimeUnit.SECONDS)
.output
.string().also { if (logger.isTraceEnabled) { logger.trace(it) } }
// touch /Volumes/RAMDisk/.metadata_never_index
File("${suggestedLocation}/.metadata_never_index").createNewFile()
suggestedLocation
}
else {
// we don't always want to create a ram drive!
OS.TEMP_DIR
}
}
@ -287,13 +368,14 @@ abstract class Configuration protected constructor() {
}
}
/**
* Specify the application ID. This is necessary because aeron drivers can be shared; it prevents multiple instances of aeron
* from responding to applications that are not theirs.
*
* This is a human-readable string, and it MUST be configured the same for both the client/server
*/
var applicationId = ""
var appId = ""
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -370,24 +452,10 @@ abstract class Configuration protected constructor() {
field = value
}
/**
* How long a connection must be disconnected before we cleanup the memory associated with it
*/
var connectionCloseTimeoutInSeconds: Int = 10
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* How often to check if the underlying aeron publication/subscription is connected or not.
*
* Aeron Publications and Subscriptions are, and can be, constantly in flux (because of UDP!).
*
* Too low and it's wasting CPU cycles, too high and there will be some lag when detecting if a connection has been disconnected.
*/
var connectionCheckIntervalNanos = TimeUnit.MILLISECONDS.toNanos(200)
var connectionCloseTimeoutInSeconds: Int = 60
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -401,7 +469,7 @@ abstract class Configuration protected constructor() {
*
* Too low and it's likely to get false-positives, too high and there will be some lag when detecting if a connection has been disconnected.
*/
var connectionExpirationTimoutNanos = TimeUnit.SECONDS.toNanos(2)
var connectionExpirationTimoutNanos = TimeUnit.SECONDS.toNanos(4)
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -416,27 +484,6 @@ abstract class Configuration protected constructor() {
field = value
}
/**
* Changes the default ping timeout, used to test the liveliness of a connection, specifically its round-trip performance
*/
var pingTimeoutSeconds = 30
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Responsible for publishing messages that arrive via the network.
*
* Normally, events should be dispatched asynchronously across a thread pool, but in certain circumstances you may want to constrain this to a single-threaded dispatcher or another custom dispatcher.
*/
var messageDispatch = defaultMessageCoroutineScope
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Allows the user to change how endpoint settings and public key information are saved.
*
@ -481,7 +528,7 @@ abstract class Configuration protected constructor() {
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
* how much CPU should be consumed when no work is being done. There is an inherent tradeoff to consider.
*/
var pollIdleStrategy: CoroutineIdleStrategy = CoroutineBackoffIdleStrategy(maxSpins = 100, maxYields = 10, minParkPeriodMs = 1, maxParkPeriodMs = 100)
var pollIdleStrategy: IdleStrategy = BackoffIdleStrategy()
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -498,12 +545,49 @@ abstract class Configuration protected constructor() {
* The main difference in strategies is how responsive to changes should the idler be when idle for a little bit of time and
* how much CPU should be consumed when no work is being done. There is an inherent tradeoff to consider.
*/
var sendIdleStrategy: CoroutineIdleStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100)
var sendIdleStrategy: IdleStrategy = BackoffIdleStrategy()
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* The idle strategy used by the Aeron Media Driver to write to the network when in DEDICATED mode. Null will use the aeron defaults
*/
var senderIdleStrategy: IdleStrategy? = null
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* The idle strategy used by the Aeron Media Driver to read from the network when in DEDICATED mode. Null will use the aeron defaults
*/
var receiverIdleStrategy: IdleStrategy? = null
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* The idle strategy used by the Aeron Media Driver to read/write to the network when in NETWORK_SHARED mode. Null will use the aeron defaults
*/
var sharedIdleStrategy: IdleStrategy? = null
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* The idle strategy used by the Aeron Media Driver conductor when in DEDICATED mode. Null will use the aeron defaults
*/
var conductorIdleStrategy: IdleStrategy? = null
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* ## A Media Driver, whether being run embedded or not, needs 1-3 threads to perform its operation.
*
@ -563,10 +647,8 @@ abstract class Configuration protected constructor() {
* (> 4KB) messages and for maximizing throughput above everything else. Various checks during publication and subscription/connection
* setup are done to verify a decent relationship with MTU.
*
*
* However, it is good to understand these relationships.
*
*
* The MTU on the Media Driver controls the length of the MTU of data frames. This value is communicated to the Aeron clients during
* registration. So, applications do not have to concern themselves with the MTU value used by the Media Driver and use the same value.
*
@ -574,13 +656,9 @@ abstract class Configuration protected constructor() {
* An MTU value over the interface MTU will cause IP to fragment the datagram. This may increase the likelihood of loss under several
* circumstances. If increasing the MTU over the interface MTU, consider various ways to increase the interface MTU first in preparation.
*
*
* The MTU value indicates the largest message that Aeron will send as a single data frame.
*
*
* MTU length also has implications for socket buffer sizing.
*
*
* Default value is 1408 for internet; for a LAN, 9k is possible with jumbo frames (if the routers/interfaces support it)
*/
var networkMtuSize = Configuration.MTU_LENGTH_DEFAULT
@ -589,6 +667,12 @@ abstract class Configuration protected constructor() {
field = value
}
var ipcMtuSize = Configuration.MAX_UDP_PAYLOAD_LENGTH
set(value) {
require(!contextDefined) { errorMessage }
field = value
}
/**
* Default initial window length for flow control sender to receiver purposes. This assumes a system free of pauses.
*
@ -602,7 +686,7 @@ abstract class Configuration protected constructor() {
* Buffer (10 Gbps) = (10 * 1000 * 1000 * 1000 / 8) * 0.0001 = 125000 (Round to 128KB)
* Buffer (1 Gbps) = (1 * 1000 * 1000 * 1000 / 8) * 0.0001 = 12500 (Round to 16KB)
*/
var initialWindowLength = SystemUtil.getSizeAsInt(Configuration.INITIAL_WINDOW_LENGTH_PROP_NAME, 16 * 1024)
var initialWindowLength = 16 * 1024
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -621,7 +705,7 @@ abstract class Configuration protected constructor() {
*
* A value of 0 will 'auto-configure' this setting
*/
var sendBufferSize = 1048576
var sendBufferSize = 0
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -638,7 +722,7 @@ abstract class Configuration protected constructor() {
*
* A value of 0 will 'auto-configure' this setting.
*/
var receiveBufferSize = 2097152
var receiveBufferSize = 0
set(value) {
require(!contextDefined) { errorMessage }
field = value
@ -716,12 +800,12 @@ abstract class Configuration protected constructor() {
open fun validate() {
// have to do some basic validation of our configuration
require(applicationId.isNotEmpty()) { "The application ID must be set, as it prevents an listener from responding to differently configured applications. This is a human-readable string, and it MUST be configured the same for both the clint/server!"}
require(appId.isNotEmpty()) { "The application ID must be set, as it prevents a listener from responding to differently configured applications. This is a human-readable string, and it MUST be configured the same for both the client/server!"}
// The applicationID is used to create the prefix for the aeron directory -- EVEN IF the directory name is specified.
require(applicationId.length < 32) { "The application ID is too long, it must be < 32 characters" }
require(appId.length < 32) { "The application ID is too long, it must be < 32 characters" }
require(isAppIdValid(applicationId)) { "The application ID is not valid. It may only be the following characters: $appIdRegexString" }
require(isAppIdValid(appId)) { "The application ID is not valid. It may only be the following characters: $appIdRegexString" }
// can't disable everything!
require(enableIpc || enableIPv4 || enableIPv6) { "At least one of IPC/IPv4/IPv6 must be enabled!" }
@ -742,17 +826,19 @@ abstract class Configuration protected constructor() {
require(maxStreamSizeInMemoryMB <= 256) { "configuration maxStreamSizeInMemoryMB must be <= 256" } // 256 is arbitrary
require(networkMtuSize > 0) { "configuration networkMtuSize must be > 0" }
require(networkMtuSize < 9 * 1024) { "configuration networkMtuSize must be < ${9 * 1024}" }
require(networkMtuSize < Configuration.MAX_UDP_PAYLOAD_LENGTH) { "configuration networkMtuSize must be < ${Configuration.MAX_UDP_PAYLOAD_LENGTH}" }
require(ipcMtuSize > 0) { "configuration ipcMtuSize must be > 0" }
require(ipcMtuSize <= Configuration.MAX_UDP_PAYLOAD_LENGTH) { "configuration ipcMtuSize must be <= ${Configuration.MAX_UDP_PAYLOAD_LENGTH}" }
require(sendBufferSize > 0) { "configuration socket send buffer must be > 0"}
require(receiveBufferSize > 0) { "configuration socket receive buffer must be > 0"}
require(sendBufferSize >= 0) { "configuration socket send buffer must be >= 0"}
require(receiveBufferSize >= 0) { "configuration socket receive buffer must be >= 0"}
require(ipcTermBufferLength > 65535) { "configuration IPC term buffer must be > 65535"}
require(ipcTermBufferLength < 1_073_741_824) { "configuration IPC term buffer must be < 1,073,741,824"}
require(publicationTermBufferLength > 65535) { "configuration publication term buffer must be > 65535"}
require(publicationTermBufferLength < 1_073_741_824) { "configuration publication term buffer must be < 1,073,741,824"}
}
internal open fun initialize(logger: KLogger): dorkbox.network.Configuration {
internal open fun initialize(logger: Logger): dorkbox.network.Configuration {
// explicitly don't set defaults if we already have the context defined!
if (contextDefined) {
return this
@ -768,76 +854,65 @@ abstract class Configuration protected constructor() {
}
}
/*
* Linux
* Linux normally requires some settings of sysctl values. One is net.core.rmem_max to allow larger SO_RCVBUF and
* net.core.wmem_max to allow larger SO_SNDBUF values to be set.
*
* Windows
* Windows tends to use SO_SNDBUF values that are too small. It is recommended to use values more like 1MB or so.
*
* Mac/Darwin
*
* Mac tends to use SO_SNDBUF values that are too small. It is recommended to use larger values, like 16KB.
*/
if (receiveBufferSize == 0) {
receiveBufferSize = io.aeron.driver.Configuration.SOCKET_RCVBUF_LENGTH_DEFAULT
// when {
// OS.isLinux() ->
// OS.isWindows() ->
// OS.isMacOsX() ->
// }
// val rmem_max = dorkbox.network.other.NetUtil.sysctlGetInt("net.core.rmem_max")
// val wmem_max = dorkbox.network.other.NetUtil.sysctlGetInt("net.core.wmem_max")
}
// /*
// * Linux
// * Linux normally requires some settings of sysctl values. One is net.core.rmem_max to allow larger SO_RCVBUF and
// * net.core.wmem_max to allow larger SO_SNDBUF values to be set.
// *
// * Windows
// * Windows tends to use SO_SNDBUF values that are too small. It is recommended to use values more like 1MB or so.
// *
// * Mac/Darwin
// * Mac tends to use SO_SNDBUF values that are too small. It is recommended to use larger values, like 16KB.
// */
// if (receiveBufferSize == 0) {
// receiveBufferSize = io.aeron.driver.Configuration.SOCKET_RCVBUF_LENGTH_DEFAULT * 4
// // when {
// // OS.isLinux() ->
// // OS.isWindows() ->
// // OS.isMacOsX() ->
// // }
//
// // val rmem_max = dorkbox.network.other.NetUtil.sysctlGetInt("net.core.rmem_max")
// }
//
//
// if (sendBufferSize == 0) {
// sendBufferSize = io.aeron.driver.Configuration.SOCKET_SNDBUF_LENGTH_DEFAULT * 4
// // when {
// // OS.isLinux() ->
// // OS.isWindows() ->
// // OS.isMacOsX() ->
// // }
//
// val wmem_max = dorkbox.netUtil.SocketUtils.sysctlGetInt("net.core.wmem_max")
// }
if (sendBufferSize == 0) {
sendBufferSize = io.aeron.driver.Configuration.SOCKET_SNDBUF_LENGTH_DEFAULT
// when {
// OS.isLinux() ->
// OS.isWindows() ->
// OS.isMacOsX() ->
// }
// val rmem_max = dorkbox.network.other.NetUtil.sysctlGetInt("net.core.rmem_max")
// val wmem_max = dorkbox.network.other.NetUtil.sysctlGetInt("net.core.wmem_max")
}
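// If OS-specific defaults were ever re-enabled, a sketch might look like the following. The OS helper and
// the concrete sizes are assumptions taken from the commented-out notes above, not tested values:
//
//     if (sendBufferSize == 0) {
//         sendBufferSize = when {
//             OS.isLinux()   -> 1 * 1024 * 1024   // also requires raising net.core.wmem_max via sysctl
//             OS.isWindows() -> 1 * 1024 * 1024   // Windows defaults tend to be too small
//             OS.isMacOsX()  -> 16 * 1024         // Mac defaults tend to be too small
//             else           -> io.aeron.driver.Configuration.SOCKET_SNDBUF_LENGTH_DEFAULT
//         }
//     }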
/*
* Note: Since Mac OS does not have a built-in support for /dev/shm it is advised to create a RAM disk for the Aeron directory (aeron.dir).
*
* You can create a RAM disk with the following command:
*
* $ diskutil erasevolume HFS+ "DISK_NAME" `hdiutil attach -nomount ram://$((2048 * SIZE_IN_MB))`
*
* where:
*
* DISK_NAME should be replaced with a name of your choice.
* SIZE_IN_MB is the size in megabytes for the disk (e.g. 4096 for a 4GB disk).
*
* For example, the following command creates a RAM disk named DevShm which is 2GB in size:
*
* $ diskutil erasevolume HFS+ "DevShm" `hdiutil attach -nomount ram://$((2048 * 2048))`
*
* After this command is executed the new disk will be mounted under /Volumes/DevShm.
*/
var dir = aeronDirectory
if (forceAllowSharedAeronDriver && dir != null) {
logger.warn { "Forcing the Aeron driver to be shared between processes. THIS IS DANGEROUS!" }
} else if (dir != null) {
// we have defined an aeron directory
dir = File(dir.absolutePath + "_$applicationId")
} else {
if (dir != null) {
if (forceAllowSharedAeronDriver) {
logger.warn("Forcing the Aeron driver to be shared between processes. THIS IS DANGEROUS!")
} else if (!dir.absolutePath.endsWith(appId)) {
// we have defined an aeron directory
dir = File(dir.absolutePath + "_$appId")
}
}
else {
val baseFileLocation = defaultAeronLogLocation(logger)
val prefix = if (appId.startsWith("aeron_")) {
""
} else {
"aeron_"
}
val aeronLogDirectory = if (uniqueAeronDirectory) {
// this is incompatible with IPC, and will not be set if IPC is enabled (error will be thrown on validate)
File(baseFileLocation, "aeron_${applicationId}_${mediaDriverIdNoDir()}")
File(baseFileLocation, "$prefix${appId}_${mediaDriverIdNoDir()}")
} else {
File(baseFileLocation, "aeron_$applicationId")
File(baseFileLocation, "$prefix$appId")
}
dir = aeronLogDirectory.absoluteFile
}
@ -857,6 +932,7 @@ abstract class Configuration protected constructor() {
val threadingMode get() = config.threadingMode
val networkMtuSize get() = config.networkMtuSize
val ipcMtuSize get() = config.ipcMtuSize
val initialWindowLength get() = config.initialWindowLength
val sendBufferSize get() = config.sendBufferSize
val receiveBufferSize get() = config.receiveBufferSize
@ -870,6 +946,10 @@ abstract class Configuration protected constructor() {
val ipcTermBufferLength get() = config.ipcTermBufferLength
val publicationTermBufferLength get() = config.publicationTermBufferLength
val conductorIdleStrategy get() = config.conductorIdleStrategy
val sharedIdleStrategy get() = config.sharedIdleStrategy
val receiverIdleStrategy get() = config.receiverIdleStrategy
val senderIdleStrategy get() = config.senderIdleStrategy
val aeronErrorFilter get() = config.aeronErrorFilter
var contextDefined
@ -883,15 +963,7 @@ abstract class Configuration protected constructor() {
*/
@Suppress("DuplicatedCode")
fun validate() {
require(networkMtuSize > 0) { "configuration networkMtuSize must be > 0" }
require(networkMtuSize < 9 * 1024) { "configuration networkMtuSize must be < ${9 * 1024}" }
require(sendBufferSize > 0) { "configuration socket send buffer must be > 0"}
require(receiveBufferSize > 0) { "configuration socket receive buffer must be > 0"}
require(ipcTermBufferLength > 65535) { "configuration IPC term buffer must be > 65535"}
require(ipcTermBufferLength < 1_073_741_824) { "configuration IPC term buffer must be < 1,073,741,824"}
require(publicationTermBufferLength > 65535) { "configuration publication term buffer must be > 65535"}
require(publicationTermBufferLength < 1_073_741_824) { "configuration publication term buffer must be < 1,073,741,824"}
// already validated! do nothing.
}
/**
@ -900,7 +972,7 @@ abstract class Configuration protected constructor() {
*
* This is because configs that are DIFFERENT, but have the same values MUST use the same aeron driver.
*/
val id: Int get() {
fun mediaDriverId(): Int {
return config.mediaDriverId()
}
@ -921,6 +993,7 @@ abstract class Configuration protected constructor() {
if (connectionCloseTimeoutInSeconds != other.connectionCloseTimeoutInSeconds) return false
if (threadingMode != other.threadingMode) return false
if (networkMtuSize != other.networkMtuSize) return false
if (ipcMtuSize != other.ipcMtuSize) return false
if (initialWindowLength != other.initialWindowLength) return false
if (sendBufferSize != other.sendBufferSize) return false
if (receiveBufferSize != other.receiveBufferSize) return false
@ -933,6 +1006,11 @@ abstract class Configuration protected constructor() {
if (publicationTermBufferLength != other.publicationTermBufferLength) return false
if (aeronErrorFilter != other.aeronErrorFilter) return false
if (conductorIdleStrategy != other.conductorIdleStrategy) return false
if (sharedIdleStrategy != other.sharedIdleStrategy) return false
if (receiverIdleStrategy != other.receiverIdleStrategy) return false
if (senderIdleStrategy != other.senderIdleStrategy) return false
return true
}
}
@ -942,10 +1020,16 @@ abstract class Configuration protected constructor() {
if (forceAllowSharedAeronDriver != other.forceAllowSharedAeronDriver) return false
if (threadingMode != other.threadingMode) return false
if (networkMtuSize != other.networkMtuSize) return false
if (ipcMtuSize != other.ipcMtuSize) return false
if (initialWindowLength != other.initialWindowLength) return false
if (sendBufferSize != other.sendBufferSize) return false
if (receiveBufferSize != other.receiveBufferSize) return false
if (conductorIdleStrategy != other.conductorIdleStrategy) return false
if (sharedIdleStrategy != other.sharedIdleStrategy) return false
if (receiverIdleStrategy != other.receiverIdleStrategy) return false
if (senderIdleStrategy != other.senderIdleStrategy) return false
if (aeronDirectory != other.aeronDirectory) return false
if (uniqueAeronDirectory != other.uniqueAeronDirectory) return false
if (uniqueAeronDirectoryID != other.uniqueAeronDirectoryID) return false
@ -959,7 +1043,7 @@ abstract class Configuration protected constructor() {
abstract fun copy(): dorkbox.network.Configuration
protected fun copy(config: dorkbox.network.Configuration) {
config.applicationId = applicationId
config.appId = appId
config.forceAllowSharedAeronDriver = forceAllowSharedAeronDriver
config.enableIPv4 = enableIPv4
config.enableIPv6 = enableIPv6
@ -968,16 +1052,13 @@ abstract class Configuration protected constructor() {
config.udpId = udpId
config.enableRemoteSignatureValidation = enableRemoteSignatureValidation
config.connectionCloseTimeoutInSeconds = connectionCloseTimeoutInSeconds
config.connectionCheckIntervalNanos = connectionCheckIntervalNanos
config.connectionExpirationTimoutNanos = connectionExpirationTimoutNanos
config.isReliable = isReliable
config.pingTimeoutSeconds = pingTimeoutSeconds
config.messageDispatch = messageDispatch
config.settingsStore = settingsStore
config.serialization = serialization
config.maxStreamSizeInMemoryMB = maxStreamSizeInMemoryMB
config.pollIdleStrategy = pollIdleStrategy.clone()
config.sendIdleStrategy = sendIdleStrategy.clone()
config.pollIdleStrategy = pollIdleStrategy
config.sendIdleStrategy = sendIdleStrategy
config.threadingMode = threadingMode
config.aeronDirectory = aeronDirectory
config.uniqueAeronDirectoryID = uniqueAeronDirectoryID
@ -1002,6 +1083,7 @@ abstract class Configuration protected constructor() {
private fun mediaDriverIdNoDir(): Int {
var result = threadingMode.hashCode()
result = 31 * result + networkMtuSize
result = 31 * result + ipcMtuSize
result = 31 * result + initialWindowLength
result = 31 * result + sendBufferSize
result = 31 * result + receiveBufferSize
@ -1022,21 +1104,22 @@ abstract class Configuration protected constructor() {
// some values are defined here. Not necessary to list them twice
if (!mediaDriverEquals(other)) return false
if (applicationId != other.applicationId) return false
if (appId != other.appId) return false
if (enableIPv4 != other.enableIPv4) return false
if (enableIPv6 != other.enableIPv6) return false
if (enableIpc != other.enableIpc) return false
if (ipcId != other.ipcId) return false
if (udpId != other.udpId) return false
if (enableRemoteSignatureValidation != other.enableRemoteSignatureValidation) return false
if (connectionCloseTimeoutInSeconds != other.connectionCloseTimeoutInSeconds) return false
if (connectionCheckIntervalNanos != other.connectionCheckIntervalNanos) return false
if (connectionExpirationTimoutNanos != other.connectionExpirationTimoutNanos) return false
if (isReliable != other.isReliable) return false
if (pingTimeoutSeconds != other.pingTimeoutSeconds) return false
if (settingsStore != other.settingsStore) return false
if (serialization != other.serialization) return false
if (maxStreamSizeInMemoryMB != other.maxStreamSizeInMemoryMB) return false
if (pollIdleStrategy != other.pollIdleStrategy) return false
if (sendIdleStrategy != other.sendIdleStrategy) return false
@ -1051,7 +1134,7 @@ abstract class Configuration protected constructor() {
override fun hashCode(): Int {
var result = mediaDriverId()
result = 31 * result + applicationId.hashCode()
result = 31 * result + appId.hashCode()
result = 31 * result + forceAllowSharedAeronDriver.hashCode()
result = 31 * result + enableIPv4.hashCode()
result = 31 * result + enableIPv6.hashCode()
@ -1059,12 +1142,9 @@ abstract class Configuration protected constructor() {
result = 31 * result + enableRemoteSignatureValidation.hashCode()
result = 31 * result + ipcId
result = 31 * result + udpId
result = 31 * result + pingTimeoutSeconds
result = 31 * result + connectionCloseTimeoutInSeconds
result = 31 * result + connectionCheckIntervalNanos.hashCode()
result = 31 * result + connectionExpirationTimoutNanos.hashCode()
result = 31 * result + isReliable.hashCode()
result = 31 * result + messageDispatch.hashCode()
result = 31 * result + settingsStore.hashCode()
result = 31 * result + serialization.hashCode()
result = 31 * result + maxStreamSizeInMemoryMB

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,24 +15,19 @@
*/
package dorkbox.network
import dorkbox.bytes.toHexString
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronPoller
import dorkbox.network.aeron.EventPoller
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo
import dorkbox.hex.toHexString
import dorkbox.network.aeron.*
import dorkbox.network.connection.*
import dorkbox.network.connection.IpInfo.Companion.IpListenType
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.connectionType.ConnectionRule
import dorkbox.network.connection.buffer.BufferManager
import dorkbox.network.exceptions.ServerException
import dorkbox.network.handshake.ServerHandshake
import dorkbox.network.handshake.ServerHandshakePollers
import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.network.rmi.RmiSupportServer
import kotlinx.coroutines.runBlocking
import mu.KotlinLogging
import org.slf4j.LoggerFactory
import java.net.InetAddress
import java.util.concurrent.*
/**
@ -45,82 +40,14 @@ import java.util.concurrent.*
* @param connectionFunc allows for custom connection implementations defined as a unit function
* @param loggerName allows for a custom logger name for this endpoint (for when there are multiple endpoints)
*/
open class Server<CONNECTION : Connection>(
config: ServerConfiguration = ServerConfiguration(),
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
loggerName: String = Server::class.java.simpleName)
: EndPoint<CONNECTION>(config, connectionFunc, loggerName) {
/**
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
*
* To put it bluntly, ONLY have the server do work inside a listener!
*
* @param config these are the specific connection options
* @param loggerName allows for a custom logger name for this endpoint (for when there are multiple endpoints)
* @param connectionFunc allows for custom connection implementations defined as a unit function
*/
constructor(config: ServerConfiguration,
loggerName: String,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION)
: this(config, connectionFunc, loggerName)
/**
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
*
* To put it bluntly, ONLY have the server do work inside of a listener!
*
* @param config these are the specific connection options
* @param connectionFunc allows for custom connection implementations defined as a unit function
*/
constructor(config: ServerConfiguration,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION)
: this(config, connectionFunc, Server::class.java.simpleName)
/**
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
*
* To put it bluntly, ONLY have the server do work inside a listener!
*
* @param config these are the specific connection options
* @param loggerName allows for a custom logger name for this endpoint (for when there are multiple endpoints)
*/
constructor(config: ServerConfiguration,
loggerName: String = Server::class.java.simpleName)
: this(config,
{
@Suppress("UNCHECKED_CAST")
Connection(it) as CONNECTION
},
loggerName)
/**
* The server can only be accessed in an ASYNC manner. This means that the server can only be used in RESPONSE to events. If you access the
* server OUTSIDE of events, you will get inaccurate information from the server (such as getConnections())
*
* To put it bluntly, ONLY have the server do work inside a listener!
*
* @param config these are the specific connection options
*/
constructor(config: ServerConfiguration)
: this(config,
{
@Suppress("UNCHECKED_CAST")
Connection(it) as CONNECTION
},
Server::class.java.simpleName)
open class Server<CONNECTION : Connection>(config: ServerConfiguration = ServerConfiguration(), loggerName: String = Server::class.java.simpleName)
: EndPoint<CONNECTION>(config, loggerName) {
companion object {
/**
* Gets the version number.
*/
const val version = "6.4"
const val version = Configuration.version
/**
* Ensures that an endpoint (using the specified configuration) is NO LONGER running.
@ -131,11 +58,11 @@ open class Server<CONNECTION : Connection>(
*
* @return true if the media driver is STOPPED.
*/
fun ensureStopped(configuration: ServerConfiguration): Boolean = runBlocking {
fun ensureStopped(configuration: ServerConfiguration): Boolean {
val timeout = TimeUnit.SECONDS.toMillis(configuration.connectionCloseTimeoutInSeconds.toLong() * 2)
val logger = KotlinLogging.logger(Server::class.java.simpleName)
AeronDriver.ensureStopped(configuration.copy(), logger, timeout)
val logger = LoggerFactory.getLogger(Server::class.java.simpleName)
return AeronDriver.ensureStopped(configuration.copy(), logger, timeout)
}
/**
@ -145,9 +72,9 @@ open class Server<CONNECTION : Connection>(
*
* @return true if the media driver is active and running
*/
fun isRunning(configuration: ServerConfiguration): Boolean = runBlocking {
val logger = KotlinLogging.logger(Server::class.java.simpleName)
AeronDriver.isRunning(configuration.copy(), logger)
fun isRunning(configuration: ServerConfiguration): Boolean {
val logger = LoggerFactory.getLogger(Server::class.java.simpleName)
return AeronDriver.isRunning(configuration.copy(), logger)
}
init {
@ -161,10 +88,10 @@ open class Server<CONNECTION : Connection>(
*/
val rmiGlobal = RmiSupportServer(logger, rmiGlobalSupport)
/**
* Maintains a thread-safe collection of rules used to define the connection type with this server.
*/
private val connectionRules = CopyOnWriteArrayList<ConnectionRule>()
// /**
// * Maintains a thread-safe collection of rules used to define the connection type with this server.
// */
// private val connectionRules = CopyOnWriteArrayList<ConnectionRule>()
/**
* the IP address information, if available.
@ -175,14 +102,18 @@ open class Server<CONNECTION : Connection>(
internal lateinit var handshake: ServerHandshake<CONNECTION>
/**
* The machine port that the server will listen for connections on
* Different connections (to the same client) can be "buffered", meaning that if they "go down" because of a network glitch -- the data
* being sent is not lost (it is buffered) and then re-sent once the new connection is established. References to the old connection
* will also redirect to the new connection.
*/
@Volatile
var port: Int = 0
private set
internal val bufferedManager: BufferManager<CONNECTION>
private val string0: String by lazy {
"EndPoint [Server: $${storage.publicKey!!.toHexString()}]"
"EndPoint [Server: ${storage.publicKey.toHexString()}]"
}
init {
bufferedManager = BufferManager(config, listenerManager, aeronDriver, config.bufferedConnectionTimeoutSeconds)
}
final override fun newException(message: String, cause: Throwable?): Throwable {
@ -192,139 +123,214 @@ open class Server<CONNECTION : Connection>(
return serverException
}
init {
verifyState()
/**
* Binds the server IPC only, using the previously set AERON configuration
*/
fun bindIpc() {
if (!config.enableIpc) {
logger.warn("IPC explicitly requested, but not enabled. Enabling IPC...")
// we explicitly requested IPC, make sure it's enabled
config.contextDefined = false
config.enableIpc = true
config.contextDefined = true
}
if (config.enableIPv4) { logger.warn("IPv4 is enabled, but only IPC will be used.") }
if (config.enableIPv6) { logger.warn("IPv6 is enabled, but only IPC will be used.") }
internalBind(port1 = 0, port2 = 0, onlyBindIpc = true, runShutdownCheck = true)
}
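// Usage sketch for an IPC-only server. The names are taken from this diff; the surrounding setup is assumed:
//
//     val server = Server<Connection>(ServerConfiguration().apply { enableIpc = true })
//     server.bindIpc()   // UDP pollers stay disabled; only the IPC handshake poller runs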
/**
* Binds the server to AERON configuration
* Binds the server to UDP ports, using the previously set AERON configuration
*
* @param port this is the network port which will be listening for incoming connections
* @param port1 this is the network port which will be listening for incoming connections
* @param port2 this is the network port that the server will use to work around NAT firewalls. By default, this is port1+1, but
* can also be configured independently. This is required, and must be different from port1.
*/
@Suppress("DuplicatedCode")
fun bind(port: Int = 0) = runBlocking {
// NOTE: it is critical to remember that Aeron DOES NOT like running from coroutines!
fun bind(port1: Int, port2: Int = port1+1) {
if (config.enableIPv4 || config.enableIPv6) {
require(port1 != port2) { "port1 cannot be the same as port2" }
require(port1 > 0) { "port1 must be > 0" }
require(port2 > 0) { "port2 must be > 0" }
require(port1 < 65535) { "port1 must be < 65535" }
require(port2 < 65535) { "port2 must be < 65535" }
}
require(port > 0 || config.enableIpc) { "port must be > 0" }
require(port < 65535) { "port must be < 65535" }
internalBind(port1 = port1, port2 = port2, onlyBindIpc = false, runShutdownCheck = true)
}
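// Usage sketch for UDP binding. The port numbers are only illustrative:
//
//     val server = Server<Connection>(ServerConfiguration())
//     server.bind(port1 = 2000)                  // port2 defaults to port1 + 1 (2001), used to work around NAT firewalls
//     // server.bind(port1 = 2000, port2 = 2010) // or configure port2 independently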
@Suppress("DuplicatedCode")
private fun internalBind(port1: Int, port2: Int, onlyBindIpc: Boolean, runShutdownCheck: Boolean) {
// the lifecycle of a server is the ENDPOINT (measured via the network event poller)
if (endpointIsRunning.value) {
listenerManager.notifyError(ServerException("Unable to start, the server is already running!"))
return@runBlocking
return
}
if (!waitForClose()) {
if (runShutdownCheck && !waitForEndpointShutdown()) {
listenerManager.notifyError(ServerException("Unable to start the server!"))
return@runBlocking
return
}
try {
startDriver()
verifyState()
initializeState()
} catch (e: Exception) {
}
catch (e: Exception) {
resetOnError()
listenerManager.notifyError(ServerException("Unable to start the server!", e))
return@runBlocking
return
}
this@Server.port = port
this@Server.port1 = port1
this@Server.port2 = port2
config as ServerConfiguration
// we are done with initial configuration, now initialize aeron and the general state of this endpoint
val server = this@Server
handshake = ServerHandshake(config, listenerManager, aeronDriver)
handshake = ServerHandshake(config, listenerManager, aeronDriver, eventDispatch)
val ipcPoller: AeronPoller = if (config.enableIpc) {
val ipcPoller: AeronPoller = if (config.enableIpc || onlyBindIpc) {
ServerHandshakePollers.ipc(server, handshake)
} else {
ServerHandshakePollers.disabled("IPC Disabled")
}
val ipPoller = when (ipInfo.ipType) {
// IPv6 will bind to IPv4 wildcard as well, so don't bind both!
IpListenType.IPWildcard -> ServerHandshakePollers.ip6Wildcard(server, handshake)
IpListenType.IPv4Wildcard -> ServerHandshakePollers.ip4(server, handshake)
IpListenType.IPv6Wildcard -> ServerHandshakePollers.ip6(server, handshake)
IpListenType.IPv4 -> ServerHandshakePollers.ip4(server, handshake)
IpListenType.IPv6 -> ServerHandshakePollers.ip6(server, handshake)
IpListenType.IPC -> ServerHandshakePollers.disabled("IPv4/6 Disabled")
val ipPoller = if (onlyBindIpc) {
ServerHandshakePollers.disabled("IPv4/6 Disabled")
} else {
when (ipInfo.ipType) {
// IPv6 will bind to IPv4 wildcard as well, so don't bind both!
IpListenType.IPWildcard -> ServerHandshakePollers.ip6Wildcard(server, handshake)
IpListenType.IPv4Wildcard -> ServerHandshakePollers.ip4(server, handshake)
IpListenType.IPv6Wildcard -> ServerHandshakePollers.ip6(server, handshake)
IpListenType.IPv4 -> ServerHandshakePollers.ip4(server, handshake)
IpListenType.IPv6 -> ServerHandshakePollers.ip6(server, handshake)
IpListenType.IPC -> ServerHandshakePollers.disabled("IPv4/6 Disabled")
}
}
logger.info { ipcPoller.info }
logger.info { ipPoller.info }
logger.info(ipcPoller.info)
logger.info(ipPoller.info)
// if we shutdown/close before the poller starts, we don't want to block forever
pollerClosedLatch = CountDownLatch(1)
networkEventPoller.submit(
action = {
if (!shutdownEventPoller) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
action = object : EventActionOperator {
override fun invoke(): Int {
return if (!shutdownEventPoller) {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
// this checks to see if there are NEW clients to handshake with
var pollCount = ipcPoller.poll() + ipPoller.poll()
// this checks to see if there are NEW clients to handshake with
var pollCount = ipcPoller.poll() + ipPoller.poll()
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
// so we can modify this as we go
connections.forEach { connection ->
if (!connection.isClosedViaAeron()) {
// Otherwise, poll the connection for messages
pollCount += connection.poll()
} else {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
logger.debug { "[${connection}] connection expired (cleanup)" }
// this manages existing clients (for cleanup + connection polling). This has a concurrent iterator,
// so we can modify this as we go
connections.forEach { connection ->
if (connection.canPoll()) {
// Otherwise, poll the connection for messages
pollCount += connection.poll()
} else {
// If the connection has either been closed, or has expired, it needs to be cleaned-up/deleted.
if (logger.isDebugEnabled) {
logger.debug("[${connection}] connection expired (cleanup)")
}
// the connection MUST be removed in the same thread that is processing events (it will be removed again in close, and that is expected)
removeConnection(connection)
// the connection MUST be removed in the same thread that is processing events (it will be removed again in close, and that is expected)
removeConnection(connection)
// we already removed the connection, we can call it again without side affects
connection.close()
// we already removed the connection, we can call it again without side effects
connection.close()
}
}
pollCount
} else {
// remove ourselves from processing
EventPoller.REMOVE
}
}
},
onClose = object : EventCloseOperator {
override fun invoke() {
val mustRestartDriverOnError = aeronDriver.internal.mustRestartDriverOnError
logger.debug("Server event dispatch closing...")
ipcPoller.close()
ipPoller.close()
// clear all the handshake info
handshake.clear()
// we only need to run shutdown methods if there was a network outage or D/C
if (!shutdownInProgress.value) {
// this is because we restart automatically on driver errors
this@Server.close(closeEverything = false, sendDisconnectMessage = true, releaseWaitingThreads = !mustRestartDriverOnError)
}
if (mustRestartDriverOnError) {
logger.error("Critical driver error detected, restarting server.")
eventDispatch.CLOSE.launch {
waitForEndpointShutdown()
// also wait for everyone else to shutdown!!
aeronDriver.internal.endPointUsages.forEach {
if (it !== this@Server) {
it.waitForEndpointShutdown()
}
}
// if we restart/reconnect too fast, errors from the previous run will still be present!
aeronDriver.delayLingerTimeout()
val p1 = this@Server.port1
val p2 = this@Server.port2
if (p1 == 0 && p2 == 0) {
internalBind(port1 = 0, port2 = 0, onlyBindIpc = true, runShutdownCheck = false)
} else {
internalBind(port1 = p1, port2 = p2, onlyBindIpc = false, runShutdownCheck = false)
}
}
}
pollCount
} else {
// remove ourselves from processing
EventPoller.REMOVE
// we can now call bind again
endpointIsRunning.lazySet(false)
logger.debug("Closed the Network Event Poller task.")
pollerClosedLatch.countDown()
}
},
onShutdown = {
logger.debug { "Server event dispatch closing..." }
ipcPoller.close()
ipPoller.close()
// clear all the handshake info
handshake.clear()
// we can now call bind again
endpointIsRunning.lazySet(false)
pollerClosedLatch.countDown()
logger.debug { "Closed the Network Event Poller..." }
})
}
/**
* Adds an IP+subnet rule that defines what type of connection this IP+subnet should have.
* - NOTHING : Nothing happens to the in/out bytes
* - COMPRESS: The in/out bytes are compressed with LZ4-fast
* - COMPRESS_AND_ENCRYPT: The in/out bytes are compressed (LZ4-fast) THEN encrypted (AES-256-GCM)
*
* If no rules are defined, then for LOOPBACK, it will always be `COMPRESS` and for everything else it will always be `COMPRESS_AND_ENCRYPT`.
*
* If rules are defined, then everything by default is `COMPRESS_AND_ENCRYPT`.
*
* The compression algorithm is LZ4-fast, so there is a small performance impact for a very large gain
* Compress : 6.210 micros/op; 629.0 MB/s (output: 55.4%)
* Uncompress : 0.641 micros/op; 6097.9 MB/s
*/
fun addConnectionRules(vararg rules: ConnectionRule) {
connectionRules.addAll(listOf(*rules))
}
// /**
// * Adds an IP+subnet rule that defines what type of connection this IP+subnet should have.
// * - NOTHING : Nothing happens to the in/out bytes
// * - COMPRESS: The in/out bytes are compressed with LZ4-fast
// * - COMPRESS_AND_ENCRYPT: The in/out bytes are compressed (LZ4-fast) THEN encrypted (AES-256-GCM)
// *
// * If no rules are defined, then for LOOPBACK, it will always be `COMPRESS` and for everything else it will always be `COMPRESS_AND_ENCRYPT`.
// *
// * If rules are defined, then everything by default is `COMPRESS_AND_ENCRYPT`.
// *
// * The compression algorithm is LZ4-fast, so there is a small performance impact for a very large gain
// * Compress : 6.210 micros/op; 629.0 MB/s (output: 55.4%)
// * Uncompress : 0.641 micros/op; 6097.9 MB/s
// */
// fun addConnectionRules(vararg rules: ConnectionRule) {
// connectionRules.addAll(listOf(*rules))
// }
/**
* Adds an IP+subnet rule that defines if that IP+subnet is allowed/denied connectivity to this server.
@ -335,11 +341,11 @@ open class Server<CONNECTION : Connection>(
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called **only** for network clients (IPC clients are excluded)
*
* @param ipFilterRule the IpFilterRule to determine if this connection will be allowed to connect
*/
fun filter(ipFilterRule: IpFilterRule) {
runBlocking {
listenerManager.filter(ipFilterRule)
}
listenerManager.filter(ipFilterRule)
}
/**
@ -359,11 +365,33 @@ open class Server<CONNECTION : Connection>(
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called **only** for network clients (IPC clients are excluded)
*
* @param function clientAddress: UDP connection address
* tagName: the connection tag name
*/
fun filter(function: CONNECTION.() -> Boolean) {
runBlocking {
listenerManager.filter(function)
}
fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) {
listenerManager.filter(function)
}
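// Usage sketch for the lambda filter above, given a Server instance `server`. The loopback/tag checks are
// illustrative only:
//
//     server.filter { clientAddress, tagName ->
//         clientAddress.isLoopbackAddress || tagName.startsWith("trusted")
//     }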
/**
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages
* for a connection should be enabled
*
* By default, if there are no rules, then all connections will have buffered messages enabled
* If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the buffered messages for a connection are enabled.
* If the function returns FALSE, then the buffered messages for a connection are disabled.
*
* If ANY rule that is applied returns true, then the buffered messages for a connection are enabled
*
* @param function clientAddress: not-null when UDP connection, null when IPC connection
* tagName: the connection tag name
*/
fun enableBufferedMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) {
listenerManager.enableBufferedMessages(function)
}
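// Usage sketch: enable buffering only for network (non-IPC) connections, where clientAddress is non-null:
//
//     server.enableBufferedMessages { clientAddress, _ -> clientAddress != null }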
/**
@ -391,18 +419,14 @@ open class Server<CONNECTION : Connection>(
* @param closeEverything if true, all parts of the server will be closed (listeners, driver, event polling, etc)
*/
fun close(closeEverything: Boolean = true) {
runBlocking {
close(closeEverything = closeEverything, initiatedByClientClose = false, initiatedByShutdown = false)
}
bufferedManager.close()
close(closeEverything = closeEverything, sendDisconnectMessage = true, releaseWaitingThreads = true)
}
override fun toString(): String {
return string0
}
/**
* Enable
*/
fun <R> use(block: (Server<CONNECTION>) -> R): R {
return try {
block(this)

View File

@ -17,8 +17,11 @@
package dorkbox.network.aeron
import dorkbox.network.Configuration
import dorkbox.network.exceptions.AeronDriverException
import dorkbox.util.Sys
import io.aeron.driver.MediaDriver
import io.aeron.exceptions.DriverTimeoutException
import org.slf4j.Logger
import java.io.Closeable
import java.io.File
import java.util.concurrent.*
@ -29,7 +32,7 @@ import java.util.concurrent.*
* @throws IllegalStateException if the configuration has already been used to create a context
* @throws IllegalArgumentException if the aeron media driver directory cannot be setup
*/
internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorHandler: (Throwable) -> Unit) : Closeable {
internal class AeronContext(config: Configuration.MediaDriverConfig, logger: Logger, aeronErrorHandler: (Throwable) -> Unit) : Closeable {
companion object {
private fun create(config: Configuration.MediaDriverConfig, aeronErrorHandler: (Throwable) -> Unit): MediaDriver.Context {
// LOW-LATENCY SETTINGS
@ -52,15 +55,16 @@ internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorH
val mediaDriverContext = MediaDriver.Context()
.termBufferSparseFile(false) // files occupy the same space virtually AND physically!
.useWindowsHighResTimer(true)
// we assign our OWN ID! so we reserve everything.
.publicationReservedSessionIdLow(AeronDriver.RESERVED_SESSION_ID_LOW)
.publicationReservedSessionIdHigh(AeronDriver.RESERVED_SESSION_ID_HIGH)
.threadingMode(config.threadingMode)
.mtuLength(config.networkMtuSize)
.ipcMtuLength(config.ipcMtuSize)
.initialWindowLength(config.initialWindowLength)
.socketSndbufLength(config.sendBufferSize)
.socketRcvbufLength(config.receiveBufferSize)
.conductorThreadFactory(threadFactory)
.receiverThreadFactory(threadFactory)
@ -68,6 +72,28 @@ internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorH
.sharedNetworkThreadFactory(threadFactory)
.sharedThreadFactory(threadFactory)
if (config.sendBufferSize > 0) {
mediaDriverContext.socketSndbufLength(config.sendBufferSize)
}
if (config.receiveBufferSize > 0) {
mediaDriverContext.socketRcvbufLength(config.receiveBufferSize)
}
if (config.conductorIdleStrategy != null) {
mediaDriverContext.conductorIdleStrategy(config.conductorIdleStrategy)
}
if (config.sharedIdleStrategy != null) {
mediaDriverContext.sharedIdleStrategy(config.sharedIdleStrategy)
}
if (config.receiverIdleStrategy != null) {
mediaDriverContext.receiverIdleStrategy(config.receiverIdleStrategy)
}
if (config.senderIdleStrategy != null) {
mediaDriverContext.senderIdleStrategy(config.senderIdleStrategy)
}
mediaDriverContext.aeronDirectoryName(config.aeronDirectory!!.path)
if (config.ipcTermBufferLength > 0) {
@ -125,7 +151,11 @@ internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorH
private fun isRunning(context: MediaDriver.Context): Boolean {
// if the media driver is running, it will be a quick connection. Usually 100ms or so
return context.isDriverActive(context.driverTimeoutMs()) { }
return try {
context.isDriverActive(context.driverTimeoutMs()) { }
} catch (e: Exception) {
false
}
}
init {
@ -144,14 +174,16 @@ internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorH
// sometimes when starting up, if a PREVIOUS run was corrupted (during startup, for example)
// we ONLY do this during the initial startup check because it will delete the directory, and we don't always want to do this.
//
var isRunning = try {
val isRunning = try {
context.isDriverActive(driverTimeout) { }
} catch (e: DriverTimeoutException) {
// we have to delete the directory, since it was corrupted, and we try again.
if (aeronDir.deleteRecursively()) {
if (!config.forceAllowSharedAeronDriver && aeronDir.deleteRecursively()) {
context.isDriverActive(driverTimeout) { }
} else if (config.forceAllowSharedAeronDriver) {
// we are expecting a shared directory. SOMETHING is screwed up!
throw AeronDriverException("Aeron was expected to be running, and the current location is corrupted. Not doing anything!", e)
} else {
// unable to delete the directory
throw e
@ -167,29 +199,20 @@ internal class AeronContext(config: Configuration.MediaDriverConfig, aeronErrorH
// if we are not CURRENTLY running, then we should ALSO delete it when we are done!
context.dirDeleteOnShutdown()
} else {
} else if (!config.forceAllowSharedAeronDriver) {
// maybe it's a mistake because we restarted too quickly! A brief pause to fix this!
// wait for it to close!
val timeoutInNanos = TimeUnit.SECONDS.toMillis(config.connectionCloseTimeoutInSeconds.toLong())
val closeTimeoutTime = System.nanoTime()
while (isRunning(context) && System.nanoTime() - closeTimeoutTime < timeoutInNanos) {
Thread.sleep(timeoutInNanos)
val timeoutInNs = TimeUnit.SECONDS.toNanos(config.connectionCloseTimeoutInSeconds.toLong()) + context.publicationLingerTimeoutNs()
val timeoutInMs = TimeUnit.NANOSECONDS.toMillis(timeoutInNs)
logger.warn("Aeron is currently running, waiting ${Sys.getTimePrettyFull(timeoutInNs)} for it to close.")
// wait for it to close! wait longer.
val startTime = System.nanoTime()
while (isRunning(context) && System.nanoTime() - startTime < timeoutInNs) {
Thread.sleep(timeoutInMs)
}
isRunning = try {
context.isDriverActive(driverTimeout) { }
} catch (e: DriverTimeoutException) {
// we have to delete the directory, since it was corrupted, and we try again.
if (aeronDir.deleteRecursively()) {
context.isDriverActive(driverTimeout) { }
} else {
// unable to delete the directory
throw e
}
}
require(!isRunning || config.forceAllowSharedAeronDriver) { "Aeron is currently running, and this is the first instance created by this JVM. " +
require(!isRunning(context)) { "Aeron is currently running, and this is the first instance created by this JVM. " +
"You must use `config.forceAllowSharedAeronDriver` to be able to re-use a shared aeron process at: $aeronDir" }
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,31 +21,39 @@ package dorkbox.network.aeron
import dorkbox.collections.IntMap
import dorkbox.netUtil.IPv6
import dorkbox.network.Configuration
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.ListenerManager.Companion.cleanAllStackTrace
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTraceInternal
import dorkbox.network.exceptions.AllocationException
import dorkbox.network.handshake.RandomId65kAllocator
import dorkbox.network.serialization.AeronOutput
import dorkbox.util.Sys
import io.aeron.*
import io.aeron.driver.reports.LossReportReader
import io.aeron.driver.reports.LossReportUtil
import io.aeron.samples.SamplesUtil
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import mu.KotlinLogging
import org.agrona.DirectBuffer
import org.agrona.IoUtil
import org.agrona.SemanticVersion
import io.aeron.logbuffer.BufferClaim
import io.aeron.protocol.DataHeaderFlyweight
import kotlinx.atomicfu.AtomicBoolean
import org.agrona.*
import org.agrona.concurrent.AtomicBuffer
import org.agrona.concurrent.IdleStrategy
import org.agrona.concurrent.UnsafeBuffer
import org.agrona.concurrent.errors.ErrorLogReader
import org.agrona.concurrent.ringbuffer.RingBufferDescriptor
import org.agrona.concurrent.status.CountersReader
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.File
import java.io.IOException
import java.io.RandomAccessFile
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.util.concurrent.locks.*
import kotlin.concurrent.read
import kotlin.concurrent.write
fun ChannelUriStringBuilder.endpoint(isIpv4: Boolean, addressString: String, port: Int): ChannelUriStringBuilder {
this.endpoint(AeronDriver.address(isIpv4, addressString, port))
@ -59,11 +67,7 @@ fun ChannelUriStringBuilder.endpoint(isIpv4: Boolean, addressString: String, por
/**
* Class for managing the Aeron+Media drivers
*/
class AeronDriver private constructor(config: Configuration, val logger: KLogger, val endPoint: EndPoint<*>?) {
constructor(config: Configuration, logger: KLogger) : this(config, logger, null)
constructor(endPoint: EndPoint<*>) : this(endPoint.config, endPoint.logger, endPoint)
class AeronDriver(config: Configuration, val logger: Logger, val endPoint: EndPoint<*>?) {
companion object {
/**
@ -87,24 +91,42 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
// prevents multiple instances, within the same JVM, from starting at the exact same time.
private val lock = Mutex()
private val lock = ReentrantReadWriteLock()
// have to keep track of configurations and drivers, as we do not want to start the same media driver configuration multiple times (this causes problems!)
internal val driverConfigurations = IntMap<AeronDriverInternal>(4)
fun new(endPoint: EndPoint<*>): AeronDriver {
var driver: AeronDriver?
lock.write {
driver = AeronDriver(endPoint.config, endPoint.logger, endPoint)
}
return driver!!
}
fun withLock(action: () -> Unit) {
lock.write {
action()
}
}
/**
* Ensures that an endpoint (using the specified configuration) is NO LONGER running.
*
* @return true if the media driver is STOPPED.
*/
suspend fun ensureStopped(configuration: Configuration, logger: KLogger, timeout: Long): Boolean {
fun ensureStopped(configuration: Configuration, logger: Logger, timeout: Long): Boolean {
if (!isLoaded(configuration.copy(), logger)) {
return true
}
val stopped = AeronDriver(configuration, logger, null).use {
it.ensureStopped(timeout, 500)
var stopped = false
lock.write {
stopped = AeronDriver(configuration, logger, null).use {
it.ensureStopped(timeout, 500)
}
}
// hacky, but necessary for multiple checks
@ -118,15 +140,15 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the media driver is loaded.
*/
suspend fun isLoaded(configuration: Configuration, logger: KLogger): Boolean {
fun isLoaded(configuration: Configuration, logger: Logger): Boolean {
// not EVERYTHING is used for the media driver. For ** REUSING ** the media driver, only care about those specific settings
val mediaDriverConfig = getDriverConfig(configuration, logger)
// assign the driver for this configuration. THIS IS GLOBAL for a JVM, because for a specific configuration, aeron only needs to be initialized ONCE.
// we have INSTANCE of the "wrapper" AeronDriver, because we want to be able to have references to the logger when doing things,
// however - the code that actually does stuff is a "singleton" in regard to an aeron configuration
return lock.withLock {
driverConfigurations.get(mediaDriverConfig.id) != null
return lock.read {
driverConfigurations[mediaDriverConfig.mediaDriverId()] != null
}
}
@ -135,9 +157,12 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the media driver is active and running
*/
suspend fun isRunning(configuration: Configuration, logger: KLogger): Boolean {
val running = AeronDriver(configuration, logger).use {
it.isRunning()
fun isRunning(configuration: Configuration, logger: Logger): Boolean {
var running = false
lock.read {
running = AeronDriver(configuration, logger, null).use {
it.isRunning()
}
}
return running
@ -146,25 +171,17 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
/**
* @return true if all JVM tracked Aeron drivers are closed, false otherwise
*/
suspend fun areAllInstancesClosed(logger: Logger): Boolean {
val logger1 = KotlinLogging.logger(logger)
return areAllInstancesClosed(logger1)
}
/**
* @return true if all JVM tracked Aeron drivers are closed, false otherwise
*/
suspend fun areAllInstancesClosed(logger: KLogger = KotlinLogging.logger(AeronDriver::class.java.simpleName)): Boolean {
return lock.withLock {
fun areAllInstancesClosed(logger: Logger = LoggerFactory.getLogger(AeronDriver::class.java.simpleName)): Boolean {
return lock.read {
val traceEnabled = logger.isTraceEnabled
driverConfigurations.forEach { entry ->
val driver = entry.value
val closed = if (traceEnabled) driver.isInUse(logger) else driver.isRunning()
val closed = if (traceEnabled) driver.isInUse(null, logger) else driver.isRunning()
if (closed) {
logger.error { "Aeron Driver [${driver.driverId}]: still running during check (${driver.aeronDirectory})" }
return@withLock false
logger.error( "Aeron Driver [${driver.driverId}]: still running during check (${driver.aeronDirectory})")
return@read false
}
}
@ -172,9 +189,9 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
// this is already checked if we are in trace mode.
driverConfigurations.forEach { entry ->
val driver = entry.value
if (driver.isInUse(logger)) {
logger.error { "Aeron Driver [${driver.driverId}]: still in use during check (${driver.aeronDirectory})" }
return@withLock false
if (driver.isInUse(null, logger)) {
logger.error("Aeron Driver [${driver.driverId}]: still in use during check (${driver.aeronDirectory})")
return@read false
}
}
}
@ -183,10 +200,34 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
}
}
/**
* @return the error code text for the specified number
*/
internal fun errorCodeName(result: Long): String {
return when (result) {
// The publication is not connected to a subscriber, this can be an intermittent state as subscribers come and go.
Publication.NOT_CONNECTED -> "Not connected"
// The offer failed due to back pressure from the subscribers preventing further transmission.
Publication.BACK_PRESSURED -> "Back pressured"
// The action is an operation such as log rotation which is likely to have succeeded by the next retry attempt.
Publication.ADMIN_ACTION -> "Administrative action"
// The Publication has been closed and should no longer be used.
Publication.CLOSED -> "Publication is closed"
// If this happens then the publication should be closed and a new one added. To make it less likely to happen then increase the term buffer length.
Publication.MAX_POSITION_EXCEEDED -> "Maximum term position exceeded"
else -> throw IllegalStateException("Unknown error code: $result")
}
}
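// A sketch of how these codes surface in practice: Publication.offer() returns one of these negative values
// on failure, which can be mapped through errorCodeName() for logging. The handling policy shown is an
// assumption, not taken from this codebase:
//
//     val result = publication.offer(buffer, 0, length)
//     if (result < 0 && result != Publication.BACK_PRESSURED && result != Publication.ADMIN_ACTION) {
//         logger.warn("offer failed: ${errorCodeName(result)}")
//     }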
private fun aeronCounters(aeronLocation: File): CountersReader? {
val resolve = aeronLocation.resolve("cnc.dat")
return if (resolve.exists()) {
val cncByteBuffer = SamplesUtil.mapExistingFileReadOnly(resolve)
val cncByteBuffer = mapExistingFileReadOnly(resolve)
val cncMetaDataBuffer: DirectBuffer = CncFileDescriptor.createMetaDataBuffer(cncByteBuffer)
CountersReader(
@ -228,7 +269,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
* @return the number of errors for the Aeron driver
*/
fun driverErrors(aeronLocation: File, errorAction: (observationCount: Int, firstObservationTimestamp: Long, lastObservationTimestamp: Long, encodedException: String) -> Unit): Int {
val errorMmap = SamplesUtil.mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
val errorMmap = mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
try {
val buffer: AtomicBuffer = CommonContext.errorLogBuffer(errorMmap)
@ -257,7 +298,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
val lossReportFile = aeronLocation.resolve(LossReportUtil.LOSS_REPORT_FILE_NAME)
return if (lossReportFile.exists()) {
val mappedByteBuffer = SamplesUtil.mapExistingFileReadOnly(lossReportFile)
val mappedByteBuffer = mapExistingFileReadOnly(lossReportFile)
val buffer: AtomicBuffer = UnsafeBuffer(mappedByteBuffer)
LossReportReader.read(buffer, lossStats)
@ -270,7 +311,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
* @return the internal heartbeat of the Aeron driver in the specified aeron directory
*/
fun driverHeartbeatMs(aeronLocation: File): Long {
val cncByteBuffer = SamplesUtil.mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
val cncByteBuffer = mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
val cncMetaDataBuffer: DirectBuffer = CncFileDescriptor.createMetaDataBuffer(cncByteBuffer)
val toDriverBuffer = CncFileDescriptor.createToDriverBuffer(cncByteBuffer, cncMetaDataBuffer)
@ -283,7 +324,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
* @return the internal version of the Aeron driver in the specified aeron directory
*/
fun driverVersion(aeronLocation: File): String {
val cncByteBuffer = SamplesUtil.mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
val cncByteBuffer = mapExistingFileReadOnly(aeronLocation.resolve("cnc.dat"))
val cncMetaDataBuffer: DirectBuffer = CncFileDescriptor.createMetaDataBuffer(cncByteBuffer)
val cncVersion = cncMetaDataBuffer.getInt(CncFileDescriptor.cncVersionOffset(0))
@ -380,7 +421,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
internal fun getDriverConfig(config: Configuration, logger: KLogger): Configuration.MediaDriverConfig {
internal fun getDriverConfig(config: Configuration, logger: Logger): Configuration.MediaDriverConfig {
val mediaDriverConfig = Configuration.MediaDriverConfig(config)
// this happens more than once! (this is ok)
@ -388,16 +429,40 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
mediaDriverConfig.validate()
require(!config.contextDefined) { "Aeron configuration has already been initialized, unable to reuse this configuration!" }
require(!config.contextDefined) { "Aeron configuration [${config.mediaDriverId()}] has already been initialized, unable to reuse this configuration!" }
// cannot make any more changes to the configuration!
config.initialize(logger)
// technically possible, but practically unlikely because of the different values calculated
require(mediaDriverConfig.id != 0) { "There has been a severe error when calculating the media configuration ID. Aborting" }
require(mediaDriverConfig.mediaDriverId() != 0) { "There has been a severe error when calculating the media configuration ID. Aborting" }
return mediaDriverConfig
}
/**
* Map an existing file as a read only buffer.
*
* @param location of file to map.
* @return the mapped file.
*/
fun mapExistingFileReadOnly(location: File): MappedByteBuffer? {
if (!location.exists()) {
val msg = "file not found: " + location.absolutePath
throw IllegalStateException(msg)
}
var mappedByteBuffer: MappedByteBuffer? = null
try {
RandomAccessFile(location, "r").use { file ->
file.channel.use { channel ->
mappedByteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size())
}
}
} catch (ex: IOException) {
LangUtil.rethrowUnchecked(ex)
}
return mappedByteBuffer
}
}
@ -411,38 +476,38 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
// assign the driver for this configuration. THIS IS GLOBAL for a JVM, because for a specific configuration, aeron only needs to be initialized ONCE.
// we have INSTANCE of the "wrapper" AeronDriver, because we want to be able to have references to the logger when doing things,
// however - the code that actually does stuff is a "singleton" in regard to an aeron configuration
internal = runBlocking {
lock.withLock {
val driverId = mediaDriverConfig.id
val driverId = mediaDriverConfig.mediaDriverId()
var driver = driverConfigurations.get(driverId)
if (driver == null) {
driver = AeronDriverInternal(endPoint, mediaDriverConfig)
logger.debug("Aeron Driver [$driverId]: Initializing...")
val aeronDriver = driverConfigurations.get(driverId)
if (aeronDriver == null) {
val driver = AeronDriverInternal(endPoint, mediaDriverConfig, logger)
driverConfigurations.put(driverId, driver)
driverConfigurations.put(driverId, driver)
// register a logger so that we are notified when there is an error in Aeron
driver.addError {
logger.error(this) { "Aeron Driver [$driverId]: error!" }
}
if (logEverything) {
logger.debug { "Aeron Driver [$driverId]: Creating at '${driver.aeronDirectory}'" }
}
} else {
if (logEverything) {
logger.debug { "Aeron Driver [$driverId]: Reusing driver" }
}
// assign our endpoint to the driver
driver.addEndpoint(endPoint)
}
driver
// register a logger so that we are notified when there is an error in Aeron
driver.addError {
logger.error("Aeron Driver [$driverId]: error!", this)
}
if (logEverything && logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Creating at '${driver.aeronDirectory}'")
}
internal = driver
} else {
if (logEverything && logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Reusing driver")
}
// assign our endpoint to the driver
aeronDriver.addEndpoint(endPoint)
internal = aeronDriver
}
}
/**
* This does TWO things
* - start the media driver if not already running
@ -450,7 +515,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if we are successfully connected to the aeron client
*/
suspend fun start()= lock.withLock {
fun start() = lock.write {
internal.start(logger)
}
@ -459,14 +524,13 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
* ESPECIALLY if it is with the same streamID
*
* The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
*
* this check is in the "reconnect" logic
*/
suspend fun waitForConnection(
fun waitForConnection(
shutdown: AtomicBoolean,
publication: Publication,
handshakeTimeoutNs: Long,
logInfo: String,
onErrorHandler: suspend (Throwable) -> Exception
onErrorHandler: (Throwable) -> Exception
) {
if (publication.isConnected) {
return
@ -478,13 +542,64 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
if (publication.isConnected) {
return
}
if (shutdown.value) {
break
}
delay(200L)
Thread.sleep(200L)
}
close(publication, logInfo)
var closeException: Exception? = null
try {
// we might not be able to close this connection.
close(publication, logInfo)
}
catch (e: Exception) {
closeException = e
}
val exception = onErrorHandler(Exception("Aeron Driver [${internal.driverId}]: Publication timed out in ${Sys.getTimePrettyFull(handshakeTimeoutNs)} while waiting for connection state: ${publication.channel()} streamId=${publication.streamId()}"))
val exception = onErrorHandler(Exception("Aeron Driver [${internal.driverId}]: Publication timed out in ${Sys.getTimePrettyFull(handshakeTimeoutNs)} while waiting for connection state: ${publication.channel()} streamId=${publication.streamId()}", closeException))
exception.cleanAllStackTrace()
throw exception
}
/**
* For subscriptions, in the client we want to guarantee that the remote server has connected BACK to us!
*/
fun waitForConnection(
shutdown: AtomicBoolean,
subscription: Subscription,
handshakeTimeoutNs: Long,
logInfo: String,
onErrorHandler: (Throwable) -> Exception
) {
if (subscription.isConnected) {
return
}
val startTime = System.nanoTime()
while (System.nanoTime() - startTime < handshakeTimeoutNs) {
if (subscription.isConnected && subscription.imageCount() > 0) {
return
}
if (shutdown.value) {
break
}
Thread.sleep(200L)
}
var closeException: Exception? = null
try {
// we might not be able to close this connection.
close(subscription, logInfo)
}
catch (e: Exception) {
closeException = e
}
val exception = onErrorHandler(Exception("Aeron Driver [${internal.driverId}]: Subscription timed out in ${Sys.getTimePrettyFull(handshakeTimeoutNs)} while waiting for connection state: ${subscription.channel()} streamId=${subscription.streamId()}", closeException))
exception.cleanAllStackTrace()
throw exception
}
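A hedged caller-side sketch of the subscription overload above; `handshakeSubscription` and the 30-second timeout are placeholder values:
// Hypothetical usage: wait for the remote side to connect back, aborting if shutdown is requested.
val shutdownFlag = atomic(false)
waitForConnection(
    shutdown = shutdownFlag,
    subscription = handshakeSubscription,               // placeholder subscription
    handshakeTimeoutNs = TimeUnit.SECONDS.toNanos(30),
    logInfo = "handshake-sub"
) { cause -> Exception("Handshake subscription never connected", cause) }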
@ -496,7 +611,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* The publication returned is thread-safe.
*/
suspend fun addPublication(publicationUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Publication {
fun addPublication(publicationUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Publication {
return internal.addPublication(logger, publicationUri, streamId, logInfo, isIpc)
}
@ -507,7 +622,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* This is not a thread-safe publication!
*/
suspend fun addExclusivePublication(publicationUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Publication {
fun addExclusivePublication(publicationUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Publication {
return internal.addExclusivePublication(logger, publicationUri, streamId, logInfo, isIpc)
}
@ -520,7 +635,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
* {@link Aeron.Context#availableImageHandler(AvailableImageHandler)} and
* {@link Aeron.Context#unavailableImageHandler(UnavailableImageHandler)} from the {@link Aeron.Context}.
*/
suspend fun addSubscription(subscriptionUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Subscription {
fun addSubscription(subscriptionUri: ChannelUriStringBuilder, streamId: Int, logInfo: String, isIpc: Boolean): Subscription {
return internal.addSubscription(logger, subscriptionUri, streamId, logInfo, isIpc)
}
@ -532,7 +647,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* This can throw exceptions!
*/
suspend fun close(publication: Publication, logInfo: String) {
fun close(publication: Publication, logInfo: String) {
internal.close(publication, logger, logInfo)
}
@ -541,7 +656,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* This can throw exceptions!
*/
suspend fun close(subscription: Subscription, logInfo: String) {
fun close(subscription: Subscription, logInfo: String) {
internal.close(subscription, logger, logInfo)
}
@ -551,7 +666,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the media driver is STOPPED.
*/
suspend fun ensureStopped(timeoutMS: Long, intervalTimeoutMS: Long): Boolean =
fun ensureStopped(timeoutMS: Long, intervalTimeoutMS: Long): Boolean =
internal.ensureStopped(timeoutMS, intervalTimeoutMS, logger)
/**
@ -571,9 +686,9 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the media driver was explicitly closed
*/
suspend fun closed() = internal.closed()
fun closed() = internal.closed()
suspend fun isInUse(): Boolean = internal.isInUse(logger)
fun isInUse(endPoint: EndPoint<*>?): Boolean = internal.isInUse(endPoint, logger)
/**
* @return the aeron media driver log file for a specific publication.
@ -666,7 +781,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
/**
* Make sure that we DO NOT approach the Aeron linger timeout!
*/
suspend fun delayLingerTimeout(multiplier: Number) = internal.delayLingerTimeout(multiplier.toDouble())
fun delayLingerTimeout(multiplier: Number = 1) = internal.delayLingerTimeout(multiplier.toDouble())
/**
* A safer way to try to close the media driver if, in the ENTIRE JVM, our process is the only one using aeron with its specific configuration
@ -676,8 +791,8 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the driver was successfully stopped.
*/
suspend fun closeIfSingle(): Boolean = lock.withLock {
if (!isInUse()) {
fun closeIfSingle(): Boolean = lock.write {
if (!isInUse(endPoint)) {
if (logEverything) {
internal.close(endPoint, logger)
} else {
@ -700,7 +815,7 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
*
* @return true if the driver was successfully stopped.
*/
suspend fun close(): Boolean = lock.withLock {
fun close(): Boolean = lock.write {
if (logEverything) {
internal.close(endPoint, logger)
} else {
@ -708,11 +823,261 @@ class AeronDriver private constructor(config: Configuration, val logger: KLogger
}
}
suspend fun <R> use(block: suspend (AeronDriver) -> R): R {
fun <R> use(block: (AeronDriver) -> R): R {
return try {
block(this)
} finally {
close()
}
}
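For clarity, a small sketch of the `use` helper above: the driver is closed in the finally block even when the lambda throws. The construction shown mirrors newIfClosed() elsewhere in this class and is only possible where the private constructor is accessible; the arguments are placeholders:
// Hypothetical example (constructor arguments are illustrative).
val driver = AeronDriver(endPoint.config, endPoint.logger, endPoint)
driver.use {
    it.start()
    // ... add publications/subscriptions and send data ...
}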
/**
* NOTE: This cannot be on a coroutine, because our kryo instances are NOT threadsafe!
*
* the actual bits that send data on the network.
*
* There is a maximum length allowed for messages which is the min of 1/8th a term length or 16MB.
* Messages larger than this should be chunked using an application level chunking protocol. Chunking has better recovery
* properties from failure and streams with mechanical sympathy.
*
* This can be overridden if you want to customize exactly how data is sent on the network
*
* @param publication the connection specific publication
* @param internalBuffer the internal buffer that will be copied to the Aeron network driver
* @param offset the offset in the internal buffer at which to start copying bytes
* @param objectSize the number of bytes to copy (starting at the offset)
* @param connection the connection object
*
* @return true if the message was successfully sent by aeron, false otherwise. Exceptions are caught and NOT rethrown!
*/
internal fun <CONNECTION: Connection> send(
publication: Publication,
internalBuffer: MutableDirectBuffer,
bufferClaim: BufferClaim,
offset: Int,
objectSize: Int,
sendIdleStrategy: IdleStrategy,
connection: Connection,
abortEarly: Boolean,
listenerManager: ListenerManager<CONNECTION>
): Boolean {
var result: Long
while (true) {
// The maximum claimable length is given by the maxPayloadLength() function, which is the MTU length less header (with defaults this is 1,376 bytes).
result = publication.tryClaim(objectSize, bufferClaim)
if (result >= 0) {
// success!
try {
// both .offer and .putBytes add bytes to the underlying termBuffer -- HOWEVER, putBytes is faster as there are no
// extra checks performed BECAUSE we have to do our own data fragmentation management.
// It doesn't make sense to use `.offer`, which ALSO has its own fragmentation handling (which is extra overhead for us)
bufferClaim.buffer().putBytes(DataHeaderFlyweight.HEADER_LENGTH, internalBuffer, offset, objectSize)
return true
} catch (e: Exception) {
logger.error("Error adding data to aeron buffer.", e)
return false
} finally {
// must commit() or abort() before the unblock timeout (default 15 seconds) occurs.
bufferClaim.commit()
}
}
if (internal.mustRestartDriverOnError) {
logger.error("Critical error, not able to send data.")
// there were critical errors. Don't even try anything! We will reconnect automatically (on the client) when it shuts down (the connection is closed immediately when an error of this type is encountered).
// aeron will likely report this as "BACK PRESSURE"
return false
}
/**
* Since the publication is not connected, we weren't able to send data to the remote endpoint.
*/
val endPoint = endPoint!!
if (result == Publication.NOT_CONNECTED) {
if (abortEarly) {
val exception = endPoint.newException(
"[${publication.sessionId()}] Unable to send message. (Connection in non-connected state, aborted attempt! ${errorCodeName(result)})"
)
listenerManager.notifyError(exception)
return false
}
else if (publication.isConnected) {
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Error sending message. (Connection in non-connected state longer than linger timeout. ${errorCodeName(result)})"
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
listenerManager.notifyError(exception)
return false
}
else {
// by default, we BUFFER data on a connection -- so the message will be placed into a queue to be re-sent once the connection comes back
// no extra actions required by us.
// Returning "false" here makes sure that the session manager picks up this message to re-broadcast (eventually) on the updated connection
return false
}
}
/**
* The publication is not connected to a subscriber, this can be an intermittent state as subscribers come and go.
* val NOT_CONNECTED: Long = -1
*
* The offer failed due to back pressure from the subscribers preventing further transmission.
* val BACK_PRESSURED: Long = -2
*
* The offer failed due to an administration action and should be retried.
* The action is an operation such as log rotation which is likely to have succeeded by the next retry attempt.
* val ADMIN_ACTION: Long = -3
*/
if (result >= Publication.ADMIN_ACTION) {
// we should retry, BUT we want to block ANYONE ELSE trying to write at the same time!
sendIdleStrategy.idle()
continue
}
if (result == Publication.CLOSED && connection.isClosed()) {
// this can happen when we use RMI to close a connection. RMI will (in most cases) ALWAYS send a response when it's
// done executing. If the connection is *closed* first (because an RMI method closed it), then we will not be able to
// send the message.
return false
}
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Error sending message. (${errorCodeName(result)})"
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
listenerManager.notifyError(exception)
return false
}
}
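The KDoc above notes that messages beyond min(1/8 term length, 16MB) must be chunked by the application; a rough, purely illustrative sketch of sizing such chunks against a publication (not an API of this library):
// Hypothetical helper: compute chunk sizes for a payload that exceeds Aeron's per-message limit.
// Publication.maxMessageLength() is min(1/8 of the term length, 16MB).
fun chunkSizes(totalBytes: Int, publication: Publication): List<Int> {
    val maxChunk = publication.maxMessageLength()
    return (0 until totalBytes step maxChunk).map { offset -> minOf(maxChunk, totalBytes - offset) }
}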
/**
* NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine!
* CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
* Server -> will be network polling thread
* Client -> will be thread that calls `connect()`
*
* @return true if the message was successfully sent by aeron
*/
internal fun <CONNECTION: Connection> send(
publication: Publication,
buffer: AeronOutput,
logInfo: String,
listenerManager: ListenerManager<CONNECTION>,
handshakeSendIdleStrategy: IdleStrategy
): Boolean {
val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer
var result: Long
while (true) {
result = publication.offer(internalBuffer, 0, objectSize)
if (result >= 0) {
// success!
return true
}
if (internal.mustRestartDriverOnError) {
// there were critical errors. Don't even try anything! We will reconnect automatically (on the client) when it shuts down (the connection is closed immediately when an error of this type is encountered).
// aeron will likely report this as "BACK PRESSURE"
return false
}
/**
* Since the publication is not connected, we weren't able to send data to the remote endpoint.
*
* According to Aeron Docs, Pubs and Subs can "come and go", whatever that means. We just want to make sure that we
* don't "loop forever" if a publication is ACTUALLY closed, like on purpose.
*/
val endPoint = endPoint!!
if (result == Publication.NOT_CONNECTED) {
if (publication.isConnected) {
// more critical error sending the message. we shouldn't retry or anything.
// this exception will be a ClientException or a ServerException
val exception = endPoint.newException(
"[$logInfo] Error sending message. (Connection in non-connected state longer than linger timeout. ${errorCodeName(result)})",
null
)
exception.cleanStackTraceInternal()
listenerManager.notifyError(exception)
throw exception
}
else {
// publication was actually closed, so no bother throwing an error
return false
}
}
/**
* The publication is not connected to a subscriber, this can be an intermittent state as subscribers come and go.
* val NOT_CONNECTED: Long = -1
*
* The offer failed due to back pressure from the subscribers preventing further transmission.
* val BACK_PRESSURED: Long = -2
*
* The offer failed due to an administration action and should be retried.
* The action is an operation such as log rotation which is likely to have succeeded by the next retry attempt.
* val ADMIN_ACTION: Long = -3
*/
if (result >= Publication.ADMIN_ACTION) {
// we should retry.
handshakeSendIdleStrategy.idle()
continue
}
if (result == Publication.CLOSED) {
// this can happen when we use RMI to close a connection. RMI will (in most cases) ALWAYS send a response when it's
// done executing. If the connection is *closed* first (because an RMI method closed it), then we will not be able to
// send the message.
return false
}
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Error sending message. (${errorCodeName(result)})"
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
listenerManager.notifyError(exception)
return false
}
}
fun newIfClosed(): AeronDriver {
endPoint!!
var driver: AeronDriver? = null
withLock {
driver = if (closed()) {
// Only starts the media driver if we are NOT already running!
try {
AeronDriver(endPoint.config, endPoint.logger, endPoint)
} catch (e: Exception) {
throw endPoint.newException("Error initializing aeron driver", e)
}
} else {
this
}
}
return driver!!
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package dorkbox.network.aeron
import dorkbox.collections.ConcurrentIterator
import dorkbox.collections.LockFreeHashSet
import dorkbox.network.Configuration
import dorkbox.network.connection.EndPoint
@ -29,39 +30,43 @@ import io.aeron.*
import io.aeron.driver.MediaDriver
import io.aeron.status.ChannelEndpointStatus
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import mu.KotlinLogging
import org.agrona.DirectBuffer
import org.agrona.concurrent.BackoffIdleStrategy
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.File
import java.io.IOException
import java.net.BindException
import java.net.SocketException
import java.util.concurrent.*
import java.util.concurrent.locks.*
import kotlin.concurrent.write
internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: Configuration.MediaDriverConfig) {
internal class AeronDriverInternal(endPoint: EndPoint<*>?, config: Configuration.MediaDriverConfig, logger: Logger) {
companion object {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
private const val AERON_PUBLICATION_LINGER_TIMEOUT = 5_000L // in MS
private val driverLogger = KotlinLogging.logger(AeronDriver::class.java.simpleName)
private const val AERON_PUB_SUB_TIMEOUT = 50L // in MS
private val driverLogger = LoggerFactory.getLogger(AeronDriver::class.java.simpleName)
private val onErrorGlobalList = atomic(Array<Throwable.() -> Unit>(0) { { } })
private val onErrorGlobalMutex = Mutex()
private val onErrorGlobalLock = ReentrantReadWriteLock()
/**
* Called when there is an Aeron error
*/
suspend fun onError(function: Throwable.() -> Unit) {
onErrorGlobalMutex.withLock {
fun onError(function: Throwable.() -> Unit) {
onErrorGlobalLock.write {
// we have to follow the single-writer principle!
onErrorGlobalList.lazySet(ListenerManager.add(function, onErrorGlobalList.value))
}
}
private suspend fun removeOnError(function: Throwable.() -> Unit) {
onErrorGlobalMutex.withLock {
private fun removeOnError(function: Throwable.() -> Unit) {
onErrorGlobalLock.write {
// we have to follow the single-writer principle!
onErrorGlobalList.lazySet(ListenerManager.remove(function, onErrorGlobalList.value))
}
@ -76,7 +81,6 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
fun notifyError(exception: Throwable) {
onErrorGlobalList.value.forEach {
try {
driverLogger.error(exception) { "Aeron error!" }
it(exception)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
@ -92,16 +96,16 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
}
}
val driverId = config.id
val driverId = config.mediaDriverId()
private val endPointUsages = mutableListOf<EndPoint<*>>()
internal val endPointUsages = ConcurrentIterator<EndPoint<*>>()
@Volatile
private var aeron: Aeron? = null
private var mediaDriver: MediaDriver? = null
private val onErrorLocalList = mutableListOf<Throwable.() -> Unit>()
private val onErrorLocalMutex = Mutex()
private val onErrorLocalLock = ReentrantReadWriteLock()
private val context: AeronContext
private val aeronErrorHandler: (Throwable) -> Unit
@ -112,11 +116,21 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
private val registeredPublicationsTrace: LockFreeHashSet<Long> = LockFreeHashSet()
private val registeredSubscriptionsTrace: LockFreeHashSet<Long> = LockFreeHashSet()
private val stateMutex = Mutex()
private val stateLock = ReentrantReadWriteLock()
/**
* Checks to see if there are any critical network errors (for example, a VPN connection getting disconnected while running)
*/
@Volatile
internal var mustRestartDriverOnError = false
@Volatile
private var closedTime = 0L
@Volatile
private var closed = false
suspend fun closed(): Boolean = stateMutex.withLock {
fun closed(): Boolean {
return closed
}
@ -129,7 +143,68 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// configure the aeron error handler
val filter = config.aeronErrorFilter
aeronErrorHandler = { error ->
if (filter(error)) {
// NOTE: this is an error callback for MANY things, MOST of them are ASYNC! This means that a message can successfully be ADDED
// to aeron, but NOT successfully sent over the network.
// this is bad! We must close this connection. THIS WILL BE CALLED AS FAST AS THE CPU CAN RUN (because of how aeron works).
if (!mustRestartDriverOnError) {
var restartNetwork = false
// if the network interface is removed (for example, a VPN connection).
if (error is io.aeron.exceptions.ChannelEndpointException ||
error.cause is BindException ||
error.cause is SocketException ||
error.cause is IOException) {
restartNetwork = true
if (error.message?.startsWith("ERROR - channel error - Network is unreachable") == true) {
val exception = AeronDriverException("Aeron Driver [$driverId]: Network is disconnected or unreachable.")
exception.cleanAllStackTrace()
notifyError(exception)
} else if (error.message?.startsWith("WARN - failed to send") == true) {
val exception = AeronDriverException("Aeron Driver [$driverId]: Network socket error, can't send data.")
exception.cleanAllStackTrace()
notifyError(exception)
}
else if (error.message == "Can't assign requested address") {
val exception = AeronDriverException("Aeron Driver [$driverId]: Network socket error, can't assign requested address.")
exception.cleanAllStackTrace()
notifyError(exception)
} else {
error.cleanStackTrace()
// send this out to the listener-manager so we can be notified of global errors
notifyError(AeronDriverException("Aeron Driver [$driverId]: Unexpected error!", error.cause))
}
}
else if (error is io.aeron.exceptions.AeronException) {
if (error.message?.startsWith("ERROR - unexpected close of heartbeat timestamp counter:") == true) {
restartNetwork = true
val exception = AeronDriverException("Aeron Driver [$driverId]: HEARTBEAT error, can't continue.")
exception.cleanAllStackTrace()
notifyError(exception)
}
}
if (restartNetwork) {
notifyError(AeronDriverException("Critical network error internal to the Aeron Driver, restarting network!").cleanAllStackTrace())
// this must be set before anything else happens
mustRestartDriverOnError = true
// close will make sure to run on a different thread
endPointUsages.forEach {
// we cannot send the DC message because the network layer has issues!
it.close(closeEverything = false, sendDisconnectMessage = false, releaseWaitingThreads = false)
}
}
}
// if we are restarting the network, ignore all future messages
if (!mustRestartDriverOnError && filter(error)) {
error.cleanStackTrace()
// send this out to the listener-manager so we can be notified of global errors
notifyError(AeronDriverException(error))
@ -138,7 +213,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// @throws IllegalStateException if the configuration has already been used to create a context
// @throws IllegalArgumentException if the aeron media driver directory cannot be setup
context = AeronContext(config, aeronErrorHandler)
context = AeronContext(config, logger, aeronErrorHandler)
addEndpoint(endPoint)
}
@ -153,19 +228,22 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
}
suspend fun addError(function: Throwable.() -> Unit) {
fun addError(function: Throwable.() -> Unit) {
// always add this to the global one
onError(function)
// this is so we can track all the added error listeners (and removed them when we close, since the DRIVER has a global list)
onErrorLocalMutex.withLock {
onErrorLocalLock.write {
onErrorLocalList.add(function)
}
}
private suspend fun removeErrors() = onErrorLocalMutex.withLock {
onErrorLocalList.forEach {
removeOnError(it)
private fun removeErrors() {
onErrorLocalLock.write {
mustRestartDriverOnError = false
onErrorLocalList.forEach {
removeOnError(it)
}
}
}
@ -176,12 +254,14 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
*
* @return true if we are successfully connected to the aeron client
*/
suspend fun start(logger: KLogger): Boolean = stateMutex.withLock {
fun start(logger: Logger): Boolean = stateLock.write {
require(!closed) { "Aeron Driver [$driverId]: Cannot start a driver that was closed. A new driver + context must be created" }
val isLoaded = mediaDriver != null && aeron != null && aeron?.isClosed == false
if (isLoaded) {
logger.debug { "Aeron Driver [$driverId]: Already running... Not starting again." }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Already running... Not starting again.")
}
return true
}
@ -193,8 +273,10 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// wait for a bit, because we are running, but we ALSO issued a START, and expect it to start.
// SOMETIMES aeron is in the middle of shutting down, and this prevents us from trying to connect to
// that instance
logger.debug { "Aeron Driver [$driverId]: Already running. Double checking status..." }
delay(context.driverTimeout / 2)
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Already running. Double checking status...")
}
Thread.sleep(context.driverTimeout / 2)
running = isRunning()
}
@ -204,20 +286,22 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
while (count-- > 0) {
try {
mediaDriver = MediaDriver.launch(context.context)
logger.debug { "Aeron Driver [$driverId]: Successfully started" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Successfully started")
}
break
} catch (e: Exception) {
logger.warn(e) { "Aeron Driver [$driverId]: Unable to start at ${context.directory}. Retrying $count more times..." }
delay(context.driverTimeout)
logger.warn("Aeron Driver [$driverId]: Unable to start at ${context.directory}. Retrying $count more times...", e)
Thread.sleep(context.driverTimeout)
}
}
} else {
logger.debug { "Aeron Driver [$driverId]: Not starting. It was already running." }
} else if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Not starting. It was already running.")
}
// if we were unable to load the aeron driver, don't continue.
if (!running && mediaDriver == null) {
logger.error { "Aeron Driver [$driverId]: Not running and unable to start at ${context.directory}." }
logger.error("Aeron Driver [$driverId]: Not running and unable to start at ${context.directory}.")
return false
}
}
@ -243,7 +327,9 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// this might succeed if we can connect to the media driver
aeron = Aeron.connect(aeronDriverContext)
logger.debug { "Aeron Driver [$driverId]: Connected to '${context.directory}'" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Connected to '${context.directory}'")
}
return true
}
@ -257,13 +343,13 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
* The publication returned is threadsafe.
*/
@Suppress("DEPRECATION")
suspend fun addPublication(
logger: KLogger,
fun addPublication(
logger: Logger,
publicationUri: ChannelUriStringBuilder,
streamId: Int,
logInfo: String,
isIpc: Boolean
): Publication = stateMutex.withLock {
): Publication = stateLock.write {
val uri = publicationUri.build()
@ -283,7 +369,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val aeron1 = aeron
if (aeron1 == null || aeron1.isClosed) {
logger.error { "Aeron Driver [$driverId]: Aeron is closed, error creating publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId" }
logger.error("Aeron Driver [$driverId]: Aeron is closed, error creating publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a publication to aeron")
ex.cleanAllStackTrace()
@ -293,7 +379,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val publication: ConcurrentPublication? = try {
aeron1.addPublication(uri, streamId)
} catch (e: Exception) {
logger.error(e) { "Aeron Driver [$driverId]: Error creating publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId" }
logger.error("Aeron Driver [$driverId]: Error creating publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId", e)
// this happens if the aeron media driver cannot actually establish connection... OR IF IT IS TOO FAST BETWEEN ADD AND REMOVE FOR THE SAME SESSION/STREAM ID!
e.cleanAllStackTrace()
@ -303,7 +389,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
}
if (publication == null) {
logger.error { "Aeron Driver [$driverId]: Error creating publication (is null) [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId" }
logger.error("Aeron Driver [$driverId]: Error creating publication (is null) [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=$streamId")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a publication")
@ -313,16 +399,27 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
var hasDelay = false
while (publication.channelStatus() != ChannelEndpointStatus.ACTIVE || (!isIpc && publication.localSocketAddresses().isEmpty())) {
if (publication.channelStatus() == ChannelEndpointStatus.ERRORED) {
logger.error("Aeron Driver [$driverId]: Error creating publication (has errors) $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a publication")
ex.cleanAllStackTrace()
throw ex
}
if (!hasDelay) {
hasDelay = true
logger.debug { "Aeron Driver [$driverId]: Delaying creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Delaying creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
}
}
// the publication has not ACTUALLY been created yet!
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
}
if (hasDelay) {
logger.debug { "Aeron Driver [$driverId]: Delayed creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
if (hasDelay && logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Delayed creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
}
@ -331,7 +428,9 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
registeredPublicationsTrace.add(publication.registrationId())
}
logger.trace { "Aeron Driver [$driverId]: Creating publication [$logInfo] :: regId=${publication.registrationId()}, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}, channel=${publication.channel()}" }
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Creating publication [$logInfo] :: regId=${publication.registrationId()}, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}, channel=${publication.channel()}")
}
return publication
}
@ -343,12 +442,12 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
* This is not a thread-safe publication!
*/
@Suppress("DEPRECATION")
suspend fun addExclusivePublication(
logger: KLogger,
fun addExclusivePublication(
logger: Logger,
publicationUri: ChannelUriStringBuilder,
streamId: Int,
logInfo: String,
isIpc: Boolean): Publication = stateMutex.withLock {
isIpc: Boolean): Publication = stateLock.write {
val uri = publicationUri.build()
@ -366,7 +465,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val aeron1 = aeron
if (aeron1 == null || aeron1.isClosed) {
logger.error { "Aeron Driver [$driverId]: Aeron is closed, error creating ex-publication $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Aeron is closed, error creating ex-publication $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding an ex-publication to aeron")
@ -379,7 +478,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val publication: ExclusivePublication? = try {
aeron1.addExclusivePublication(uri, streamId)
} catch (e: Exception) {
logger.error(e) { "Aeron Driver [$driverId]: Error creating ex-publication $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Error creating ex-publication $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}", e)
// this happens if the aeron media driver cannot actually establish connection... OR IF IT IS TOO FAST BETWEEN ADD AND REMOVE FOR THE SAME SESSION/STREAM ID!
e.cleanAllStackTrace()
@ -389,7 +488,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
}
if (publication == null) {
logger.error { "Aeron Driver [$driverId]: Error creating ex-publication (is null) $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Error creating ex-publication (is null) $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding an ex-publication")
@ -399,16 +498,28 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
var hasDelay = false
while (publication.channelStatus() != ChannelEndpointStatus.ACTIVE || (!isIpc && publication.localSocketAddresses().isEmpty())) {
if (publication.channelStatus() == ChannelEndpointStatus.ERRORED) {
logger.error("Aeron Driver [$driverId]: Error creating ex-publication (has errors) $logInfo :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding an ex-publication")
ex.cleanAllStackTrace()
throw ex
}
if (!hasDelay) {
hasDelay = true
logger.debug { "Aeron Driver [$driverId]: Delaying creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Delaying creation of ex-publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
}
}
// the publication has not ACTUALLY been created yet!
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
}
if (hasDelay) {
logger.debug { "Aeron Driver [$driverId]: Delayed creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}" }
logger.debug("Aeron Driver [$driverId]: Delayed creation of publication [$logInfo] :: sessionId=${publicationUri.sessionId()}, streamId=${streamId}")
}
registeredPublications.getAndIncrement()
@ -416,7 +527,9 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
registeredPublicationsTrace.add(publication.registrationId())
}
logger.trace { "Aeron Driver [$driverId]: Creating ex-publication $logInfo :: regId=${publication.registrationId()}, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}, channel=${publication.channel()}" }
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Creating ex-publication $logInfo :: regId=${publication.registrationId()}, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}, channel=${publication.channel()}")
}
return publication
}
@ -430,12 +543,12 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
* {@link Aeron.Context#unavailableImageHandler(UnavailableImageHandler)} from the {@link Aeron.Context}.
*/
@Suppress("DEPRECATION")
suspend fun addSubscription(
logger: KLogger,
fun addSubscription(
logger: Logger,
subscriptionUri: ChannelUriStringBuilder,
streamId: Int,
logInfo: String,
isIpc: Boolean): Subscription = stateMutex.withLock {
isIpc: Boolean): Subscription = stateLock.write {
val uri = subscriptionUri.build()
@ -457,18 +570,17 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val aeron1 = aeron
if (aeron1 == null || aeron1.isClosed) {
logger.error { "Aeron Driver [$driverId]: Aeron is closed, error creating subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Aeron is closed, error creating subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a subscription to aeron")
ex.cleanStackTraceInternal()
throw ex
}
val subscription = try {
aeron1.addSubscription(uri, streamId)
} catch (e: Exception) {
logger.error(e) { "Aeron Driver [$driverId]: Error creating subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Error creating subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
e.cleanAllStackTrace()
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a subscription", e) // maybe not retry? or not clientRetry?
@ -477,7 +589,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
}
if (subscription == null) {
logger.error { "Aeron Driver [$driverId]: Error creating subscription (is null) [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}" }
logger.error("Aeron Driver [$driverId]: Error creating subscription (is null) [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a subscription")
@ -487,16 +599,27 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
var hasDelay = false
while (subscription.channelStatus() != ChannelEndpointStatus.ACTIVE || (!isIpc && subscription.localSocketAddresses().isEmpty())) {
if (subscription.channelStatus() == ChannelEndpointStatus.ERRORED) {
logger.error("Aeron Driver [$driverId]: Error creating subscription (has errors) $logInfo :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
// there was an error connecting to the aeron client or media driver.
val ex = ClientRetryException("Aeron Driver [$driverId]: Error adding a subscription")
ex.cleanAllStackTrace()
throw ex
}
if (!hasDelay) {
hasDelay = true
logger.debug { "Aeron Driver [$driverId]: Delaying creation of subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Delaying creation of subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
}
}
// the subscription has not ACTUALLY been created yet!
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
}
if (hasDelay) {
logger.debug { "Aeron Driver [$driverId]: Delayed creation of subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}" }
if (hasDelay && logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Delayed creation of subscription [$logInfo] :: sessionId=${subscriptionUri.sessionId()}, streamId=${streamId}")
}
registeredSubscriptions.getAndIncrement()
@ -504,14 +627,16 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
registeredSubscriptionsTrace.add(subscription.registrationId())
}
logger.trace { "Aeron Driver [$driverId]: Creating subscription [$logInfo] :: regId=${subscription.registrationId()}, sessionId=${subscriptionUri.sessionId()}, streamId=${subscription.streamId()}, channel=${subscription.channel()}" }
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Creating subscription [$logInfo] :: regId=${subscription.registrationId()}, sessionId=${subscriptionUri.sessionId()}, streamId=${subscription.streamId()}, channel=${subscription.channel()}")
}
return subscription
}
/**
* Guarantee that the publication is closed AND the backing file is removed
*/
suspend fun close(publication: Publication, logger: KLogger, logInfo: String) = stateMutex.withLock {
fun close(publication: Publication, logger: Logger, logInfo: String) = stateLock.write {
val name = if (publication is ConcurrentPublication) {
"publication"
} else {
@ -520,13 +645,14 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
val registrationId = publication.registrationId()
logger.trace { "Aeron Driver [$driverId]: Closing $name file [$logInfo] :: regId=$registrationId, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}" }
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Closing $name file [$logInfo] :: regId=$registrationId, sessionId=${publication.sessionId()}, streamId=${publication.streamId()}")
}
val aeron1 = aeron
if (aeron1 == null || aeron1.isClosed) {
val e = Exception("Aeron Driver [$driverId]: Error closing $name [$logInfo] :: sessionId=${publication.sessionId()}, streamId=${publication.streamId()}")
e.cleanStackTraceInternal()
throw e
}
@ -534,20 +660,20 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// This can throw exceptions!
publication.close()
} catch (e: Exception) {
logger.error(e) { "Aeron Driver [$driverId]: Unable to close [$logInfo] $name $publication" }
logger.error("Aeron Driver [$driverId]: Unable to close [$logInfo] $name $publication", e)
}
if (publication is ConcurrentPublication) {
// aeron is async. close() doesn't immediately close, it just submits the close command!
// THIS CAN TAKE A WHILE TO ACTUALLY CLOSE!
while (publication.isConnected || publication.channelStatus() == ChannelEndpointStatus.ACTIVE || aeron1.getPublication(registrationId) != null) {
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
}
} else {
// aeron is async. close() doesn't immediately close, it just submits the close command!
// THIS CAN TAKE A WHILE TO ACTUALLY CLOSE!
while (publication.isConnected || publication.channelStatus() == ChannelEndpointStatus.ACTIVE || aeron1.getExclusivePublication(registrationId) != null) {
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
}
}
@ -562,13 +688,14 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
/**
* Guarantee that the publication is closed AND the backing file is removed
*/
suspend fun close(subscription: Subscription, logger: KLogger, logInfo: String) {
logger.trace { "Aeron Driver [$driverId]: Closing subscription [$logInfo] :: regId=${subscription.registrationId()}, sessionId=${subscription.images().firstOrNull()?.sessionId()}, streamId=${subscription.streamId()}" }
fun close(subscription: Subscription, logger: Logger, logInfo: String) = stateLock.write {
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Closing subscription [$logInfo] :: regId=${subscription.registrationId()}, sessionId=${subscription.images().firstOrNull()?.sessionId()}, streamId=${subscription.streamId()}")
}
val aeron1 = aeron
if (aeron1 == null || aeron1.isClosed) {
val e = Exception("Aeron Driver [$driverId]: Error closing publication [$logInfo] :: sessionId=${subscription.images().firstOrNull()?.sessionId()}, streamId=${subscription.streamId()}")
e.cleanStackTraceInternal()
val e = Exception("Aeron Driver [$driverId]: Error closing subscription [$logInfo] :: sessionId=${subscription.images().firstOrNull()?.sessionId()}, streamId=${subscription.streamId()}")
throw e
}
@ -576,13 +703,16 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// This can throw exceptions!
subscription.close()
} catch (e: Exception) {
logger.error(e) { "Aeron Driver [$driverId]: Unable to close [$logInfo] subscription $subscription" }
logger.error("Aeron Driver [$driverId]: Unable to close [$logInfo] subscription $subscription")
}
// aeron is async. close() doesn't immediately close, it just submits the close command!
// THIS CAN TAKE A WHILE TO ACTUALLY CLOSE!
while (subscription.isConnected || subscription.channelStatus() == ChannelEndpointStatus.ACTIVE || subscription.images().isNotEmpty()) {
delay(50)
Thread.sleep(AERON_PUB_SUB_TIMEOUT)
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: Still closing sub!")
}
}
// deleting log files is generally not recommended in a production environment as it can result in data loss and potential disruption of the messaging system!!
@ -598,7 +728,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
*
* @return true if the media driver is STOPPED.
*/
suspend fun ensureStopped(timeoutMS: Long, intervalTimeoutMS: Long, logger: KLogger): Boolean {
fun ensureStopped(timeoutMS: Long, intervalTimeoutMS: Long, logger: Logger): Boolean {
if (closed) {
return true
}
@ -611,9 +741,11 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
// only emit the log info once. It's rather spammy otherwise!
if (!didLog) {
didLog = true
logger.debug { "Aeron Driver [$driverId]: Still running (${aeronDirectory}). Waiting for it to stop..." }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Still running (${aeronDirectory}). Waiting for it to stop...")
}
}
delay(intervalTimeoutMS)
Thread.sleep(intervalTimeoutMS)
}
return !isRunning()
@ -636,21 +768,22 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
return context.isRunning()
}
suspend fun isInUse(logger: KLogger): Boolean {
fun isInUse(endPoint: EndPoint<*>?, logger: Logger): Boolean {
// as many "shortcuts" as we can for checking if the current Aeron Driver/client is still in use
if (!isRunning()) {
logger.trace { "Aeron Driver [$driverId]: not running" }
if (logger.isTraceEnabled) {
logger.trace("Aeron Driver [$driverId]: not running")
}
return false
}
val driverId = config.id
if (registeredPublications.value > 0) {
if (logger.isTraceEnabled) {
val elements = registeredPublicationsTrace.elements
val joined = elements.joinToString()
logger.debug { "Aeron Driver [$driverId]: has [$joined] publications (${registeredPublications.value} total)" }
} else {
logger.debug { "Aeron Driver [$driverId]: has publications (${registeredPublications.value} total)" }
logger.trace("Aeron Driver [$driverId]: has publications: [$joined] (${registeredPublications.value} total)")
} else if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: has publications (${registeredPublications.value} total)")
}
return true
}
@ -659,18 +792,26 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
if (logger.isTraceEnabled) {
val elements = registeredSubscriptionsTrace.elements
val joined = elements.joinToString()
logger.debug { "Aeron Driver [$driverId]: has [$joined] subscriptions (${registeredSubscriptions.value} total)" }
} else {
logger.debug { "Aeron Driver [$driverId]: has subscriptions (${registeredSubscriptions.value} total)" }
logger.trace("Aeron Driver [$driverId]: has subscriptions: [$joined] (${registeredSubscriptions.value} total)")
} else if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: has subscriptions (${registeredSubscriptions.value} total)")
}
return true
}
if (endPointUsages.isNotEmpty()) {
logger.debug { "Aeron Driver [$driverId]: still referenced by ${endPointUsages.size} endpoints" }
if (endPointUsages.size() > 1 && !endPointUsages.contains(endPoint)) {
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: still referenced by ${endPointUsages.size()} endpoints")
}
return true
}
// ignore the extra driver checks, because in SOME situations, when trying to reconnect upon an error, the
// driver gets into a bad state. When this happens, we cannot rely on the driver stat info!
if (mustRestartDriverOnError) {
return false
}
// check to see if we ALREADY have loaded this location.
// null or empty snapshot means that this location is currently unused
// >0 can also happen because the location is old. It's not running, but still has info because it hasn't been cleaned up yet
@ -679,7 +820,9 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
var count = 3
while (count > 0 && currentUsage > 0) {
logger.debug { "Aeron Driver [$driverId]: in use, double checking status" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: usage is: $currentUsage, double checking status")
}
delayLingerTimeout()
currentUsage = driverBacklog()?.snapshot()?.size ?: 0
count--
@ -692,14 +835,16 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
count = 3
while (count > 0 && currentUsage > 0) {
logger.debug { "Aeron Driver [$driverId]: in use, double checking status (long)" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: usage is: $currentUsage, double checking status (long)")
}
delayDriverTimeout()
currentUsage = driverBacklog()?.snapshot()?.size ?: 0
count--
}
if (currentUsage > 0) {
logger.debug { "Aeron Driver [$driverId]: usage is: $currentUsage" }
if (currentUsage > 0 && logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: usage is: $currentUsage")
}
return currentUsage > 0
@ -713,27 +858,44 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
*
* @return true if the driver was successfully stopped.
*/
suspend fun close(endPoint: EndPoint<*>?, logger: KLogger): Boolean = stateMutex.withLock {
val driverId = config.id
fun close(endPoint: EndPoint<*>?, logger: Logger): Boolean = stateLock.write {
if (endPoint != null) {
endPointUsages.remove(endPoint)
}
logger.trace { "Aeron Driver [$driverId]: Requested close... (${endPointUsages.size} endpoints still in use)" }
if (isInUse(logger)) {
logger.debug { "Aeron Driver [$driverId]: in use, not shutting down this instance." }
return@withLock false
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Requested close... (${endPointUsages.size()} endpoints still in use)")
}
val removed = AeronDriver.driverConfigurations.remove(driverId)
// ignore the extra driver checks, because in SOME situations, when trying to reconnect upon an error, the
if (isInUse(endPoint, logger)) {
if (mustRestartDriverOnError) {
// driver gets into a bad state. When this happens, we have to ignore "are we already in use" checks, BECAUSE the driver is now corrupted and unusable!
}
else {
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: in use, not shutting down this instance.")
}
// reset our contextDefine value, so that this configuration can safely be reused
endPoint?.config?.contextDefined = false
return@write false
}
}
val removed = AeronDriver.driverConfigurations[driverId]
if (removed == null) {
logger.debug { "Aeron Driver [$driverId]: already closed. Ignoring close request." }
return@withLock false
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: already closed. Ignoring close request.")
}
// reset our contextDefine value, so that this configuration can safely be reused
endPoint?.config?.contextDefined = false
return@write false
}
logger.debug { "Aeron Driver [$driverId]: Closing..." }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Closing...")
}
// we have to assign context BEFORE we close, because the `getter` for context will create it if necessary
val aeronContext = context
@ -745,7 +907,7 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error stopping", e))
} else {
logger.error(e) { "Aeron Driver [$driverId]: Error stopping" }
logger.error("Aeron Driver [$driverId]: Error stopping", e)
}
}
@ -753,48 +915,52 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
if (mediaDriver == null) {
logger.debug { "Aeron Driver [$driverId]: No driver started, not Stopping." }
return@withLock false
}
logger.debug { "Aeron Driver [$driverId]: Stopping driver at '${driverDirectory}'..." }
if (!isRunning()) {
// not running
logger.debug { "Aeron Driver [$driverId]: is not running at '${driverDirectory}' for this context. Not Stopping." }
return@withLock false
}
// if we are the ones that started the media driver, then we must be the ones to close it
try {
mediaDriver!!.close()
} catch (e: Exception) {
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error closing", e))
} else {
logger.error(e) { "Aeron Driver [$driverId]: Error closing" }
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: No driver started, not stopping driver or context.")
}
} else {
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Stopping driver at '${driverDirectory}'...")
}
// if we are the ones that started the media driver, then we must be the ones to close it
try {
mediaDriver!!.close()
} catch (e: Exception) {
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error closing", e))
} else {
logger.error("Aeron Driver [$driverId]: Error closing", e)
}
}
mediaDriver = null
}
mediaDriver = null
// it can actually close faster, if everything is ideal.
val timeout = (aeronContext.driverTimeout + AERON_PUBLICATION_LINGER_TIMEOUT) / 4
// it can actually close faster, if everything is ideal.
if (isRunning()) {
// on close, we want to wait for the driver to timeout before considering it "closed". Connections can still LINGER (see below)
// on close, the publication CAN linger (in case a client goes away, and then comes back)
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
delay(timeout)
}
try {
if (isRunning()) {
// on close, we want to wait for the driver to timeout before considering it "closed". Connections can still LINGER (see below)
// on close, the publication CAN linger (in case a client goes away, and then comes back)
// AERON_PUBLICATION_LINGER_TIMEOUT, 5s by default (this can also be set as a URI param)
Thread.sleep(timeout)
}
// wait for the media driver to actually stop
var count = 10
while (--count >= 0 && isRunning()) {
logger.warn { "Aeron Driver [$driverId]: still running at '${driverDirectory}'. Waiting for it to stop. Trying to close $count more times." }
delay(timeout)
// wait for the media driver to actually stop
var count = 10
while (--count >= 0 && isRunning()) {
logger.warn("Aeron Driver [$driverId]: still running at '${driverDirectory}'. Waiting for it to stop. Trying to close $count more times.")
Thread.sleep(timeout)
}
}
catch (e: Exception) {
if (!mustRestartDriverOnError) {
logger.error("Error while checking isRunning() state.", e)
}
}
// make sure the context is also closed, but ONLY if we are the last one
@ -809,34 +975,40 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory"))
} else {
logger.error { "Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory" }
logger.error("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory")
}
}
} catch (e: Exception) {
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory", e))
} else {
logger.error(e) { "Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory" }
logger.error("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory", e)
}
}
// check to make sure it's actually deleted
if (driverDirectory.isDirectory) {
if (endPoint != null) {
endPoint.listenerManager.notifyError(AeronDriverException("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory"))
} else {
logger.error { "Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory" }
logger.error("Aeron Driver [$driverId]: Error deleting Aeron directory at: $driverDirectory")
}
}
logger.debug { "Aeron Driver [$driverId]: Closed the media driver at '${driverDirectory}'" }
// reset our contextDefined value, so that this configuration can safely be reused
config.contextDefined = false
endPoint?.config?.contextDefined = false
// actually remove it, since we've passed all the checks to guarantee it's closed...
AeronDriver.driverConfigurations.remove(driverId)
if (logger.isDebugEnabled) {
logger.debug("Aeron Driver [$driverId]: Closed the media driver at '${driverDirectory}'")
}
closed = true
closedTime = System.nanoTime()
return true
return@write true
}
/**
@ -873,7 +1045,9 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
*/
fun deleteLogFile(image: Image) {
val file = getMediaDriverFile(image)
driverLogger.debug { "Deleting log file: $image" }
if (driverLogger.isDebugEnabled) {
driverLogger.debug("Deleting log file: $image")
}
file.delete()
}
@ -953,15 +1127,27 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
/**
* Make sure that we DO NOT approach the Aeron linger timeout!
*/
suspend fun delayDriverTimeout(multiplier: Number = 1) {
delay((driverTimeout() * multiplier.toDouble()).toLong())
fun delayDriverTimeout(multiplier: Number = 1) {
Thread.sleep((driverTimeout() * multiplier.toDouble()).toLong())
}
/**
* Make sure that we DO NOT approach the Aeron linger timeout!
* Make sure that we DO NOT approach the Aeron linger timeout! If we have already passed it, do nothing.
*/
suspend fun delayLingerTimeout(multiplier: Number = 1) {
delay(driverTimeout().coerceAtLeast(TimeUnit.NANOSECONDS.toSeconds((lingerNs() * multiplier.toDouble()).toLong())) )
fun delayLingerTimeout(multiplier: Number = 1) {
val lingerTimeoutNs = (lingerNs() * multiplier.toDouble()).toLong()
val driverTimeoutSec = driverTimeout().coerceAtLeast(TimeUnit.NANOSECONDS.toSeconds(lingerTimeoutNs))
val driverTimeoutNs = TimeUnit.SECONDS.toNanos(driverTimeoutSec)
val elapsedNs = System.nanoTime() - closedTime
if (elapsedNs >= driverTimeoutNs) {
// timeout already expired, do nothing.
return
}
// not always the full duration, but the duration since the close event
val adjustedTimeoutMs = TimeUnit.NANOSECONDS.toMillis(driverTimeoutNs - elapsedNs)
Thread.sleep(adjustedTimeoutMs) // Thread.sleep() expects milliseconds
}
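// A small worked sketch of the "remaining window" logic above (standalone, illustrative values only;
// not part of this class): with a 10-second window and a close event 7 seconds ago, only the
// remaining ~3 seconds are slept; once the window has fully elapsed, nothing is slept at all.
//
//     val windowNs  = TimeUnit.SECONDS.toNanos(10)
//     val elapsedNs = TimeUnit.SECONDS.toNanos(7)                       // time since closedTime
//     if (elapsedNs < windowNs) {
//         Thread.sleep(TimeUnit.NANOSECONDS.toMillis(windowNs - elapsedNs))  // sleeps ~3000 ms
//     }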
override fun equals(other: Any?): Boolean {
@ -979,6 +1165,4 @@ internal class AeronDriverInternal(endPoint: EndPoint<*>?, private val config: C
override fun toString(): String {
return "Aeron Driver [${driverId}]"
}
}

View File

@ -18,6 +18,6 @@ package dorkbox.network.aeron
internal interface AeronPoller {
fun poll(): Int
suspend fun close()
fun close()
val info: String
}
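For reference, a minimal sketch of what an AeronPoller implementation could look like after this change. The subscription and handler names are assumptions (not the library's actual pollers); the point is that poll() drains a subscription and reports how many fragments were handled, while close() is now a plain blocking call instead of a suspend function.

internal class SubscriptionPoller(
    private val subscription: io.aeron.Subscription,
    private val handler: io.aeron.logbuffer.FragmentHandler
) : AeronPoller {
    // poll a single fragment per call; the caller's idle strategy decides how to back off on 0
    override fun poll(): Int = subscription.poll(handler, 1)

    // plain blocking close, matching the interface change above
    override fun close() = subscription.close()

    override val info = "subscription ${subscription.channel()} [streamId=${subscription.streamId()}]"
}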

View File

@ -1,381 +0,0 @@
/*
* Copyright 2014-2020 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
import kotlinx.coroutines.delay
import kotlinx.coroutines.yield
import org.agrona.concurrent.BackoffIdleStrategy
import org.agrona.hints.ThreadHints
abstract class BackoffIdleStrategyPrePad {
val p000: Byte = 0
val p001: Byte = 0
val p002: Byte = 0
val p003: Byte = 0
val p004: Byte = 0
val p005: Byte = 0
val p006: Byte = 0
val p007: Byte = 0
val p008: Byte = 0
val p009: Byte = 0
val p010: Byte = 0
val p011: Byte = 0
val p012: Byte = 0
val p013: Byte = 0
val p014: Byte = 0
val p015: Byte = 0
val p016: Byte = 0
val p017: Byte = 0
val p018: Byte = 0
val p019: Byte = 0
val p020: Byte = 0
val p021: Byte = 0
val p022: Byte = 0
val p023: Byte = 0
val p024: Byte = 0
val p025: Byte = 0
val p026: Byte = 0
val p027: Byte = 0
val p028: Byte = 0
val p029: Byte = 0
val p030: Byte = 0
val p031: Byte = 0
val p032: Byte = 0
val p033: Byte = 0
val p034: Byte = 0
val p035: Byte = 0
val p036: Byte = 0
val p037: Byte = 0
val p038: Byte = 0
val p039: Byte = 0
val p040: Byte = 0
val p041: Byte = 0
val p042: Byte = 0
val p043: Byte = 0
val p044: Byte = 0
val p045: Byte = 0
val p046: Byte = 0
val p047: Byte = 0
val p048: Byte = 0
val p049: Byte = 0
val p050: Byte = 0
val p051: Byte = 0
val p052: Byte = 0
val p053: Byte = 0
val p054: Byte = 0
val p055: Byte = 0
val p056: Byte = 0
val p057: Byte = 0
val p058: Byte = 0
val p059: Byte = 0
val p060: Byte = 0
val p061: Byte = 0
val p062: Byte = 0
val p063: Byte = 0
}
abstract class BackoffIdleStrategyData(
protected val maxSpins: Long, protected val maxYields: Long, protected val minParkPeriodMs: Long, protected val maxParkPeriodMs: Long) : BackoffIdleStrategyPrePad() {
protected var state = 0 // NOT_IDLE
protected var spins: Long = 0
protected var yields: Long = 0
protected var parkPeriodMs: Long = 0
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is BackoffIdleStrategyData) return false
if (maxSpins != other.maxSpins) return false
if (maxYields != other.maxYields) return false
if (minParkPeriodMs != other.minParkPeriodMs) return false
if (maxParkPeriodMs != other.maxParkPeriodMs) return false
if (state != other.state) return false
if (spins != other.spins) return false
if (yields != other.yields) return false
if (parkPeriodMs != other.parkPeriodMs) return false
return true
}
override fun hashCode(): Int {
var result = maxSpins.hashCode()
result = 31 * result + maxYields.hashCode()
result = 31 * result + minParkPeriodMs.hashCode()
result = 31 * result + maxParkPeriodMs.hashCode()
result = 31 * result + state
result = 31 * result + spins.hashCode()
result = 31 * result + yields.hashCode()
result = 31 * result + parkPeriodMs.hashCode()
return result
}
}
/**
* Idling strategy for threads when they have no work to do.
* <p>
* Spin for maxSpins, then
* [Coroutine.yield] for maxYields, then
* [Coroutine.delay] on an exponential backoff to maxParkPeriodMs
*/
@Suppress("unused")
class CoroutineBackoffIdleStrategy : BackoffIdleStrategyData, CoroutineIdleStrategy {
val p064: Byte = 0
val p065: Byte = 0
val p066: Byte = 0
val p067: Byte = 0
val p068: Byte = 0
val p069: Byte = 0
val p070: Byte = 0
val p071: Byte = 0
val p072: Byte = 0
val p073: Byte = 0
val p074: Byte = 0
val p075: Byte = 0
val p076: Byte = 0
val p077: Byte = 0
val p078: Byte = 0
val p079: Byte = 0
val p080: Byte = 0
val p081: Byte = 0
val p082: Byte = 0
val p083: Byte = 0
val p084: Byte = 0
val p085: Byte = 0
val p086: Byte = 0
val p087: Byte = 0
val p088: Byte = 0
val p089: Byte = 0
val p090: Byte = 0
val p091: Byte = 0
val p092: Byte = 0
val p093: Byte = 0
val p094: Byte = 0
val p095: Byte = 0
val p096: Byte = 0
val p097: Byte = 0
val p098: Byte = 0
val p099: Byte = 0
val p100: Byte = 0
val p101: Byte = 0
val p102: Byte = 0
val p103: Byte = 0
val p104: Byte = 0
val p105: Byte = 0
val p106: Byte = 0
val p107: Byte = 0
val p108: Byte = 0
val p109: Byte = 0
val p110: Byte = 0
val p111: Byte = 0
val p112: Byte = 0
val p113: Byte = 0
val p114: Byte = 0
val p115: Byte = 0
val p116: Byte = 0
val p117: Byte = 0
val p118: Byte = 0
val p119: Byte = 0
val p120: Byte = 0
val p121: Byte = 0
val p122: Byte = 0
val p123: Byte = 0
val p124: Byte = 0
val p125: Byte = 0
val p126: Byte = 0
val p127: Byte = 0
companion object {
private const val NOT_IDLE = 0
private const val SPINNING = 1
private const val YIELDING = 2
private const val PARKING = 3
/**
* Name to be returned from [.alias].
*/
const val ALIAS = "backoff"
/**
* Default number of times the strategy will spin without work before going to the next state.
*/
const val DEFAULT_MAX_SPINS = 10L
/**
* Default number of times the strategy will yield without work before going to the next state.
*/
const val DEFAULT_MAX_YIELDS = 5L
/**
* Default interval the strategy will park the thread on entering the park state in milliseconds.
*/
const val DEFAULT_MIN_PARK_PERIOD_MS = 1L
/**
* Default maximum interval, in milliseconds, to which the strategy will expand the park period.
*/
const val DEFAULT_MAX_PARK_PERIOD_MS = 1000L
}
/**
* Default constructor using [.DEFAULT_MAX_SPINS], [.DEFAULT_MAX_YIELDS], [.DEFAULT_MIN_PARK_PERIOD_MS], and [.DEFAULT_MAX_PARK_PERIOD_MS].
*/
constructor() : super(DEFAULT_MAX_SPINS, DEFAULT_MAX_YIELDS, DEFAULT_MIN_PARK_PERIOD_MS, DEFAULT_MAX_PARK_PERIOD_MS)
/**
* Create a set of state tracking idle behavior
* <p>
* @param maxSpins to perform before moving to [Coroutine.yield]
* @param maxYields to perform before moving to [Coroutine.delay]
* @param minParkPeriodMs to use when initiating parking
* @param maxParkPeriodMs to use for end duration when parking
*/
constructor(maxSpins: Long, maxYields: Long, minParkPeriodMs: Long, maxParkPeriodMs: Long)
: super(maxSpins, maxYields, minParkPeriodMs, maxParkPeriodMs) {
}
/**
* Perform current idle action (e.g. nothing/yield/sleep). This method signature expects users to call into it on
* every work 'cycle'. The implementations may use the indication "workCount &gt; 0" to reset internal backoff
* state. This method works well with 'work' APIs which follow the following rules:
* <ul>
* <li>'work' returns a value larger than 0 when some work has been done</li>
* <li>'work' returns 0 when no work has been done</li>
* <li>'work' may return error codes which are less than 0, but which amount to no work has been done</li>
* </ul>
* <p>
* Callers are expected to follow this pattern:
*
* <pre>
* <code>
* while (isRunning)
* {
* idleStrategy.idle(doWork());
* }
* </code>
* </pre>
*
* @param workCount performed in last duty cycle.
*/
override suspend fun idle(workCount: Int) {
if (workCount > 0) {
reset()
} else {
idle()
}
}
/**
* Perform current idle action (e.g. nothing/yield/sleep). To be used in conjunction with
* {@link IdleStrategy#reset()} to clear internal state when idle period is over (or before it begins).
* Callers are expected to follow this pattern:
*
* <pre>
* <code>
* while (isRunning)
* {
* if (!hasWork())
* {
* idleStrategy.reset();
* while (!hasWork())
* {
* if (!isRunning)
* {
* return;
* }
* idleStrategy.idle();
* }
* }
* doWork();
* }
* </code>
* </pre>
*/
override suspend fun idle() {
when (state) {
NOT_IDLE -> {
state = SPINNING
spins++
}
SPINNING -> {
ThreadHints.onSpinWait()
if (++spins > maxSpins) {
state = YIELDING
yields = 0
}
}
YIELDING -> if (++yields > maxYields) {
state = PARKING
parkPeriodMs = minParkPeriodMs
} else {
yield()
}
PARKING -> {
delay(parkPeriodMs)
// double the delay until we get to MAX
parkPeriodMs = (parkPeriodMs shl 1).coerceAtMost(maxParkPeriodMs)
}
}
}
/**
* Reset the internal state in preparation for entering an idle state again.
*/
override fun reset() {
spins = 0
yields = 0
parkPeriodMs = minParkPeriodMs
state = NOT_IDLE
}
/**
* Simple name by which the strategy can be identified.
*
* @return simple name by which the strategy can be identified.
*/
override fun alias(): String {
return ALIAS
}
/**
* Creates a clone of this IdleStrategy
*/
override fun clone(): CoroutineBackoffIdleStrategy {
return CoroutineBackoffIdleStrategy(maxSpins = maxSpins, maxYields = maxYields, minParkPeriodMs = minParkPeriodMs, maxParkPeriodMs = maxParkPeriodMs)
}
/**
* Creates a clone of this IdleStrategy
*/
override fun cloneToNormal(): BackoffIdleStrategy {
return BackoffIdleStrategy(maxSpins, maxYields, minParkPeriodMs, maxParkPeriodMs)
}
override fun toString(): String {
return "BackoffIdleStrategy{" +
"alias=" + ALIAS +
", maxSpins=" + maxSpins +
", maxYields=" + maxYields +
", minParkPeriodMs=" + minParkPeriodMs +
", maxParkPeriodMs=" + maxParkPeriodMs +
'}'
}
}
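A hedged usage sketch of this (now removed) coroutine backoff strategy, following the duty-cycle pattern from its own documentation. The isRunning and doWork parameters are stand-ins for whatever the caller actually polls; they are not names from this library.

// drive a duty cycle with spin -> yield -> delay backoff; any workCount > 0 resets the backoff
suspend fun dutyCycle(isRunning: () -> Boolean, doWork: () -> Int) {
    val idle = CoroutineBackoffIdleStrategy()   // defaults: 10 spins, 5 yields, 1..1000 ms park
    while (isRunning()) {
        idle.idle(doWork())
    }
}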

View File

@ -1,117 +0,0 @@
/*
* Copyright 2014-2020 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
import org.agrona.concurrent.IdleStrategy
/**
* Idle strategy for use by threads when they do not have work to do.
*
*
* **Note regarding implementor state**
*
*
* Some implementations are known to be stateful, please note that you cannot safely assume implementations to be
* stateless. Where implementations are stateful it is recommended that implementation state is padded to avoid false
* sharing.
*
*
* **Note regarding potential for TTSP(Time To Safe Point) issues**
*
*
* If the caller spins in a 'counted' loop, and the implementation does not include a safepoint poll, this may cause a
* TTSP (Time To SafePoint) problem. If this is the case for your application you can solve it by preventing the idle
* method from being inlined by using a Hotspot compiler command as a JVM argument e.g:
* `-XX:CompileCommand=dontinline,org.agrona.concurrent.NoOpIdleStrategy::idle`
*/
interface CoroutineIdleStrategy {
/**
* Perform current idle action (e.g. nothing/yield/sleep). This method signature expects users to call into it on
* every work 'cycle'. The implementations may use the indication "workCount &gt; 0" to reset internal backoff
* state. This method works well with 'work' APIs which follow the following rules:
* <ul>
* <li>'work' returns a value larger than 0 when some work has been done</li>
* <li>'work' returns 0 when no work has been done</li>
* <li>'work' may return error codes which are less than 0, but which amount to no work has been done</li>
* </ul>
* <p>
* Callers are expected to follow this pattern:
*
* <pre>
* <code>
* while (isRunning)
* {
* idleStrategy.idle(doWork());
* }
* </code>
* </pre>
*
* @param workCount performed in last duty cycle.
*/
suspend fun idle(workCount: Int)
/**
* Perform current idle action (e.g. nothing/yield/sleep). To be used in conjunction with
* {@link IdleStrategy#reset()} to clear internal state when idle period is over (or before it begins).
* Callers are expected to follow this pattern:
*
* <pre>
* <code>
* while (isRunning)
* {
* if (!hasWork())
* {
* idleStrategy.reset();
* while (!hasWork())
* {
* if (!isRunning)
* {
* return;
* }
* idleStrategy.idle();
* }
* }
* doWork();
* }
* </code>
* </pre>
*/
suspend fun idle()
/**
* Reset the internal state in preparation for entering an idle state again.
*/
fun reset()
/**
* Simple name by which the strategy can be identified.
*
* @return simple name by which the strategy can be identified.
*/
fun alias(): String {
return ""
}
/**
* Creates a clone of this IdleStrategy
*/
fun clone(): CoroutineIdleStrategy
/**
* Creates a clone of this IdleStrategy
*/
fun cloneToNormal(): IdleStrategy
}

View File

@ -1,121 +0,0 @@
/*
* Copyright 2014-2020 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
import kotlinx.coroutines.delay
import org.agrona.concurrent.SleepingMillisIdleStrategy
/**
* When idle, this strategy sleeps for a specified period of time in milliseconds.
*
*
* This class uses [Coroutine.delay] to idle.
*/
class CoroutineSleepingMillisIdleStrategy : CoroutineIdleStrategy {
companion object {
/**
* Name to be returned from [.alias].
*/
const val ALIAS = "sleep-ms"
/**
* Default sleep period when the default constructor is used.
*/
const val DEFAULT_SLEEP_PERIOD_MS = 1L
}
private val sleepPeriodMs: Long
/**
* Default constructor that uses [.DEFAULT_SLEEP_PERIOD_MS].
*/
constructor() {
sleepPeriodMs = DEFAULT_SLEEP_PERIOD_MS
}
/**
* Constructs a new strategy that will sleep for a given period when idle.
*
* @param sleepPeriodMs period in milliseconds for which the strategy will sleep when work count is 0.
*/
constructor(sleepPeriodMs: Long) {
this.sleepPeriodMs = sleepPeriodMs
}
/**
* {@inheritDoc}
*/
override suspend fun idle(workCount: Int) {
if (workCount > 0) {
return
}
delay(sleepPeriodMs)
}
/**
* {@inheritDoc}
*/
override suspend fun idle() {
delay(sleepPeriodMs)
}
/**
* {@inheritDoc}
*/
override fun reset() {}
/**
* {@inheritDoc}
*/
override fun alias(): String {
return ALIAS
}
/**
* Creates a clone of this IdleStrategy
*/
override fun clone(): CoroutineSleepingMillisIdleStrategy {
return CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = sleepPeriodMs)
}
/**
* Creates a clone of this IdleStrategy
*/
override fun cloneToNormal(): SleepingMillisIdleStrategy {
return SleepingMillisIdleStrategy(sleepPeriodMs)
}
override fun toString(): String {
return "SleepingMillisIdleStrategy{" +
"alias=" + ALIAS +
", sleepPeriodMs=" + sleepPeriodMs +
'}'
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is CoroutineSleepingMillisIdleStrategy) return false
if (sleepPeriodMs != other.sleepPeriodMs) return false
return true
}
override fun hashCode(): Int {
return sleepPeriodMs.hashCode()
}
}
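A hedged sketch of the migration path these removed classes provide: cloneToNormal() converts the coroutine strategy into its Agrona equivalent so it can be reused on a plain thread, which matches the direction of this change set (suspend-based polling replaced with thread-based polling). The sleep period below is an illustrative value.

// convert a coroutine strategy into Agrona's SleepingMillisIdleStrategy for use on a plain thread
val threadStrategy = CoroutineSleepingMillisIdleStrategy(sleepPeriodMs = 100L).cloneToNormal()
// IdleStrategy.idle(workCount) sleeps 100 ms whenever workCount == 0
threadStrategy.idle(0)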

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,13 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi.messages
/**
* @param rmiId which rmi object was deleted
*/
data class ConnectionObjectDeleteResponse(val rmiId: Int) : RmiMessage {
override fun toString(): String {
return "ConnectionObjectDeleteResponse(id: $rmiId)"
}
package dorkbox.network.aeron
internal interface EventActionOperator {
operator fun invoke(): Int
}

View File

@ -0,0 +1,21 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron
internal interface EventCloseOperator {
operator fun invoke()
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,15 +21,13 @@ import dorkbox.collections.ConcurrentIterator
import dorkbox.network.Configuration
import dorkbox.network.connection.EndPoint
import dorkbox.util.NamedThreadFactory
import dorkbox.util.sync.CountDownLatch
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import mu.KotlinLogging
import org.agrona.concurrent.IdleStrategy
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.concurrent.*
import java.util.concurrent.locks.*
import kotlin.concurrent.write
/**
* there are threading issues if there are client(s) and server's within the same JVM, where we have thread starvation
@ -38,69 +36,69 @@ import java.util.concurrent.*
* this among ALL clients within the same JVM so that we can support multiple clients/servers
*/
internal class EventPoller {
private class EventAction(val onAction: EventActionOperator, val onClose: EventCloseOperator)
companion object {
internal const val REMOVE = -1
val eventLogger = KotlinLogging.logger(EventPoller::class.java.simpleName)
private class EventAction(val onAction: suspend ()->Int, val onClose: suspend ()->Unit)
val eventLogger = LoggerFactory.getLogger(EventPoller::class.java.simpleName)
private val pollDispatcher = Executors.newSingleThreadExecutor(
private val pollExecutor = Executors.newSingleThreadExecutor(
NamedThreadFactory("Poll Dispatcher", Configuration.networkThreadGroup, true)
).asCoroutineDispatcher()
)
}
private var configured = false
private lateinit var dispatchScope: CoroutineScope
private lateinit var pollStrategy: CoroutineIdleStrategy
private lateinit var clonedStrategy: IdleStrategy
private lateinit var pollStrategy: IdleStrategy
private var running = true
private var mutex = Mutex()
@Volatile
private var running = false
private var lock = ReentrantReadWriteLock()
// this is thread safe
private val pollEvents = ConcurrentIterator<EventAction>()
private val submitEvents = atomic(0)
private val configureEventsEndpoints = mutableSetOf<ByteArrayWrapper>()
@Volatile
private var delayClose = false
@Volatile
private var shutdownLatch = CountDownLatch(0)
@Volatile
private var threadId = Thread.currentThread().id
private var threadId = 0L
fun inDispatch(): Boolean {
fun isDispatch(): Boolean {
// this only works because we are a single thread dispatch
return threadId == Thread.currentThread().id
}
fun configure(logger: KLogger, config: Configuration, endPoint: EndPoint<*>) = runBlocking {
mutex.withLock {
logger.debug { "Initializing the Network Event Poller..." }
fun configure(logger: Logger, config: Configuration, endPoint: EndPoint<*>) {
lock.write {
if (logger.isDebugEnabled) {
logger.debug("Initializing the Network Event Poller...")
}
configureEventsEndpoints.add(ByteArrayWrapper.wrap(endPoint.storage.publicKey))
if (!configured) {
logger.trace { "Configuring the Network Event Poller..." }
if (logger.isTraceEnabled) {
logger.trace("Configuring the Network Event Poller...")
}
delayClose = false
running = true
configured = true
shutdownLatch = CountDownLatch(1)
pollStrategy = config.pollIdleStrategy
clonedStrategy = config.pollIdleStrategy.cloneToNormal()
dispatchScope = CoroutineScope(pollDispatcher + SupervisorJob())
require(pollDispatcher.isActive) { "Unable to start the event dispatch in the terminated state!" }
dispatchScope.launch {
val pollIdleStrategy = clonedStrategy
pollExecutor.submit {
val pollIdleStrategy = pollStrategy
var pollCount = 0
threadId = Thread.currentThread().id
threadId = Thread.currentThread().id // only ever 1 thread!!!
pollIdleStrategy.reset()
while (running) {
pollEvents.forEachRemovable {
@ -116,31 +114,22 @@ internal class EventPoller {
// remove our event, it is no longer valid
pollEvents.remove(this)
it.onClose() // shutting down
// check to see if we requested a shutdown
if (delayClose) {
doClose()
}
} else if (poll > 0) {
pollCount += poll
}
} catch (e: Exception) {
eventLogger.error(e) { "Unexpected error during Network Event Polling! Aborting event dispatch for it!" }
eventLogger.error("Unexpected error during Network Event Polling! Aborting event dispatch for it!", e)
// remove our event, it is no longer valid
pollEvents.remove(this)
it.onClose() // shutting down
// check to see if we requested a shutdown
if (delayClose) {
doClose()
}
}
}
pollIdleStrategy.idle(pollCount)
}
// now we have to REMOVE all poll events -- so that their remove logic will run.
pollEvents.forEachRemovable {
// remove our event, it is no longer valid
@ -151,7 +140,9 @@ internal class EventPoller {
shutdownLatch.countDown()
}
} else {
require(pollStrategy == config.pollIdleStrategy) {
// we don't want to use .equals, because that also compares STATE, which for us is going to be different because we are cloned!
// toString has the right info to compare types/config accurately
require(pollStrategy.toString() == config.pollIdleStrategy.toString()) {
"The network event poll strategy is different between the multiple instances of network clients/servers. There **WILL BE** thread starvation, so this behavior is forbidden!"
}
}
@ -161,21 +152,25 @@ internal class EventPoller {
/**
* Will cause the executing thread to wait until the event has been started
*/
suspend fun submit(action: suspend () -> Int, onShutdown: suspend () -> Unit) = mutex.withLock {
fun submit(action: EventActionOperator, onClose: EventCloseOperator) = lock.write {
submitEvents.getAndIncrement()
// this forces the current thread to WAIT until the network poll system has started
val pollStartupLatch = CountDownLatch(1)
pollEvents.add(EventAction(action, onShutdown))
pollEvents.add(EventAction(action, onClose))
pollEvents.add(EventAction(
{
pollStartupLatch.countDown()
object : EventActionOperator {
override fun invoke(): Int {
pollStartupLatch.countDown()
// remove ourselves
REMOVE
},
{}
// remove ourselves
return REMOVE
}
}
, object : EventCloseOperator {
override fun invoke() {}
}
))
pollStartupLatch.await()
@ -188,9 +183,17 @@ internal class EventPoller {
/**
* Waits for all events to finish running
*/
suspend fun close(logger: KLogger, endPoint: EndPoint<*>) {
mutex.withLock {
logger.debug { "Requesting close for the Network Event Poller..." }
fun close(logger: Logger, endPoint: EndPoint<*>) {
// make sure that we close on the CLOSE dispatcher if we run on the poll dispatcher!
if (isDispatch()) {
endPoint.eventDispatch.CLOSE.launch {
close(logger, endPoint)
}
return
}
lock.write {
logger.debug("Requesting close for the Network Event Poller...")
// ONLY if there are no more poll-events do we ACTUALLY shut down.
// when an endpoint closes its polling, it will automatically be removed from this datastructure.
@ -205,33 +208,34 @@ internal class EventPoller {
if (running && sEvents == 0 && cEvents == 0) {
when (pEvents) {
0 -> {
logger.debug { "Closing the Network Event Poller..." }
doClose()
}
1 -> {
// this means we are trying to close on our poll event, and obviously it won't work.
logger.debug { "Delayed closing the Network Event Poller..." }
delayClose = true
logger.debug("Closing the Network Event Poller...")
doClose(logger)
}
else -> {
logger.debug { "Not closing the Network Event Poller... (isRunning=$running submitEvents=$sEvents configureEvents=${cEvents} pollEvents=$pEvents)" }
if (logger.isDebugEnabled) {
logger.debug("Not closing the Network Event Poller... (isRunning=$running submitEvents=$sEvents configureEvents=${cEvents} pollEvents=$pEvents)")
}
}
}
} else {
logger.debug { "Not closing the Network Event Poller... (isRunning=$running submitEvents=$sEvents configureEvents=${cEvents} pollEvents=$pEvents)" }
} else if (logger.isDebugEnabled) {
logger.debug("Not closing the Network Event Poller... (isRunning=$running submitEvents=$sEvents configureEvents=${cEvents} pollEvents=$pEvents)")
}
}
}
private suspend fun doClose() {
private fun doClose(logger: Logger) {
val wasRunning = running
running = false
shutdownLatch.await()
while (!shutdownLatch.await(500, TimeUnit.MILLISECONDS)) {
logger.error("Waiting for Network Event Poller to close. It should not take this long")
}
configured = false
if (wasRunning) {
dispatchScope.cancel("Closed event dispatch")
pollExecutor.awaitTermination(200, TimeUnit.MILLISECONDS)
}
logger.debug("Closed Network Event Poller: wasRunning=$wasRunning")
}
}
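A hedged sketch of how an endpoint might register work with the poller after this change (the eventPoller and connection values stand in for the real call sites). The action operator returns the number of messages handled, or REMOVE once the connection should be dropped from the poll loop; the close operator runs on the poll thread when the event is removed or the poller shuts down.

eventPoller.submit(
    object : EventActionOperator {
        override fun invoke(): Int {
            // keep polling while the connection is healthy; ask to be removed otherwise
            return if (connection.canPoll()) connection.poll() else EventPoller.REMOVE
        }
    },
    object : EventCloseOperator {
        override fun invoke() {
            // runs exactly once, on the poll thread, when this event is removed
            connection.close()
        }
    }
)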

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.aeron;

View File

@ -16,32 +16,32 @@
package dorkbox.network.connection
import dorkbox.network.Client
import dorkbox.network.Server
import dorkbox.network.aeron.AeronDriver.Companion.sessionIdAllocator
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.SerializationException
import dorkbox.network.exceptions.ServerException
import dorkbox.network.exceptions.TransmitException
import dorkbox.network.connection.buffer.BufferedMessages
import dorkbox.network.connection.buffer.BufferedSession
import dorkbox.network.ping.Ping
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.rmi.messages.MethodResponse
import dorkbox.network.serialization.KryoExtra
import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
import io.aeron.protocol.DataHeaderFlyweight
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.agrona.DirectBuffer
import org.agrona.concurrent.IdleStrategy
import java.util.concurrent.*
import javax.crypto.SecretKey
/**
* This connection is established once the registration information is validated, and the various connect/filter checks have passed
* This connection is established once the registration information is validated, and the various connect/filter checks have passed.
*
* Connections are also BUFFERED, meaning that if the connection between a client and server goes down because of a network glitch, the
* data being sent is not lost (it is buffered) and is re-sent once a new connection with the same UUID is established within the timeout period.
*
* References to the old connection will also redirect to the new connection.
*/
open class Connection(connectionParameters: ConnectionParams<*>) {
private var messageHandler: FragmentAssembler
private val messageHandler: FragmentHandler
/**
* The specific connection details for this connection!
@ -58,19 +58,18 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
internal val subscription = info.sub
internal val publication = info.pub
private lateinit var image: Image
/**
* When publishing data, we cannot have concurrent publications for a single connection (per Aeron publication)
*/
private val writeMutex = Mutex()
// only accessed on a single thread!
private val connectionExpirationTimoutNanos = endPoint.config.connectionExpirationTimoutNanos
// the timeout starts from when the connection is first created, so that we don't get "instant" timeouts when the server rejects a connection
private var connectionTimeoutTimeNanos = System.nanoTime()
/**
* There can be concurrent writes to the network stack, at most 1 per connection. Each connection has its own logic on the remote endpoint,
* and can have its own back-pressure.
*/
private val sendIdleStrategy: IdleStrategy
private val writeKryo: KryoExtra<Connection>
internal val sendIdleStrategy = endPoint.config.sendIdleStrategy
/**
* This is the client UUID. This is useful to determine whether the same client is connecting multiple times to a server (instead of relying only on the IP address)
@ -88,6 +87,11 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
info.sessionIdSub
}
/**
* The tag name for a connection permits an INCOMING client to define a custom string. The max length is 32
*/
val tag = info.tagName
/**
* The remote address, as a string. Will be null for IPC connections
*/
@ -113,7 +117,27 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*/
val isNetwork = !isIpc
/**
* used when the connection is buffered
*/
private val bufferedSession: BufferedSession
/**
* used to determine if this connection will have buffered messages enabled or not.
*/
internal val enableBufferedMessages = connectionParameters.enableBufferedMessages
/**
* The largest size a SINGLE message via AERON can be. Because the maximum size we can send in a "single fragment" is the
* publication.maxPayloadLength() function (which is the MTU length less header). We could depend on Aeron for fragment reassembly,
* but that has a (very low) maximum reassembly size -- so we have our own mechanism for object fragmentation/assembly, which
* is (in reality) only limited by available ram.
*/
internal val maxMessageSize = if (isNetwork) {
endPoint.config.networkMtuSize - DataHeaderFlyweight.HEADER_LENGTH
} else {
endPoint.config.ipcMtuSize - DataHeaderFlyweight.HEADER_LENGTH
}
private val listenerManager = atomic<ListenerManager<Connection>?>(null)
@ -121,40 +145,34 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
private val isClosed = atomic(false)
// only accessed on a single thread!
private var connectionLastCheckTimeNanos = 0L
private var connectionTimeoutTimeNanos = 0L
// always offset by the linger amount, since we cannot act faster than the linger timeout for adding/removing publications
private val connectionCheckIntervalNanos = endPoint.config.connectionCheckIntervalNanos + endPoint.aeronDriver.lingerNs()
private val connectionExpirationTimoutNanos = endPoint.config.connectionExpirationTimoutNanos + endPoint.aeronDriver.lingerNs()
// while on the CLIENT, if the SERVER's ecc key has changed, the client will abort and show an error.
private val remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED
// The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 8 (external counter) + 4 (GCM counter)
// The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
// counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
// private val aes_gcm_iv = atomic(0)
internal val remoteKeyChanged = connectionParameters.publicKeyValidation == PublicKeyValidationState.TAMPERED
/**
* Methods supporting Remote Method Invocation and Objects
*/
val rmi: RmiSupportConnection<out Connection>
// we customize the toString() value for this connection, and it's just better to cache it's value (since it's a modestly complex string)
// we customize the toString() value for this connection, and it's just better to cache its value (since it's a modestly complex string)
private val toString0: String
/**
* @return the AES key
*/
internal val cryptoKey: SecretKey = connectionParameters.cryptoKey
// The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 4 (external counter) + 4 (GCM counter)
// The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
// counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
internal val aes_gcm_iv = atomic(0)
// Used to track that this connection WILL be closed, but has not yet been closed.
@Volatile
internal var closeRequested = false
init {
@Suppress("UNCHECKED_CAST")
writeKryo = endPoint.serialization.initKryo() as KryoExtra<Connection>
sendIdleStrategy = endPoint.config.sendIdleStrategy.cloneToNormal()
// NOTE: subscriptions (ie: reading from buffers, etc) are not thread safe! Because it is ambiguous HOW EXACTLY they are unsafe,
// we exclusively read from the DirectBuffer on a single thread.
@ -162,12 +180,17 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
// publication of any state to other threads and not be:
// - long running
// - re-entrant with the client
messageHandler = FragmentAssembler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// this is processed on the thread that calls "poll". Subscriptions are NOT multi-thread safe!
messageHandler = FragmentHandler { buffer: DirectBuffer, offset: Int, length: Int, header: Header ->
// Subscriptions are NOT multi-thread safe, so only processed on the thread that calls .poll()!
endPoint.dataReceive(buffer, offset, length, header, this@Connection)
}
bufferedSession = when (endPoint) {
is Server -> endPoint.bufferedManager.onConnect(this)
is Client -> endPoint.bufferedManager!!.onConnect(this)
else -> throw RuntimeException("Unable to determine type, aborting!")
}
@Suppress("LeakingThis")
rmi = endPoint.rmiConnectionSupport.getNewRmiSupport(this)
@ -176,99 +199,88 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
}
/**
* @return true if the remote public key changed. This can be useful if specific actions are necessary when the key has changed.
* When this is called, we should always have a subscription image!
*/
fun hasRemoteKeyChanged(): Boolean {
return remoteKeyChanged
internal fun setImage() {
var triggered = false
while (subscription.hasNoImages()) {
triggered = true
Thread.sleep(50)
}
if (triggered) {
logger.error("Delay while configuring subscription!")
}
image = subscription.imageAtIndex(0)
}
// /**
// * This is the per-message sequence number.
// *
// * The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 4 (external counter) + 4 (GCM counter)
// * The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
// * counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
// */
// fun nextGcmSequence(): Long {
// return aes_gcm_iv.getAndIncrement()
// }
//
// /**
// * @return the AES key. key=32 byte, iv=12 bytes (AES-GCM implementation).
// */
// fun cryptoKey(): SecretKey {
//// return channelWrapper.cryptoKey()
// }
/**
* Polls the AERON media driver subscription channel for incoming messages
*/
internal fun poll(): Int {
// NOTE: regarding fragment limit size. Repeated calls to '.poll' will reassemble a fragment.
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
return subscription.poll(messageHandler, 1)
return image.poll(messageHandler, 1)
}
/**
* Safely sends objects to a destination, if `abortEarly` is true, there are no retries if sending the message fails.
*
* NOTE: this is dispatched to the IO context!! (since network calls are IO/blocking calls)
*
* @return true if the message was successfully sent, false otherwise. Exceptions are caught and NOT rethrown!
*/
internal suspend fun send(message: Any, abortEarly: Boolean): Boolean {
// we use a mutex because we do NOT want different threads/coroutines to be able to send data over the SAME connections at the SAME time.
// NOTE: additionally we want to propagate back-pressure to the calling coroutines, PER CONNECTION!
val success = writeMutex.withLock {
// we reset the sending timeout strategy when a message was successfully sent.
sendIdleStrategy.reset()
try {
// The handshake sessionId IS NOT globally unique
logger.trace { "[$toString0] send: ${message.javaClass.simpleName} : $message" }
val write = endPoint.write(writeKryo, message, publication, sendIdleStrategy, this@Connection, abortEarly)
write
} catch (e: Throwable) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager(logger)
}
val listenerManager = listenerManager.value!!
if (message is MethodResponse && message.result is Exception) {
val result = message.result as Exception
val newException = SerializationException("Error serializing message ${message.javaClass.simpleName}: '$message'", result)
listenerManager.notifyError(this@Connection, newException)
} else if (message is ClientException || message is ServerException) {
val newException = TransmitException("Error with message ${message.javaClass.simpleName}: '$message'", e)
listenerManager.notifyError(this@Connection, newException)
} else {
val newException = TransmitException("Error sending message ${message.javaClass.simpleName}: '$message'", e)
listenerManager.notifyError(this@Connection, newException)
}
false
internal fun send(message: Any, abortEarly: Boolean): Boolean {
if (logger.isTraceEnabled) {
// The handshake sessionId IS NOT globally unique
// don't automatically create the lambda when trace is disabled! Because this uses 'outside' scoped info, it's a new lambda each time!
if (logger.isTraceEnabled) {
logger.trace("[$toString0] send: ${message.javaClass.simpleName} : $message")
}
}
val success = endPoint.write(message, publication, sendIdleStrategy, this@Connection, maxMessageSize, abortEarly)
return success
return if (!success && message !is DisconnectMessage) {
// queue up the messages, because we couldn't write them for whatever reason!
// NEVER QUEUE THE DISCONNECT MESSAGE!
bufferedSession.queueMessage(this@Connection, message, abortEarly)
} else {
success
}
}
private fun sendNoBuffer(message: Any): Boolean {
if (logger.isTraceEnabled) {
// The handshake sessionId IS NOT globally unique
// don't automatically create the lambda when trace is disabled! Because this uses 'outside' scoped info, it's a new lambda each time!
if (logger.isTraceEnabled) {
logger.trace("[$toString0] send: ${message.javaClass.simpleName} : $message")
}
}
return endPoint.write(message, publication, sendIdleStrategy, this@Connection, maxMessageSize, false)
}
/**
* Safely sends objects to a destination.
*
* NOTE: this is dispatched to the IO context!! (since network calls are IO/blocking calls)
* @return true if the message was successfully sent, false otherwise. Exceptions are caught and NOT rethrown!
*/
fun send(message: Any): Boolean {
return send(message, false)
}
/**
* Safely sends objects to a destination, where the callback is notified once the remote endpoint has received the message.
*
* This is to guarantee happens-before, and using this will depend upon APP+NETWORK latency, and is (by design) not as performant as
* sending a regular message!
*
* @return true if the message was successfully sent, false otherwise. Exceptions are caught and NOT rethrown!
*/
suspend fun send(message: Any): Boolean {
return send(message, false)
fun send(message: Any, onSuccessCallback: Connection.() -> Unit): Boolean {
return sendSync(message, onSuccessCallback)
}
/**
@ -276,8 +288,19 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*
* @return true if the message was successfully sent by aeron
*/
suspend fun ping(pingTimeoutSeconds: Int = endPoint.config.pingTimeoutSeconds, function: suspend Ping.() -> Unit = {}): Boolean {
return endPoint.ping(this, pingTimeoutSeconds, function)
fun ping(function: Ping.() -> Unit = {}): Boolean {
return sendPing(function)
}
/**
* This is the per-message sequence number.
*
* The IV for AES-GCM must be 12 bytes, since it's 4 (salt) + 4 (external counter) + 4 (GCM counter)
* The 12 bytes IV is created during connection registration, and during the AES-GCM crypto, we override the last 8 with this
* counter, which is also transmitted as an optimized int. (which is why it starts at 0, so the transmitted bytes are small)
*/
internal fun nextGcmSequence(): Int {
return aes_gcm_iv.getAndIncrement()
}
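// A standalone sketch (not this class's actual crypto path) of how a 12-byte AES-GCM IV can be
// assembled from the parts described above -- a 4-byte salt, a 4-byte external counter and the
// 4-byte per-message counter from nextGcmSequence() -- and then used with AES/GCM/NoPadding and a
// 128-bit tag. The salt/counter parameters and the function name are illustrative assumptions.
//
//     fun encryptSketch(key: javax.crypto.SecretKey, salt: ByteArray, externalCounter: Int, gcmCounter: Int, plaintext: ByteArray): ByteArray {
//         val iv = java.nio.ByteBuffer.allocate(12)      // GCM_IV_LENGTH_BYTES
//             .put(salt, 0, 4)                           // 4-byte salt (fixed at registration)
//             .putInt(externalCounter)                   // 4-byte external counter
//             .putInt(gcmCounter)                        // 4-byte per-message GCM counter
//             .array()
//         val cipher = javax.crypto.Cipher.getInstance("AES/GCM/NoPadding")
//         cipher.init(javax.crypto.Cipher.ENCRYPT_MODE, key, javax.crypto.spec.GCMParameterSpec(128, iv))
//         return cipher.doFinal(plaintext)
//     }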
/**
@ -290,10 +313,10 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
* (via connection.addListener), meaning that ONLY that listener attached to
* the connection is notified on that event (ie, admin type listeners)
*/
suspend fun onDisconnect(function: suspend Connection.() -> Unit) {
fun onDisconnect(function: Connection.() -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager(logger)
origManager ?: ListenerManager(logger, endPoint.eventDispatch)
}
listenerManager.value!!.onDisconnect(function)
@ -302,10 +325,10 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
/**
* Adds a function that will be called only for this connection, when a client/server receives a message
*/
suspend fun <MESSAGE> onMessage(function: suspend Connection.(MESSAGE) -> Unit) {
fun <MESSAGE> onMessage(function: Connection.(MESSAGE) -> Unit) {
// make sure we atomically create the listener manager, if necessary
listenerManager.getAndUpdate { origManager ->
origManager ?: ListenerManager(logger)
origManager ?: ListenerManager(logger, endPoint.eventDispatch)
}
listenerManager.value!!.onMessage(function)
@ -316,54 +339,24 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
*
* This is ALWAYS called on a new dispatch
*/
internal suspend fun notifyOnMessage(message: Any): Boolean {
internal fun notifyOnMessage(message: Any): Boolean {
return listenerManager.value?.notifyOnMessage(this, message) ?: false
}
/**
* We must account for network blips. The blips will be recovered by aeron, but we want to make sure that we are actually
* disconnected for a set period of time before we start the close process for a connection
*
* @return `true` if this connection has been closed via aeron
*/
fun isClosedViaAeron(): Boolean {
if (isClosed.value) {
// if we are manually closed, then don't check aeron timeouts!
return true
internal fun sendBufferedMessages() {
if (enableBufferedMessages) {
val bufferedMessage = BufferedMessages()
val numberDrained = bufferedSession.pendingMessagesQueue.drainTo(bufferedMessage.messages)
if (numberDrained > 0) {
// now send all buffered/pending messages
if (logger.isDebugEnabled) {
logger.debug("Sending buffered messages: ${bufferedSession.pendingMessagesQueue.size}")
}
sendNoBuffer(bufferedMessage)
}
}
// we ONLY want to actually, legit check, 1 time every XXX ms.
val now = System.nanoTime()
if (now - connectionLastCheckTimeNanos < connectionCheckIntervalNanos) {
// we haven't waited long enough for another check. always return false (true means we are closed)
return false
}
connectionLastCheckTimeNanos = now
// as long as we are connected, we reset the state, so that if there is a network blip, we want to make sure that it is
// a network blip for a while, instead of just once or twice. (which can happen)
if (subscription.isConnected && publication.isConnected) {
// reset connection timeout
connectionTimeoutTimeNanos = 0L
// we are still connected (true means we are closed)
return false
}
//
// aeron is not connected
//
if (connectionTimeoutTimeNanos == 0L) {
connectionTimeoutTimeNanos = now
}
// make sure that our "isConnected" state lasts LONGER than the expiry timeout!
// 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them).
// 2) network blips happen. Aeron will recover, and we want to make sure that WE don't instantly DC
return now - connectionTimeoutTimeNanos >= connectionExpirationTimoutNanos
}
/**
@ -373,66 +366,171 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
return isClosed.value
}
/**
* Is this a "dirty" disconnect, meaning that it has timed out, but not been explicitly closed
*/
internal fun isDirtyClose(): Boolean {
return !closeRequested && !isClosed() && isClosedWithTimeout()
}
/**
* Is this connection considered still safe for polling (or rather, has it been closed in an unusual way?)
*/
internal fun canPoll(): Boolean {
return !closeRequested && !isClosed() && !isClosedWithTimeout()
}
/**
* We must account for network blips. The blips will be recovered by aeron, but we want to make sure that we are actually
* disconnected for a set period of time before we start the close process for a connection
*
* @return `true` if this connection has been closed via aeron
*/
internal fun isClosedWithTimeout(): Boolean {
// we ONLY want to actually, legit check, 1 time every XXX ms.
val now = System.nanoTime()
// as long as we are connected, we reset the state, so that if there is a network blip, we want to make sure that it is
// a network blip for a while, instead of just once or twice. (which WILL happen)
if (subscription.isConnected && publication.isConnected) {
// reset connection timeout
connectionTimeoutTimeNanos = now
// we are still connected (true means we are closed)
return false
}
// make sure that our "isConnected" state lasts LONGER than the expiry timeout!
// 1) connections take a little bit of time from polling -> connecting (because of how we poll connections before 'connecting' them).
// 2) network blips happen. Aeron will recover, and we want to make sure that WE don't instantly DC
return now - connectionTimeoutTimeNanos >= connectionExpirationTimoutNanos
}
/**
* Closes the connection, and removes all connection specific listeners
*/
suspend fun close() {
fun close() {
close(sendDisconnectMessage = true,
closeEverything = true)
}
/**
* Closes the connection, and removes all connection specific listeners
*/
internal fun close(sendDisconnectMessage: Boolean, closeEverything: Boolean) {
// there are 2 ways to call close.
// MANUALLY
// When a connection is disconnected via a timeout/expire.
// the compareAndSet is used to make sure that if we call close() MANUALLY, (and later) when the auto-cleanup/disconnect is called -- it doesn't
// try to do it again.
closeRequested = true
// make sure that EVERYTHING before "close()" runs before we do
EventDispatcher.launchSequentially(EventDispatcher.CLOSE) {
closeImmediately()
// make sure that EVERYTHING before "close()" runs before we do.
// If there are multiple clients/servers sharing the same NetworkPoller -- then they will wait on each other!
val close = endPoint.eventDispatch.CLOSE
if (!close.isDispatch()) {
close.launch {
close(sendDisconnectMessage = sendDisconnectMessage, closeEverything = closeEverything)
}
return
}
closeImmediately(sendDisconnectMessage = sendDisconnectMessage, closeEverything = closeEverything)
}
// connection.close() -> this
// endpoint.close() -> connection.close() -> this
internal suspend fun closeImmediately() {
internal fun closeImmediately(sendDisconnectMessage: Boolean, closeEverything: Boolean) {
// the server 'handshake' connection info is cleaned up with the disconnect via timeout/expire.
if (!isClosed.compareAndSet(expect = false, update = true)) {
logger.debug("[$toString0] connection ignoring close request.")
return
}
logger.debug {"[$toString0] connection closing"}
if (logger.isDebugEnabled) {
logger.debug("[$toString0] connection closing. sendDisconnectMessage=$sendDisconnectMessage, closeEverything=$closeEverything")
}
// make sure to save off the RMI objects for session management
if (!closeEverything) {
when (endPoint) {
is Server -> endPoint.bufferedManager.onDisconnect(this)
is Client -> endPoint.bufferedManager!!.onDisconnect(this)
else -> throw RuntimeException("Unable to determine type, aborting!")
}
}
// on close, we want to make sure this file is DELETED!
endPoint.aeronDriver.close(subscription, toString0)
try {
// we might not be able to close this connection!!
endPoint.aeronDriver.close(subscription, toString0)
}
catch (e: Exception) {
endPoint.listenerManager.notifyError(e)
}
// notify the remote endPoint that we are closing
// we send this AFTER we close our subscription (so that no more messages will be received, when the remote end ping-pong's this message back)
if (publication.isConnected) {
// sometimes the remote end has already disconnected, THERE WILL BE ERRORS if this happens (but they are ok)
send(DisconnectMessage.INSTANCE, true)
}
if (sendDisconnectMessage) {
if (publication.isConnected) {
if (logger.isDebugEnabled) {
logger.debug("Sending disconnect message to ${endPoint.otherTypeName}")
}
val timeoutInNanos = TimeUnit.SECONDS.toNanos(endPoint.config.connectionCloseTimeoutInSeconds.toLong())
val closeTimeoutTime = System.nanoTime()
// sometimes the remote end has already disconnected, THERE WILL BE ERRORS if this happens (but they are ok)
if (closeEverything) {
send(DisconnectMessage.CLOSE_EVERYTHING, true)
} else {
send(DisconnectMessage.CLOSE_SIMPLE, true)
}
// we do not want to close until AFTER all publications have been sent. Calling this WITHOUT waiting will instantly stop everything
// we want a timeout-check, otherwise this will run forever
while (writeMutex.isLocked && System.nanoTime() - closeTimeoutTime < timeoutInNanos) {
delay(50)
// wait for .5 seconds to (help) make sure that the messages are sent before shutdown! This is not guaranteed!
if (logger.isDebugEnabled) {
logger.debug("Waiting for disconnect message to send")
}
Thread.sleep(500L)
} else {
if (logger.isDebugEnabled) {
logger.debug("Publication is not connected with ${endPoint.otherTypeName}, not sending disconnect message.")
}
}
}
// on close, we want to make sure this file is DELETED!
endPoint.aeronDriver.close(publication, toString0)
try {
// we might not be able to close this connection.
endPoint.aeronDriver.close(publication, toString0)
}
catch (e: Exception) {
endPoint.listenerManager.notifyError(e)
}
// NOTE: any waiting RMI messages that are in-flight will terminate when they time-out (and then do nothing)
// NOTE: notifyDisconnect() is called inside closeAction()!!
// if there are errors within the driver, we do not want to notify disconnect, as we will automatically reconnect.
endPoint.listenerManager.notifyDisconnect(this)
endPoint.removeConnection(this)
endPoint.listenerManager.notifyDisconnect(this)
val connection = this
endPoint.isServer {
if (endPoint.isServer()) {
// clean up the resources associated with this connection when it's closed
logger.debug { "[${connection}] freeing resources" }
if (logger.isDebugEnabled) {
logger.debug("[${connection}] freeing resources")
}
sessionIdAllocator.free(info.sessionIdPub)
sessionIdAllocator.free(info.sessionIdSub)
@ -441,11 +539,13 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
if (remoteAddress != null) {
// unique for UDP endpoints
handshake.connectionsPerIpCounts.decrementSlow(remoteAddress)
(endPoint as Server).handshake.connectionsPerIpCounts.decrementSlow(remoteAddress)
}
}
logger.debug {"[$toString0] connection closed"}
if (logger.isDebugEnabled) {
logger.debug("[$toString0] connection closed")
}
}
@ -478,4 +578,94 @@ open class Connection(connectionParameters: ConnectionParams<*>) {
val other1 = other as Connection
return id == other1.id
}
internal fun receiveSendSync(sendSync: SendSync) {
if (sendSync.message != null) {
// this is on the "remote end".
sendSync.message = null
if (!send(sendSync)) {
logger.error("Error returning send-sync: $sendSync")
}
} else {
// this is on the "local end" when the response comes back
val responseId = sendSync.id
// process the ping message so that our ping callback does something
// this will be null if the ping took longer than XXX seconds and was cancelled
val result = EndPoint.responseManager.removeWaiterCallback<Connection.() -> Unit>(responseId, logger)
if (result != null) {
result(this)
} else {
logger.error("Unable to receive send-sync, there was no waiting response for $sendSync ($responseId)")
}
}
}
/**
* Safely sends objects to a destination, where the callback is notified once the remote endpoint has received the message.
*
* This is to guarantee happens-before, and using this will depend upon APP+NETWORK latency, and is (by design) not as performant as
* sending a regular message!
*
* @return true if the message was successfully sent, false otherwise. Exceptions are caught and NOT rethrown!
*/
private fun sendSync(message: Any, onSuccessCallback: Connection.() -> Unit): Boolean {
val id = EndPoint.responseManager.prepWithCallback(logger, onSuccessCallback)
val sendSync = SendSync()
sendSync.message = message
sendSync.id = id
// if there is no sync response EVER, it means that the connection is in a critically BAD state!
// eventually, all the ping/sync replies (or, in our case, the replies that have timed out) will
// become recycled.
// Is it a memory-leak? No, because the memory will **EVENTUALLY** get freed.
return send(sendSync, false)
}
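// A hedged usage sketch of the happens-before send above (the message type and log text are
// illustrative, not from this library): the callback only runs once the remote endpoint has
// echoed the SendSync wrapper back, so anything done inside it is ordered after remote receipt.
//
//     connection.send(SomeMessage("state update")) {
//         // runs on the local end once the remote side has received (and returned) the wrapper
//         logger.debug("remote endpoint acknowledged the message")
//     }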
internal fun receivePing(ping: Ping) {
if (ping.pongTime == 0L) {
// this is on the "remote end".
ping.pongTime = System.currentTimeMillis()
if (!send(ping)) {
logger.error("Error returning ping: $ping")
}
} else {
// this is on the "local end" when the response comes back
ping.finishedTime = System.currentTimeMillis()
val responseId = ping.packedId
// process the ping message so that our ping callback does something
// this will be null if the ping took longer than XXX seconds and was cancelled
val result = EndPoint.responseManager.removeWaiterCallback<Ping.() -> Unit>(responseId, logger)
if (result != null) {
result(ping)
} else {
logger.error("Unable to receive ping, there was no waiting response for $ping ($responseId)")
}
}
}
private fun sendPing(function: Ping.() -> Unit): Boolean {
val id = EndPoint.responseManager.prepWithCallback(logger, function)
val ping = Ping()
ping.packedId = id
ping.pingTime = System.currentTimeMillis()
// if there is no ping response EVER, it means that the connection is in a critically BAD state!
// eventually, all the ping replies (or, in our case, the RMI replies that have timed out) will
// become recycled.
// Is it a memory-leak? No, because the memory will **EVENTUALLY** get freed.
return send(ping)
}
}
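// Hedged sketch of how the three timestamps on the Ping message above relate (the PingTimes type is
// purely illustrative, not part of the library): pingTime is stamped locally on send, pongTime is
// stamped by the remote end, and finishedTime is stamped locally when the reply arrives.
data class PingTimes(val pingTime: Long, val pongTime: Long, val finishedTime: Long) {
    // total round trip as observed by the local end - no clock synchronization required
    val roundTripMillis: Long get() = finishedTime - pingTime

    // rough one-way estimate; assumes both clocks agree, which is NOT guaranteed in practice
    val outboundEstimateMillis: Long get() = pongTime - pingTime
}

fun main() {
    val times = PingTimes(pingTime = 1_000L, pongTime = 1_021L, finishedTime = 1_043L)
    println("RTT=${times.roundTripMillis}ms, outbound estimate=${times.outboundEstimateMillis}ms")
}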

View File

@ -16,10 +16,13 @@
package dorkbox.network.connection
import dorkbox.network.handshake.PubSub
import javax.crypto.spec.SecretKeySpec
data class ConnectionParams<CONNECTION : Connection>(
val publicKey: ByteArray,
val endPoint: EndPoint<CONNECTION>,
val connectionInfo: PubSub,
val publicKeyValidation: PublicKeyValidationState
val publicKeyValidation: PublicKeyValidationState,
val enableBufferedMessages: Boolean,
val cryptoKey: SecretKeySpec
)

View File

@ -16,13 +16,13 @@
package dorkbox.network.connection
import dorkbox.bytes.Hash
import dorkbox.bytes.toHexString
import dorkbox.hex.toHexString
import dorkbox.network.handshake.ClientConnectionInfo
import dorkbox.network.serialization.AeronInput
import dorkbox.network.serialization.AeronOutput
import dorkbox.network.serialization.SettingsStore
import dorkbox.util.entropy.Entropy
import mu.KLogger
import org.slf4j.Logger
import java.math.BigInteger
import java.net.InetAddress
import java.security.KeyFactory
@ -42,30 +42,36 @@ import javax.crypto.spec.SecretKeySpec
/**
* Management for all the crypto stuff used
*/
internal class CryptoManagement(val logger: KLogger,
internal class CryptoManagement(val logger: Logger,
private val settingsStore: SettingsStore,
type: Class<*>,
private val enableRemoteSignatureValidation: Boolean) {
companion object {
private val X25519 = "X25519"
const val curve25519 = "curve25519"
const val GCM_IV_LENGTH_BYTES = 12 // 12 bytes for a 96-bit IV
const val GCM_TAG_LENGTH_BITS = 128
const val AES_ALGORITHM = "AES/GCM/NoPadding"
val NOCRYPT = SecretKeySpec(ByteArray(1), "NOCRYPT")
val secureRandom = SecureRandom()
}
private val X25519 = "X25519"
private val X25519KeySpec = NamedParameterSpec(X25519)
private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes (256 bits)
private val keyAgreement = KeyAgreement.getInstance("XDH")
private val aesCipher = Cipher.getInstance("AES/GCM/NoPadding")
private val aesCipher = Cipher.getInstance(AES_ALGORITHM)
companion object {
const val curve25519 = "curve25519"
const val GCM_IV_LENGTH_BYTES = 12
const val GCM_TAG_LENGTH_BITS = 128
val secureRandom = SecureRandom()
}
val privateKey: XECPrivateKey
val publicKey: XECPublicKey
// These are both 32 bytes long (256 bits)
val privateKeyBytes: ByteArray
val publicKeyBytes: ByteArray
@ -117,7 +123,7 @@ internal class CryptoManagement(val logger: KLogger,
this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) as XECPublicKey
this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) as XECPrivateKey
this.privateKeyBytes = privateKeyBytes!!
this.privateKeyBytes = privateKeyBytes
this.publicKeyBytes = publicKeyBytes
}
@ -170,12 +176,14 @@ internal class CryptoManagement(val logger: KLogger,
return PublicKeyValidationState.VALID
}
private fun makeInfo(serverPublicKeyBytes: ByteArray): ClientConnectionInfo {
private fun makeInfo(serverPublicKeyBytes: ByteArray, secretKey: SecretKeySpec): ClientConnectionInfo {
val sessionIdPub = cryptInput.readInt()
val sessionIdSub = cryptInput.readInt()
val streamIdPub = cryptInput.readInt()
val streamIdSub = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val sessionTimeout = cryptInput.readLong()
val bufferedMessages = cryptInput.readBoolean()
val regDetails = cryptInput.readBytes(regDetailsSize)
// now save data off
@ -185,15 +193,22 @@ internal class CryptoManagement(val logger: KLogger,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
publicKey = serverPublicKeyBytes,
kryoRegistrationDetails = regDetails)
sessionTimeout = sessionTimeout,
bufferedMessages = bufferedMessages,
kryoRegistrationDetails = regDetails,
secretKey = secretKey)
}
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt)
fun nocrypt(sessionIdPub: Int,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
kryoRegDetails: ByteArray): ByteArray {
fun nocrypt(
sessionIdPub: Int,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
sessionTimeout: Long,
bufferedMessages: Boolean,
kryoRegDetails: ByteArray
): ByteArray {
return try {
// now create the byte array that holds all our data
@ -203,6 +218,8 @@ internal class CryptoManagement(val logger: KLogger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBoolean(bufferedMessages)
cryptOutput.writeBytes(kryoRegDetails)
cryptOutput.toBytes()
@ -219,7 +236,7 @@ internal class CryptoManagement(val logger: KLogger,
// this message is NOT-ENCRYPTED!
cryptInput.buffer = registrationData
makeInfo(serverPublicKeyBytes)
makeInfo(serverPublicKeyBytes, NOCRYPT)
} catch (e: Exception) {
logger.error("Error during IPC decrypt!", e)
null
@ -229,7 +246,7 @@ internal class CryptoManagement(val logger: KLogger,
/**
* Generate the AES key based on ECDH
*/
private fun generateAesKey(remotePublicKeyBytes: ByteArray, bytesA: ByteArray, bytesB: ByteArray): SecretKeySpec {
internal fun generateAesKey(remotePublicKeyBytes: ByteArray, bytesA: ByteArray, bytesB: ByteArray): SecretKeySpec {
val clientPublicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(remotePublicKeyBytes)))
keyAgreement.init(privateKey)
keyAgreement.doPhase(clientPublicKey, true)
@ -246,19 +263,22 @@ internal class CryptoManagement(val logger: KLogger,
}
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt)
fun encrypt(clientPublicKeyBytes: ByteArray,
sessionIdPub: Int,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
kryoRegDetails: ByteArray): ByteArray {
fun encrypt(
cryptoSecretKey: SecretKeySpec,
sessionIdPub: Int,
sessionIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
sessionTimeout: Long,
bufferedMessages: Boolean,
kryoRegDetails: ByteArray
): ByteArray {
try {
val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes)
secureRandom.nextBytes(iv)
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv)
aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec)
aesCipher.init(Cipher.ENCRYPT_MODE, cryptoSecretKey, gcmParameterSpec)
// now create the byte array that holds all our data
cryptOutput.reset()
@ -267,6 +287,8 @@ internal class CryptoManagement(val logger: KLogger,
cryptOutput.writeInt(streamIdPub)
cryptOutput.writeInt(streamIdSub)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeLong(sessionTimeout)
cryptOutput.writeBoolean(bufferedMessages)
cryptOutput.writeBytes(kryoRegDetails)
return iv + aesCipher.doFinal(cryptOutput.toBytes())
@ -278,7 +300,7 @@ internal class CryptoManagement(val logger: KLogger,
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt)
fun decrypt(registrationData: ByteArray, serverPublicKeyBytes: ByteArray): ClientConnectionInfo? {
try {
return try {
val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes)
// now decrypt the data
@ -287,11 +309,11 @@ internal class CryptoManagement(val logger: KLogger,
cryptInput.buffer = aesCipher.doFinal(registrationData, GCM_IV_LENGTH_BYTES, registrationData.size - GCM_IV_LENGTH_BYTES)
return makeInfo(serverPublicKeyBytes)
makeInfo(serverPublicKeyBytes, secretKeySpec)
} catch (e: Exception) {
logger.error("Error during AES decrypt!", e)
return null
null
}
}
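// Hedged, standalone sketch of the framing used by encrypt()/decrypt() above: a random 12-byte IV is
// prepended to the AES/GCM ciphertext (128-bit tag) and the receiver splits it back apart. This is an
// illustration only, not the library's CryptoManagement class.
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec

private const val IV_BYTES = 12
private const val TAG_BITS = 128

fun encryptWithPrependedIv(key: SecretKeySpec, plaintext: ByteArray): ByteArray {
    val iv = ByteArray(IV_BYTES).also { SecureRandom().nextBytes(it) }
    val cipher = Cipher.getInstance("AES/GCM/NoPadding")
    cipher.init(Cipher.ENCRYPT_MODE, key, GCMParameterSpec(TAG_BITS, iv))
    return iv + cipher.doFinal(plaintext)           // the IV travels in front of the ciphertext
}

fun decryptWithPrependedIv(key: SecretKeySpec, data: ByteArray): ByteArray {
    val cipher = Cipher.getInstance("AES/GCM/NoPadding")
    // the first 12 bytes are the IV; the remainder is ciphertext + GCM tag
    cipher.init(Cipher.DECRYPT_MODE, key, GCMParameterSpec(TAG_BITS, data, 0, IV_BYTES))
    return cipher.doFinal(data, IV_BYTES, data.size - IV_BYTES)
}

fun main() {
    val key = SecretKeySpec(KeyGenerator.getInstance("AES").apply { init(256) }.generateKey().encoded, "AES")
    val roundTrip = decryptWithPrependedIv(key, encryptWithPrependedIv(key, "handshake".toByteArray()))
    println(String(roundTrip)) // prints: handshake
}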

View File

@ -16,8 +16,9 @@
package dorkbox.network.connection
class DisconnectMessage {
class DisconnectMessage(val closeEverything: Boolean) {
companion object {
val INSTANCE = DisconnectMessage()
val CLOSE_SIMPLE = DisconnectMessage(false)
val CLOSE_EVERYTHING = DisconnectMessage(true)
}
}
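// Hedged sketch (the receiving-side handler is not part of this diff, and the exact semantics of the
// flag are an assumption): a handler could branch on closeEverything to decide between tearing the
// whole endpoint down and closing just this logical connection while keeping session state around.
fun onDisconnectMessage(message: DisconnectMessage) {
    if (message.closeEverything) {
        // CLOSE_EVERYTHING: discard any per-session state; no reconnect is expected
    } else {
        // CLOSE_SIMPLE: close the transport, but a later reconnect may resume buffered state
    }
}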

File diff suppressed because it is too large

View File

@ -19,158 +19,161 @@ package dorkbox.network.connection
import dorkbox.network.Configuration
import dorkbox.util.NamedThreadFactory
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import mu.KotlinLogging
import org.slf4j.LoggerFactory
import java.util.concurrent.*
/**
* This MUST be run on multiple coroutines! There are deadlock issues if it is only one.
*
* This class LITERALLY forces a coroutine dispatcher to be exclusively on a single thread.
* Event logic throughout the network MUST be run on multiple threads! There are deadlock issues if it runs on only one, or if the client + server
* share an event dispatcher (multiple network restarts were required to verify this)
*
* WARNING: The logic in this class will ONLY work in this class, as it relies on this specific behavior. Do not use it elsewhere!
*/
enum class EventDispatcher {
// NOTE: CLOSE must be last!
HANDSHAKE, CONNECT, DISCONNECT, RESPONSE_MANAGER, ERROR, CLOSE;
internal class EventDispatcher(val type: String) {
enum class EDType {
// CLOSE must be last!
HANDSHAKE, CONNECT, ERROR, CLOSE
}
internal class ED(private val dispatcher: EventDispatcher, private val type: EDType) {
fun launch(function: () -> Unit) {
dispatcher.launch(type, function)
}
fun isDispatch(): Boolean {
return dispatcher.isDispatch(type)
}
fun shutdownAndWait(timeout: Long, timeoutUnit: TimeUnit) {
dispatcher.shutdownAndWait(type, timeout, timeoutUnit)
}
}
companion object {
private val DEBUG_EVENTS = false
private val traceId = atomic(0)
private val logger = KotlinLogging.logger(EventDispatcher::class.java.simpleName)
private val threadIds = values().map { atomic(0L) }.toTypedArray()
private val executors = values().map { event ->
// It CANNOT be the default dispatch because there will be thread starvation
// NOTE: THIS CANNOT CHANGE!! IT WILL BREAK EVERYTHING IF IT CHANGES!
Executors.newSingleThreadExecutor(
NamedThreadFactory("Event Dispatcher-${event.name}",
Configuration.networkThreadGroup, Thread.NORM_PRIORITY, true) { thread ->
// when a new thread is created, assign it to the array
threadIds[event.ordinal].lazySet(thread.id)
}
)
}
private val eventData = executors.map { executor ->
CoroutineScope(executor.asCoroutineDispatcher() + SupervisorJob())
}
private val typedEntries: Array<EDType>
init {
executors.forEachIndexed { _, executor ->
executor.submit {
// this is to create a new thread only, so that the thread ID can be assigned
}
}
}
/**
* Checks if the current execution thread is running inside one of the event dispatchers listed.
*
* No values specified means we check ALL events
*/
fun isDispatch(): Boolean {
return isCurrentEvent(*values())
}
/**
* Checks if the current execution thread is running inside one of the event dispatchers listed.
*
* No values specified means we check ALL events
*/
fun isCurrentEvent(vararg events: EventDispatcher = values()): Boolean {
val threadId = Thread.currentThread().id
events.forEach { event ->
if (threadIds[event.ordinal].value == threadId) {
return true
}
}
return false
}
/**
* Checks if the current execution thread is NOT running inside one of the event dispatchers listed.
*
* No values specified means we check ALL events
*/
fun isNotCurrentEvent(vararg events: EventDispatcher = values()): Boolean {
val currentDispatch = getCurrentEvent() ?: return false
return events.contains(currentDispatch)
}
/**
* @return which event dispatch thread we are running in, if any
*/
fun getCurrentEvent(): EventDispatcher? {
val threadId = Thread.currentThread().id
values().forEach { event ->
if (threadIds[event.ordinal].value == threadId) {
return event
}
}
return null
}
/**
* Each event type runs inside its own coroutine dispatcher.
*
* We want EACH event type to run in its own dispatcher... on its OWN thread, in order to prevent deadlocks
* This is because there are blocking dependencies: DISCONNECT -> CONNECT.
*
* If an event is RE-ENTRANT, then it will immediately execute!
*/
private fun launch(event: EventDispatcher, function: suspend () -> Unit): Job {
val eventId = event.ordinal
return if (DEBUG_EVENTS) {
val id = traceId.getAndIncrement()
eventData[eventId].launch(block = {
logger.debug { "Starting $event : $id" }
function()
logger.debug { "Finished $event : $id" }
})
} else {
eventData[eventId].launch {
function()
}
}
}
suspend fun launchSequentially(endEvent: EventDispatcher, function: suspend () -> Unit) {
// If one of our callbacks requested a shutdown, we wait until all callbacks have run ... THEN shutdown
val event = getCurrentEvent()
val index = event?.ordinal ?: -1
// This will loop through until it runs on the CLOSE EventDispatcher
if (index < endEvent.ordinal) {
// If this runs inside EVENT.CONNECT/DISCONNECT/ETC, we must ***WAIT*** until all listeners have been called!
// this problem is solved by running AGAIN after we have finished running whatever event dispatcher we are currently on
// MORE SPECIFICALLY, we must run at the end of our current one, but repeatedly until CLOSE
EventDispatcher.launch(values()[index+1]) {
launchSequentially(endEvent, function)
}
} else {
function()
}
typedEntries = EDType.entries.toTypedArray()
}
}
fun launch(function: suspend () -> Unit): Job {
return launch(this, function)
private val logger = LoggerFactory.getLogger("$type Dispatch")
private val threadIds = EDType.entries.map { atomic(0L) }.toTypedArray()
private val executors = EDType.entries.map { event ->
// It CANNOT be the default dispatch because there will be thread starvation
// NOTE: THIS CANNOT CHANGE!! IT WILL BREAK EVERYTHING IF IT CHANGES!
Executors.newSingleThreadExecutor(
NamedThreadFactory(
namePrefix = "$type-${event.name}",
group = Configuration.networkThreadGroup,
threadPriority = Thread.NORM_PRIORITY,
daemon = true
) { thread ->
// when a new thread is created, assign it to the array
threadIds[event.ordinal].lazySet(thread.id)
}
)
}.toTypedArray()
val HANDSHAKE: ED
val CONNECT: ED
val ERROR: ED
val CLOSE: ED
init {
executors.forEachIndexed { _, executor ->
executor.submit {
// this is to create a new thread only, so that the thread ID can be assigned
}
}
HANDSHAKE = ED(this, EDType.HANDSHAKE)
CONNECT = ED(this, EDType.CONNECT)
ERROR = ED(this, EDType.ERROR)
CLOSE = ED(this, EDType.CLOSE)
}
/**
* Shuts-down each event dispatcher executor, and waits for it to gracefully shutdown. Once shutdown, it cannot be restarted.
*
* @param timeout how long to wait
* @param timeoutUnit what the unit count is
*/
fun shutdownAndWait(timeout: Long, timeoutUnit: TimeUnit) {
// the lambda must return the failure message (logging here would only produce a useless "kotlin.Unit" message)
require(timeout > 0) { "The EventDispatcher shutdown timeout must be > 0!" }
HANDSHAKE.shutdownAndWait(timeout, timeoutUnit)
CONNECT.shutdownAndWait(timeout, timeoutUnit)
ERROR.shutdownAndWait(timeout, timeoutUnit)
CLOSE.shutdownAndWait(timeout, timeoutUnit)
}
/**
* Checks if the current execution thread is running inside one of the event dispatchers.
*/
fun isDispatch(): Boolean {
val threadId = Thread.currentThread().id
typedEntries.forEach { event ->
if (threadIds[event.ordinal].value == threadId) {
return true
}
}
return false
}
/**
* Checks if the current execution thread is running inside one of the event dispatchers.
*/
private fun isDispatch(type: EDType): Boolean {
val threadId = Thread.currentThread().id
return threadIds[type.ordinal].value == threadId
}
/**
* shuts-down the current execution thread and waits for it complete.
*/
private fun shutdownAndWait(type: EDType, timeout: Long, timeoutUnit: TimeUnit) {
executors[type.ordinal].shutdown()
executors[type.ordinal].awaitTermination(timeout, timeoutUnit)
}
/**
* Each event type runs inside its own thread executor.
*
* We want EACH event type to run in its own executor... on its OWN thread, in order to prevent deadlocks
* This is because there are blocking dependencies: DISCONNECT -> CONNECT.
*
* If an event is RE-ENTRANT, then it will immediately execute!
*/
private fun launch(event: EDType, function: () -> Unit) {
val eventId = event.ordinal
try {
if (DEBUG_EVENTS) {
val id = traceId.getAndIncrement()
executors[eventId].submit {
if (logger.isDebugEnabled) {
logger.debug("Starting $event : $id")
}
function()
if (logger.isDebugEnabled) {
logger.debug("Finished $event : $id")
}
}
} else {
executors[eventId].submit(function)
}
} catch (e: Exception) {
logger.error("Error during event dispatch!", e)
}
}
}
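// Hedged usage sketch of the EventDispatcher above (it is internal to the library, so this is purely
// illustrative): each event type owns a single-threaded executor, so work launched on CONNECT can
// never be starved or deadlocked by work running on CLOSE, and vice versa.
import java.util.concurrent.TimeUnit

fun dispatchExample(dispatch: EventDispatcher) {
    dispatch.CONNECT.launch {
        // runs on the dedicated "<type>-CONNECT" thread
        println("connect logic, on a dispatcher thread: ${dispatch.isDispatch()}") // true
    }
    dispatch.CLOSE.launch {
        // runs on the separate "<type>-CLOSE" thread, even while CONNECT work is still busy
    }
    // shuts down all of the executors and waits (up to 1 second each) for queued work to drain
    dispatch.shutdownAndWait(1, TimeUnit.SECONDS)
}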

View File

@ -189,7 +189,7 @@ internal class IpInfo(config: ServerConfiguration) {
}
else -> {
ipType = IPC
listenAddressString = "IPC"
listenAddressString = EndPoint.IPC_NAME
formattedListenAddressString = listenAddressString
}
}
@ -228,48 +228,4 @@ internal class IpInfo(config: ServerConfiguration) {
formattedListenAddressString
}
}
// localhost/loopback IP might not always be 127.0.0.1 or ::1
// We want to listen on BOTH IPv4 and IPv6 (config option lets us configure this)
// val listenIPv4Address: InetAddress? =
// if (canUseIPv4) {
// formatCommonAddress(config.listenIpAddress, true) { null } // if it's not a valid IP, the lambda will return null
// }
// else {
// null
// }
// val listenIPv6Address: InetAddress? =
// if (canUseIPv6) {
// EndPoint.formatCommonAddress(config.listenIpAddress, false) { null } // if it's not a valid IP, the lambda will return null
// }
// else {
// null
// }
//
// val listenAddressString: String by lazy {
// if (listenIPv6Address == IPv6.WILDCARD) {
// IPv6.WILDCARD_STRING
// } else {
// IPv4.WILDCARD_STRING
// }
// }
//
// val listenAddressPrettyString: String by lazy {
// if (listenIPv4Address == null) {
// "IPC"
// }
// else {
// "IPC"
// }
// val listenAddressString = IP.toString(listenAddress!!)
//
// val prettyAddressString = when (listenAddress) {
// IPv4.WILDCARD -> listenAddressString
// IPv6.WILDCARD -> IPv4.WILDCARD.hostAddress + "/" + listenAddressString
// else -> listenAddressString
// }
// }
}

View File

@ -15,20 +15,21 @@
*/
package dorkbox.network.connection
import dorkbox.classUtil.ClassHelper
import dorkbox.classUtil.ClassHierarchy
import dorkbox.collections.IdentityMap
import dorkbox.network.ipFilter.IpFilterRule
import dorkbox.os.OS
import dorkbox.util.classes.ClassHelper
import dorkbox.util.classes.ClassHierarchy
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import mu.KLogger
import net.jodah.typetools.TypeResolver
import org.slf4j.Logger
import java.net.InetAddress
import java.util.concurrent.locks.*
import kotlin.concurrent.write
/**
* Manages all of the different connect/disconnect/etc listeners
*/
internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogger) {
internal class ListenerManager<CONNECTION: Connection>(private val logger: Logger, val eventDispatch: EventDispatcher) {
companion object {
/**
* Specifies the load-factor for the IdentityMap used to manage keeping track of the number of connections + listeners
@ -40,13 +41,13 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* Neither of these is useful when resolving exceptions from a user's perspective, and they only clutter the stacktrace.
*/
fun Throwable.cleanStackTrace(adjustedStartOfStack: Int = 0) {
fun Throwable.cleanStackTrace(adjustedStartOfStack: Int = 0): Throwable {
// we never care about coroutine stacks, so filter then to start with.
val origStackTrace = this.stackTrace
val size = origStackTrace.size
if (size == 0) {
return
return this
}
val stackTrace = origStackTrace.filterNot {
@ -92,6 +93,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
// keep just one, since it's a stack frame INSIDE our network library, and we need that!
this.stackTrace = stackTrace.copyOfRange(0, 1)
}
return this
}
/**
@ -129,19 +132,20 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* We only want the error message, because we do something based on it (and the full stack trace is meaningless)
*/
fun Throwable.cleanAllStackTrace() {
fun Throwable.cleanAllStackTrace(): Throwable {
val stackTrace = this.stackTrace
val size = stackTrace.size
if (size == 0) {
return
return this
}
// throw everything out
this.stackTrace = stackTrace.copyOfRange(0, 1)
return this
}
internal inline fun <reified T> add(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> add(thing: T, array: Array<T>): Array<T> {
val currentLength: Int = array.size
// add the new subscription to the END of the array
@ -152,40 +156,45 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
return newMessageArray
}
internal inline fun <reified T> remove(thing: T, array: Array<T>): Array<T> {
internal inline fun <reified T: Any> remove(thing: T, array: Array<T>): Array<T> {
// remove the subscription form the array
// THIS IS IDENTITY CHECKS, NOT EQUALITY
return array.filter { it !== thing }.toTypedArray()
}
}
// initialize empty arrays
@Volatile
private var onConnectFilterList = Array<suspend (CONNECTION.() -> Boolean)>(0) { { true } }
private val onConnectFilterMutex = Mutex()
private var onConnectFilterList = Array<((InetAddress, String) -> Boolean)>(0) { { _, _ -> true } }
private val onConnectFilterLock = ReentrantReadWriteLock()
@Volatile
private var onInitList = Array<suspend (CONNECTION.() -> Unit)>(0) { { } }
private val onInitMutex = Mutex()
private var onConnectBufferedMessageFilterList = Array<((InetAddress?, String) -> Boolean)>(0) { { _, _ -> true } }
private val onConnectBufferedMessageFilterLock = ReentrantReadWriteLock()
@Volatile
private var onConnectList = Array<suspend (CONNECTION.() -> Unit)>(0) { { } }
private val onConnectMutex = Mutex()
private var onInitList = Array<(CONNECTION.() -> Unit)>(0) { { } }
private val onInitLock = ReentrantReadWriteLock()
@Volatile
private var onDisconnectList = Array<suspend CONNECTION.() -> Unit>(0) { { } }
private val onDisconnectMutex = Mutex()
private var onConnectList = Array<(CONNECTION.() -> Unit)>(0) { { } }
private val onConnectLock = ReentrantReadWriteLock()
@Volatile
private var onErrorList = Array<suspend CONNECTION.(Throwable) -> Unit>(0) { { } }
private val onErrorMutex = Mutex()
private var onDisconnectList = Array<CONNECTION.() -> Unit>(0) { { } }
private val onDisconnectLock = ReentrantReadWriteLock()
@Volatile
private var onErrorGlobalList = Array<suspend Throwable.() -> Unit>(0) { { } }
private val onErrorGlobalMutex = Mutex()
private var onErrorList = Array<CONNECTION.(Throwable) -> Unit>(0) { { } }
private val onErrorLock = ReentrantReadWriteLock()
@Volatile
private var onMessageMap = IdentityMap<Class<*>, Array<suspend CONNECTION.(Any) -> Unit>>(32, LOAD_FACTOR)
private val onMessageMutex = Mutex()
private var onErrorGlobalList = Array<Throwable.() -> Unit>(0) { { } }
private val onErrorGlobalLock = ReentrantReadWriteLock()
@Volatile
private var onMessageMap = IdentityMap<Class<*>, Array<CONNECTION.(Any) -> Unit>>(32, LOAD_FACTOR)
private val onMessageLock = ReentrantReadWriteLock()
// used to keep a cache of class hierarchy for distributing messages
private val classHierarchyCache = ClassHierarchy(LOAD_FACTOR)
@ -196,10 +205,10 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
* If there are no rules added, then all connections are allowed
* If there are rules added, then a rule MUST be matched to be allowed
*/
suspend fun filter(ipFilterRule: IpFilterRule) {
filter {
// IPC will not filter, so this is OK to coerce to not-null
ipFilterRule.matches(remoteAddress!!)
fun filter(ipFilterRule: IpFilterRule) {
filter { clientAddress, _ ->
// IPC will not filter
ipFilterRule.matches(clientAddress)
}
}
@ -208,29 +217,62 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if a connection
* should be allowed
*
* By default, if there are no filter rules, then all connections are allowed to connect
* If there are filter rules - then ONLY connections for the filter that returns true are allowed to connect (all else are denied)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the connection will continue to connect.
* If the function returns FALSE, then the other end of the connection will
* receive a connection error
*
* For a server, this function will be called for ALL clients.
*
* If ANY filter rule that is applied returns true, then the connection is permitted
*
* This function will be called for **only** network clients (IPC client are excluded)
*
* @param function clientAddress: UDP connection address
* tagName: the connection tag name
*/
suspend fun filter(function: suspend CONNECTION.() -> Boolean) {
onConnectFilterMutex.withLock {
fun filter(function: (clientAddress: InetAddress, tagName: String) -> Boolean) {
onConnectFilterLock.write {
// we have to follow the single-writer principle!
onConnectFilterList = add(function, onConnectFilterList)
}
}
/**
* Adds a function that will be called BEFORE a client/server "connects" with each other, and used to determine if buffered messages
* for a connection should be enabled
*
* By default, if there are no rules, then all connections will have buffered messages enabled
* If there are rules - then ONLY connections for the rule that returns true will have buffered messages enabled (all else are disabled)
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* If the function returns TRUE, then the buffered messages for a connection are enabled.
* If the function returns FALSE, then the buffered messages for a connection are disabled.
*
* If ANY rule that is applied returns true, then the buffered messages for a connection are enabled
*
* @param function clientAddress: not-null when UDP connection, null when IPC connection
* tagName: the connection tag name
*/
fun enableBufferedMessages(function: (clientAddress: InetAddress?, tagName: String) -> Boolean) {
onConnectBufferedMessageFilterLock.write {
// we have to follow the single-writer principle!
onConnectBufferedMessageFilterList = add(function, onConnectBufferedMessageFilterList)
}
}
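// Hedged usage sketch of the two rule hooks above (ListenerManager is internal; in practice these are
// reached through whatever server/client API wraps it, which is not shown in this diff). Remember the
// documented semantics: no rules registered -> everything allowed; otherwise any rule returning true allows.
import java.net.Inet4Address
import java.net.InetAddress

fun configureRules(listeners: ListenerManager<Connection>) {
    // only allow IPv4 clients in 10.0.0.0/8, or any client whose tag name is "trusted"
    listeners.filter { clientAddress, tagName ->
        tagName == "trusted" || (clientAddress is Inet4Address && clientAddress.address[0].toInt() == 10)
    }

    // only keep buffered (re-sendable) messages for IPC clients (null address) or tagged clients
    listeners.enableBufferedMessages { clientAddress: InetAddress?, tagName ->
        clientAddress == null || tagName.isNotEmpty()
    }
}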
/**
* Adds a function that will be called when a client/server connection is FIRST initialized, but before it's
* connected to the remote endpoint
*
* For a server, this function will be called for ALL client connections.
*/
suspend fun onInit(function: suspend CONNECTION.() -> Unit) {
onInitMutex.withLock {
fun onInit(function: CONNECTION.() -> Unit) {
onInitLock.write {
// we have to follow the single-writer principle!
onInitList = add(function, onInitList)
}
@ -240,8 +282,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
* Adds a function that will be called when a client/server connection first establishes a connection with the remote end.
* 'onInit()' callbacks will execute for both the client and server before `onConnect()` executes when they "connect" with each other
*/
suspend fun onConnect(function: suspend CONNECTION.() -> Unit) {
onConnectMutex.withLock {
fun onConnect(function: CONNECTION.() -> Unit) {
onConnectLock.write {
// we have to follow the single-writer principle!
onConnectList = add(function, onConnectList)
}
@ -252,8 +294,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* Do not try to send messages! The connection will already be closed, resulting in an error if you attempt to do so.
*/
suspend fun onDisconnect(function: suspend CONNECTION.() -> Unit) {
onDisconnectMutex.withLock {
fun onDisconnect(function: CONNECTION.() -> Unit) {
onDisconnectLock.write {
// we have to follow the single-writer principle!
onDisconnectList = add(function, onDisconnectList)
}
@ -264,8 +306,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* The error is also sent to an error log before this method is called.
*/
suspend fun onError(function: suspend CONNECTION.(Throwable) -> Unit) {
onErrorMutex.withLock {
fun onError(function: CONNECTION.(Throwable) -> Unit) {
onErrorLock.write {
// we have to follow the single-writer principle!
onErrorList = add(function, onErrorList)
}
@ -276,8 +318,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* The error is also sent to an error log before this method is called.
*/
suspend fun onError(function: suspend Throwable.() -> Unit) {
onErrorGlobalMutex.withLock {
fun onError(function: Throwable.() -> Unit) {
onErrorGlobalLock.write {
// we have to follow the single-writer principle!
onErrorGlobalList = add(function, onErrorGlobalList)
}
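// Hedged usage sketch for the registration methods in this class (onMessage is defined just below).
// Each connection-scoped lambda runs with the connection as its receiver, so 'this' is the connection.
fun registerCallbacks(listeners: ListenerManager<Connection>) {
    listeners.onInit { println("[$this] initialized, not yet connected") }
    listeners.onConnect { println("[$this] connected") }
    listeners.onDisconnect { println("[$this] disconnected - do not try to send messages from here") }
    listeners.onError { throwable -> throwable.printStackTrace() }
    listeners.onMessage<String> { text -> println("[$this] received: $text") }
}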
@ -288,8 +330,8 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* This method should not block for long periods as other network activity will not be processed until it returns.
*/
suspend fun <MESSAGE> onMessage(function: suspend CONNECTION.(MESSAGE) -> Unit) {
onMessageMutex.withLock {
fun <MESSAGE> onMessage(function: CONNECTION.(MESSAGE) -> Unit) {
onMessageLock.write {
// we have to follow the single-writer principle!
// this is the connection generic parameter for the listener, works for lambda expressions as well
@ -310,26 +352,26 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
}
if (success) {
// NOTE: https://github.com/Kotlin/kotlinx.atomicfu
// https://github.com/Kotlin/kotlinx.atomicfu
// this is EXPLICITLY listed as a "Don't" via the documentation. The ****ONLY**** reason this is actually OK is because
// we are following the "single-writer principle", so only ONE THREAD can modify this at a time.
val tempMap = onMessageMap
@Suppress("UNCHECKED_CAST")
val func = function as suspend (CONNECTION, Any) -> Unit
val func = function as (CONNECTION, Any) -> Unit
val newMessageArray: Array<suspend (CONNECTION, Any) -> Unit>
val onMessageArray: Array<suspend (CONNECTION, Any) -> Unit>? = tempMap.get(messageClass)
val newMessageArray: Array<(CONNECTION, Any) -> Unit>
val onMessageArray: Array<(CONNECTION, Any) -> Unit>? = tempMap[messageClass]
if (onMessageArray != null) {
newMessageArray = add(function, onMessageArray)
} else {
@Suppress("RemoveExplicitTypeArguments")
newMessageArray = Array<suspend (CONNECTION, Any) -> Unit>(1) { { _, _ -> } }
newMessageArray = Array<(CONNECTION, Any) -> Unit>(1) { { _, _ -> } }
newMessageArray[0] = func
}
tempMap.put(messageClass, newMessageArray)
tempMap.put(messageClass!!, newMessageArray)
onMessageMap = tempMap
} else {
throw IllegalArgumentException("Unable to add incompatible types! Detected connection/message classes: $connectionClass, $messageClass")
@ -342,25 +384,18 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* NOTE: This is run directly on the thread that calls it!
* This is run directly on the thread that calls it!
*
* @return true if the connection will be allowed to connect. False if we should terminate this connection
* @return true if the client address is allowed to connect. False if we should terminate this connection
*/
suspend fun notifyFilter(connection: CONNECTION): Boolean {
// remote address will NOT be null at this stage, but best to verify.
val remoteAddress = connection.remoteAddress
if (remoteAddress == null) {
logger.error("Connection ${connection.id}: Unable to attempt connection stages when no remote address is present")
return false
}
fun notifyFilter(clientAddress: InetAddress, clientTagName: String): Boolean {
// by default, there is a SINGLE rule that will always exist, and will always ACCEPT ALL connections.
// This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectFilterList
// if there is a rule, a connection must match for it to connect
list.forEach {
if (it.invoke(connection)) {
if (it.invoke(clientAddress, clientTagName)) {
return true
}
}
@ -371,19 +406,46 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
return list.isEmpty()
}
/**
* Invoked just after a connection is created, but before it is connected.
*
* It is the responsibility of the custom filter to write the error, if there is one
*
* This is run directly on the thread that calls it!
*
* @return true if the connection will have buffered messages enabled. False if buffered messages for this connection should be disabled.
*/
fun notifyEnableBufferedMessages(clientAddress: InetAddress?, clientTagName: String): Boolean {
// by default, there is a SINGLE rule that will always exist, and will always PERMIT buffered messages.
// This is so the array types can be setup (the compiler needs SOMETHING there)
val list = onConnectBufferedMessageFilterList
// if there is a rule, a connection must match for it to enable buffered messages
list.forEach {
if (it.invoke(clientAddress, clientTagName)) {
return true
}
}
// default if nothing matches
// NO RULES ADDED -> ALLOW Buffered Messages
// RULES ADDED -> DISABLE Buffered Messages
return list.isEmpty()
}
/**
* Invoked when a connection is first initialized, but BEFORE it's connected to the remote address.
*
* NOTE: This is run directly on the thread that calls it! Things that happen in this event are TIME-CRITICAL, and must happen before connect happens.
* Because of this guarantee, init is executed immediately, whereas connect runs on a separate thread
*/
suspend fun notifyInit(connection: CONNECTION) {
fun notifyInit(connection: CONNECTION) {
val list = onInitList
list.forEach {
try {
it(connection)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
t.cleanStackTrace()
logger.error("Connection ${connection.id} error", t)
}
@ -393,17 +455,17 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
/**
* Invoked when a connection is connected to a remote address.
*
* NOTE: This is run on the EventDispatch!
* This is run on the EventDispatch!
*/
fun notifyConnect(connection: CONNECTION) {
val list = onConnectList
if (list.isNotEmpty()) {
EventDispatcher.CONNECT.launch {
connection.endPoint.eventDispatch.CONNECT.launch {
list.forEach {
try {
it(connection)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
t.cleanStackTrace()
logger.error("Connection ${connection.id} error", t)
}
@ -415,9 +477,9 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
/**
* Invoked when a connection is disconnected to a remote address.
*
* NOTE: This is exclusively called from a connection, when that connection is closed!
* This is exclusively called from a connection, when that connection is closed!
*
* NOTE: This is run on the EventDispatch!
* This is run on the EventDispatch!
*/
fun notifyDisconnect(connection: Connection) {
connection.notifyDisconnect()
@ -432,12 +494,12 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
fun directNotifyDisconnect(connection: CONNECTION) {
val list = onDisconnectList
if (list.isNotEmpty()) {
EventDispatcher.DISCONNECT.launch {
connection.endPoint.eventDispatch.CLOSE.launch {
list.forEach {
try {
it(connection)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
t.cleanStackTrace()
logger.error("Connection ${connection.id} error", t)
}
@ -452,17 +514,17 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* The error is also sent to an error log before notifying callbacks
*
* NOTE: This is run on the EventDispatch!
* This is run on the EventDispatch!
*/
fun notifyError(connection: CONNECTION, exception: Throwable) {
val list = onErrorList
if (list.isNotEmpty()) {
EventDispatcher.ERROR.launch {
connection.endPoint.eventDispatch.ERROR.launch {
list.forEach {
try {
it(connection, exception)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
t.cleanStackTrace()
logger.error("Connection ${connection.id} error", t)
}
@ -481,12 +543,12 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
fun notifyError(exception: Throwable) {
val list = onErrorGlobalList
if (list.isNotEmpty()) {
EventDispatcher.ERROR.launch {
eventDispatch.ERROR.launch {
list.forEach {
try {
it(exception)
} catch (t: Throwable) {
// NOTE: when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
// when we remove stuff, we ONLY want to remove the "tail" of the stacktrace, not ALL parts of the stacktrace
t.cleanStackTrace()
logger.error("Global error", t)
}
@ -502,7 +564,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
*
* @return true if there were listeners assigned for this message type
*/
suspend fun notifyOnMessage(connection: CONNECTION, message: Any): Boolean {
fun notifyOnMessage(connection: CONNECTION, message: Any): Boolean {
val messageClass: Class<*> = message.javaClass
// have to save the types + hierarchy (note: duplicates are OK, since they will just be overwritten)
@ -525,7 +587,7 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
val tempMap = onMessageMap
var hasListeners = false
hierarchy.forEach { clazz ->
val onMessageArray: Array<suspend (CONNECTION, Any) -> Unit>? = tempMap.get(clazz)
val onMessageArray: Array<(CONNECTION, Any) -> Unit>? = tempMap[clazz]
if (onMessageArray != null) {
hasListeners = true
@ -545,30 +607,33 @@ internal class ListenerManager<CONNECTION: Connection>(private val logger: KLogg
/**
* This will remove all listeners that have been registered!
*/
suspend fun close() {
fun close() {
// we have to follow the single-writer principle!
logger.debug { "Closing the listener manager" }
logger.debug("Closing the listener manager")
onConnectFilterMutex.withLock {
onConnectFilterList = Array(0) { { true } }
onConnectFilterLock.write {
onConnectFilterList = Array(0) { { _, _ -> true } }
}
onInitMutex.withLock {
onConnectBufferedMessageFilterLock.write {
onConnectBufferedMessageFilterList = Array(0) { { _, _ -> true } }
}
onInitLock.write {
onInitList = Array(0) { { } }
}
onConnectMutex.withLock {
onConnectLock.write {
onConnectList = Array(0) { { } }
}
onDisconnectMutex.withLock {
onDisconnectLock.write {
onDisconnectList = Array(0) { { } }
}
onErrorMutex.withLock {
onErrorLock.write {
onErrorList = Array(0) { { } }
}
onErrorGlobalMutex.withLock {
onErrorGlobalLock.write {
onErrorGlobalList = Array(0) { { } }
}
onMessageMutex.withLock {
onMessageMap = IdentityMap<Class<*>, Array<suspend CONNECTION.(Any) -> Unit>>(32, LOAD_FACTOR)
onMessageLock.write {
onMessageMap = IdentityMap(32, LOAD_FACTOR)
}
}
}

View File

@ -0,0 +1,22 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
class Paired<CONNECTION : Connection> {
lateinit var connection: CONNECTION
lateinit var message: Any
}

View File

@ -0,0 +1,46 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection
import dorkbox.network.rmi.RmiUtils
class SendSync {
var message: Any? = null
// used to notify the remote endpoint that the message has been processed
var id: Int = 0
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is SendSync) return false
if (message != other.message) return false
if (id != other.id) return false
return true
}
override fun hashCode(): Int {
var result = message?.hashCode() ?: 0
result = 31 * result + id
return result
}
override fun toString(): String {
return "SendSync ${RmiUtils.unpackUnsignedRight(id)} (message=$message)"
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.buffer
import dorkbox.bytes.ByteArrayWrapper
import dorkbox.collections.LockFreeHashMap
import dorkbox.hex.toHexString
import dorkbox.network.Configuration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.util.Sys
import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap
import org.slf4j.LoggerFactory
import java.util.concurrent.*
internal open class BufferManager<CONNECTION: Connection>(
config: Configuration,
listenerManager: ListenerManager<CONNECTION>,
aeronDriver: AeronDriver,
sessionTimeout: Long
) {
companion object {
private val logger = LoggerFactory.getLogger(BufferManager::class.java.simpleName)
}
private val sessions = LockFreeHashMap<ByteArrayWrapper, BufferedSession>()
private val expiringSessions: ExpiringMap<ByteArrayWrapper, BufferedSession>
init {
require(sessionTimeout >= 60) { "The buffered connection timeout 'bufferedConnectionTimeoutSeconds' must be at least 60 seconds!" }
// ignore 0
val check = TimeUnit.SECONDS.toNanos(sessionTimeout)
val lingerNs = aeronDriver.lingerNs()
val required = TimeUnit.SECONDS.toNanos(config.connectionCloseTimeoutInSeconds.toLong())
require(check == 0L || check > required + lingerNs) {
"The session timeout (${Sys.getTimePretty(check)}) must be longer than the connection close timeout (${Sys.getTimePretty(required)}) + the aeron driver linger timeout (${Sys.getTimePretty(lingerNs)})!"
}
// connections are extremely difficult to diagnose when the connection timeout is short
val timeUnit = if (EndPoint.DEBUG_CONNECTIONS) { TimeUnit.HOURS } else { TimeUnit.SECONDS }
expiringSessions = ExpiringMap.builder()
.expiration(sessionTimeout, timeUnit)
.expirationPolicy(ExpirationPolicy.CREATED)
.expirationListener<ByteArrayWrapper, BufferedSession> { publicKeyWrapped, sessionConnection ->
// this blocks until it fully runs (which is ok. this is fast)
logger.debug("Connection session expired for: ${publicKeyWrapped.bytes.toHexString()}")
// this SESSION has expired, so we should call the onDisconnect for the underlying connection, in order to clean it up.
listenerManager.notifyDisconnect(sessionConnection.connection)
}
.build()
}
/**
* this must be called when a new connection is created
*
* @return true if this is a new session, false if it is an existing session
*/
fun onConnect(connection: Connection): BufferedSession {
val publicKeyWrapped = ByteArrayWrapper.wrap(connection.uuid)
return synchronized(sessions) {
// always check if we are expiring first...
val expiring = expiringSessions.remove(publicKeyWrapped)
if (expiring != null) {
expiring.connection = connection
expiring
} else {
val existing = sessions[publicKeyWrapped]
if (existing != null) {
// we must always set this session value!!
existing.connection = connection
existing
} else {
val newSession = BufferedSession(connection)
sessions[publicKeyWrapped] = newSession
// we must always set this when the connection is created, and it must be inside the sync block!
newSession
}
}
}
}
/**
* Always called when a connection is disconnected from the network
*/
fun onDisconnect(connection: Connection) {
try {
val publicKeyWrapped = ByteArrayWrapper.wrap(connection.uuid)
synchronized(sessions) {
val sess = sessions.remove(publicKeyWrapped)
// we want to expire this session after XYZ time
expiringSessions[publicKeyWrapped] = sess
}
}
catch (e: Exception) {
logger.error("Unable to run session expire logic!", e)
}
}
fun close() {
synchronized(sessions) {
sessions.clear()
expiringSessions.clear()
}
}
}
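// Hedged sketch of the session lifecycle managed above. 'bufferManager', 'oldConnection' and
// 'newConnection' are hypothetical stand-ins for objects created elsewhere in the library, and the
// two connections are assumed to share the same uuid (the session key is the connection's uuid).
fun sessionLifecycle(bufferManager: BufferManager<Connection>, oldConnection: Connection, newConnection: Connection) {
    // 1) first connect: no live or expiring session exists for this uuid, so a new BufferedSession is created
    val session = bufferManager.onConnect(oldConnection)

    // 2) disconnect: the session moves into the expiring map and its sessionTimeout countdown starts
    bufferManager.onDisconnect(oldConnection)

    // 3) reconnect with the same uuid BEFORE the timeout: the very same session (and its pending
    //    message queue) is handed to the new connection instead of creating a fresh one
    val resumed = bufferManager.onConnect(newConnection)
    check(resumed === session)

    // had the client NOT reconnected in time, the expiration listener would have fired and
    // listenerManager.notifyDisconnect() would have been called for the buffered connection
}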

View File

@ -0,0 +1,21 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.buffer
class BufferedMessages {
var messages = arrayListOf<Any>()
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.buffer
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
internal class BufferedSerializer: Serializer<BufferedMessages>() {
override fun write(kryo: Kryo, output: Output, messages: BufferedMessages) {
kryo.writeClassAndObject(output, messages.messages)
}
override fun read(kryo: Kryo, input: Input, type: Class<out BufferedMessages>): BufferedMessages {
val messages = BufferedMessages()
messages.messages = kryo.readClassAndObject(input) as ArrayList<Any>
return messages
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.buffer
import dorkbox.network.connection.Connection
import java.util.concurrent.*
open class BufferedSession(@Volatile var connection: Connection) {
/**
* Only used when configured. Will re-send all missing messages to a connection when a connection re-connects.
*/
val pendingMessagesQueue: LinkedTransferQueue<Any> = LinkedTransferQueue()
fun queueMessage(connection: Connection, message: Any, abortEarly: Boolean): Boolean {
if (this.connection != connection) {
connection.logger.trace("[{}] message received on old connection, resending", connection)
// we received a message on an OLD connection (which is no longer connected ---- BUT we have a NEW connection that is connected)
// this can happen on RMI object that are old
val success = this.connection.send(message, abortEarly)
if (success) {
connection.logger.trace("[{}] successfully resent message", connection)
return true
}
}
if (!connection.enableBufferedMessages) {
// do nothing; we already log during connection initialization that pending messages are DISABLED
return false
}
if (!abortEarly) {
// this was a "normal" send (instead of the disconnect message).
pendingMessagesQueue.put(message)
connection.logger.trace("[{}] queueing message", connection)
}
else if (connection.endPoint.aeronDriver.internal.mustRestartDriverOnError) {
// the only way we get errors, is if the connection is bad OR if we are sending so fast that the connection cannot keep up.
// don't restart/reconnect -- there was an internal network error
pendingMessagesQueue.put(message)
connection.logger.trace("[{}] queueing message", connection)
}
else if (!connection.isClosedWithTimeout()) {
// there was an issue - the connection should automatically reconnect
pendingMessagesQueue.put(message)
connection.logger.trace("[{}] queueing message", connection)
}
return false
}
}
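// Hedged sketch (the reconnect-side replay is not part of this diff): on a successful reconnect the
// library is expected to drain pendingMessagesQueue and re-send its contents on the new connection;
// conceptually something along these lines:
fun resendPending(session: BufferedSession, newConnection: Connection) {
    val pending = ArrayList<Any>()
    session.pendingMessagesQueue.drainTo(pending)   // LinkedTransferQueue supports bulk draining
    pending.forEach { message ->
        if (!newConnection.send(message)) {
            // anything that still cannot be delivered goes back through the normal queueing path
            session.queueMessage(newConnection, message, abortEarly = false)
        }
    }
}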

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.buffer;

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection;

View File

@ -17,5 +17,18 @@
package dorkbox.network.connection.streaming
import dorkbox.network.serialization.AeronOutput
import kotlinx.atomicfu.atomic
class AeronWriter: StreamingWriter, AeronOutput()
class AeronWriter(val size: Int): StreamingWriter, AeronOutput(size) {
private val written = atomic(0)
override fun writeBytes(startPosition: Int, bytes: ByteArray) {
position = startPosition
writeBytes(bytes)
written.getAndAdd(bytes.size)
}
override fun isFinished(): Boolean {
return written.value == size
}
}

View File

@ -16,11 +16,47 @@
package dorkbox.network.connection.streaming
import kotlinx.atomicfu.atomic
import java.io.File
import java.io.FileOutputStream
import java.io.RandomAccessFile
class FileWriter(file: File) : StreamingWriter, FileOutputStream(file) {
override fun writeBytes(bytes: ByteArray) {
class FileWriter(val size: Int, val file: File) : StreamingWriter, RandomAccessFile(file, "rw") {
private val written = atomic(0)
init {
// reserve space on disk!
val saveSize = size.coerceAtMost(4096)
var bytes = ByteArray(saveSize)
this.write(bytes)
if (saveSize < size) {
var remainingBytes = size - saveSize
while (remainingBytes > 0) {
if (saveSize > remainingBytes) {
bytes = ByteArray(remainingBytes)
}
this.write(bytes)
remainingBytes = (remainingBytes - saveSize).coerceAtLeast(0)
}
}
}
override fun writeBytes(startPosition: Int, bytes: ByteArray) {
// the OS will synchronize writes to disk
this.seek(startPosition.toLong())
write(bytes)
written.addAndGet(bytes.size)
}
override fun isFinished(): Boolean {
return written.value == size
}
fun finishAndClose() {
fd.sync()
close()
}
}
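// Hedged usage sketch for the file-backed streaming writer above: the constructor pre-allocates the
// full file size on disk, chunks are then written at their absolute positions (possibly out of order),
// and finishAndClose() syncs the data to disk once isFinished() reports completion.
import java.io.File

fun receiveToFile(totalSize: Int, chunks: List<Pair<Int, ByteArray>>): File {
    val target = File.createTempFile("stream", ".bin")
    val writer = FileWriter(totalSize, target)
    chunks.forEach { (startPosition, payload) -> writer.writeBytes(startPosition, payload) }
    check(writer.isFinished()) { "stream incomplete: expected $totalSize bytes" }
    writer.finishAndClose()
    return target
}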

View File

@ -17,6 +17,7 @@
package dorkbox.network.connection.streaming
data class StreamingControl(val state: StreamingState,
val isFile: Boolean,
val streamId: Int,
val totalSize: Long = 0L
): StreamingMessage

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.streaming
import com.esotericsoftware.kryo.Kryo
@ -20,30 +21,21 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
class StreamingControlSerializer: Serializer<StreamingControl>() {
internal class StreamingControlSerializer: Serializer<StreamingControl>() {
override fun write(kryo: Kryo, output: Output, data: StreamingControl) {
output.writeByte(data.state.ordinal)
output.writeBoolean(data.isFile)
output.writeVarInt(data.streamId, true)
output.writeVarLong(data.totalSize, true)
}
override fun read(kryo: Kryo, input: Input, type: Class<out StreamingControl>): StreamingControl {
val stateOrdinal = input.readByte().toInt()
val state = StreamingState.values().first { it.ordinal == stateOrdinal }
val isFile = input.readBoolean()
val state = StreamingState.entries.first { it.ordinal == stateOrdinal }
val streamId = input.readVarInt(true)
val totalSize = input.readVarLong(true)
return StreamingControl(state, streamId, totalSize)
}
}
class StreamingDataSerializer: Serializer<StreamingData>() {
override fun write(kryo: Kryo, output: Output, data: StreamingData) {
output.writeVarInt(data.streamId, true)
}
override fun read(kryo: Kryo, input: Input, type: Class<out StreamingData>): StreamingData {
val streamId = input.readVarInt(true)
return StreamingData(streamId)
return StreamingControl(state, isFile, streamId, totalSize)
}
}

View File

@ -1,9 +1,27 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.streaming
import dorkbox.bytes.xxHash32
class StreamingData(val streamId: Int) : StreamingMessage {
// These are set just after we receive the message, and before we process it
@Transient var payload: ByteArray? = null
var payload: ByteArray? = null
var startPosition: Int = 0
override fun equals(other: Any?): Boolean {
if (this === other) return true
@ -17,16 +35,19 @@ class StreamingData(val streamId: Int) : StreamingMessage {
if (!payload.contentEquals(other.payload)) return false
} else if (other.payload != null) return false
if (startPosition != other.startPosition) return false
return true
}
override fun hashCode(): Int {
var result = streamId.hashCode()
result = 31 * result + (payload?.contentHashCode() ?: 0)
result = 31 * result + (startPosition)
return result
}
override fun toString(): String {
return "StreamingData(streamId=$streamId)"
return "StreamingData(streamId=$streamId position=${startPosition}, xxHash=${payload?.xxHash32()})"
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.streaming
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
internal class StreamingDataSerializer: Serializer<StreamingData>() {
override fun write(kryo: Kryo, output: Output, data: StreamingData) {
output.writeVarInt(data.streamId, true)
// we re-use this data when streaming data to the remote endpoint, so we don't write out the payload here, we do it in another place
}
override fun read(kryo: Kryo, input: Input, type: Class<out StreamingData>): StreamingData {
val streamId = input.readVarInt(true)
val streamingData = StreamingData(streamId)
// we want to read out the start-position AND payload. It is not written by the serializer, but by the streaming manager
val startPosition = input.readVarInt(true)
val payloadSize = input.readVarInt(true)
streamingData.startPosition = startPosition
streamingData.payload = input.readBytes(payloadSize)
return streamingData
}
}
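The division of labour here matters: write() emits only the streamId, while the streaming manager appends the start-position var-int, the payload-size var-int and the raw payload to the same buffer, and read() consumes them in exactly that order. A minimal framing sketch with plain Kryo Output/Input (the real code uses AeronOutput and no-copy buffers, not shown here):

import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output

fun main() {
    val payload = byteArrayOf(1, 2, 3, 4, 5)

    // sending side, sketched: [streamId var-int][start-position var-int][payload-size var-int][payload bytes]
    val out = Output(64)
    out.writeVarInt(42, true)            // what write() above emits
    out.writeVarInt(0, true)             // start position, appended by the streaming manager
    out.writeVarInt(payload.size, true)  // payload size, appended by the streaming manager
    out.writeBytes(payload)              // raw payload, appended by the streaming manager
    val wire = out.toBytes()

    // receiving side, matching read() above
    val input = Input(wire)
    val streamId = input.readVarInt(true)
    val startPosition = input.readVarInt(true)
    val size = input.readVarInt(true)
    val bytes = input.readBytes(size)
    println("streamId=$streamId startPosition=$startPosition bytes=${bytes.size}")
}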

View File

@ -19,31 +19,29 @@
package dorkbox.network.connection.streaming
import com.esotericsoftware.kryo.io.Input
import dorkbox.bytes.OptimizeUtilsByteArray
import dorkbox.bytes.OptimizeUtilsByteBuf
import dorkbox.collections.LockFreeHashMap
import dorkbox.collections.LockFreeLongMap
import dorkbox.network.Configuration
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager.Companion.cleanAllStackTrace
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.exceptions.StreamingException
import dorkbox.network.serialization.AeronInput
import dorkbox.network.serialization.AeronOutput
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.KryoWriter
import dorkbox.os.OS
import dorkbox.util.Sys
import io.aeron.Publication
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import mu.KLogger
import org.agrona.ExpandableDirectByteBuffer
import org.agrona.MutableDirectBuffer
import org.agrona.concurrent.IdleStrategy
import org.agrona.concurrent.UnsafeBuffer
import org.slf4j.Logger
import java.io.File
import java.io.FileInputStream
internal class StreamingManager<CONNECTION : Connection>(
private val logger: KLogger, private val messageDispatch: CoroutineScope, val config: Configuration
) {
internal class StreamingManager<CONNECTION : Connection>(private val logger: Logger, val config: Configuration) {
companion object {
private const val KILOBYTE = 1024
@ -51,7 +49,7 @@ internal class StreamingManager<CONNECTION : Connection>(
private const val GIGABYTE = 1024 * MEGABYTE
private const val TERABYTE = 1024L * GIGABYTE
@Suppress("UNUSED_CHANGED_VALUE")
@Suppress("UNUSED_CHANGED_VALUE", "SameParameterValue")
private fun writeVarInt(internalBuffer: MutableDirectBuffer, position: Int, value: Int, optimizePositive: Boolean): Int {
var p = position
var newValue = value
@ -91,15 +89,33 @@ internal class StreamingManager<CONNECTION : Connection>(
}
private val streamingDataTarget = LockFreeHashMap<Long, StreamingControl>()
private val streamingDataInMemory = LockFreeHashMap<Long, StreamingWriter>()
private val streamingDataTarget = LockFreeLongMap<StreamingControl>()
private val streamingDataInMemory = LockFreeLongMap<StreamingWriter>()
/**
* What is the max stream size that can exist in memory when deciding if data chunks are in memory or on a temp-file on disk
* What is the max stream size that can exist in memory when deciding if data blocks are in memory or temp-file on disk
*/
private val maxStreamSizeInMemoryInBytes = config.maxStreamSizeInMemoryMB * MEGABYTE
fun getFile(connection: CONNECTION, endPoint: EndPoint<CONNECTION>, messageStreamId: Int): File {
// NOTE: the stream session ID is a combination of the connection ID + random ID (on the receiving side),
// otherwise clients can abuse it and corrupt OTHER clients data!!
val streamId = (connection.id.toLong() shl 4) or messageStreamId.toLong()
val output = streamingDataInMemory[streamId]
return if (output is FileWriter) {
streamingDataInMemory.remove(streamId)
output.file
} else {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "Error while reading file output, stream $streamId was of the wrong type!"
// either client or server. No other choices. We create an exception, because it's more useful!
throw endPoint.newException(errorMessage)
}
}
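The per-stream key used throughout the manager is the connection id shifted left four bits, OR'd with the message stream id, so the same messageStreamId arriving on two different connections maps to two different entries. A tiny worked sketch of that composition (the ids below are hypothetical):

fun combinedStreamId(connectionId: Int, messageStreamId: Int): Long =
    (connectionId.toLong() shl 4) or messageStreamId.toLong()

fun main() {
    println(combinedStreamId(7, 3))  // (7 shl 4) or 3 = 112 or 3 = 115
    println(combinedStreamId(8, 3))  // (8 shl 4) or 3 = 128 or 3 = 131
}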
/**
* NOTE: MUST BE ON THE AERON THREAD!
@ -108,7 +124,6 @@ internal class StreamingManager<CONNECTION : Connection>(
*/
fun processControlMessage(
message: StreamingControl,
kryo: KryoExtra<CONNECTION>,
endPoint: EndPoint<CONNECTION>,
connection: CONNECTION
) {
@ -118,40 +133,70 @@ internal class StreamingManager<CONNECTION : Connection>(
when (message.state) {
StreamingState.START -> {
// message.totalSize > maxInMemory, then write to a temp file INSTEAD
if (message.totalSize > maxStreamSizeInMemoryInBytes) {
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
// message.totalSize > maxInMemory OR if we are a file, then write to a temp file INSTEAD
if (message.isFile || message.totalSize > maxStreamSizeInMemoryInBytes) {
var fileName = "${config.appId}_${streamId}_${connection.id}.tmp"
var tempFileLocation = OS.TEMP_DIR.resolve(fileName)
while (tempFileLocation.canRead()) {
fileName = "${config.appId}_${streamId}_${connection.id}_${CryptoManagement.secureRandom.nextInt()}.tmp"
tempFileLocation = OS.TEMP_DIR.resolve(fileName)
}
tempFileLocation.deleteOnExit()
val prettySize = Sys.getSizePretty(message.totalSize)
endPoint.logger.info { "Saving $prettySize of streaming data [${streamId}] to: $tempFileLocation" }
streamingDataInMemory[streamId] = FileWriter(tempFileLocation)
if (endPoint.logger.isInfoEnabled) {
endPoint.logger.info("Saving $prettySize of streaming data [${streamId}] to: $tempFileLocation")
}
streamingDataInMemory[streamId] = FileWriter(message.totalSize.toInt(), tempFileLocation)
} else {
endPoint.logger.info { "Saving streaming data [${streamId}] in memory" }
streamingDataInMemory[streamId] = AeronWriter()
if (endPoint.logger.isTraceEnabled) {
endPoint.logger.trace("Saving streaming data [${streamId}] in memory")
}
// .toInt is safe because we know the total size is less than maxStreamSizeInMemoryInBytes
streamingDataInMemory[streamId] = AeronWriter(message.totalSize.toInt())
}
// this must be last
streamingDataTarget[streamId] = message
}
StreamingState.FINISHED -> {
// NOTE: cannot be on a coroutine before kryo usage!
if (message.isFile) {
// we do not do anything with this file yet! The serializer has to return this instance!
val output = streamingDataInMemory[streamId]
if (output is FileWriter) {
output.finishAndClose()
// we don't need to do anything else (no de-serialization into an object) because we are already our target object
return
} else {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "Error while processing streaming content, stream $streamId was supposed to be a FileWriter."
// either client or server. No other choices. We create an exception, because it's more useful!
throw endPoint.newException(errorMessage)
}
}
// get the data out and send messages!
val output = streamingDataInMemory.remove(streamId)
val input = when (output) {
is AeronWriter -> {
// the position can be wrong, especially if there are multiple threads setting the data
output.setPosition(output.size)
AeronInput(output.internalBuffer)
}
is FileWriter -> {
output.flush()
output.close()
// if we are too large to fit in memory while streaming, we store it on disk.
output.finishAndClose()
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
val fileInputStream = FileInputStream(tempFileLocation)
val fileInputStream = FileInputStream(output.file)
Input(fileInputStream)
}
else -> {
@ -160,29 +205,31 @@ internal class StreamingManager<CONNECTION : Connection>(
}
val streamedMessage = if (input != null) {
try {
kryo.read(input)
} catch (e: Exception) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "Error deserializing message from received streaming content, stream $streamId"
val kryo = endPoint.serialization.takeRead()
try {
kryo.read(connection, input)
} catch (e: Exception) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "Error deserializing message from received streaming content, stream $streamId"
// either client or server. No other choices. We create an exception, because it's more useful!
throw endPoint.newException(errorMessage, e)
} finally {
if (output is FileWriter) {
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
tempFileLocation.delete()
// either client or server. No other choices. We create an exception, because it's more useful!
throw endPoint.newException(errorMessage, e)
} finally {
endPoint.serialization.putRead(kryo)
if (output is FileWriter) {
val fileName = "${config.appId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
tempFileLocation.delete()
}
}
} else {
null
}
} else {
null
}
if (streamedMessage == null) {
if (output is FileWriter) {
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val fileName = "${config.appId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
tempFileLocation.delete()
}
@ -196,29 +243,13 @@ internal class StreamingManager<CONNECTION : Connection>(
}
// NOTE: This MUST be on a new co-routine
messageDispatch.launch {
val listenerManager = endPoint.listenerManager
try {
var hasListeners = listenerManager.notifyOnMessage(connection, streamedMessage)
// each connection registers, and is polled INDEPENDENTLY for messages.
hasListeners = hasListeners or connection.notifyOnMessage(streamedMessage)
if (!hasListeners) {
logger.error("No streamed message callbacks found for ${streamedMessage::class.java.name}")
}
} catch (e: Exception) {
val newException = StreamingException("Error processing message ${streamedMessage::class.java.name}", e)
listenerManager.notifyError(connection, newException)
}
}
// this can be a regular message or an RMI message. Redispatch!
endPoint.processMessageFromChannel(connection, streamedMessage)
}
StreamingState.FAILED -> {
val output = streamingDataInMemory.remove(streamId)
if (output is FileWriter) {
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val fileName = "${config.appId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
tempFileLocation.delete()
}
@ -239,7 +270,7 @@ internal class StreamingManager<CONNECTION : Connection>(
StreamingState.UNKNOWN -> {
val output = streamingDataInMemory.remove(streamId)
if (output is FileWriter) {
val fileName = "${config.applicationId}_${streamId}_${connection.id}.tmp"
val fileName = "${config.appId}_${streamId}_${connection.id}.tmp"
val tempFileLocation = OS.TEMP_DIR.resolve(fileName)
tempFileLocation.delete()
}
@ -260,20 +291,20 @@ internal class StreamingManager<CONNECTION : Connection>(
}
/**
* NOTE: MUST BE ON THE AERON THREAD!
* NOTE: MUST BE ON THE AERON THREAD BECAUSE THIS MUST BE SINGLE THREADED!!!
*
* Reassemble/figure out the internal message pieces
*
* NOTE sending a huge file can prevent other network traffic from arriving until it's done!
* NOTE sending a huge file can cause other network traffic delays!
*/
fun processDataMessage(message: StreamingData, endPoint: EndPoint<CONNECTION>, connection: CONNECTION) {
// the receiving data will ALWAYS come sequentially, but there might be OTHER streaming data received meanwhile.
// NOTE: the stream session ID is a combination of the connection ID + random ID (on the receiving side)
val streamId = (connection.id.toLong() shl 4) or message.streamId.toLong()
val controlMessage = streamingDataTarget[streamId]
if (controlMessage != null) {
streamingDataInMemory[streamId]!!.writeBytes(message.payload!!)
val dataWriter = streamingDataInMemory[streamId]
if (dataWriter != null) {
dataWriter.writeBytes(message.startPosition, message.payload!!)
} else {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
@ -294,13 +325,13 @@ internal class StreamingManager<CONNECTION : Connection>(
streamSessionId: Int,
publication: Publication,
endPoint: EndPoint<CONNECTION>,
kryoExtra: KryoExtra<Connection>,
sendIdleStrategy: IdleStrategy,
connection: Connection
connection: CONNECTION,
kryo: KryoWriter<CONNECTION>
) {
val failMessage = StreamingControl(StreamingState.FAILED, streamSessionId)
val failMessage = StreamingControl(StreamingState.FAILED, false, streamSessionId)
val failSent = endPoint.writeUnsafe(kryoExtra, failMessage, publication, sendIdleStrategy, connection)
val failSent = endPoint.writeUnsafe(failMessage, publication, sendIdleStrategy, connection, kryo)
if (!failSent) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
@ -328,39 +359,35 @@ internal class StreamingManager<CONNECTION : Connection>(
* We don't write max possible length per message, we write out MTU (payload) length (so aeron doesn't fragment the message).
* The max possible length is WAY, WAY more than the max payload length.
*
* @param internalBuffer this is the ORIGINAL object data that is to be "chunked" and sent across the wire
* @return true if ALL the message chunks were successfully sent by aeron, false otherwise. Exceptions are caught and rethrown!
* @param originalBuffer this is the ORIGINAL object data that is to be sent in blocks across the wire
*
* @return true if ALL the message blocks were successfully sent by aeron, false otherwise. Exceptions are caught and rethrown!
*/
fun send(
publication: Publication,
internalBuffer: MutableDirectBuffer,
originalBuffer: MutableDirectBuffer,
maxMessageSize: Int,
objectSize: Int,
endPoint: EndPoint<CONNECTION>,
kryo: KryoExtra<Connection>,
kryo: KryoWriter<CONNECTION>,
sendIdleStrategy: IdleStrategy,
connection: Connection
connection: CONNECTION
): Boolean {
// this buffer is the exact size as our internal buffer, so it is unnecessary to have multiple kryo instances
val originalBuffer = ExpandableDirectByteBuffer(objectSize) // this can grow, so it's fine to lock it to this size!
// we have to save out our internal buffer, so we can reuse the kryo instance!
originalBuffer.putBytes(0, internalBuffer, 0, objectSize)
// NOTE: our max object size for IN-MEMORY messages is an INT. For file transfer it's a LONG (so everything here is cast to a long)
var remainingPayload = objectSize
var payloadSent = 0
// NOTE: the stream session ID is a combination of the connection ID + random ID (on the receiving side)
val streamSessionId = CryptoManagement.secureRandom.nextInt()
// tell the other side how much data we are sending
val startMessage = StreamingControl(StreamingState.START, streamSessionId, objectSize.toLong())
val startMessage = StreamingControl(StreamingState.START, false, streamSessionId, remainingPayload.toLong())
val startSent = endPoint.writeUnsafe(kryo, startMessage, publication, sendIdleStrategy, connection)
val startSent = endPoint.writeUnsafe(startMessage, publication, sendIdleStrategy, connection, kryo)
if (!startSent) {
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Error starting streaming content."
val errorMessage = "[${publication.sessionId()}] Error starting streaming content (could not send data)."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
@ -372,56 +399,63 @@ internal class StreamingManager<CONNECTION : Connection>(
}
// we do the FIRST chunk super-weird, because of the way we copy data around (we inject headers,
// we do the FIRST block super-weird, because of the way we copy data around (we inject headers,
// so the first message is SUPER tiny and is a COPY, the rest are no-copy.
// This is REUSED to prevent garbage collection issues.
val chunkData = StreamingData(streamSessionId)
// payload size is for a PRODUCER, and not SUBSCRIBER, so we have to include this amount every time.
// MINOR fragmentation by aeron is OK, since that will greatly speed up data transfer rates!
// the maxPayloadLength MUST ABSOLUTELY be less that the max size + header!
var sizeOfPayload = publication.maxMessageLength() - 200
var sizeOfBlockData = maxMessageSize
val header: ByteArray
val headerSize: Int
try {
val objectBuffer = kryo.write(connection, chunkData)
// This is REUSED to prevent garbage collection issues.
val blockData = StreamingData(streamSessionId)
val objectBuffer = kryo.write(connection, blockData)
headerSize = objectBuffer.position()
header = ByteArray(headerSize)
// we have to account for the header + the MAX optimized int size
sizeOfPayload -= (headerSize + 5)
// we have to account for the header + the MAX optimized int size (position and data-length)
val dataSize = headerSize + 5 + 5
sizeOfBlockData -= dataSize
// this size might be a LITTLE too big, but that's ok, since we only make this specific buffer once.
val chunkBuffer = AeronOutput(headerSize + sizeOfPayload)
val blockBuffer = AeronOutput(dataSize)
// copy out our header info
objectBuffer.internalBuffer.getBytes(0, header, 0, headerSize)
// write out our header
chunkBuffer.writeBytes(header)
blockBuffer.writeBytes(header)
// write out the payload size using optimized data structures.
val varIntSize = chunkBuffer.writeVarInt(sizeOfPayload, true)
// write out the start-position (of the payload). First start-position is always 0
val positionIntSize = blockBuffer.writeVarInt(0, true)
// write out the payload size
val payloadIntSize = blockBuffer.writeVarInt(sizeOfBlockData, true)
// write out the payload. Our resulting data written out is the ACTUAL MTU of aeron.
originalBuffer.getBytes(0, chunkBuffer.internalBuffer, headerSize + varIntSize, sizeOfPayload)
originalBuffer.getBytes(0, blockBuffer.internalBuffer, headerSize + positionIntSize + payloadIntSize, sizeOfBlockData)
remainingPayload -= sizeOfPayload
payloadSent += sizeOfPayload
remainingPayload -= sizeOfBlockData
payloadSent += sizeOfBlockData
val success = endPoint.dataSend(
publication,
chunkBuffer.internalBuffer,
0,
headerSize + varIntSize + sizeOfPayload,
sendIdleStrategy,
connection,
false
// we reuse/recycle objects, so the payload size is not EXACTLY what is specified
val reusedPayloadSize = headerSize + positionIntSize + payloadIntSize + sizeOfBlockData
val success = endPoint.aeronDriver.send(
publication = publication,
internalBuffer = blockBuffer.internalBuffer,
bufferClaim = kryo.bufferClaim,
offset = 0,
objectSize = reusedPayloadSize,
sendIdleStrategy = sendIdleStrategy,
connection = connection,
abortEarly = false,
listenerManager = endPoint.listenerManager
)
if (!success) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
@ -436,16 +470,16 @@ internal class StreamingManager<CONNECTION : Connection>(
throw exception
}
} catch (e: Exception) {
sendFailMessageAndThrow(e, streamSessionId, publication, endPoint, kryo, sendIdleStrategy, connection)
sendFailMessageAndThrow(e, streamSessionId, publication, endPoint, sendIdleStrategy, connection, kryo)
return false // doesn't actually get here because exceptions are thrown, but this makes the IDE happy.
}
// now send the chunks as fast as possible. Aeron will have us back-off if we send too quickly
// now send the block as fast as possible. Aeron will have us back-off if we send too quickly
while (remainingPayload > 0) {
val amountToSend = if (remainingPayload < sizeOfPayload) {
val amountToSend = if (remainingPayload < sizeOfBlockData) {
remainingPayload
} else {
sizeOfPayload
sizeOfBlockData
}
remainingPayload -= amountToSend
@ -458,32 +492,283 @@ internal class StreamingManager<CONNECTION : Connection>(
// fortunately, the way that serialization works, we can safely ADD data to the tail and then appropriately read it off
// on the receiving end without worry.
/// TODO: Compression/encryption??
try {
val varIntSize = OptimizeUtilsByteBuf.intLength(sizeOfPayload, true)
val writeIndex = payloadSent - headerSize - varIntSize
val positionIntSize = OptimizeUtilsByteBuf.intLength(payloadSent, true)
val payloadIntSize = OptimizeUtilsByteBuf.intLength(amountToSend, true)
val writeIndex = payloadSent - headerSize - positionIntSize - payloadIntSize
// write out our header data (this will OVERWRITE previous data!)
originalBuffer.putBytes(writeIndex, header)
// write out the payload size using optimized data structures.
writeVarInt(originalBuffer, writeIndex + headerSize, sizeOfPayload, true)
// write out the payload start position
writeVarInt(originalBuffer, writeIndex + headerSize, payloadSent, true)
// write out the payload size
writeVarInt(originalBuffer, writeIndex + headerSize + positionIntSize, amountToSend, true)
// we reuse/recycle objects, so the payload size is not EXACTLY what is specified
val reusedPayloadSize = headerSize + payloadIntSize + positionIntSize + amountToSend
// write out the payload
endPoint.dataSend(
publication,
originalBuffer,
writeIndex,
headerSize + varIntSize + amountToSend,
sendIdleStrategy,
connection,
false
val success = endPoint.aeronDriver.send(
publication = publication,
internalBuffer = originalBuffer,
bufferClaim = kryo.bufferClaim,
offset = writeIndex,
objectSize = reusedPayloadSize,
sendIdleStrategy = sendIdleStrategy,
connection = connection,
abortEarly = false,
listenerManager = endPoint.listenerManager
)
if (!success) {
// critical errors have an exception. Normal "the connection is closed" do not.
return false
}
payloadSent += amountToSend
} catch (e: Exception) {
val failMessage = StreamingControl(StreamingState.FAILED, false, streamSessionId)
val failSent = endPoint.writeUnsafe(failMessage, publication, sendIdleStrategy, connection, kryo)
if (!failSent) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Abnormal failure with exception while streaming content."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage, e)
exception.cleanAllStackTrace()
throw exception
} else {
// send it up!
throw e
}
}
}
// send the last block of data
val finishedMessage = StreamingControl(StreamingState.FINISHED, false, streamSessionId, payloadSent.toLong())
return endPoint.writeUnsafe(finishedMessage, publication, sendIdleStrategy, connection, kryo)
}
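To make the bookkeeping above concrete: every block has to leave room for the serialized StreamingData header plus two var-ints (start position and payload length, at most 5 bytes each), and after the first block the header for block N is written over the already-sent tail of the original buffer. A small sketch of those computations with hypothetical numbers (the real header size depends on the kryo registration):

fun main() {
    val maxMessageSize = 8192  // per-block budget handed to send()
    val headerSize = 6         // serialized StreamingData header size (hypothetical)

    // reserve header + two max-size var-ints for every block
    val sizeOfBlockData = maxMessageSize - (headerSize + 5 + 5)  // 8176

    // for the second block: step back over header + the var-ints actually used,
    // so the new header overwrites payload bytes that were already sent
    val payloadSent = sizeOfBlockData
    val positionIntSize = 2    // var-int length of payloadSent (8176 fits in 2 bytes)
    val payloadIntSize = 2     // var-int length of the block length
    val writeIndex = payloadSent - headerSize - positionIntSize - payloadIntSize  // 8166
    println("sizeOfBlockData=$sizeOfBlockData writeIndex=$writeIndex")
}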
/**
* This is called ONLY when a message is too large to send across the network in a single message (large data messages should
* be split into smaller ones anyways!)
*
* NOTE: this **MUST** stay on the same co-routine that calls "send". This cannot be re-dispatched onto a different coroutine!
*
* We don't write max possible length per message, we write out MTU (payload) length (so aeron doesn't fragment the message).
* The max possible length is WAY, WAY more than the max payload length.
*
* @param streamSessionId the stream session ID is a combination of the connection ID + random ID (on the receiving side)
*
* @return true if ALL the message blocks were successfully sent by aeron, false otherwise. Exceptions are caught and rethrown!
*/
@Suppress("SameParameterValue")
fun sendFile(
file: File,
publication: Publication,
endPoint: EndPoint<CONNECTION>,
kryo: KryoWriter<CONNECTION>,
sendIdleStrategy: IdleStrategy,
connection: CONNECTION,
streamSessionId: Int
): Boolean {
val maxMessageSize = connection.maxMessageSize.toLong()
val fileInputStream = file.inputStream()
// if the message is a file, we xfer the file AS a file, and leave it as a temp file (with a file reference to it) on the remote endpoint
// the temp file will be unique.
// NOTE: our max object size for IN-MEMORY messages is an INT. For file transfer it's a LONG (so everything here is cast to a long)
var remainingPayload = file.length()
var payloadSent = 0
// tell the other side how much data we are sending
val startMessage = StreamingControl(StreamingState.START, true, streamSessionId, remainingPayload)
val startSent = endPoint.writeUnsafe(startMessage, publication, sendIdleStrategy, connection, kryo)
if (!startSent) {
fileInputStream.close()
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Error starting streaming file."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
throw exception
}
// we do the FIRST block super-weird, because of the way we copy data around (we inject headers),
// so the first message is SUPER tiny and is a COPY, the rest are no-copy.
// payload size is for a PRODUCER, and not SUBSCRIBER, so we have to include this amount every time.
// we don't know which is larger, the max message size or the file size!
var sizeOfBlockData = maxMessageSize.coerceAtMost(remainingPayload).toInt()
val headerSize: Int
val buffer: ByteArray
val blockBuffer: UnsafeBuffer
try {
// This is REUSED to prevent garbage collection issues.
val blockData = StreamingData(streamSessionId)
val objectBuffer = kryo.write(connection, blockData)
headerSize = objectBuffer.position()
// we have to account for the header + the MAX optimized int size (position and data-length)
val dataSize = headerSize + 5 + 5
sizeOfBlockData -= dataSize
// this size might be a LITTLE too big, but that's ok, since we only make this specific buffer once.
buffer = ByteArray(sizeOfBlockData + dataSize)
blockBuffer = UnsafeBuffer(buffer)
// copy out our header info (this skips the header object)
objectBuffer.internalBuffer.getBytes(0, buffer, 0, headerSize)
// write out the start-position (of the payload). First start-position is always 0
val positionIntSize = OptimizeUtilsByteArray.writeInt(buffer, 0, true, headerSize)
// write out the payload size
val payloadIntSize = OptimizeUtilsByteArray.writeInt(buffer, sizeOfBlockData, true, headerSize + positionIntSize)
// write out the payload. Our resulting data written out is the ACTUAL MTU of aeron.
val readBytes = fileInputStream.read(buffer, headerSize + positionIntSize + payloadIntSize, sizeOfBlockData)
if (readBytes != sizeOfBlockData) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Abnormal failure while streaming file (read bytes was wrong! ${readBytes} - ${sizeOfBlockData}."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
throw exception
}
remainingPayload -= sizeOfBlockData
payloadSent += sizeOfBlockData
// we reuse/recycle objects, so the payload size is not EXACTLY what is specified
val reusedPayloadSize = headerSize + positionIntSize + payloadIntSize + sizeOfBlockData
val success = endPoint.aeronDriver.send(
publication = publication,
internalBuffer = blockBuffer,
bufferClaim = kryo.bufferClaim,
offset = 0,
objectSize = reusedPayloadSize,
sendIdleStrategy = sendIdleStrategy,
connection = connection,
abortEarly = false,
listenerManager = endPoint.listenerManager
)
if (!success) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Abnormal failure while streaming file."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
throw exception
}
} catch (e: Exception) {
fileInputStream.close()
sendFailMessageAndThrow(e, streamSessionId, publication, endPoint, sendIdleStrategy, connection, kryo)
return false // doesn't actually get here because exceptions are thrown, but this makes the IDE happy.
}
val aeronDriver = endPoint.aeronDriver
val listenerManager = endPoint.listenerManager
// now send the block as fast as possible. Aeron will have us back-off if we send too quickly
while (remainingPayload > 0) {
val amountToSend = if (remainingPayload < sizeOfBlockData) {
remainingPayload.toInt()
} else {
sizeOfBlockData
}
remainingPayload -= amountToSend
// to properly do this, we have to be careful with the underlying protocol, in order to avoid copying the buffer multiple times.
// the data that will be sent is object data + buffer data. We are sending the SAME parent buffer, just at different spots and
// with different headers -- so we don't copy out the data repeatedly
// fortunately, the way that serialization works, we can safely ADD data to the tail and then appropriately read it off
// on the receiving end without worry.
/// TODO: Compression/encryption??
try {
// write out the payload start position
val positionIntSize = OptimizeUtilsByteArray.writeInt(buffer, payloadSent, true, headerSize)
// write out the payload size
val payloadIntSize = OptimizeUtilsByteArray.writeInt(buffer, amountToSend, true, headerSize + positionIntSize)
// write out the payload. Our resulting data written out is the ACTUAL MTU of aeron.
val readBytes = fileInputStream.read(buffer, headerSize + positionIntSize + payloadIntSize, amountToSend)
if (readBytes != amountToSend) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
val errorMessage = "[${publication.sessionId()}] Abnormal failure while streaming file (read bytes was wrong! ${readBytes} - ${amountToSend}."
// either client or server. No other choices. We create an exception, because it's more useful!
val exception = endPoint.newException(errorMessage)
// +3 more because we do not need to see the "internals" for sending messages. The important part of the stack trace is
// where we see who is calling "send()"
exception.cleanStackTrace(3)
throw exception
}
// we reuse/recycle objects, so the payload size is not EXACTLY what is specified
val reusedPayloadSize = headerSize + positionIntSize + payloadIntSize + amountToSend
// write out the payload
aeronDriver.send(
publication = publication,
internalBuffer = blockBuffer,
bufferClaim = kryo.bufferClaim,
offset = 0, // 0 because we are not reading the entire file at once
objectSize = reusedPayloadSize,
sendIdleStrategy = sendIdleStrategy,
connection = connection,
abortEarly = false,
listenerManager = listenerManager
)
payloadSent += amountToSend
} catch (e: Exception) {
val failMessage = StreamingControl(StreamingState.FAILED, streamSessionId)
fileInputStream.close()
val failSent = endPoint.writeUnsafe(kryo, failMessage, publication, sendIdleStrategy, connection)
val failMessage = StreamingControl(StreamingState.FAILED, false, streamSessionId)
val failSent = endPoint.writeUnsafe(failMessage, publication, sendIdleStrategy, connection, kryo)
if (!failSent) {
// something SUPER wrong!
// more critical error sending the message. we shouldn't retry or anything.
@ -503,9 +788,11 @@ internal class StreamingManager<CONNECTION : Connection>(
}
}
// send the last chunk of data
val finishedMessage = StreamingControl(StreamingState.FINISHED, streamSessionId, payloadSent.toLong())
fileInputStream.close()
return endPoint.writeUnsafe(kryo, finishedMessage, publication, sendIdleStrategy, connection)
// send the last block of data
val finishedMessage = StreamingControl(StreamingState.FINISHED, true, streamSessionId, payloadSent.toLong())
return endPoint.writeUnsafe(finishedMessage, publication, sendIdleStrategy, connection, kryo)
}
}

View File

@ -17,5 +17,6 @@
package dorkbox.network.connection.streaming
interface StreamingWriter {
fun writeBytes(bytes: ByteArray)
fun writeBytes(startPosition: Int, bytes: ByteArray)
fun isFinished(): Boolean
}
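A hypothetical, minimal in-memory implementation of this interface (the project's AeronWriter and FileWriter are not shown here, and the two members above are assumed to be the whole interface) illustrates why writeBytes() takes a start position: each block carries its own offset, so blocks can be written independently and completion is simply "all expected bytes have landed":

import java.util.concurrent.atomic.AtomicInteger

class InMemoryWriter(private val expectedSize: Int) : StreamingWriter {
    private val data = ByteArray(expectedSize)
    private val written = AtomicInteger(0)

    override fun writeBytes(startPosition: Int, bytes: ByteArray) {
        // each block lands at its own offset; ordering between blocks does not matter
        bytes.copyInto(data, destinationOffset = startPosition)
        written.addAndGet(bytes.size)
    }

    override fun isFinished(): Boolean = written.get() >= expectedSize
}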

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connection.streaming;

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.connectionType;

View File

@ -1,39 +0,0 @@
/*
* Copyright 2020 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.coroutines;
import kotlin.coroutines.Continuation;
import kotlin.jvm.functions.Function1;
/**
* Class to access suspending invocation of methods from kotlin...
*
* ULTIMATELY, this is all java bytecode, and the bytecode signature here matches what kotlin expects. The generics type information is
* discarded at compile time.
*/
public
class SuspendFunctionTrampoline {
/**
* trampoline so we can access suspend functions correctly using reflection
*/
@SuppressWarnings("unchecked")
public static
Object invoke(final Continuation<?> continuation, final Object suspendFunction) throws Throwable {
Function1<? super Continuation<? super Object>, ?> suspendFunction1 = (Function1<? super Continuation<? super Object>, ?>) suspendFunction;
return suspendFunction1.invoke((Continuation<? super Object>) continuation);
}
}
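The trampoline works because a compiled suspend () -> T lambda is, at the bytecode level, a Function1<Continuation<T>, Any?>; invoking it with a continuation returns either the finished result or COROUTINE_SUSPENDED. A rough Kotlin restatement of the same call, as a hypothetical helper rather than part of the project:

import kotlin.coroutines.Continuation

@Suppress("UNCHECKED_CAST")
fun invokeSuspendFunction(continuation: Continuation<Any?>, suspendFunction: Any): Any? {
    val fn = suspendFunction as Function1<Continuation<Any?>, Any?>
    return fn.invoke(continuation)  // either the real result, or COROUTINE_SUSPENDED
}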

View File

@ -0,0 +1,43 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.exceptions
/**
* The type of exceptions raised for send-sync errors
*/
open class SendSyncException : Exception {
/**
* Create an exception.
*
* @param message The message
*/
constructor(message: String) : super(message)
/**
* Create an exception.
*
* @param cause The cause
*/
constructor(cause: Throwable) : super(cause)
/**
* Create an exception.
*
* @param message The message
* @param cause The cause
*/
constructor(message: String, cause: Throwable?) : super(message, cause)
}

View File

@ -0,0 +1,21 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.exceptions
class TimeoutException: Exception() {
}

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.exceptions;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,9 +21,11 @@ import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.aeron.controlEndpoint
import dorkbox.network.aeron.endpoint
import dorkbox.network.connection.EndPoint
import dorkbox.network.exceptions.ClientRetryException
import dorkbox.network.exceptions.ClientTimedOutException
import io.aeron.CommonContext
import kotlinx.atomicfu.AtomicBoolean
import java.net.Inet4Address
import java.net.InetAddress
@ -39,11 +41,14 @@ import java.net.InetAddress
internal class ClientConnectionDriver(val connectionInfo: PubSub) {
companion object {
suspend fun build(
fun build(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
handshakeConnection: ClientHandshakeDriver,
connectionInfo: ClientConnectionInfo
connectionInfo: ClientConnectionInfo,
port2Server: Int, // this is the port2 value from the server
tagName: String
): ClientConnectionDriver {
val handshakePubSub = handshakeConnection.pubSub
val reliable = handshakePubSub.reliable
@ -65,6 +70,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
logInfo = "CONNECTION-IPC"
pubSub = buildIPC(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
@ -72,6 +78,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -88,6 +95,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
}
pubSub = buildUDP(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
@ -98,7 +106,9 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
port2Server = port2Server,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -107,7 +117,8 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
}
@Throws(ClientTimedOutException::class)
private suspend fun buildIPC(
private fun buildIPC(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
@ -115,6 +126,7 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
streamIdPub: Int,
streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -128,11 +140,11 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// can throw an exception! We catch it in the calling class
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, true)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, true)
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server!", cause)
}
@ -141,14 +153,32 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
// wait for the REMOTE end to also connect to us!
aeronDriver.waitForConnection(shutdown, subscription, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
@Throws(ClientTimedOutException::class)
private suspend fun buildUDP(
private fun buildUDP(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
@ -159,7 +189,9 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
remoteAddressString: String,
portPub: Int,
portSub: Int,
port2Server: Int, // this is the port2 value from the server
reliable: Boolean,
tagName: String,
logInfo: String,
): PubSub {
val isRemoteIpv4 = remoteAddress is Inet4Address
@ -176,11 +208,11 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// publication of any state to other threads and not be long running or re-entrant with the client.
// can throw an exception! We catch it in the calling class
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, false)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, false)
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server $remoteAddressString", cause)
}
@ -191,20 +223,32 @@ internal class ClientConnectionDriver(val connectionInfo: PubSub) {
// A control endpoint for the subscriptions will cause a periodic service management "heartbeat" to be sent to the
// remote endpoint publication, which permits the remote publication to send us data, thereby getting us around NAT
val subscriptionUri = uri(CommonContext.UDP_MEDIA, sessionIdSub, reliable)
.endpoint(isRemoteIpv4, localAddressString, 0) // 0 for MDC!
.controlEndpoint(isRemoteIpv4, remoteAddressString, portSub)
.endpoint(isRemoteIpv4, localAddressString, portSub)
.controlEndpoint(isRemoteIpv4, remoteAddressString, port2Server)
.controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
// wait for the REMOTE end to also connect to us!
aeronDriver.waitForConnection(shutdown, subscription, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo subscription cannot connect with server!", cause)
}
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}
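The control-endpoint comments above are the heart of the NAT traversal: the subscription binds a local endpoint, points its control endpoint at the server's second port, and dynamic MDC control mode makes it send periodic status/heartbeat frames to that address, which lets the remote publication send data back through NAT. A sketch of the kind of channel URI this produces, using Aeron's ChannelUriStringBuilder directly (the project uses its own uri()/endpoint()/controlEndpoint() helpers; addresses and ports below are hypothetical):

import io.aeron.ChannelUriStringBuilder
import io.aeron.CommonContext

fun main() {
    val channel = ChannelUriStringBuilder()
        .media(CommonContext.UDP_MEDIA)
        .reliable(true)
        .endpoint("192.168.1.20:41000")        // local listen address (portSub)
        .controlEndpoint("203.0.113.5:41001")  // the server's second port (port2Server)
        .controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
        .build()
    println(channel)  // e.g. aeron:udp?endpoint=...|control=...|control-mode=dynamic|reliable=true
}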

View File

@ -15,10 +15,16 @@
*/
package dorkbox.network.handshake
internal class ClientConnectionInfo(val sessionIdPub: Int = 0,
val sessionIdSub: Int = 0,
val streamIdPub: Int,
val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0),
val kryoRegistrationDetails: ByteArray) {
}
import javax.crypto.spec.SecretKeySpec
internal class ClientConnectionInfo(
val sessionIdPub: Int = 0,
val sessionIdSub: Int = 0,
val streamIdPub: Int,
val streamIdSub: Int = 0,
val publicKey: ByteArray = ByteArray(0),
val sessionTimeout: Long,
val bufferedMessages: Boolean,
val kryoRegistrationDetails: ByteArray,
val secretKey: SecretKeySpec
)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.Client
import dorkbox.network.connection.Connection
import dorkbox.network.connection.CryptoManagement
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager.Companion.cleanAllStackTrace
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTraceInternal
import dorkbox.network.exceptions.*
@ -26,13 +27,12 @@ import io.aeron.FragmentAssembler
import io.aeron.Image
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
import kotlinx.coroutines.delay
import mu.KLogger
import org.agrona.DirectBuffer
import org.slf4j.Logger
internal class ClientHandshake<CONNECTION: Connection>(
private val client: Client<CONNECTION>,
private val logger: KLogger
private val logger: Logger
) {
// @Volatile is used BECAUSE suspension of coroutines can continue on a DIFFERENT thread. We want to make sure that thread visibility is
@ -92,8 +92,8 @@ internal class ClientHandshake<CONNECTION: Connection>(
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
throw ClientRejectedException("[$logInfo] Connection not allowed! unrecognized message: $msg") .apply { cleanAllStackTrace() }
} else {
logger.trace { "[$logInfo] (${msg.connectKey}) received HS: $msg" }
} else if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${msg.connectKey}) received HS: $msg")
}
msg
} catch (e: Exception) {
@ -177,7 +177,12 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed
suspend fun hello(handshakeConnection: ClientHandshakeDriver, handshakeTimeoutNs: Long) : ClientConnectionInfo {
fun hello(
tagName: String,
endPoint: EndPoint<CONNECTION>,
handshakeConnection: ClientHandshakeDriver,
handshakeTimeoutNs: Long
) : ClientConnectionInfo {
val pubSub = handshakeConnection.pubSub
// is our pub still connected??
@ -194,12 +199,13 @@ internal class ClientHandshake<CONNECTION: Connection>(
handshaker.writeMessage(pubSub.pub, handshakeConnection.details,
HandshakeMessage.helloFromClient(
connectKey = connectKey,
publicKey = client.storage.publicKey!!,
publicKey = client.storage.publicKey,
streamIdSub = pubSub.streamIdSub,
portSub = pubSub.portSub
portSub = pubSub.portSub,
tagName = tagName
))
} catch (e: Exception) {
handshakeConnection.close()
handshakeConnection.close(endPoint)
throw TransmitException("$handshakeConnection Handshake message error!", e)
}
@ -211,23 +217,23 @@ internal class ClientHandshake<CONNECTION: Connection>(
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
pubSub.sub.poll(handler, 1)
if (failedException != null || connectionHelloInfo != null) {
if (endPoint.isShutdown() || failedException != null || connectionHelloInfo != null) {
break
}
delay(100)
Thread.sleep(100)
}
val failedEx = failedException
if (failedEx != null) {
handshakeConnection.close()
handshakeConnection.close(endPoint)
failedEx.cleanStackTraceInternal()
throw failedEx
}
if (connectionHelloInfo == null) {
handshakeConnection.close()
handshakeConnection.close(endPoint)
val exception = ClientTimedOutException("$handshakeConnection Waiting for registration response from server for more than ${Sys.getTimePrettyFull(handshakeTimeoutNs)}")
throw exception
@ -238,11 +244,12 @@ internal class ClientHandshake<CONNECTION: Connection>(
// called from the connect thread
// when exceptions are thrown, the handshake pub/sub will be closed
suspend fun done(
fun done(
endPoint: EndPoint<CONNECTION>,
handshakeConnection: ClientHandshakeDriver,
clientConnection: ClientConnectionDriver,
handshakeTimeoutNs: Long,
aeronLogInfo: String
logInfo: String
) {
val pubSub = clientConnection.connectionInfo
val handshakePubSub = handshakeConnection.pubSub
@ -254,15 +261,14 @@ internal class ClientHandshake<CONNECTION: Connection>(
// Send the done message to the server.
try {
handshaker.writeMessage(handshakeConnection.pubSub.pub, aeronLogInfo,
handshaker.writeMessage(handshakeConnection.pubSub.pub, logInfo,
HandshakeMessage.doneFromClient(
connectKey = connectKey,
sessionIdSub = handshakePubSub.sessionIdSub,
streamIdSub = handshakePubSub.streamIdSub,
portSub = handshakePubSub.portSub
streamIdSub = handshakePubSub.streamIdSub
))
} catch (e: Exception) {
handshakeConnection.close()
handshakeConnection.close(endPoint)
throw TransmitException("$handshakeConnection Handshake message error!", e)
}
@ -277,7 +283,7 @@ internal class ClientHandshake<CONNECTION: Connection>(
// `.poll(handler, 4)` == `.poll(handler, 2)` + `.poll(handler, 2)`
handshakePubSub.sub.poll(handler, 1)
if (failedException != null || connectionDone) {
if (endPoint.isShutdown() || failedException != null || connectionDone) {
break
}
@ -288,19 +294,19 @@ internal class ClientHandshake<CONNECTION: Connection>(
startTime = System.nanoTime()
}
delay(100)
Thread.sleep(100)
}
val failedEx = failedException
if (failedEx != null) {
handshakeConnection.close()
handshakeConnection.close(endPoint)
throw failedEx
}
if (!connectionDone) {
// since this failed, close everything
handshakeConnection.close()
handshakeConnection.close(endPoint)
val exception = ClientTimedOutException("Timed out waiting for registration response from server: ${Sys.getTimePrettyFull(handshakeTimeoutNs)}")
throw exception
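Both hello() and done() follow the same bounded poll pattern: poll the handshake subscription one fragment at a time, break as soon as the handler produces a result (or the endpoint shuts down), sleep briefly, and give up after the handshake timeout. A sketch of that loop, with the subscription, handler and completion checks as stand-ins for the real ClientHandshake state:

import io.aeron.Subscription
import io.aeron.logbuffer.FragmentHandler

fun pollUntilDone(
    subscription: Subscription,
    handler: FragmentHandler,
    handshakeTimeoutNs: Long,
    isShutdown: () -> Boolean,
    isComplete: () -> Boolean  // e.g. "connectionHelloInfo != null || failedException != null"
): Boolean {
    val startTime = System.nanoTime()
    while (System.nanoTime() - startTime < handshakeTimeoutNs) {
        // one fragment per pass: `.poll(handler, 4)` is just `.poll(handler, 2)` twice,
        // so a limit of 1 keeps the handshake handler simple
        subscription.poll(handler, 1)
        if (isShutdown() || isComplete()) return isComplete()
        Thread.sleep(100)
    }
    return false  // timed out
}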

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
package dorkbox.network.handshake
import dorkbox.network.Configuration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.getLocalAddressString
import dorkbox.network.aeron.AeronDriver.Companion.streamIdAllocator
@ -34,7 +33,8 @@ import dorkbox.network.exceptions.ClientTimedOutException
import dorkbox.util.Sys
import io.aeron.CommonContext
import io.aeron.Subscription
import mu.KLogger
import kotlinx.atomicfu.AtomicBoolean
import org.slf4j.Logger
import java.net.Inet4Address
import java.net.InetAddress
import java.util.*
@ -48,33 +48,36 @@ import java.util.*
* @throws ClientTimedOutException if we cannot connect to the server in the designated time
*/
internal class ClientHandshakeDriver(
private val aeronDriver: AeronDriver,
val aeronDriver: AeronDriver,
val pubSub: PubSub,
private val logInfo: String,
val details: String
) {
companion object {
suspend fun build(
config: Configuration,
fun build(
endpoint: EndPoint<*>,
aeronDriver: AeronDriver,
autoChangeToIpc: Boolean,
remoteAddress: InetAddress?,
remoteAddressString: String,
remotePort: Int,
port: Int,
remotePort1: Int,
remotePort2: Int,
clientListenPort: Int,
handshakeTimeoutNs: Long,
connectionTimoutInNs: Long,
reliable: Boolean,
logger: KLogger
tagName: String,
logger: Logger
): ClientHandshakeDriver {
logger.trace { "Starting client handshake" }
logger.trace("Starting client handshake")
var isUsingIPC = false
if (autoChangeToIpc) {
if (remoteAddress == null) {
logger.info { "IPC enabled" }
logger.info("IPC enabled")
} else {
logger.warn { "IPC for loopback enabled and aeron is already running. Auto-changing network connection from '$remoteAddressString' -> IPC" }
logger.warn("IPC for loopback enabled and aeron is already running. Auto-changing network connection from '$remoteAddressString' -> IPC")
}
isUsingIPC = true
}
@ -96,24 +99,38 @@ internal class ClientHandshakeDriver(
var pubSub: PubSub? = null
val timeoutInfo = if (connectionTimoutInNs > 0L) {
"[Handshake: ${Sys.getTimePrettyFull(handshakeTimeoutNs)}, Max connection attempt: ${Sys.getTimePrettyFull(connectionTimoutInNs)}]"
} else {
"[Handshake: ${Sys.getTimePrettyFull(handshakeTimeoutNs)}, Max connection attempt: Unlimited]"
}
val config = endpoint.config
val shutdown = endpoint.shutdown
if (isUsingIPC) {
streamIdPub = config.ipcId
logInfo = "HANDSHAKE-IPC"
details = logInfo
logger.info("Client connecting via IPC. $timeoutInfo")
try {
pubSub = buildIPC(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
sessionIdPub = sessionIdPub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} catch (exception: Exception) {
logger.error(exception) { "Error initializing IPC connection" }
logger.error("Error initializing IPC connection", exception)
// MAYBE the server doesn't have IPC enabled? If no, we need to connect via network instead
isUsingIPC = false
@ -147,17 +164,26 @@ internal class ClientHandshakeDriver(
streamIdPub = config.udpId
if (remoteAddress is Inet4Address) {
logger.info("Client connecting to IPv4 $remoteAddressString. $timeoutInfo")
} else {
logger.info("Client connecting to IPv6 $remoteAddressString. $timeoutInfo")
}
pubSub = buildUDP(
shutdown = shutdown,
aeronDriver = aeronDriver,
handshakeTimeoutNs = handshakeTimeoutNs,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = remotePort,
portSub = port,
portPub = remotePort1,
portSub = clientListenPort,
port2Server = remotePort2,
sessionIdPub = sessionIdPub,
streamIdPub = streamIdPub,
reliable = reliable,
streamIdSub = streamIdSub,
tagName = tagName,
logInfo = logInfo
)
@ -180,13 +206,16 @@ internal class ClientHandshakeDriver(
}
@Throws(ClientTimedOutException::class)
private suspend fun buildIPC(
private fun buildIPC(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
sessionIdPub: Int,
streamIdPub: Int, streamIdSub: Int,
streamIdPub: Int,
streamIdSub: Int,
reliable: Boolean,
logInfo: String
tagName: String,
logInfo: String,
): PubSub {
// Create a publication at the given address and port, using the given stream ID.
// Note: The Aeron.addPublication method will block until the Media Driver acknowledges the request or a timeout occurs.
@ -200,11 +229,11 @@ internal class ClientHandshakeDriver(
// this check is in the "reconnect" logic
// can throw an exception! We catch it in the calling class
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, true)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, true)
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ClientTimedOutException("$logInfo publication cannot connect with server in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
@ -212,24 +241,37 @@ internal class ClientHandshakeDriver(
val subscriptionUri = uriHandshake(CommonContext.IPC_MEDIA, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
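For reference, a minimal standalone sketch of the same IPC handshake channel setup using the raw Aeron client API (the stream IDs and busy-wait are placeholders; the helper above additionally honors the handshake timeout and the shutdown flag):

import io.aeron.Aeron
import io.aeron.CommonContext

fun ipcHandshakeChannels() {
    Aeron.connect(Aeron.Context()).use { aeron ->
        val pub = aeron.addPublication(CommonContext.IPC_CHANNEL, 10)   // client -> server handshake stream
        val sub = aeron.addSubscription(CommonContext.IPC_CHANNEL, 11)  // server -> client handshake stream
        while (!pub.isConnected) Thread.sleep(10)                       // the real code waits with a timeout and shutdown check
        // ... offer()/poll() handshake messages here ...
        sub.close()
        pub.close()
    }
}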
@Throws(ClientTimedOutException::class)
private suspend fun buildUDP(
private fun buildUDP(
shutdown: AtomicBoolean,
aeronDriver: AeronDriver,
handshakeTimeoutNs: Long,
remoteAddress: InetAddress,
remoteAddressString: String,
portPub: Int,
portPub: Int, // this is the port1 value from the server
portSub: Int,
port2Server: Int, // this is the port2 value from the server
sessionIdPub: Int,
streamIdPub: Int,
reliable: Boolean,
streamIdSub: Int,
tagName: String,
logInfo: String,
): PubSub {
@Suppress("NAME_SHADOWING")
@ -254,11 +296,11 @@ internal class ClientHandshakeDriver(
// can throw an exception! We catch it in the calling class
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, false)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, false)
// can throw an exception! We catch it in the calling class
// we actually have to wait for it to connect before we continue
aeronDriver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
aeronDriver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
streamIdAllocator.free(streamIdSub) // we don't continue, so close this as well
ClientTimedOutException("$logInfo publication cannot connect with server in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
@ -277,8 +319,8 @@ internal class ClientHandshakeDriver(
// A control endpoint for the subscriptions will cause a periodic service management "heartbeat" to be sent to the
// remote endpoint publication, which permits the remote publication to send us data, thereby getting us around NAT
val subscriptionUri = uriHandshake(CommonContext.UDP_MEDIA, reliable)
.endpoint(isRemoteIpv4, localAddressString, 0) // 0 for MDC!
.controlEndpoint(isRemoteIpv4, remoteAddressString, portSub)
.endpoint(isRemoteIpv4, localAddressString, portSub)
.controlEndpoint(isRemoteIpv4, remoteAddressString, port2Server)
.controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
@ -300,8 +342,8 @@ internal class ClientHandshakeDriver(
// A control endpoint for the subscriptions will cause a periodic service management "heartbeat" to be sent to the
// remote endpoint publication, which permits the remote publication to send us data, thereby getting us around NAT
val subscriptionUri = uriHandshake(CommonContext.UDP_MEDIA, reliable)
.endpoint(isRemoteIpv4, localAddressString, 0) // 0 for MDC!
.controlEndpoint(isRemoteIpv4, remoteAddressString, portSub)
.endpoint(isRemoteIpv4, localAddressString, portSub)
.controlEndpoint(isRemoteIpv4, remoteAddressString, port2Server)
.controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
@ -317,16 +359,24 @@ internal class ClientHandshakeDriver(
throw ex
}
return PubSub(publication, subscription,
sessionIdPub, 0,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = 0,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
suspend fun close() {
fun close(endpoint: EndPoint<*>) {
// only the subs are allocated on the client!
// sessionIdAllocator.free(pubSub.sessionIdPub)
// sessionIdAllocator.free(sessionIdSub)
@ -334,7 +384,19 @@ internal class ClientHandshakeDriver(
streamIdAllocator.free(pubSub.streamIdSub)
// on close, we want to make sure this file is DELETED!
aeronDriver.close(pubSub.sub, logInfo)
aeronDriver.close(pubSub.pub, logInfo)
// we might not be able to close these connections.
try {
aeronDriver.close(pubSub.sub, logInfo)
}
catch (e: Exception) {
endpoint.listenerManager.notifyError(e)
}
try {
aeronDriver.close(pubSub.pub, logInfo)
}
catch (e: Exception) {
endpoint.listenerManager.notifyError(e)
}
}
}
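The two UDP subscription hunks above replace the zero "MDC" endpoint with the client's real listen port, and point the control endpoint at the server's second port (port2). A hedged sketch of that URI built with the raw Aeron ChannelUriStringBuilder (host and port values are placeholders):

import io.aeron.ChannelUriStringBuilder
import io.aeron.CommonContext

fun mdcSubscriptionUri(localHost: String, localPort: Int, serverHost: String, serverControlPort: Int): String =
    ChannelUriStringBuilder()
        .media(CommonContext.UDP_MEDIA)
        .reliable(true)
        .endpoint("$localHost:$localPort")                  // where this client actually listens
        .controlEndpoint("$serverHost:$serverControlPort")  // server's MDC control (port2) endpoint
        .controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
        .build()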

View File

@ -1,3 +1,19 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.handshake
import org.agrona.collections.Object2IntHashMap
@ -9,18 +25,22 @@ import java.net.InetAddress
internal class ConnectionCounts {
private val connectionsPerIpCounts = Object2IntHashMap<InetAddress>(-1)
@Synchronized
fun get(inetAddress: InetAddress): Int {
return connectionsPerIpCounts.getOrPut(inetAddress) { 0 }
}
@Synchronized
fun increment(inetAddress: InetAddress, currentCount: Int) {
connectionsPerIpCounts[inetAddress] = currentCount + 1
}
@Synchronized
fun decrement(inetAddress: InetAddress, currentCount: Int) {
connectionsPerIpCounts[inetAddress] = currentCount - 1
}
@Synchronized
fun decrementSlow(inetAddress: InetAddress) {
if (connectionsPerIpCounts.containsKey(inetAddress)) {
val defaultVal = connectionsPerIpCounts.getValue(inetAddress)
@ -28,10 +48,12 @@ internal class ConnectionCounts {
}
}
@Synchronized
fun isEmpty(): Boolean {
return connectionsPerIpCounts.isEmpty()
}
@Synchronized
override fun toString(): String {
return connectionsPerIpCounts.entries.map { it.key }.joinToString()
}
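A hedged usage sketch of the synchronized counter above, mirroring the per-IP admission check the server handshake performs (the admit helper is illustrative only, and module-internal access to ConnectionCounts is assumed):

fun admit(counts: ConnectionCounts, address: java.net.InetAddress, maxPerIp: Int): Boolean {
    val current = counts.get(address)
    if (maxPerIp in 1..current) return false   // a limit is configured and already reached: reject
    counts.increment(address, current)          // otherwise count this connection
    return true                                 // callers decrement again when the connection closes
}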

View File

@ -32,6 +32,9 @@ internal class HandshakeMessage private constructor() {
// -1 means there is an error
var state = INVALID
// used to name a connection (via the client)
var tag: String = ""
var errorMessage: String? = null
var port = 0
@ -51,7 +54,7 @@ internal class HandshakeMessage private constructor() {
const val DONE = 3
const val DONE_ACK = 4
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int): HandshakeMessage {
fun helloFromClient(connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, tagName: String): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = HELLO
hello.connectKey = connectKey // this is 'bounced back' by the server, so the client knows if it's the correct connection message
@ -59,6 +62,7 @@ internal class HandshakeMessage private constructor() {
hello.sessionId = 0 // not used by the server, since it connects in a different way!
hello.streamId = streamIdSub
hello.port = portSub
hello.tag = tagName
return hello
}
@ -76,13 +80,12 @@ internal class HandshakeMessage private constructor() {
return hello
}
fun doneFromClient(connectKey: Long, sessionIdSub: Int, streamIdSub: Int, portSub: Int): HandshakeMessage {
fun doneFromClient(connectKey: Long, sessionIdSub: Int, streamIdSub: Int): HandshakeMessage {
val hello = HandshakeMessage()
hello.state = DONE
hello.connectKey = connectKey // THIS MUST NEVER CHANGE! (the server/client expect this)
hello.sessionId = sessionIdSub
hello.streamId = streamIdSub
hello.port = portSub
return hello
}
@ -136,6 +139,6 @@ internal class HandshakeMessage private constructor() {
""
}
return "HandshakeMessage($stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
return "HandshakeMessage($tag :: $stateStr$errorMsg sessionId=$sessionId, streamId=$streamId, port=$port${connectInfo})"
}
}
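A hedged sketch of the client side of the exchange these factory methods define (HELLO, acknowledgement, DONE, DONE_ACK); the send/receive hooks are hypothetical stand-ins for the real Handshaker calls, and module-internal access to HandshakeMessage is assumed:

fun clientHandshake(
    connectKey: Long, publicKey: ByteArray, streamIdSub: Int, portSub: Int, tag: String,
    send: (HandshakeMessage) -> Unit,     // hypothetical transport hook (Handshaker.writeMessage in the library)
    receive: () -> HandshakeMessage       // hypothetical transport hook (the handshake poller in the library)
) {
    val hello = HandshakeMessage.helloFromClient(connectKey, publicKey, streamIdSub, portSub, tag)
    send(hello)                                                           // state = HELLO
    val ack = receive()                                                   // server replies with session/stream IDs + registrationData
    send(HandshakeMessage.doneFromClient(connectKey, ack.sessionId, ack.streamId))  // state = DONE
    // the server finishes the exchange with doneToClient(connectKey), i.e. DONE_ACK
}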

View File

@ -18,38 +18,44 @@ package dorkbox.network.handshake
import dorkbox.network.Configuration
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.CoroutineIdleStrategy
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTraceInternal
import dorkbox.network.exceptions.ClientException
import dorkbox.network.exceptions.ServerException
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.KryoReader
import dorkbox.network.serialization.KryoWriter
import dorkbox.network.serialization.Serialization
import io.aeron.Publication
import mu.KLogger
import io.aeron.logbuffer.FrameDescriptor
import org.agrona.DirectBuffer
import org.agrona.concurrent.IdleStrategy
import org.slf4j.Logger
internal class Handshaker<CONNECTION : Connection>(
private val logger: KLogger,
private val logger: Logger,
config: Configuration,
serialization: Serialization<CONNECTION>,
private val listenerManager: ListenerManager<CONNECTION>,
aeronDriver: AeronDriver,
val aeronDriver: AeronDriver,
val newException: (String, Throwable?) -> Throwable
) {
private val handshakeReadKryo: KryoExtra<CONNECTION>
private val handshakeWriteKryo: KryoExtra<CONNECTION>
private val handshakeSendIdleStrategy: CoroutineIdleStrategy
private val writeTimeoutNS = (aeronDriver.lingerNs() * 1.2).toLong() // close enough. Just needs to be slightly longer
private val handshakeReadKryo: KryoReader<CONNECTION>
private val handshakeWriteKryo: KryoWriter<CONNECTION>
private val handshakeSendIdleStrategy: IdleStrategy
init {
handshakeReadKryo = serialization.newHandshakeKryo()
handshakeWriteKryo = serialization.newHandshakeKryo()
handshakeSendIdleStrategy = config.sendIdleStrategy.clone()
val maxMessageSize = FrameDescriptor.computeMaxMessageLength(config.publicationTermBufferLength)
// All registration MUST happen in-order of when the register(*) method was called, otherwise there are problems.
handshakeReadKryo = KryoReader(maxMessageSize)
handshakeWriteKryo = KryoWriter(maxMessageSize)
serialization.newHandshakeKryo(handshakeReadKryo)
serialization.newHandshakeKryo(handshakeWriteKryo)
handshakeSendIdleStrategy = config.sendIdleStrategy
}
/**
@ -61,94 +67,33 @@ internal class Handshaker<CONNECTION : Connection>(
* @return true if the message was successfully sent by aeron
*/
@Suppress("DuplicatedCode")
internal suspend fun writeMessage(publication: Publication, aeronLogInfo: String, message: HandshakeMessage) {
internal fun writeMessage(publication: Publication, logInfo: String, message: HandshakeMessage): Boolean {
// The handshake sessionId IS NOT globally unique
logger.trace { "[$aeronLogInfo] (${message.connectKey}) send HS: $message" }
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${message.connectKey}) send HS: $message")
}
try {
val buffer = handshakeWriteKryo.write(message)
val objectSize = buffer.position()
val internalBuffer = buffer.internalBuffer
var timeoutInNanos = 0L
var startTime = 0L
var result: Long
while (true) {
result = publication.offer(internalBuffer, 0, objectSize)
if (result >= 0) {
// success!
return
}
/**
* Since the publication is not connected, we weren't able to send data to the remote endpoint.
*
* According to Aeron Docs, Pubs and Subs can "come and go", whatever that means. We just want to make sure that we
* don't "loop forever" if a publication is ACTUALLY closed, like on purpose.
*/
if (result == Publication.NOT_CONNECTED) {
if (timeoutInNanos == 0L) {
timeoutInNanos = writeTimeoutNS
startTime = System.nanoTime()
}
if (System.nanoTime() - startTime < timeoutInNanos) {
// we should retry.
handshakeSendIdleStrategy.idle()
continue
} else if (publication.isConnected) {
// more critical error sending the message. we shouldn't retry or anything.
// this exception will be a ClientException or a ServerException
val exception = newException(
"[$aeronLogInfo] Error sending message. (Connection in non-connected state longer than linger timeout. ${
EndPoint.errorCodeName(result)
})",
null
)
exception.cleanStackTraceInternal()
listenerManager.notifyError(exception)
throw exception
}
else {
// publication was actually closed, so no bother throwing an error
return
}
}
/**
* The publication is not connected to a subscriber, this can be an intermittent state as subscribers come and go.
* val NOT_CONNECTED: Long = -1
*
* The offer failed due to back pressure from the subscribers preventing further transmission.
* val BACK_PRESSURED: Long = -2
*
* The offer failed due to an administration action and should be retried.
* The action is an operation such as log rotation which is likely to have succeeded by the next retry attempt.
* val ADMIN_ACTION: Long = -3
*/
if (result >= Publication.ADMIN_ACTION) {
// we should retry.
handshakeSendIdleStrategy.idle()
continue
}
// more critical error sending the message. we shouldn't retry or anything.
// this exception will be a ClientException or a ServerException
val exception = newException("[$aeronLogInfo] Error sending handshake message. $message (${EndPoint.errorCodeName(result)})", null)
exception.cleanStackTraceInternal()
listenerManager.notifyError(exception)
throw exception
}
return aeronDriver.send(publication, buffer, logInfo, listenerManager, handshakeSendIdleStrategy)
} catch (e: Exception) {
if (e is ClientException || e is ServerException) {
// if the driver is closed due to a network disconnect or a remote-client termination, we also must close the connection.
if (aeronDriver.internal.mustRestartDriverOnError) {
// we had a HARD network crash/disconnect, we close the driver and then reconnect automatically
//NOTE: notifyDisconnect IS NOT CALLED!
}
else if (e is ClientException || e is ServerException) {
throw e
} else {
val exception = newException("[$aeronLogInfo] Error serializing handshake message $message", e)
}
else {
val exception = newException("[$logInfo] Error serializing handshake message $message", e)
exception.cleanStackTrace(2) // 2 because we do not want to see the stack for the abstract `newException`
listenerManager.notifyError(exception)
throw exception
}
return false
} finally {
handshakeSendIdleStrategy.reset()
}
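The retry policy that the removed loop implemented, and that aeronDriver.send() now centralizes, can be summarized in a self-contained sketch (offerWithRetry is illustrative, not the library's API): retry on back pressure and admin action, retry NOT_CONNECTED only until a timeout, and fail on anything else.

import io.aeron.Publication
import org.agrona.DirectBuffer
import org.agrona.concurrent.BackoffIdleStrategy

fun offerWithRetry(pub: Publication, buffer: DirectBuffer, length: Int, timeoutNs: Long): Boolean {
    val idle = BackoffIdleStrategy()
    val deadline = System.nanoTime() + timeoutNs
    while (true) {
        val result = pub.offer(buffer, 0, length)
        when {
            result >= 0 -> return true                                  // sent
            result == Publication.NOT_CONNECTED ->
                if (System.nanoTime() > deadline) return false          // give up after the linger-based timeout
            result == Publication.BACK_PRESSURED || result == Publication.ADMIN_ACTION -> { /* transient: retry */ }
            else -> return false                                        // CLOSED / MAX_POSITION_EXCEEDED: do not retry
        }
        idle.idle()
    }
}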

View File

@ -1,120 +0,0 @@
/*
* Copyright 2020 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.handshake
import org.agrona.collections.IntArrayList
/**
* An allocator for port numbers.
*
* The allocator accepts a base number `p` and a maximum count `n | n > 0`, and will allocate
* up to `n` numbers, in a random order, in the range `[p, p + n - 1]`.
*
* @param basePort The base port
* @param numberOfPortsToAllocate The maximum number of ports that will be allocated
*
* @throws IllegalArgumentException If the port range is not valid
*/
class PortAllocator(basePort: Int, numberOfPortsToAllocate: Int) {
private val minPort: Int
private val maxPort: Int
private val portShuffleReset: Int
private var portShuffleCount: Int
private val freePorts: IntArrayList
init {
if (basePort !in 1..65535) {
throw IllegalArgumentException("Base port $basePort must be in the range [1, 65535]")
}
minPort = basePort
maxPort = Math.max(basePort+1, basePort + (numberOfPortsToAllocate - 1))
if (maxPort !in (basePort + 1)..65535) {
throw IllegalArgumentException("Uppermost port $maxPort must be in the range [$basePort, 65535]")
}
// every time we add 25% of ports back (via 'free'), reshuffle the ports
portShuffleReset = numberOfPortsToAllocate/4
portShuffleCount = portShuffleReset
freePorts = IntArrayList()
for (port in basePort..maxPort) {
freePorts.addInt(port)
}
freePorts.shuffle()
}
/**
* Allocate `count` number of ports.
*
* @param count The number of ports that will be allocated
*
* @return An array of allocated ports
*
* @throws PortAllocationException If there are fewer than `count` ports available to allocate
*/
fun allocate(count: Int): IntArray {
if (freePorts.size < count) {
throw IllegalArgumentException("Too few ports available to allocate $count ports")
}
// reshuffle the ports once we need to re-allocate a new port
if (portShuffleCount <= 0) {
portShuffleCount = portShuffleReset
freePorts.shuffle()
}
val result = IntArray(count)
for (index in 0 until count) {
val lastValue = freePorts.size - 1
val removed = freePorts.removeAt(lastValue)
result[index] = removed
}
return result
}
/**
* Frees the given ports. Has no effect if the given port is outside of the range considered by the allocator.
*
* @param ports The array of ports to free
*/
fun free(ports: IntArray) {
ports.forEach {
free(it)
}
}
/**
* Free a given port.
* <p>
* Has no effect if the given port is outside of the range considered by the allocator.
*
* @param port The port
*/
fun free(port: Int) {
if (port in minPort..maxPort) {
// add at the end (so we don't have unnecessary array resizes)
freePorts.addInt(freePorts.size, port)
portShuffleCount--
}
}
}
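A short usage sketch of the removed allocator, under its own documented contract (the port range and demo function are placeholders):

fun portAllocatorDemo() {
    val allocator = PortAllocator(basePort = 40000, numberOfPortsToAllocate = 100)
    val ports = allocator.allocate(2)      // e.g. one publication port and one subscription port
    try {
        // ... bind handshake pub/sub to ports[0] / ports[1] ...
    } finally {
        allocator.free(ports)              // return both ports so they can be reallocated later
    }
}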

View File

@ -16,6 +16,7 @@
package dorkbox.network.handshake
import dorkbox.network.connection.EndPoint
import io.aeron.Publication
import io.aeron.Subscription
import java.net.Inet4Address
@ -29,31 +30,42 @@ data class PubSub(
val streamIdPub: Int,
val streamIdSub: Int,
val reliable: Boolean,
val remoteAddress: InetAddress? = null,
val remoteAddressString: String = "IPC",
val portPub: Int = 0,
val portSub: Int = 0
val remoteAddress: InetAddress?,
val remoteAddressString: String,
val portPub: Int,
val portSub: Int,
val tagName: String // will either be "", or will be "[tag_name]"
) {
val isIpc get() = remoteAddress == null
fun getLogInfo(debugEnabled: Boolean): String {
fun getLogInfo(extraDetails: Boolean): String {
return if (isIpc) {
if (debugEnabled) {
"IPC sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}"
val prefix = if (tagName.isNotEmpty()) {
EndPoint.IPC_NAME + " ($tagName)"
} else {
"IPC [${sessionIdPub}|${sessionIdSub}|${streamIdPub}|${streamIdSub}]"
EndPoint.IPC_NAME
}
if (extraDetails) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
prefix
}
} else {
val prefix = if (remoteAddress is Inet4Address) {
var prefix = if (remoteAddress is Inet4Address) {
"IPv4 $remoteAddressString"
} else {
"IPv6 $remoteAddressString"
}
if (debugEnabled) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}"
if (tagName.isNotEmpty()) {
prefix += " ($tagName)"
}
if (extraDetails) {
"$prefix sessionID: p=${sessionIdPub} s=${sessionIdSub}, streamID: p=${streamIdPub} s=${streamIdSub}, port: p=${portPub} s=${portSub}, reg: p=${pub.registrationId()} s=${sub.registrationId()}"
} else {
"$prefix [${sessionIdPub}|${sessionIdSub}|${streamIdPub}|${streamIdSub}|${portPub}|${portSub}]"
prefix
}
}
}

View File

@ -20,7 +20,7 @@ import dorkbox.network.exceptions.AllocationException
import dorkbox.objectPool.ObjectPool
import dorkbox.objectPool.Pool
import kotlinx.atomicfu.atomic
import mu.KotlinLogging
import org.slf4j.LoggerFactory
/**
* An allocator for random IDs, the maximum number of IDs is an unsigned short (65535).
@ -37,7 +37,7 @@ class RandomId65kAllocator(private val min: Int, max: Int) {
constructor(size: Int): this(1, size + 1)
companion object {
private val logger = KotlinLogging.logger("RandomId65k")
private val logger = LoggerFactory.getLogger("RandomId65k")
}
@ -53,6 +53,7 @@ class RandomId65kAllocator(private val min: Int, max: Int) {
maxAssignments = (max - min).coerceIn(1, max65k)
// create a shuffled list of ID's. This operation is ONLY performed ONE TIME per endpoint!
// Boxing the Ints here is OK, because they are boxed in the cache as well (so it doesn't matter).
val ids = ArrayList<Int>(maxAssignments)
for (id in min until min + maxAssignments) {
ids.add(id)
@ -78,7 +79,9 @@ class RandomId65kAllocator(private val min: Int, max: Int) {
val count = assigned.incrementAndGet()
val id = cache.take()
logger.trace { "Allocating $id (total $count)" }
if (logger.isTraceEnabled) {
logger.trace("Allocating $id (total $count)")
}
return id
}
@ -92,7 +95,9 @@ class RandomId65kAllocator(private val min: Int, max: Int) {
if (assigned < 0) {
throw AllocationException("Unequal allocate/free method calls attempting to free [$id] (too many 'free' calls).")
}
logger.trace { "Freeing $id" }
if (logger.isTraceEnabled) {
logger.trace("Freeing $id")
}
cache.put(id)
}
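A hedged usage sketch of the allocator above (the demo function is illustrative); the only contract is exactly one free() per allocate(), otherwise AllocationException is thrown:

fun sessionIdDemo() {
    val sessionIds = RandomId65kAllocator(1, 65535)
    val id = sessionIds.allocate()         // random, and unique while held
    try {
        // ... use id as an Aeron session ID ...
    } finally {
        sessionIds.free(id)                // exactly one free per allocate
    }
}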

View File

@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uri
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo
import io.aeron.CommonContext
import java.net.Inet4Address
@ -32,7 +33,7 @@ import java.net.InetAddress
*/
internal class ServerConnectionDriver(val pubSub: PubSub) {
companion object {
suspend fun build(isIpc: Boolean,
fun build(isIpc: Boolean,
aeronDriver: AeronDriver,
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
@ -40,8 +41,9 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
ipInfo: IpInfo,
remoteAddress: InetAddress?,
remoteAddressString: String,
portPub: Int, portSub: Int,
portPubMdc: Int, portPub: Int, portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String): ServerConnectionDriver {
val pubSub: PubSub
@ -54,6 +56,7 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
} else {
@ -66,9 +69,11 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
streamIdSub = streamIdSub,
remoteAddress = remoteAddress!!,
remoteAddressString = remoteAddressString,
portPubMdc = portPubMdc,
portPub = portPub,
portSub = portSub,
reliable = reliable,
tagName = tagName,
logInfo = logInfo
)
}
@ -76,11 +81,12 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
return ServerConnectionDriver(pubSub)
}
private suspend fun buildIPC(
private fun buildIPC(
aeronDriver: AeronDriver,
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -91,26 +97,39 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, true)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, true)
// Create a subscription at the given address and port, using the given stream ID.
val subscriptionUri = uri(CommonContext.IPC_MEDIA, sessionIdSub, reliable)
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, true)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = null,
remoteAddressString = EndPoint.IPC_NAME,
portPub = 0,
portSub = 0,
tagName = tagName
)
}
private suspend fun buildUdp(
private fun buildUdp(
aeronDriver: AeronDriver,
ipInfo: IpInfo,
sessionIdPub: Int, sessionIdSub: Int,
streamIdPub: Int, streamIdSub: Int,
remoteAddress: InetAddress, remoteAddressString: String,
portPub: Int, portSub: Int,
portPubMdc: Int, // this is the MDC port - used to dynamically discover the portPub value (but we manually save this info)
portPub: Int,
portSub: Int,
reliable: Boolean,
tagName: String,
logInfo: String
): PubSub {
// on close, the publication CAN linger (in case a client goes away, and then comes back)
@ -122,16 +141,17 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
// create a new publication for the connection (since the handshake ALWAYS closes the current publication)
// we explicitly have the publisher "connect to itself", because we are using MDC to work around NAT
// A control endpoint for the subscriptions will cause a periodic service management "heartbeat" to be sent to the
// remote endpoint publication, which permits the remote publication to send us data, thereby getting us around NAT
val publicationUri = uri(CommonContext.UDP_MEDIA, sessionIdPub, reliable)
.controlEndpoint(ipInfo.getAeronPubAddress(isRemoteIpv4) + ":" + portPub) // this is the port of the client subscription!
.controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
.controlEndpoint(ipInfo.getAeronPubAddress(isRemoteIpv4) + ":" + portPubMdc) // this is the control port! (listens to status messages and NAK from client)
// NOTE: Handlers are called on the client conductor thread. The client conductor thread expects handlers to do safe
// publication of any state to other threads and not be long running or re-entrant with the client.
val publication = aeronDriver.addExclusivePublication(publicationUri, streamIdPub, logInfo, false)
val publication = aeronDriver.addPublication(publicationUri, streamIdPub, logInfo, false)
// if we are IPv6 WILDCARD -- then our subscription must ALSO be IPv6, even if our connection is via IPv4
@ -142,12 +162,20 @@ internal class ServerConnectionDriver(val pubSub: PubSub) {
val subscription = aeronDriver.addSubscription(subscriptionUri, streamIdSub, logInfo, false)
return PubSub(publication, subscription,
sessionIdPub, sessionIdSub,
streamIdPub, streamIdSub,
reliable,
remoteAddress, remoteAddressString,
portPub, portSub)
return PubSub(
pub = publication,
sub = subscription,
sessionIdPub = sessionIdPub,
sessionIdSub = sessionIdSub,
streamIdPub = streamIdPub,
streamIdSub = streamIdSub,
reliable = reliable,
remoteAddress = remoteAddress,
remoteAddressString = remoteAddressString,
portPub = portPub,
portSub = portSub,
tagName = tagName
)
}
}
}
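A hedged counterpart to the client-side URI sketch earlier: the server publication's MDC control endpoint (the portPubMdc value above), built with the raw Aeron API and placeholder addresses:

import io.aeron.ChannelUriStringBuilder
import io.aeron.CommonContext

fun mdcPublicationUri(serverBindHost: String, controlPort: Int, sessionId: Int): String =
    ChannelUriStringBuilder()
        .media(CommonContext.UDP_MEDIA)
        .reliable(true)
        .sessionId(sessionId)
        .controlEndpoint("$serverBindHost:$controlPort")     // listens for client status messages and NAKs
        .controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
        .build()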

View File

@ -25,14 +25,13 @@ import dorkbox.network.exceptions.AllocationException
import dorkbox.network.exceptions.ServerHandshakeException
import dorkbox.network.exceptions.ServerTimedoutException
import dorkbox.network.exceptions.TransmitException
import dorkbox.util.sync.CountDownLatch
import io.aeron.Publication
import kotlinx.coroutines.runBlocking
import mu.KLogger
import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap
import org.slf4j.Logger
import java.net.Inet4Address
import java.net.InetAddress
import java.util.*
import java.util.concurrent.*
@ -45,11 +44,10 @@ import java.util.concurrent.*
internal class ServerHandshake<CONNECTION : Connection>(
private val config: ServerConfiguration,
private val listenerManager: ListenerManager<CONNECTION>,
val aeronDriver: AeronDriver
private val aeronDriver: AeronDriver,
private val eventDispatch: EventDispatcher
) {
// note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val pendingConnections = ExpiringMap.builder()
.apply {
@ -63,10 +61,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
.expirationListener<Long, CONNECTION> { clientConnectKey, connection ->
// this blocks until it fully runs (which is ok. this is fast)
listenerManager.notifyError(ServerTimedoutException("[${clientConnectKey}] Connection (${connection.id}) timed out waiting for registration response from client"))
runBlocking {
connection.close()
}
connection.close()
}
.build<Long, CONNECTION>()
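A hedged, standalone sketch of the same pending-connection pattern with the net.jodah.expiringmap API (the 30-second expiry and AutoCloseable value type are placeholders): entries never completed by a DONE message expire and are closed automatically.

import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap
import java.util.concurrent.TimeUnit

val pending = ExpiringMap.builder()
    .expiration(30, TimeUnit.SECONDS)                  // a little longer than the client's own timeout
    .expirationPolicy(ExpirationPolicy.CREATED)
    .expirationListener<Long, AutoCloseable> { connectKey, connection ->
        println("Handshake $connectKey timed out")     // the real code notifies the error listener instead
        connection.close()
    }
    .build<Long, AutoCloseable>()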
@ -80,7 +75,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
init {
// we MUST include the publication linger timeout, otherwise we might encounter problems that are NOT REALLY problems
var handshakeTimeoutNs = aeronDriver.publicationConnectionTimeoutNs() + aeronDriver.lingerNs()
var handshakeTimeoutNs = TimeUnit.SECONDS.toNanos(config.connectionCloseTimeoutInSeconds.toLong()) + aeronDriver.publicationConnectionTimeoutNs() + aeronDriver.lingerNs()
if (EndPoint.DEBUG_CONNECTIONS) {
// connections are extremely difficult to diagnose when the connection timeout is short
@ -94,13 +89,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
* @return true if we should continue parsing the incoming message, false if we should abort (as we are DONE processing data)
*/
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD. ONLY RESPONSES ARE ON ACTION DISPATCH!
suspend fun validateMessageTypeAndDoPending(
fun validateMessageTypeAndDoPending(
server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>,
handshakePublication: Publication,
message: HandshakeMessage,
aeronLogInfo: String,
logger: KLogger
logInfo: String,
logger: Logger
): Boolean {
// check to see if this sessionId is ALREADY in use by another connection!
@ -117,7 +112,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
try {
handshaker.writeMessage(handshakePublication,
aeronLogInfo,
logInfo,
HandshakeMessage.retry("Handshake already in progress for sessionID!"))
} catch (e: Error) {
listenerManager.notifyError(ServerHandshakeException("[$existingConnection] Handshake error", e))
@ -128,30 +123,44 @@ internal class ServerHandshake<CONNECTION : Connection>(
// check to see if this is a pending connection
if (message.state == HandshakeMessage.DONE) {
val existingConnection = pendingConnections.remove(message.connectKey)
if (existingConnection == null) {
val newConnection = pendingConnections.remove(message.connectKey)
if (newConnection == null) {
listenerManager.notifyError(ServerHandshakeException("[?????] (${message.connectKey}) Error! Pending connection from client was null, and cannot complete handshake!"))
return true
}
// Server is the "source", client mirrors the server
logger.debug { "[${existingConnection}] (${message.connectKey}) Connection done with handshake." }
val connectionType = if (newConnection.enableBufferedMessages) {
"Buffered connection"
} else {
"Connection"
}
// before we finish creating the connection, we initialize it (in case there needs to be logic that happens-before `onConnect` calls occur
listenerManager.notifyInit(existingConnection)
// Server is the "source", client mirrors the server
if (logger.isTraceEnabled) {
logger.trace("[${newConnection}] (${message.connectKey}) $connectionType (${newConnection.id}) done with handshake.")
} else if (logger.isDebugEnabled) {
logger.debug("[${newConnection}] $connectionType (${newConnection.id}) done with handshake.")
}
newConnection.setImage()
// before we finish creating the connection, we initialize it (in case there is logic that must happen before `onConnect` calls)
listenerManager.notifyInit(newConnection)
// this enables the connection to start polling for messages
server.addConnection(existingConnection)
server.addConnection(newConnection)
// now tell the client we are done
try {
handshaker.writeMessage(handshakePublication,
aeronLogInfo,
logInfo,
HandshakeMessage.doneToClient(message.connectKey))
listenerManager.notifyConnect(existingConnection)
listenerManager.notifyConnect(newConnection)
newConnection.sendBufferedMessages()
} catch (e: Exception) {
listenerManager.notifyError(existingConnection, TransmitException("[$existingConnection] Handshake error", e))
listenerManager.notifyError(newConnection, TransmitException("[$newConnection] Handshake error", e))
}
return false
@ -164,25 +173,25 @@ internal class ServerHandshake<CONNECTION : Connection>(
* @return true if we should continue parsing the incoming message, false if we should abort
*/
// note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
private suspend fun validateUdpConnectionInfo(
private fun validateUdpConnectionInfo(
server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>,
handshakePublication: Publication,
config: ServerConfiguration,
clientAddress: InetAddress,
aeronLogInfo: String
logInfo: String
): Boolean {
try {
// VALIDATE:: Check to see if there are already too many clients connected.
if (server.connections.size() >= config.maxClientCount) {
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Server is full. Max allowed is ${config.maxClientCount}"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Server is full"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -190,29 +199,29 @@ internal class ServerHandshake<CONNECTION : Connection>(
// VALIDATE:: we are now connected to the client and are going to create a new connection.
val currentCountForIp = connectionsPerIpCounts.get(clientAddress)
if (currentCountForIp >= config.maxConnectionsPerIpAddress) {
if (config.maxConnectionsPerIpAddress in 1..currentCountForIp) {
// decrement it now, since we aren't going to permit this connection (take the extra decrement hit on failure, instead of always)
connectionsPerIpCounts.decrement(clientAddress, currentCountForIp)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Too many connections for IP address. Max allowed is ${config.maxConnectionsPerIpAddress}"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Too many connections for IP address"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
connectionsPerIpCounts.increment(clientAddress, currentCountForIp)
} catch (e: Exception) {
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Handshake error, Could not validate client message", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Handshake error, Could not validate client message", e))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Invalid connection"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
}
@ -221,21 +230,28 @@ internal class ServerHandshake<CONNECTION : Connection>(
/**
* NOTE: This must not be called on the main thread because it is blocking!
*
* @return true if the connection was SUCCESS. False if the handshake poller should immediately close the publication
*/
suspend fun processIpcHandshakeMessageServer(
fun processIpcHandshakeMessageServer(
server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>,
aeronDriver: AeronDriver,
handshakePublication: Publication,
publicKey: ByteArray,
message: HandshakeMessage,
aeronLogInfo: String,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
logger: KLogger
logInfo: String,
logger: Logger
): Boolean {
val serialization = config.serialization
val clientTagName = message.tag
if (clientTagName.length > 32) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
/////
/////
///// DONE WITH VALIDATION
@ -248,13 +264,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
try {
connectionSessionIdPub = sessionIdAllocator.allocate()
} catch (e: AllocationException) {
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a session pub ID for the client connection!", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a session pub ID for the client connection!", e))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -266,13 +282,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
// have to unwind actions!
sessionIdAllocator.free(connectionSessionIdPub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a session sub ID for the client connection!", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a session sub ID for the client connection!", e))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -286,13 +302,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdAllocator.free(connectionSessionIdPub)
sessionIdAllocator.free(connectionSessionIdSub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a stream publication ID for the client connection!", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a stream publication ID for the client connection!", e))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -306,13 +322,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdAllocator.free(connectionSessionIdSub)
streamIdAllocator.free(connectionStreamIdPub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a stream subscription ID for the client connection!", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a stream subscription ID for the client connection!", e))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -320,14 +336,16 @@ internal class ServerHandshake<CONNECTION : Connection>(
// create a new connection. The session ID is encrypted.
var connection: CONNECTION? = null
var newConnection: CONNECTION? = null
try {
// Create a pub/sub at the given address and port, using the given stream ID.
// NOTE: This must not be called on the main thread because it is blocking!
val newConnectionDriver = ServerConnectionDriver.build(
aeronDriver = aeronDriver,
ipInfo = server.ipInfo,
isIpc = true,
logInfo = "IPC",
tagName = clientTagName,
logInfo = EndPoint.IPC_NAME,
remoteAddress = null,
remoteAddressString = "",
@ -335,21 +353,41 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
portPubMdc = 0,
portPub = 0,
portSub = 0,
reliable = true
)
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug { "Creating new connection to $logInfo" }
val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(null, clientTagName)
val connectionType = if (enableBufferedMessagesForConnection) {
"buffered connection"
} else {
logger.info { "Creating new connection to $logInfo" }
"connection"
}
val connectionTypeCaps = connectionType.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() }
connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, PublicKeyValidationState.VALID))
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug("Creating new $connectionType to $logInfo")
} else {
logger.info("Creating new $connectionType to $logInfo")
}
newConnection = server.newConnection(ConnectionParams(
publicKey = publicKey,
endPoint = server,
connectionInfo = newConnectionDriver.pubSub,
publicKeyValidation = PublicKeyValidationState.VALID,
enableBufferedMessages = enableBufferedMessagesForConnection,
cryptoKey = CryptoManagement.NOCRYPT // we don't use encryption for IPC connections
))
server.bufferedManager.onConnect(newConnection)
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
// NOTE: all IPC client connections are, by default, always allowed to connect, because they are running on the same machine
@ -368,21 +406,28 @@ internal class ServerHandshake<CONNECTION : Connection>(
// now create the encrypted payload, using no crypto
successMessage.registrationData = server.crypto.nocrypt(
connectionSessionIdPub,
connectionSessionIdSub,
connectionStreamIdPub,
connectionStreamIdSub,
serialization.getKryoRegistrationDetails())
sessionIdPub = connectionSessionIdPub,
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
bufferedMessages = enableBufferedMessagesForConnection,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = connection
pendingConnections[message.connectKey] = newConnection
logger.debug { "[$aeronLogInfo] (${message.connectKey}) Connection (${connection.id}) responding to handshake hello." }
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${message.connectKey}) $connectionType (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connectionTypeCaps (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.
handshaker.writeMessage(handshakePublication, aeronLogInfo, successMessage) // exception is already caught!
handshaker.writeMessage(handshakePublication, logInfo, successMessage) // exception is already caught!
} catch (e: Exception) {
// have to unwind actions!
sessionIdAllocator.free(connectionSessionIdPub)
@ -390,7 +435,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
streamIdAllocator.free(connectionStreamIdSub)
streamIdAllocator.free(connectionStreamIdPub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] (${message.connectKey}) Connection (${connection?.id}) handshake crashed! Message $message", e))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] (${message.connectKey}) Connection (${newConnection?.id}) handshake crashed! Message $message", e))
return false
}
@ -399,22 +444,24 @@ internal class ServerHandshake<CONNECTION : Connection>(
}
/**
* note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
* NOTE: This must not be called on the main thread because it is blocking!
*
* @return true if the connection was SUCCESS. False if the handshake poller should immediately close the publication
*/
suspend fun processUdpHandshakeMessageServer(
fun processUdpHandshakeMessageServer(
server: Server<CONNECTION>,
handshaker: Handshaker<CONNECTION>,
handshakePublication: Publication,
publicKey: ByteArray,
clientAddress: InetAddress,
clientAddressString: String,
portSub: Int,
portPub: Int,
mdcPortPub: Int,
isReliable: Boolean,
message: HandshakeMessage,
aeronLogInfo: String,
connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
logger: KLogger
logInfo: String,
logger: Logger
): Boolean {
val serialization = config.serialization
@ -425,18 +472,40 @@ internal class ServerHandshake<CONNECTION : Connection>(
// VALIDATE:: check to see if the remote connection's public key has changed!
validateRemoteAddress = server.crypto.validateRemoteAddress(clientAddress, clientAddressString, clientPublicKeyBytes)
if (validateRemoteAddress == PublicKeyValidationState.INVALID) {
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Public key mismatch."))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Public key mismatch."))
return false
}
clientPublicKeyBytes!!
val isSelfMachine = clientAddress.isLoopbackAddress || clientAddress == EndPoint.lanAddress
if (!isSelfMachine &&
!validateUdpConnectionInfo(server, handshaker, handshakePublication, config, clientAddress, aeronLogInfo)) {
!validateUdpConnectionInfo(server, handshaker, handshakePublication, config, clientAddress, logInfo)) {
// we do not want to limit the loopback addresses!
return false
}
val clientTagName = message.tag
if (clientTagName.length > 32) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Invalid tag name."))
return false
}
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(clientAddress, clientTagName)
if (!permitConnection) {
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
/////
/////
@ -453,13 +522,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a session ID for the client connection!"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a session ID for the client connection!"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -473,13 +542,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
connectionsPerIpCounts.decrementSlow(clientAddress)
sessionIdAllocator.free(connectionSessionIdPub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a session ID for the client connection!"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a session ID for the client connection!"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -494,13 +563,13 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdAllocator.free(connectionSessionIdPub)
sessionIdAllocator.free(connectionSessionIdSub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a stream ID for the client connection!"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a stream ID for the client connection!"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
@ -515,23 +584,18 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdAllocator.free(connectionSessionIdSub)
streamIdAllocator.free(connectionStreamIdPub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection not allowed! Unable to allocate a stream ID for the client connection!"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] Connection not allowed! Unable to allocate a stream ID for the client connection!"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
handshaker.writeMessage(handshakePublication, logInfo,
HandshakeMessage.error("Connection error!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
listenerManager.notifyError(TransmitException("[$logInfo] Handshake error", e))
}
return false
}
// the pub/sub do not necessarily have to be the same. They can be ANY port
val portPub = message.port
val portSub = server.port
val logType = if (clientAddress is Inet4Address) {
"IPv4"
} else {
@ -539,9 +603,10 @@ internal class ServerHandshake<CONNECTION : Connection>(
}
// create a new connection. The session ID is encrypted.
var connection: CONNECTION? = null
var newConnection: CONNECTION? = null
try {
// Create a pub/sub at the given address and port, using the given stream ID.
// NOTE: This must not be called on the main thread because it is blocking!
val newConnectionDriver = ServerConnectionDriver.build(
ipInfo = server.ipInfo,
aeronDriver = aeronDriver,
@ -554,43 +619,48 @@ internal class ServerHandshake<CONNECTION : Connection>(
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
portPubMdc = mdcPortPub,
portPub = portPub,
portSub = portSub,
tagName = clientTagName,
reliable = isReliable
)
val cryptoSecretKey = server.crypto.generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, server.crypto.publicKeyBytes)
val enableBufferedMessagesForConnection = listenerManager.notifyEnableBufferedMessages(clientAddress, clientTagName)
val connectionType = if (enableBufferedMessagesForConnection) {
"buffered connection"
} else {
"connection"
}
val connectionTypeCaps = connectionType.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() }
val logInfo = newConnectionDriver.pubSub.getLogInfo(logger.isDebugEnabled)
if (logger.isDebugEnabled) {
logger.debug { "Creating new connection to $logInfo" }
logger.debug("Creating new $connectionType to $logInfo")
} else {
logger.info { "Creating new connection to $logInfo" }
logger.info("Creating new $connectionType to $logInfo")
}
connection = connectionFunc(ConnectionParams(publicKey, server, newConnectionDriver.pubSub, validateRemoteAddress))
newConnection = server.newConnection(ConnectionParams(
publicKey = publicKey,
endPoint = server,
connectionInfo = newConnectionDriver.pubSub,
publicKeyValidation = validateRemoteAddress,
enableBufferedMessages = enableBufferedMessagesForConnection,
cryptoKey = cryptoSecretKey
))
// VALIDATE:: are we allowed to connect to this server (now that we have the initial server information)
val permitConnection = listenerManager.notifyFilter(connection)
if (!permitConnection) {
// this will also unwind/free allocations
connection.close()
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] Connection was not permitted!"))
try {
handshaker.writeMessage(handshakePublication, aeronLogInfo,
HandshakeMessage.error("Connection was not permitted!"))
} catch (e: Exception) {
listenerManager.notifyError(TransmitException("[$aeronLogInfo] Handshake error", e))
}
return false
}
server.bufferedManager.onConnect(newConnection)
///////////////
/// HANDSHAKE
///////////////
// The one-time pad is used to encrypt the session ID, so that ONLY the correct client knows what it is!
val successMessage = HandshakeMessage.helloAckToClient(message.connectKey)
@ -599,23 +669,29 @@ internal class ServerHandshake<CONNECTION : Connection>(
// now create the encrypted payload, using ECDH
successMessage.registrationData = server.crypto.encrypt(
clientPublicKeyBytes = clientPublicKeyBytes!!,
cryptoSecretKey = cryptoSecretKey,
sessionIdPub = connectionSessionIdPub,
sessionIdSub = connectionSessionIdSub,
streamIdPub = connectionStreamIdPub,
streamIdSub = connectionStreamIdSub,
sessionTimeout = config.bufferedConnectionTimeoutSeconds,
bufferedMessages = enableBufferedMessagesForConnection,
kryoRegDetails = serialization.getKryoRegistrationDetails()
)
successMessage.publicKey = server.crypto.publicKeyBytes
// before we notify connect, we have to wait for the client to tell us that they can receive data
pendingConnections[message.connectKey] = connection
pendingConnections[message.connectKey] = newConnection
logger.debug { "[$aeronLogInfo] (${message.connectKey}) Connection (${connection.id}) responding to handshake hello." }
if (logger.isTraceEnabled) {
logger.trace("[$logInfo] $connectionTypeCaps (${newConnection.id}) responding to handshake hello.")
} else if (logger.isDebugEnabled) {
logger.debug("[$logInfo] $connectionTypeCaps (${newConnection.id}) responding to handshake hello.")
}
// this tells the client all the info to connect.
handshaker.writeMessage(handshakePublication, aeronLogInfo, successMessage) // exception is already caught
handshaker.writeMessage(handshakePublication, logInfo, successMessage) // exception is already caught
} catch (e: Exception) {
// have to unwind actions!
connectionsPerIpCounts.decrementSlow(clientAddress)
@ -624,7 +700,7 @@ internal class ServerHandshake<CONNECTION : Connection>(
streamIdAllocator.free(connectionStreamIdPub)
streamIdAllocator.free(connectionStreamIdSub)
listenerManager.notifyError(ServerHandshakeException("[$aeronLogInfo] (${message.connectKey}) Connection (${connection?.id}) handshake crashed! Message $message"))
listenerManager.notifyError(ServerHandshakeException("[$logInfo] (${message.connectKey}) Connection (${newConnection?.id}) handshake crashed! Message $message", e))
return false
}
@ -650,18 +726,18 @@ internal class ServerHandshake<CONNECTION : Connection>(
*
* note: CANNOT be called in action dispatch. ALWAYS ON SAME THREAD
*/
suspend fun clear() {
fun clear() {
val connections = pendingConnections
val latch = CountDownLatch(connections.size)
EventDispatcher.launchSequentially(EventDispatcher.CLOSE) {
eventDispatch.CLOSE.launch {
connections.forEach { (_, v) ->
v.close()
latch.countDown()
}
}
latch.await(config.connectionCloseTimeoutInSeconds.toLong() * connections.size)
latch.await(config.connectionCloseTimeoutInSeconds.toLong() * connections.size, TimeUnit.MILLISECONDS)
connections.clear()
}
}
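A hedged standalone sketch of the bounded shutdown wait used by clear() above, with a plain executor standing in for the event dispatcher:

import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

fun closeAll(connections: List<AutoCloseable>, perConnectionTimeoutMs: Long) {
    val latch = CountDownLatch(connections.size)
    val executor = Executors.newSingleThreadExecutor()
    executor.execute {
        connections.forEach { c ->
            runCatching { c.close() }      // a failed close must still count down
            latch.countDown()
        }
    }
    latch.await(perConnectionTimeoutMs * connections.size, TimeUnit.MILLISECONDS)
    executor.shutdown()
}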

View File

@ -18,6 +18,7 @@ package dorkbox.network.handshake
import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uriHandshake
import dorkbox.network.connection.EndPoint
import dorkbox.network.connection.IpInfo
import io.aeron.ChannelUriStringBuilder
import io.aeron.CommonContext
@ -34,7 +35,7 @@ internal class ServerHandshakeDriver(
private val logInfo: String)
{
companion object {
suspend fun build(
fun build(
aeronDriver: AeronDriver,
isIpc: Boolean,
ipInfo: IpInfo,
@ -62,7 +63,18 @@ internal class ServerHandshakeDriver(
}
}
suspend fun close() {
fun close(endPoint: EndPoint<*>) {
try {
// we might not be able to close this connection.
aeronDriver.close(subscription, logInfo)
}
catch (e: Exception) {
endPoint.listenerManager.notifyError(e)
}
}
fun unsafeClose() {
// we might not be able to close this connection.
aeronDriver.close(subscription, logInfo)
}
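The two close paths above differ only in where a failure ends up; a minimal sketch of the intended call sites (hypothetical variable names, not part of this change):

    // regular shutdown: a failed close is reported to the registered error listeners
    handshakeDriver.close(endPoint)

    // poller teardown (see the poller close() overrides below): let it propagate and ignore it
    try {
        handshakeDriver.unsafeClose()
    } catch (ignored: Exception) {
        // already shutting down
    }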

View File

@ -1,5 +1,5 @@
/*
* Copyright 2023 dorkbox, llc
* Copyright 2024 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,8 +26,6 @@ import dorkbox.network.aeron.AeronDriver
import dorkbox.network.aeron.AeronDriver.Companion.uriHandshake
import dorkbox.network.aeron.AeronPoller
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ConnectionParams
import dorkbox.network.connection.EventDispatcher
import dorkbox.network.connection.IpInfo
import dorkbox.network.exceptions.ServerException
import dorkbox.network.exceptions.ServerHandshakeException
@ -40,11 +38,10 @@ import io.aeron.Image
import io.aeron.Publication
import io.aeron.logbuffer.FragmentHandler
import io.aeron.logbuffer.Header
import kotlinx.coroutines.runBlocking
import mu.KLogger
import net.jodah.expiringmap.ExpirationPolicy
import net.jodah.expiringmap.ExpiringMap
import org.agrona.DirectBuffer
import org.slf4j.Logger
import java.net.Inet4Address
import java.util.concurrent.*
@ -52,22 +49,23 @@ internal object ServerHandshakePollers {
fun disabled(serverInfo: String): AeronPoller {
return object : AeronPoller {
override fun poll(): Int { return 0 }
override suspend fun close() {}
override fun close() {}
override val info = serverInfo
}
}
class IpcProc<CONNECTION : Connection>(
val logger: KLogger,
val logger: Logger,
val server: Server<CONNECTION>,
val driver: AeronDriver,
val handshake: ServerHandshake<CONNECTION>,
val connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION
val handshake: ServerHandshake<CONNECTION>
): FragmentHandler {
private val isReliable = server.config.isReliable
private val handshaker = server.handshaker
private val handshakeTimeoutNs = handshake.handshakeTimeoutNs
private val shutdownInProgress = server.shutdownInProgress
private val shutdown = server.shutdown
// note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val publications = ExpiringMap.builder()
@ -76,9 +74,14 @@ internal object ServerHandshakePollers {
}
.expirationPolicy(ExpirationPolicy.CREATED)
.expirationListener<Long, Publication> { connectKey, publication ->
runBlocking {
try {
// we might not be able to close this connection.
driver.close(publication, "Server IPC Handshake ($connectKey)")
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
.build<Long, Publication>()
@ -93,6 +96,12 @@ internal object ServerHandshakePollers {
val logInfo = "$sessionId/$streamId : IPC" // Server is the "source", client mirrors the server
if (shutdownInProgress.value) {
driver.deleteLogFile(image)
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] server is shutting down. Aborting new connection attempts."))
return
}
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length)
@ -100,8 +109,8 @@ internal object ServerHandshakePollers {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
throw ServerHandshakeException("[$logInfo] Connection not allowed! unrecognized message: $msg")
} else {
logger.trace { "[$logInfo] (${msg.connectKey}) received HS: $msg" }
} else if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${msg.connectKey}) received HS: $msg")
}
msg
} catch (e: Exception) {
@ -114,19 +123,20 @@ internal object ServerHandshakePollers {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
return
}
// we have read all the data, now dispatch it.
EventDispatcher.HANDSHAKE.launch {
// NOTE: This MUST happen in a separate thread so that we can take as long as we need when creating publications and handshaking,
// because under load -- this will REGULARLY time out! Under no circumstance can this happen in the main processing thread!!
server.eventDispatch.HANDSHAKE.launch {
// we have read all the data, now dispatch it.
// HandshakeMessage.HELLO
// HandshakeMessage.DONE
val messageState = message.state
val connectKey = message.connectKey
if (messageState == HandshakeMessage.HELLO) {
// we create a NEW publication for the handshake, which connects directly to the client handshake subscription
@ -134,8 +144,9 @@ internal object ServerHandshakePollers {
// this will always connect to the CLIENT handshake subscription!
val publication = try {
driver.addExclusivePublication(publicationUri, message.streamId, logInfo, true)
} catch (e: Exception) {
driver.addPublication(publicationUri, message.streamId, logInfo, true)
}
catch (e: Exception) {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
@ -146,10 +157,11 @@ internal object ServerHandshakePollers {
try {
// we actually have to wait for it to connect before we continue
driver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
driver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ServerTimedoutException("$logInfo publication cannot connect with client in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
} catch (e: Exception) {
}
catch (e: Exception) {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
@ -167,22 +179,36 @@ internal object ServerHandshakePollers {
handshakePublication = publication,
publicKey = message.publicKey!!,
message = message,
aeronLogInfo = logInfo,
connectionFunc = connectionFunc,
logInfo = logInfo,
logger = logger
)
if (success) {
publications[connectKey] = publication
} else {
driver.close(publication, logInfo)
}
} catch (e: Exception) {
else {
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
}
catch (e: Exception) {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
driver.close(publication, logInfo)
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] Error processing IPC handshake", e))
}
} else {
@ -204,7 +230,7 @@ internal object ServerHandshakePollers {
handshaker = handshaker,
handshakePublication = publication,
message = message,
aeronLogInfo = logInfo,
logInfo = logInfo,
logger = logger
)
} catch (e: Exception) {
@ -215,26 +241,37 @@ internal object ServerHandshakePollers {
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
driver.close(publication, logInfo)
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
}
}
suspend fun close() {
fun close() {
publications.forEach { (connectKey, publication) ->
AeronDriver.sessionIdAllocator.free(publication.sessionId())
driver.close(publication, "Server Handshake ($connectKey)")
try {
// we might not be able to close this connection.
driver.close(publication, "Server Handshake ($connectKey)")
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
publications.clear()
}
}
class UdpProc<CONNECTION : Connection>(
val logger: KLogger,
val logger: Logger,
val server: Server<CONNECTION>,
val driver: AeronDriver,
val handshake: ServerHandshake<CONNECTION>,
val connectionFunc: (connectionParameters: ConnectionParams<CONNECTION>) -> CONNECTION,
val isReliable: Boolean
): FragmentHandler {
companion object {
@ -246,6 +283,12 @@ internal object ServerHandshakePollers {
private val ipInfo = server.ipInfo
private val handshaker = server.handshaker
private val handshakeTimeoutNs = handshake.handshakeTimeoutNs
private val shutdownInProgress = server.shutdownInProgress
private val shutdown = server.shutdown
private val serverPortSub = server.port1
// MDC 'dynamic control mode' means that the server will listen for status messages and NAKs (from the client) on a control port.
private val mdcPortPub = server.port2
// note: the expire time here is a LITTLE longer than the expire time in the client, this way we can adjust for network lag if it's close
private val publications = ExpiringMap.builder()
@ -255,9 +298,13 @@ internal object ServerHandshakePollers {
}
.expirationPolicy(ExpirationPolicy.CREATED)
.expirationListener<Long, Publication> { connectKey, publication ->
runBlocking {
try {
// we might not be able to close this connection.
driver.close(publication, "Server UDP Handshake ($connectKey)")
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
.build<Long, Publication>()
@ -314,6 +361,12 @@ internal object ServerHandshakePollers {
val logInfo = "$sessionId/$streamId:$clientAddressString"
if (shutdownInProgress.value) {
driver.deleteLogFile(image)
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] server is shutting down. Aborting new connection attempts."))
return
}
// ugh, this is verbose -- but necessary
val message = try {
val msg = handshaker.readMessage(buffer, offset, length)
@ -321,8 +374,8 @@ internal object ServerHandshakePollers {
// VALIDATE:: a Registration object is the only acceptable message during the connection phase
if (msg !is HandshakeMessage) {
throw ServerHandshakeException("[$logInfo] Connection not allowed! unrecognized message: $msg")
} else {
logger.trace { "[$logInfo] (${msg.connectKey}) received HS: $msg" }
} else if (logger.isTraceEnabled) {
logger.trace("[$logInfo] (${msg.connectKey}) received HS: $msg")
}
msg
} catch (e: Exception) {
@ -334,13 +387,12 @@ internal object ServerHandshakePollers {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
return
}
EventDispatcher.HANDSHAKE.launch {
// NOTE: This MUST happen in a separate thread so that we can take as long as we need when creating publications and handshaking,
// because under load -- this will REGULARLY time out! Under no circumstance can this happen in the main processing thread!!
server.eventDispatch.HANDSHAKE.launch {
// HandshakeMessage.HELLO
// HandshakeMessage.DONE
val messageState = message.state
@ -349,16 +401,16 @@ internal object ServerHandshakePollers {
if (messageState == HandshakeMessage.HELLO) {
// we create a NEW publication for the handshake, which connects directly to the client handshake subscription
// A control endpoint for the subscriptions will cause a periodic service management "heartbeat" to be sent to the
// remote endpoint publication, which permits the remote publication to send us data, thereby getting us around NAT
// we explicitly have the publisher "connect to itself", because we are using MDC to work around NAT.
// It will "auto-connect" to the correct client port (negotiated by the MDC client subscription negotiating on the
// control port of the server)
val publicationUri = uriHandshake(CommonContext.UDP_MEDIA, isReliable)
.controlEndpoint(ipInfo.getAeronPubAddress(isRemoteIpv4) + ":" + message.port)
.controlMode(CommonContext.MDC_CONTROL_MODE_DYNAMIC)
.controlEndpoint(ipInfo.getAeronPubAddress(isRemoteIpv4) + ":" + mdcPortPub)
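// (illustration only, not part of this change) the resulting channel URI looks roughly like:
//     aeron:udp?control=<aeron-pub-address>:<mdcPortPub>|control-mode=dynamic|reliable=true
// i.e. the server publication exposes a control endpoint, and the client's MDC subscription
// negotiates its own receive address against it -- which is what lets this traverse NAT.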
// this will always connect to the CLIENT handshake subscription!
val publication = try {
driver.addExclusivePublication(publicationUri, message.streamId, logInfo, false)
driver.addPublication(publicationUri, message.streamId, logInfo, false)
} catch (e: Exception) {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
@ -369,8 +421,9 @@ internal object ServerHandshakePollers {
}
try {
// we actually have to wait for it to connect before we continue
driver.waitForConnection(publication, handshakeTimeoutNs, logInfo) { cause ->
// we actually have to wait for it to connect before we continue.
//
driver.waitForConnection(shutdown, publication, handshakeTimeoutNs, logInfo) { cause ->
ServerTimedoutException("$logInfo publication cannot connect with client in ${Sys.getTimePrettyFull(handshakeTimeoutNs)}", cause)
}
} catch (e: Exception) {
@ -382,7 +435,6 @@ internal object ServerHandshakePollers {
return@launch
}
try {
val success = handshake.processUdpHandshakeMessageServer(
server = server,
@ -391,10 +443,12 @@ internal object ServerHandshakePollers {
publicKey = message.publicKey!!,
clientAddress = clientAddress,
clientAddressString = clientAddressString,
portPub = message.port,
portSub = serverPortSub,
mdcPortPub = mdcPortPub,
isReliable = isReliable,
message = message,
aeronLogInfo = logInfo,
connectionFunc = connectionFunc,
logInfo = logInfo,
logger = logger
)
@ -405,14 +459,27 @@ internal object ServerHandshakePollers {
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
driver.close(publication, logInfo)
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
} catch (e: Exception) {
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
driver.close(publication, logInfo)
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] Error processing IPC handshake", e))
}
} else {
@ -435,34 +502,46 @@ internal object ServerHandshakePollers {
handshaker = handshaker,
handshakePublication = publication,
message = message,
aeronLogInfo = logInfo,
logInfo = logInfo,
logger = logger
)
} catch (e: Exception) {
server.listenerManager.notifyError(ServerHandshakeException("[$logInfo] Error processing IPC handshake", e))
}
try {
// we might not be able to close this connection.
driver.close(publication, logInfo)
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
// we should immediately remove the logbuffer for this! Aeron will **EVENTUALLY** remove the logbuffer, but if errors
// and connections occur too quickly (within the cleanup/linger period), we can run out of memory!
driver.deleteLogFile(image)
driver.close(publication, logInfo)
}
}
}
suspend fun close() {
fun close() {
publications.forEach { (connectKey, publication) ->
AeronDriver.sessionIdAllocator.free(publication.sessionId())
driver.close(publication, "Server Handshake ($connectKey)")
try {
// we might not be able to close this connection.
driver.close(publication, "Server Handshake ($connectKey)")
}
catch (e: Exception) {
server.listenerManager.notifyError(e)
}
}
publications.clear()
}
}
suspend fun <CONNECTION : Connection> ipc(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
fun <CONNECTION : Connection> ipc(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
val logger = server.logger
val connectionFunc = server.connectionFunc
val config = server.config as ServerConfiguration
val poller = try {
@ -483,18 +562,23 @@ internal object ServerHandshakePollers {
// - re-entrant with the client
val subscription = driver.subscription
val delegate = IpcProc(logger, server, server.aeronDriver, handshake, connectionFunc)
val delegate = IpcProc(logger, server, server.aeronDriver, handshake)
val handler = FragmentAssembler(delegate)
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override suspend fun close() {
override fun close() {
delegate.close()
handler.clear()
driver.close()
logger.info { "Closed IPC poller" }
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPC poller")
}
override val info = "IPC ${driver.info}"
@ -509,9 +593,8 @@ internal object ServerHandshakePollers {
suspend fun <CONNECTION : Connection> ip4(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
fun <CONNECTION : Connection> ip4(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
val logger = server.logger
val connectionFunc = server.connectionFunc
val config = server.config
val isReliable = config.isReliable
@ -520,7 +603,7 @@ internal object ServerHandshakePollers {
aeronDriver = server.aeronDriver,
isIpc = false,
ipInfo = server.ipInfo,
port = server.port,
port = server.port1,
streamIdSub = config.udpId,
sessionIdSub = 9,
logInfo = "HANDSHAKE-IPv4"
@ -533,18 +616,23 @@ internal object ServerHandshakePollers {
// - re-entrant with the client
val subscription = driver.subscription
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, connectionFunc, isReliable)
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, isReliable)
val handler = FragmentAssembler(delegate)
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override suspend fun close() {
override fun close() {
delegate.close()
handler.clear()
driver.close()
logger.info { "Closed IPv4 poller" }
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4 poller")
}
override val info = "IPv4 ${driver.info}"
@ -557,9 +645,8 @@ internal object ServerHandshakePollers {
return poller
}
suspend fun <CONNECTION : Connection> ip6(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
fun <CONNECTION : Connection> ip6(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
val logger = server.logger
val connectionFunc = server.connectionFunc
val config = server.config
val isReliable = config.isReliable
@ -568,7 +655,7 @@ internal object ServerHandshakePollers {
aeronDriver = server.aeronDriver,
isIpc = false,
ipInfo = server.ipInfo,
port = server.port,
port = server.port1,
streamIdSub = config.udpId,
sessionIdSub = 0,
logInfo = "HANDSHAKE-IPv6"
@ -581,18 +668,23 @@ internal object ServerHandshakePollers {
// - re-entrant with the client
val subscription = driver.subscription
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, connectionFunc, isReliable)
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, isReliable)
val handler = FragmentAssembler(delegate)
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override suspend fun close() {
override fun close() {
delegate.close()
handler.clear()
driver.close()
logger.info { "Closed IPv4 poller" }
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4 poller")
}
override val info = "IPv6 ${driver.info}"
@ -605,10 +697,9 @@ internal object ServerHandshakePollers {
return poller
}
suspend fun <CONNECTION : Connection> ip6Wildcard(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
fun <CONNECTION : Connection> ip6Wildcard(server: Server<CONNECTION>, handshake: ServerHandshake<CONNECTION>): AeronPoller {
val logger = server.logger
val connectionFunc = server.connectionFunc
val config = server.config
val isReliable = config.isReliable
@ -617,7 +708,7 @@ internal object ServerHandshakePollers {
aeronDriver = server.aeronDriver,
isIpc = false,
ipInfo = server.ipInfo,
port = server.port,
port = server.port1,
streamIdSub = config.udpId,
sessionIdSub = 0,
logInfo = "HANDSHAKE-IPv4+6"
@ -630,18 +721,23 @@ internal object ServerHandshakePollers {
// - re-entrant with the client
val subscription = driver.subscription
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, connectionFunc, isReliable)
val delegate = UdpProc(logger, server, server.aeronDriver, handshake, isReliable)
val handler = FragmentAssembler(delegate)
override fun poll(): Int {
return subscription.poll(handler, 1)
}
override suspend fun close() {
override fun close() {
delegate.close()
handler.clear()
driver.close()
logger.info { "Closed IPv4+6 poller" }
try {
driver.unsafeClose()
}
catch (ignored: Exception) {
// we are already shutting down, ignore
}
logger.info("Closed IPv4+6 poller")
}
override val info = "IPv4+6 ${driver.info}"

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.handshake;

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.ipFilter;

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network;

View File

@ -1,82 +0,0 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.ping
import dorkbox.network.connection.Connection
import dorkbox.network.rmi.ResponseManager
import mu.KLogger
import java.util.concurrent.*
/**
* How to handle ping messages
*/
internal class PingManager<CONNECTION : Connection> {
suspend fun manage(connection: CONNECTION, responseManager: ResponseManager, ping: Ping, logger: KLogger) {
if (ping.pongTime == 0L) {
// this is on the server.
ping.pongTime = System.currentTimeMillis()
if (!connection.send(ping)) {
logger.error { "Error returning ping: $ping" }
}
} else {
// this is on the client
ping.finishedTime = System.currentTimeMillis()
val rmiId = ping.packedId
// process the ping message so that our ping callback does something
// this will be null if the ping took longer than XXX seconds and was cancelled
val result = responseManager.getWaiterCallback<suspend Ping.() -> Unit>(rmiId, logger)
if (result != null) {
result(ping)
} else {
logger.error { "Unable to receive ping, there was no waiting response for $ping ($rmiId)" }
}
}
}
/**
* Sends a "ping" packet to measure **ROUND TRIP** time to the remote connection.
*
* @return true if the message was successfully sent by aeron
*/
internal suspend fun ping(
connection: Connection,
pingTimeoutSeconds: Int,
responseManager: ResponseManager,
logger: KLogger,
function: suspend Ping.() -> Unit
): Boolean {
val id = responseManager.prepWithCallback(logger, function)
val ping = Ping()
ping.packedId = id
ping.pingTime = System.currentTimeMillis()
// ALWAYS cancel the ping after XXX seconds
responseManager.cancelRequest(TimeUnit.SECONDS.toMillis(pingTimeoutSeconds.toLong()), id, logger) {
// kill the callback, since we are now "cancelled". If there is a race here (and the response comes at the exact same time)
// we don't care since either it will be null or it won't (if it's not null, it will run the callback)
result = null
}
return connection.send(ping)
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,7 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
class PingSerializer: Serializer<Ping>() {
internal class PingSerializer: Serializer<Ping>() {
override fun write(kryo: Kryo, output: Output, ping: Ping) {
output.writeInt(ping.packedId)
output.writeLong(ping.pingTime)

View File

@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.ping;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,25 +12,6 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided with the distribution.
* - Neither the name of Esoteric Software nor the names of its contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,
* BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
* SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package dorkbox.network.rmi

View File

@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,5 +23,5 @@ interface RemoteObjectCallback<Iface> {
/**
* @param remoteObject the remote object (as a proxy object) or null if there was an error creating the RMI object
*/
suspend fun created(remoteObject: Iface)
fun created(remoteObject: Iface)
}

View File

@ -16,8 +16,8 @@
package dorkbox.network.rmi
import dorkbox.collections.LockFreeIntBiMap
import mu.KLogger
import org.agrona.collections.IntArrayList
import org.slf4j.Logger
import java.util.concurrent.locks.*
import kotlin.concurrent.write
@ -59,7 +59,7 @@ import kotlin.concurrent.write
*
* @author Nathan Robinson
*/
internal class RemoteObjectStorage(val logger: KLogger) {
class RemoteObjectStorage(val logger: Logger) {
companion object {
const val INVALID_RMI = 0
@ -77,109 +77,41 @@ internal class RemoteObjectStorage(val logger: KLogger) {
// 2) specifically request a number
// To solve this, we use 3 data structures, because it's also possible to RETURN no-longer needed object ID's (like when a connection closes)
private var objectIdCounter: Int = 1
private val reservedObjectIds = IntArrayList(1, INVALID_RMI)
private val objectIds = IntArrayList(16, INVALID_RMI)
init {
(0..8).forEach { _ ->
objectIds.addInt(objectIdCounter++)
}
}
private fun validate(objectId: Int) {
require(objectId > 0) { "The ID must be greater than 0" }
require(objectId <= 65535) { "The ID must be less than 65,535" }
}
/**
* @return the next ID or 0 (INVALID_RMI, if it's invalid)
*/
private fun unsafeNextId(): Int {
val id = if (objectIds.size > 0) {
objectIds.removeAt(objectIds.size - 1)
} else {
objectIdCounter++
}
if (objectIdCounter > 65535) {
// basically, it's a short (but collections are a LOT easier to deal with if it's an int)
val msg = "Max ID size is 65535, because of how we pack the bytes when sending RMI messages. FATAL ERROR! (too many objects)"
logger.error(msg)
return INVALID_RMI
}
return id
}
/**
* @return the next possible RMI object ID. Either one that is next available, or 0 (INVALID_RMI) if it was invalid
*/
fun nextId(): Int {
idLock.write {
var idToReturn = unsafeNextId()
while (reservedObjectIds.contains(idToReturn)) {
idToReturn = unsafeNextId()
}
return idToReturn
}
}
/**
* Reserves an ID so that other requests for ID's will never return this ID. The number must be > 0 and < 65535
*
* Reservations are permanent and it will ALWAYS be reserved! You cannot "un-reserve" an ID.
*
* If you care about memory and performance, use the ID from "nextId()" instead.
*
* @return false if this ID was not able to be reserved
*/
fun reserveId(id: Int): Boolean {
validate(id)
idLock.write {
val contains = objectIds.remove(id)
if (contains) {
// this id is available for us to use (and was temporarily used before)
return true
}
if (reservedObjectIds.contains(id)) {
// this id is ALREADY used by something else
return false
}
if (objectIdCounter < id) {
// this id is ALREADY used by something else
return false
}
if (objectIdCounter == id) {
// we are available via the counter, so make sure the counter increments
val id = if (objectIds.size > 0) {
objectIds.removeAt(objectIds.size - 1)
} else {
objectIdCounter++
// we still want to mark this as reserved, so fall through
}
// this means that the counter is LARGER than the id (maybe even a LOT larger)
// we just stuff this requested number in a small array and check it whenever we get a new number
reservedObjectIds.add(id)
return true
if (objectIdCounter > 65535) {
// basically, it's a short (but collections are a LOT easier to deal with if it's an int)
val msg = "Max ID size is 65535, because of how we pack the bytes when sending RMI messages. FATAL ERROR! (too many objects)"
logger.error(msg)
return INVALID_RMI
}
return id
}
}
/**
* @return an ID to be used again. Reserved IDs will not be allowed to be returned
*/
fun returnId(id: Int) {
idLock.write {
if (reservedObjectIds.contains(id)) {
logger.error {
"Do not return a reserved ID ($id). Once an ID is reserved, it is permanent."
}
return
}
val shortCheck: Int = (id + 1)
if (shortCheck == objectIdCounter) {
objectIdCounter--
@ -190,9 +122,6 @@ internal class RemoteObjectStorage(val logger: KLogger) {
}
}
/**
* Automatically registers an object with the next available ID to allow a remote connection to access this object via the returned ID
*
@ -202,10 +131,10 @@ internal class RemoteObjectStorage(val logger: KLogger) {
// this will return INVALID_RMI if there are too many in the ObjectSpace
val nextObjectId = nextId()
if (nextObjectId != INVALID_RMI) {
objectMap.put(nextObjectId, `object`)
objectMap[nextObjectId] = `object`
logger.trace {
"Remote object <proxy:$nextObjectId> registered with .toString() = '${`object`}'"
if (logger.isTraceEnabled) {
logger.trace("Remote object <proxy:$nextObjectId> registered with .toString() = '${`object`}'")
}
}
@ -222,10 +151,10 @@ internal class RemoteObjectStorage(val logger: KLogger) {
fun register(`object`: Any, objectId: Int): Boolean {
validate(objectId)
objectMap.put(objectId, `object`)
objectMap[objectId] = `object`
logger.trace {
"Remote object <proxy:$objectId> registered with .toString() = '${`object`}'"
if (logger.isTraceEnabled) {
logger.trace("Remote object <proxy:$objectId> registered with .toString() = '${`object`}'")
}
return true
@ -241,10 +170,13 @@ internal class RemoteObjectStorage(val logger: KLogger) {
val rmiObject = objectMap.remove(objectId) as T?
returnId(objectId)
logger.trace {
"Object <proxy #${objectId}> removed"
if (logger.isTraceEnabled) {
if (rmiObject is RemoteObject<*>) {
logger.trace("Object <proxy #${objectId}> removed")
} else {
logger.trace("Object <proxy-impl #${objectId}> removed")
}
}
@Suppress("UNCHECKED_CAST")
return rmiObject
}
@ -259,8 +191,8 @@ internal class RemoteObjectStorage(val logger: KLogger) {
} else {
returnId(objectId)
logger.trace {
"Object '${remoteObject}' (ID: ${objectId}) removed from RMI system."
if (logger.isTraceEnabled) {
logger.trace("Object '${remoteObject}' (ID: ${objectId}) removed from RMI system.")
}
}
}
@ -277,12 +209,42 @@ internal class RemoteObjectStorage(val logger: KLogger) {
/**
* @return the ID registered for the specified object, or INVALID_RMI if not found.
*/
fun <T> getId(remoteObject: T): Int {
fun <T: Any> getId(remoteObject: T): Int {
// Find an ID with the object.
return objectMap.inverse()[remoteObject]
}
fun close() {
/**
* @return all the saved objects along with their RMI ID. This is so we can restore these later on
*/
fun getAll(): List<Pair<Int, Any>> {
return objectMap.entries.map { it -> Pair(it.key, it.value) }.toList()
}
/**
* Restores previously saved RMI implementation objects (as returned by getAll), re-registering them under their original RMI IDs
*/
fun restoreAll(implObjects: List<Pair<Int, Any>>) {
idLock.write {
// this is a bit slow, but we have to re-inject objects. THIS happens before the connection is initialized, so we know
// these RMI ids are available
implObjects.forEach {
objectMap.remove(it.first)
}
objectIdCounter += implObjects.size
}
// now we have to put our items back into the backing map.
implObjects.forEach {
objectMap[it.first] = it.second
}
}
fun clear() {
objectMap.clear()
}
}
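A small usage sketch for the new getAll()/restoreAll() pair (hypothetical storage instances; not part of this change):

    // capture the registered RMI implementation objects before the old connection goes away
    val saved: List<Pair<Int, Any>> = oldStorage.getAll()

    // after the new connection's storage is created, re-register them under the same RMI IDs
    newStorage.restoreAll(saved)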

View File

@ -15,16 +15,11 @@
*/
package dorkbox.network.rmi
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EventDispatcher
import dorkbox.objectPool.ObjectPool
import dorkbox.objectPool.SuspendingPool
import dorkbox.objectPool.Pool
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.delay
import mu.KLogger
import mu.KotlinLogging
import org.slf4j.Logger
import java.util.concurrent.locks.*
import kotlin.concurrent.read
import kotlin.concurrent.write
/**
@ -36,27 +31,26 @@ import kotlin.concurrent.write
* - these are just looped around in a ring buffer.
* - these are stored here as int, however these are REALLY shorts and are int-packed when transferring data on the wire
*
* (By default, for RMI...)
* (By default, for RMI/Ping/SendSync...)
* - 0 is reserved for INVALID
* - 1 is reserved for ASYNC (the response will never be sent back, and we don't wait for it)
*
*/
internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int = 2) {
companion object {
val TIMEOUT_EXCEPTION = Exception().apply { stackTrace = arrayOf<StackTraceElement>() }
private val logger: KLogger = KotlinLogging.logger(ResponseManager::class.java.simpleName)
val TIMEOUT_EXCEPTION = TimeoutException().apply { stackTrace = arrayOf<StackTraceElement>() }
}
private val rmiWaitersInUse = atomic(0)
private val waiterCache: SuspendingPool<ResponseWaiter>
private val responseWaitersInUse = atomic(0)
private val waiterCache: Pool<ResponseWaiter>
private val pendingLock = ReentrantReadWriteLock()
private val pending = arrayOfNulls<Any?>(maxValuesInCache+1) // +1 because it's possible to have the value 65535 in the cache
init {
require(maxValuesInCache <= 65535) { "The maximum size for the values in the response manager is 65535"}
require(maxValuesInCache > minimumValue) { "< $minimumValue (0 and 1 for RMI) are reserved"}
require(minimumValue > 0) { "The minimum value $minimumValue must be > 0"}
require(maxValuesInCache > minimumValue) { "< $minimumValue (0 and 1 for RMI/Ping/SendSync) are reserved"}
require(minimumValue > 1) { "The minimum value $minimumValue must be > 1"}
// create a shuffled list of ID's. This operation is ONLY performed ONE TIME per endpoint!
val ids = mutableListOf<ResponseWaiter>()
@ -69,17 +63,20 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
ids.shuffle()
// populate the array of randomly assigned ID's + waiters.
waiterCache = ObjectPool.suspending(ids)
waiterCache = ObjectPool.blocking(ids)
}
/**
* Called when we receive the answer for our initial request. If no response data, then the pending rmi data entry is deleted
*
* resume any pending remote object method invocations (if they are not async, or not manually waiting)
*
* NOTE: async RMI will never call this (because async doesn't return a response)
*/
suspend fun notifyWaiter(id: Int, result: Any?, logger: KLogger) {
logger.trace { "[RM] notify: $id" }
fun notifyWaiter(id: Int, result: Any?, logger: Logger) {
if (logger.isTraceEnabled) {
logger.trace("[RM] notify: [$id]")
}
val previous = pendingLock.write {
val previous = pending[id]
@ -89,7 +86,9 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
// if NULL, since either we don't exist (because we were async), or it was cancelled
if (previous is ResponseWaiter) {
logger.trace { "[RM] valid-cancel: $id" }
if (logger.isTraceEnabled) {
logger.trace("[RM] valid-notify: [$id]")
}
// this means we were NOT timed out! (we cannot be timed out here)
previous.doNotify()
@ -101,8 +100,10 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
*
* This is ONLY called when we want to get the data out of the stored entry, because we are operating ASYNC. (pure RMI async is different)
*/
suspend fun <T> getWaiterCallback(id: Int, logger: KLogger): T? {
logger.trace { "[RM] get-callback: $id" }
fun <T> removeWaiterCallback(id: Int, logger: Logger): T? {
if (logger.isTraceEnabled) {
logger.trace("[RM] get-callback: [$id]")
}
val previous = pendingLock.write {
val previous = pending[id]
@ -115,8 +116,9 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
val result = previous.result
// always return this to the cache!
previous.result = null
waiterCache.put(previous)
rmiWaitersInUse.getAndDecrement()
responseWaitersInUse.getAndDecrement()
return result as T
}
@ -129,12 +131,14 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
*
* We ONLY care about the ID to get the correct response info. If there is no response, the ID can be ignored.
*/
suspend fun prep(logger: KLogger): ResponseWaiter {
fun prep(logger: Logger): ResponseWaiter {
val waiter = waiterCache.take()
rmiWaitersInUse.getAndIncrement()
logger.trace { "[RM] prep in-use: ${rmiWaitersInUse.value}" }
responseWaitersInUse.getAndIncrement()
if (logger.isTraceEnabled) {
logger.trace("[RM] prep in-use: [${waiter.id}] ${responseWaitersInUse.value}")
}
// this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
// this will initialize the result
waiter.prep()
pendingLock.write {
@ -149,12 +153,14 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
*
* We ONLY care about the ID to get the correct response info. If there is no response, the ID can be ignored.
*/
suspend fun prepWithCallback(logger: KLogger, function: Any): Int {
fun prepWithCallback(logger: Logger, function: Any): Int {
val waiter = waiterCache.take()
rmiWaitersInUse.getAndIncrement()
logger.trace { "[RM] prep in-use: ${rmiWaitersInUse.value}" }
responseWaitersInUse.getAndIncrement()
if (logger.isTraceEnabled) {
logger.trace("[RM] prep in-use: [${waiter.id}] ${responseWaitersInUse.value}")
}
// this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
// this will initialize the result
waiter.prep()
// assign the callback that will be notified when the return message is received
@ -170,81 +176,20 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
}
/**
* Cancels the RMI request in the given timeout, the callback is executed inside the read lock
*/
suspend fun cancelRequest(timeoutMillis: Long, id: Int, logger: KLogger, onCancelled: ResponseWaiter.() -> Unit) {
EventDispatcher.RESPONSE_MANAGER.launch {
delay(timeoutMillis) // this will always wait. if this job is cancelled, this will immediately stop waiting
// check if we have a result or not
pendingLock.read {
val maybeResult = pending[id]
if (maybeResult is ResponseWaiter) {
logger.trace { "[RM] timeout ($timeoutMillis) with callback cancel: $id" }
maybeResult.cancel()
onCancelled(maybeResult)
}
}
}
}
/**
* We only wait for a reply if we are SYNC.
*
* ASYNC does not send a response
* ASYNC does not send a response and does not call this method
*
* @return the result (can be null) or timeout exception
*/
suspend fun waitForReply(
responseWaiter: ResponseWaiter,
timeoutMillis: Long,
logger: KLogger,
connection: Connection
): Any? {
fun getReply(responseWaiter: ResponseWaiter, timeoutMillis: Long, logger: Logger): Any? {
val id = RmiUtils.unpackUnsignedRight(responseWaiter.id)
logger.trace { "[RM] wait: $id" }
// NOTE: we ALWAYS send a response from the remote end (except when async).
//
// 'async' -> DO NOT WAIT (no response)
// 'timeout > 0' -> WAIT w/ TIMEOUT
// 'timeout == 0' -> WAIT FOREVER
if (timeoutMillis > 0) {
val responseTimeoutJob = EventDispatcher.RESPONSE_MANAGER.launch {
delay(timeoutMillis) // this will always wait. if this job is cancelled, this will immediately stop waiting
// check if we have a result or not
val maybeResult = pendingLock.read { pending[id] }
if (maybeResult is ResponseWaiter) {
logger.trace { "[RM] timeout ($timeoutMillis) cancel: $id" }
maybeResult.cancel()
}
}
// wait for the response.
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will suspend and either
// A) get response
// B) timeout
responseWaiter.doWait()
// always cancel the timeout
responseTimeoutJob.cancel()
} else {
// wait for the response --- THIS WAITS FOREVER (there is no timeout)!
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will suspend and
// A) get response
responseWaiter.doWait()
if (logger.isTraceEnabled) {
logger.trace("[RM] get: [$id]")
}
// deletes the entry in the map
val resultOrWaiter = pendingLock.write {
val previous = pending[id]
@ -253,29 +198,50 @@ internal class ResponseManager(maxValuesInCache: Int = 65534, minimumValue: Int
}
// always return the waiter to the cache
responseWaiter.result = null
waiterCache.put(responseWaiter)
rmiWaitersInUse.getAndDecrement()
responseWaitersInUse.getAndDecrement()
if (resultOrWaiter is ResponseWaiter) {
logger.trace { "[RM] timeout cancel ($timeoutMillis): $id" }
return if (connection.isClosed() || connection.isClosedViaAeron()) {
null
} else {
TIMEOUT_EXCEPTION
if (logger.isTraceEnabled) {
logger.trace("[RM] timeout cancel: [$id] ($timeoutMillis)")
}
// always throw an exception if we time out. EVEN if the connection is closed, we want to make sure to raise awareness!
return TIMEOUT_EXCEPTION
}
return resultOrWaiter
}
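Taken together, the reworked blocking flow on the caller side looks roughly like this (a hedged sketch using only the APIs above; variable names are borrowed from the RmiClient changes further down):

    // reserve a waiter/ID pair for this request
    val waiter = responseManager.prep(logger)
    message.packedId = RmiUtils.packShorts(rmiObjectId, waiter.id)

    if (!connection.send(message)) {
        // nothing will ever signal the waiter, so put it back immediately
        responseManager.abort(waiter, logger)
        throw RmiException("Unable to send message")
    }

    // block until the response arrives (or the timeout elapses), then fetch/clean up
    val result = if (waiter.doWait(timeoutMillis)) {
        responseManager.getReply(waiter, timeoutMillis, logger)
    } else {
        responseManager.abort(waiter, logger)
        ResponseManager.TIMEOUT_EXCEPTION
    }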
suspend fun close() {
fun abort(responseWaiter: ResponseWaiter, logger: Logger) {
val id = RmiUtils.unpackUnsignedRight(responseWaiter.id)
if (logger.isTraceEnabled) {
logger.trace("[RM] abort: [$id]")
}
// deletes the entry in the map
pendingLock.write {
pending[id] = null
}
// always return the waiter to the cache
responseWaiter.result = null
waiterCache.put(responseWaiter)
responseWaitersInUse.getAndDecrement()
}
// This is only closed when shutting down the client/server.
fun close(logger: Logger) {
// technically, this isn't closing it, so much as it's cleaning it out
logger.debug { "Closing the response manager for RMI" }
if (logger.isDebugEnabled) {
logger.debug("Closing the response manager")
}
// wait for responses, or wait for timeouts!
while (rmiWaitersInUse.value > 0) {
delay(100)
while (responseWaitersInUse.value > 0) {
Thread.sleep(50)
}
pendingLock.write {

View File

@ -15,57 +15,73 @@
*/
package dorkbox.network.rmi
import kotlinx.coroutines.channels.Channel
import kotlinx.atomicfu.locks.withLock
import java.util.concurrent.*
import java.util.concurrent.locks.*
data class ResponseWaiter(val id: Int) {
// this is bi-directional waiting. The method names do not reflect this; however, there is no possibility of race conditions w.r.t. waiting
// https://stackoverflow.com/questions/55421710/how-to-suspend-kotlin-coroutine-until-notified
// https://kotlinlang.org/docs/reference/coroutines/channels.html
// "receive' suspends until another coroutine invokes "send"
// and
// "send" suspends until another coroutine invokes "receive".
//
// these are wrapped in a try/catch, because cancel will cause exceptions to be thrown (which we DO NOT want)
@Volatile
var channel: Channel<Unit> = Channel(Channel.RENDEZVOUS)
private val lock = ReentrantLock()
private val condition = lock.newCondition()
@Volatile
var isCancelled = false
private var signalled = false
// holds the RMI result or callback. This is ALWAYS accessed from within a lock (so no synchronize/volatile/etc necessary)!
@Volatile
var result: Any? = null
/**
* this will replace the waiter if it was cancelled (waiters are not valid if cancelled)
* this will set the result to null
*/
fun prep() {
if (isCancelled) {
isCancelled = false
channel = Channel(0)
}
result = null
signalled = false
}
suspend fun doNotify() {
/**
* Waits until another thread invokes "doWait"
*/
fun doNotify() {
try {
channel.send(Unit)
lock.withLock {
signalled = true
condition.signal()
}
} catch (ignored: Throwable) {
}
}
suspend fun doWait() {
/**
* Waits (with no timeout) until another thread invokes "doNotify"
*/
fun doWait() {
try {
channel.receive()
lock.withLock {
if (signalled) {
return
}
condition.await()
}
} catch (ignored: Throwable) {
}
}
fun cancel() {
try {
isCancelled = true
channel.cancel()
/**
* Waits up to timeoutMs milliseconds for another thread to invoke "doNotify". Returns true if signalled (or interrupted), false if the timeout elapsed
*/
fun doWait(timeoutMs: Long): Boolean {
return try {
lock.withLock {
if (signalled) {
true
} else {
condition.await(timeoutMs, TimeUnit.MILLISECONDS)
}
}
} catch (ignored: Throwable) {
// we were interrupted BEFORE the timeout, so technically, the timeout did not elapse.
true
}
}
}
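A minimal sketch of how the reworked lock/condition waiter is paired across two threads (assumes kotlin.concurrent.thread; not part of this change):

    val waiter = ResponseWaiter(42)
    waiter.prep()

    // worker thread: publish the result, then signal
    thread {
        waiter.result = "pong"
        waiter.doNotify()
    }

    // caller: block until signalled, or give up after 500 ms
    if (waiter.doWait(500)) {
        println(waiter.result)   // "pong"
    }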

View File

@ -15,13 +15,15 @@
*/
package dorkbox.network.rmi
import com.conversantmedia.util.collection.FixedStack
import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.rmi.ResponseManager.Companion.TIMEOUT_EXCEPTION
import dorkbox.network.rmi.messages.MethodRequest
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import mu.KLogger
import kotlinx.coroutines.yield
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
@ -74,56 +76,28 @@ internal class RmiClient(val isGlobal: Boolean,
@Suppress("UNCHECKED_CAST")
private val EMPTY_ARRAY: Array<Any> = Collections.EMPTY_LIST.toTypedArray() as Array<Any>
private val safeAsyncState: ThreadLocal<Boolean?> = ThreadLocal.withInitial {
null
private val safeAsyncStack: ThreadLocal<FixedStack<Boolean?>> = ThreadLocal.withInitial {
FixedStack(64)
}
private const val charPrim = 0.toChar()
private const val shortPrim = 0.toShort()
private const val bytePrim = 0.toByte()
private fun returnAsyncOrSync(isAsync: Boolean, method: Method, returnValue: Any?): Any? {
if (isAsync) {
// if we are async then we return immediately.
// If you want the response value, disable async!
val returnType = method.returnType
if (returnType.isPrimitive) {
return when (returnType) {
Boolean::class.javaPrimitiveType -> java.lang.Boolean.FALSE
Int::class.javaPrimitiveType -> 0
Float::class.javaPrimitiveType -> 0.0f
Char::class.javaPrimitiveType -> charPrim
Long::class.javaPrimitiveType -> 0L
Short::class.javaPrimitiveType -> shortPrim
Byte::class.javaPrimitiveType -> bytePrim
Double::class.javaPrimitiveType -> 0.0
else -> null // void type
}
}
return null
}
else {
return returnValue
}
}
@Suppress("UNCHECKED_CAST")
private fun syncMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>) {
val action = args[0] as Any.() -> Unit
val prev = safeAsyncState.get()
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
// the docs cover that (and say, `don't do this`)
safeAsyncState.set(false)
safeAsyncStack.get().push(isAsync)
// the `sync` method is always a unit function - we want to execute that unit function directly - this way we can control
// exactly how sync state is preserved.
try {
action(proxy)
} finally {
if (prev != isAsync) {
safeAsyncState.remove()
}
safeAsyncStack.get().pop()
}
}
@ -131,8 +105,6 @@ internal class RmiClient(val isGlobal: Boolean,
private fun syncSuspendMethodAction(isAsync: Boolean, proxy: RemoteObject<*>, args: Array<Any>): Any? {
val action = args[0] as suspend Any.() -> Unit
val prev = safeAsyncState.get()
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
// We will use this for our coroutine context instead of running on a new coroutine
val suspendCoroutineArg = args.last()
@ -141,20 +113,20 @@ internal class RmiClient(val isGlobal: Boolean,
val suspendFunction: suspend () -> Any? = {
// the sync state is treated as a stack. Manually changing the state via `.async` field setter can cause problems, but
// the docs cover that (and say, `don't do this`)
withContext(safeAsyncState.asContextElement(isAsync)) {
withContext(safeAsyncStack.asContextElement()) {
yield() // must have an actually suspending call here!
safeAsyncStack.get().push(isAsync)
action(proxy)
}
}
// function suspension works differently !!
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
val result = (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
Continuation(continuation.context) {
val any = try {
it.getOrNull()
} finally {
if (prev != isAsync) {
safeAsyncState.remove()
}
safeAsyncStack.get().pop()
}
when (any) {
is Exception -> {
@ -168,6 +140,10 @@ internal class RmiClient(val isGlobal: Boolean,
}
}
})
runBlocking(safeAsyncStack.asContextElement()) {}
return result
}
}
@ -177,54 +153,13 @@ internal class RmiClient(val isGlobal: Boolean,
@Volatile private var enableHashCode = false
@Volatile private var enableEquals = false
// if we are ASYNC, then this method immediately returns
private suspend fun sendRequest(isAsync: Boolean, invokeMethod: MethodRequest, logger: KLogger): Any? {
// there is a STRANGE problem, where if we DO NOT respond/reply to method invocation, and immediate invoke multiple methods --
// the "server" side can have out-of-order method invocation. There are 2 ways to solve this
// 1) make the "server" side single threaded
// 2) make the "client" side wait for execution response (from the "server"). <--- this is what we are using.
//
// Because we have to ALWAYS make the client wait (unless 'isAsync' is true), we will always be returning, and will always have a
// response (even if it is a void response). This simplifies our response mask, and lets us use more bits for storing the
// response ID
// NOTE: we ALWAYS send a response from the remote end (except when async).
//
// 'async' -> DO NOT WAIT (no response)
// 'timeout > 0' -> WAIT w/ TIMEOUT
// 'timeout == 0' -> WAIT FOREVER
invokeMethod.isGlobal = isGlobal
return if (isAsync) {
// If we are async, we ignore the response (don't invoke the response manager at all)....
invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, RemoteObjectStorage.ASYNC_RMI)
connection.send(invokeMethod)
null
} else {
// The response, even if there is NOT one (ie: not void) will always return a thing (so our code execution is in lockstep -- unless it is ASYNC)
val rmiWaiter = responseManager.prep(logger)
invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, rmiWaiter.id)
connection.send(invokeMethod)
responseManager.waitForReply(rmiWaiter, timeoutMillis, logger, connection)
}
}
@Suppress("DuplicatedCode", "UNCHECKED_CAST")
/**
* @throws Exception
*/
override fun invoke(proxy: Any, method: Method, args: Array<Any>?): Any? {
val localAsync =
safeAsyncState.get() // value set via obj.sync {}
safeAsyncStack.get().peek() // value set via obj.sync {}
?:
isAsync // the value was set via obj.sync = xyz
@ -283,7 +218,7 @@ internal class RmiClient(val isGlobal: Boolean,
return null
}
else -> throw Exception("Invocation handler could not find RemoteObject method for ${method.name}")
else -> throw RmiException("Invocation handler could not find RemoteObject method for ${method.name}")
}
} else {
when (method) {
@ -304,6 +239,8 @@ internal class RmiClient(val isGlobal: Boolean,
}
}
val connection = connection
// setup the RMI request
val invokeMethod = MethodRequest()
@ -314,6 +251,67 @@ internal class RmiClient(val isGlobal: Boolean,
// this should be accessed via the KRYO class ID + method index (both are SHORT, and can be packed)
invokeMethod.cachedMethod = cachedMethods.first { it.method == method }
// there is a STRANGE problem, where if we DO NOT respond/reply to method invocation, and immediately invoke multiple methods --
// the "server" side can have out-of-order method invocation. There are 2 ways to solve this
// 1) make the "server" side single threaded
// 2) make the "client" side wait for execution response (from the "server"). <--- this is what we are using.
//
// Because we have to ALWAYS make the client wait (unless 'isAsync' is true), we will always be returning, and will always have a
// response (even if it is a void response). This simplifies our response mask, and lets us use more bits for storing the
// response ID
// NOTE: we ALWAYS send a response from the remote end (except when async).
//
// 'async' -> DO NOT WAIT (no response)
// 'timeout > 0' -> WAIT w/ TIMEOUT
// 'timeout == 0' -> WAIT FOREVER
invokeMethod.isGlobal = isGlobal
if (localAsync) {
// If we are async, we ignore the response (don't invoke the response manager at all)....
invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, RemoteObjectStorage.ASYNC_RMI)
val success = connection.send(invokeMethod)
if (!success) {
throw RmiException("Unable to send async message, an error occurred during the send process")
}
// if we are async then we return immediately (but must return the correct type!)
// If you want the response value, disable async!
val returnType = method.returnType
if (returnType.isPrimitive) {
return when (returnType) {
Boolean::class.javaPrimitiveType -> java.lang.Boolean.FALSE
Int::class.javaPrimitiveType -> 0
Float::class.javaPrimitiveType -> 0.0f
Char::class.javaPrimitiveType -> charPrim
Long::class.javaPrimitiveType -> 0L
Short::class.javaPrimitiveType -> shortPrim
Byte::class.javaPrimitiveType -> bytePrim
Double::class.javaPrimitiveType -> 0.0
else -> null // void type
}
}
return null
}
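A hypothetical call site for the async branch above: with async enabled, the proxy returns immediately with a type-correct placeholder (0, false, null, and so on) rather than the remote result. The interface, the proxy lookup, and the async toggle shown here are illustrative assumptions, not APIs confirmed by this diff.

// Hypothetical RMI interface used only for illustration.
interface Stats {
    fun count(): Int
}

// Assumed usage (lookup and toggle names are assumptions):
// val stats: Stats = connection.rmi.get(statsRmiId)
// RemoteObject.cast(stats).async = true
// val placeholder = stats.count()   // returns 0 immediately; the real remote result is discarded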
val logger = connection.logger
//
// this is all SYNC code
//
// The response, even if there is NOT one (ie: a void return), will always return something (so our code execution is in lockstep -- unless it is ASYNC)
val responseWaiter = responseManager.prep(logger)
invokeMethod.packedId = RmiUtils.packShorts(rmiObjectId, responseWaiter.id)
val success = connection.send(invokeMethod)
if (!success) {
responseManager.abort(responseWaiter, logger)
throw RmiException("Unable to send message, an error occurred during the send process")
}
// if a 'suspend' function is called, then our last argument is a 'Continuation' object
// We will use this for our coroutine context instead of running on a new coroutine
@ -324,15 +322,47 @@ internal class RmiClient(val isGlobal: Boolean,
val continuation = suspendCoroutineArg as Continuation<Any?>
val suspendFunction: suspend () -> Any? = {
sendRequest(localAsync, invokeMethod, connection.logger)
// NOTE: once something ELSE is suspending, we can remove the `yield`
yield() // if this is not here, it will not work (something must actually suspend!)
// NOTE: this is blocking!
// NOTE: we ALWAYS send a response from the remote end (except when async).
//
// 'async' -> DO NOT WAIT (no response)
// 'timeout > 0' -> WAIT w/ TIMEOUT
// 'timeout == 0' -> WAIT FOREVER
if (timeoutMillis > 0) {
// wait for the response.
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will wait for:
// A) get response
// B) timeout
if (!responseWaiter.doWait(timeoutMillis)) {
// if we timeout, it doesn't matter since we'll be removing the waiter from the array anyways,
// so no signal can occur, or a signal won't matter
responseManager.abort(responseWaiter, logger)
TIMEOUT_EXCEPTION
} else {
responseManager.getReply(responseWaiter, timeoutMillis, logger)
}
} else {
// wait for the response --- THIS WAITS FOREVER (there is no timeout)!
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will wait for one
// A) get response
responseWaiter.doWait()
responseManager.getReply(responseWaiter, timeoutMillis, logger)
}
}
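The wait semantics spelled out in the comments above (timeout > 0 waits with a timeout, timeout == 0 waits forever, an already-arrived reply returns instantly) can be pictured with a small latch-based waiter. This is a standalone sketch under that assumption, not the library's actual ResponseManager or waiter implementation.

import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

// Standalone sketch of the waiter contract described above; not the real ResponseManager.
class SimpleWaiter {
    private val latch = CountDownLatch(1)
    @Volatile var result: Any? = null

    fun signal(value: Any?) {                     // called when the remote reply arrives
        result = value
        latch.countDown()
    }

    fun doWait() = latch.await()                  // 'timeout == 0' -> WAIT FOREVER

    fun doWait(timeoutMillis: Long): Boolean =    // 'timeout > 0' -> WAIT w/ TIMEOUT; false on timeout
        latch.await(timeoutMillis, TimeUnit.MILLISECONDS)
}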
// function suspension works differently !!
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(
Continuation(continuation.context) {
// function suspension works differently. THIS IS A TRAMPOLINE TO CALL SUSPEND !!
return (suspendFunction as Function1<Continuation<Any?>, Any?>).invoke(Continuation(continuation.context) {
val any = it.getOrNull()
when (any) {
ResponseManager.TIMEOUT_EXCEPTION -> {
TIMEOUT_EXCEPTION -> {
val fancyName = RmiUtils.makeFancyMethodName(method)
val exception = TimeoutException("Response timed out: $fancyName")
// from top down, clean up the coroutine stack
@ -340,7 +370,7 @@ internal class RmiClient(val isGlobal: Boolean,
continuation.resumeWithException(exception)
}
is Exception -> {
is Throwable -> {
// for co-routines, it's impossible to get a legit stacktrace without impacting general performance,
// so we just don't do it.
// RmiUtils.cleanStackTraceForProxy(Exception(), any)
@ -348,17 +378,43 @@ internal class RmiClient(val isGlobal: Boolean,
}
else -> {
continuation.resume(returnAsyncOrSync(localAsync, method, any))
continuation.resume(any)
}
}
})
} else {
val any = runBlocking {
sendRequest(localAsync, invokeMethod, connection.logger)
// NOTE: this is blocking!
// NOTE: we ALWAYS send a response from the remote end (except when async).
//
// 'async' -> DO NOT WAIT (no response)
// 'timeout > 0' -> WAIT w/ TIMEOUT
// 'timeout == 0' -> WAIT FOREVER
if (timeoutMillis > 0) {
// wait for the response.
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will wait for:
// A) get response
// B) timeout
if (!responseWaiter.doWait(timeoutMillis)) {
// if we timeout, it doesn't matter since we'll be removing the waiter from the array anyways,
// so no signal can occur, or a signal won't matter
responseManager.abort(responseWaiter, logger)
throw TIMEOUT_EXCEPTION
}
} else {
// wait for the response --- THIS WAITS FOREVER (there is no timeout)!
//
// If the response is ALREADY here, the doWait() returns instantly (with result)
// if no response yet, it will wait for one
// A) get response
responseWaiter.doWait()
}
val any = responseManager.getReply(responseWaiter, timeoutMillis, logger)
when (any) {
ResponseManager.TIMEOUT_EXCEPTION -> {
TIMEOUT_EXCEPTION -> {
val fancyName = RmiUtils.makeFancyMethodName(method)
val exception = TimeoutException("Response timed out: $fancyName")
// from top down, clean up the coroutine stack
@ -366,7 +422,7 @@ internal class RmiClient(val isGlobal: Boolean,
throw exception
}
is Exception -> {
is Throwable -> {
// reconstruct the stack trace, so the calling method knows where the method invocation happened, and can trace the call
// this stack will ALWAYS run up to this method (so we remove from the top->down, to get to the call site)
RmiUtils.cleanStackTraceForProxy(Exception(), any)
@ -374,8 +430,7 @@ internal class RmiClient(val isGlobal: Boolean,
}
else -> {
// attempt to return a proper value
return returnAsyncOrSync(localAsync, method, any)
return any
}
}
}


@ -0,0 +1,26 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi
/**
* Thrown when there is a generic RMI error (for example, if the RMI message could not be sent, or an action on an RMI object is invalid).
*/
class RmiException : Exception {
constructor() : super() {}
constructor(message: String?, cause: Throwable?) : super(message, cause) {}
constructor(message: String?) : super(message) {}
constructor(cause: Throwable?) : super(cause) {}
}
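A hypothetical call site showing how the new exception surfaces: a failed send on an RMI proxy now throws RmiException instead of failing silently. The helper name and the Runnable stand-in for a proxy are illustrative assumptions.

import dorkbox.network.rmi.RmiException
import org.slf4j.Logger

// Illustrative only: 'remoteCall' stands in for any RMI proxy method that may throw RmiException.
fun invokeSafely(remoteCall: Runnable, logger: Logger) {
    try {
        remoteCall.run()                                  // proxy call; a failed send throws RmiException
    } catch (e: RmiException) {
        logger.warn("RMI invocation could not be sent", e)
    }
}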


@ -15,6 +15,7 @@
*/
package dorkbox.network.rmi
import dorkbox.classUtil.ClassHelper
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ListenerManager
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
@ -22,13 +23,11 @@ import dorkbox.network.exceptions.RMIException
import dorkbox.network.rmi.messages.ConnectionObjectCreateRequest
import dorkbox.network.rmi.messages.ConnectionObjectCreateResponse
import dorkbox.network.rmi.messages.ConnectionObjectDeleteRequest
import dorkbox.network.rmi.messages.ConnectionObjectDeleteResponse
import dorkbox.network.serialization.Serialization
import dorkbox.util.classes.ClassHelper
import mu.KLogger
import org.slf4j.Logger
class RmiManagerConnections<CONNECTION: Connection> internal constructor(
private val logger: KLogger,
private val logger: Logger,
private val responseManager: ResponseManager,
private val listenerManager: ListenerManager<CONNECTION>,
private val serialization: Serialization<CONNECTION>,
@ -38,7 +37,7 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
/**
* called on "server"
*/
suspend fun onConnectionObjectCreateRequest(serialization: Serialization<CONNECTION>, connection: CONNECTION, message: ConnectionObjectCreateRequest) {
fun onConnectionObjectCreateRequest(serialization: Serialization<CONNECTION>, connection: CONNECTION, message: ConnectionObjectCreateRequest) {
val callbackId = RmiUtils.unpackLeft(message.packedIds)
val kryoId = RmiUtils.unpackRight(message.packedIds)
val objectParameters = message.objectParameters
@ -54,13 +53,20 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
listenerManager.notifyError(connection, newException)
ConnectionObjectCreateResponse(RmiUtils.packShorts(callbackId, RemoteObjectStorage.INVALID_RMI))
} else {
val rmiId = connection.rmi.saveImplObject(implObject)
if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val newException = RMIException("Unable to create RMI object, invalid RMI ID")
listenerManager.notifyError(connection, newException)
}
try {
val rmiId = connection.rmi.saveImplObject(implObject)
if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val newException = RMIException("Unable to create RMI object, invalid RMI ID")
listenerManager.notifyError(connection, newException)
}
ConnectionObjectCreateResponse(RmiUtils.packShorts(callbackId, rmiId))
ConnectionObjectCreateResponse(RmiUtils.packShorts(callbackId, rmiId))
}
catch (e: Exception) {
val newException = RMIException("Error saving the RMI implementation object!", e)
listenerManager.notifyError(connection, newException)
ConnectionObjectCreateResponse(RmiUtils.packShorts(callbackId, RemoteObjectStorage.INVALID_RMI))
}
}
// we send the message ALWAYS, because the client needs to know it worked or not
@ -70,7 +76,7 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
/**
* called on "client"
*/
suspend fun onConnectionObjectCreateResponse(connection: CONNECTION, message: ConnectionObjectCreateResponse) {
fun onConnectionObjectCreateResponse(connection: CONNECTION, message: ConnectionObjectCreateResponse) {
val callbackId = RmiUtils.unpackLeft(message.packedIds)
val rmiId = RmiUtils.unpackRight(message.packedIds)
@ -85,15 +91,14 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
val rmi = connection.rmi as RmiSupportConnection<CONNECTION>
val callback = rmi.removeCallback(callbackId)
val interfaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(RemoteObjectCallback::class.java, callback.javaClass, 0)
val interfaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(RemoteObjectCallback::class.java, callback.javaClass, 0) ?: callback.javaClass
// create the client-side proxy object, if possible. This MUST be an object that is saved for the connection
val proxyObject = rmi.getProxyObject(false, connection, rmiId, interfaceClass)
// this should be executed on a NEW coroutine!
try {
callback(proxyObject)
} catch (exception: Exception) {
callback(proxyObject, rmiId)
} catch (exception: Throwable) {
exception.cleanStackTrace()
val newException = RMIException(exception)
listenerManager.notifyError(connection, newException)
@ -103,7 +108,7 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
/**
* called on "client" or "server"
*/
suspend fun onConnectionObjectDeleteRequest(connection: CONNECTION, message: ConnectionObjectDeleteRequest) {
fun onConnectionObjectDeleteRequest(connection: CONNECTION, message: ConnectionObjectDeleteRequest) {
val rmiId = message.rmiId
// we only delete the impl object if the RMI id is valid!
@ -116,28 +121,6 @@ class RmiManagerConnections<CONNECTION: Connection> internal constructor(
// it DOESN'T matter which "side" we are, just delete both (RMI id's must always represent the same object on both sides)
connection.rmi.removeProxyObject(rmiId)
connection.rmi.removeImplObject<Any?>(rmiId)
// tell the "other side" to delete the proxy/impl object
connection.send(ConnectionObjectDeleteResponse(rmiId))
}
/**
* called on "client" or "server"
*/
fun onConnectionObjectDeleteResponse(connection: CONNECTION, message: ConnectionObjectDeleteResponse) {
val rmiId = message.rmiId
// we only create the proxy + execute the callback if the RMI id is valid!
if (rmiId == RemoteObjectStorage.INVALID_RMI) {
val newException = RMIException("Unable to create RMI object, invalid RMI ID")
listenerManager.notifyError(connection, newException)
return
}
// it DOESN'T matter which "side" we are, just delete both (RMI id's must always represent the same object on both sides)
connection.rmi.removeProxyObject(rmiId)
connection.rmi.removeImplObject<Any?>(rmiId)
}


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,18 +16,14 @@
package dorkbox.network.rmi
import dorkbox.network.connection.Connection
import dorkbox.network.rmi.messages.ConnectionObjectCreateRequest
import dorkbox.network.rmi.messages.ConnectionObjectCreateResponse
import dorkbox.network.rmi.messages.ConnectionObjectDeleteRequest
import dorkbox.network.rmi.messages.ConnectionObjectDeleteResponse
import dorkbox.network.rmi.messages.MethodRequest
import dorkbox.network.rmi.messages.MethodResponse
import dorkbox.network.rmi.messages.*
import dorkbox.network.serialization.Serialization
import mu.KLogger
import kotlinx.coroutines.runBlocking
import org.slf4j.Logger
import java.lang.reflect.Proxy
import java.util.*
internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiObjectCache(logger) {
internal class RmiManagerGlobal<CONNECTION: Connection>(logger: Logger) : RmiObjectCache(logger) {
companion object {
/**
@ -93,13 +89,13 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
* Manages ALL OF THE RMI SCOPES
*/
@Suppress("DuplicatedCode")
suspend fun processMessage(
fun processMessage(
serialization: Serialization<CONNECTION>,
connection: CONNECTION,
message: Any,
rmiConnectionSupport: RmiManagerConnections<CONNECTION>,
responseManager: ResponseManager,
logger: KLogger
logger: Logger
) {
when (message) {
is ConnectionObjectCreateRequest -> {
@ -120,12 +116,6 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
*/
rmiConnectionSupport.onConnectionObjectDeleteRequest(connection, message)
}
is ConnectionObjectDeleteResponse -> {
/**
* called on "client" or "server"
*/
rmiConnectionSupport.onConnectionObjectDeleteResponse(connection, message)
}
is MethodRequest -> {
/**
* Invokes the method on the object and, sends the result back to the connection that made the invocation request.
@ -142,7 +132,9 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
val args = message.args
val sendResponse = rmiId != RemoteObjectStorage.ASYNC_RMI // async is always with a '1', and we should NOT send a message back if it is '1'
logger.trace { "RMI received: $rmiId" }
if (logger.isTraceEnabled) {
logger.trace("RMI received: $rmiId")
}
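The guard above is the plain-slf4j replacement for the previous KLogger lambda form; it keeps the message string from being built at all when TRACE is disabled. A minimal illustration follows (the logger name is an assumption).

import org.slf4j.LoggerFactory

private val log = LoggerFactory.getLogger("rmi-demo")   // assumed logger name, for illustration only

fun traceRmiId(rmiId: Int) {
    if (log.isTraceEnabled) {
        log.trace("RMI received: $rmiId")   // interpolation only happens when TRACE is enabled
    }
}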
val implObject: Any? = if (isGlobal) {
getImplObject(rmiObjectId)
@ -164,10 +156,18 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
return
}
logger.trace {
if (logger.isTraceEnabled) {
var argString = ""
if (args != null) {
argString = Arrays.deepToString(args)
if (!args.isNullOrEmpty()) {
// long byte arrays have SERIOUS problems!
argString = Arrays.deepToString(args.map {
when (it) {
is ByteArray -> { "${it::class.java.simpleName}(length=${it.size})"}
is Array<*> -> { "${it::class.java.simpleName}(length=${it.size})"}
is Collection<*> -> { "${it::class.java.simpleName}(length=${it.size})"}
else -> { it }
}
}.toTypedArray())
argString = argString.substring(1, argString.length - 1)
}
@ -181,7 +181,9 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
// did we override our cached method? THIS IS NOT COMMON.
stringBuilder.append(" [Connection method override]")
}
stringBuilder.toString()
logger.trace(stringBuilder.toString())
}
var result: Any?
@ -189,47 +191,48 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
if (isCoroutine) {
// https://stackoverflow.com/questions/47654537/how-to-run-suspend-method-via-reflection
// https://discuss.kotlinlang.org/t/calling-coroutines-suspend-functions-via-reflection/4672
var suspendResult = kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn<Any?> { cont ->
// if we are a coroutine, we have to replace the LAST arg with the coroutine object
// we KNOW this is OK, because a continuation arg will always be there!
args!![args.size - 1] = cont
runBlocking {
var suspendResult = kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn<Any?> { cont ->
// if we are a coroutine, we have to replace the LAST arg with the coroutine object
// we KNOW this is OK, because a continuation arg will always be there!
args!![args.size - 1] = cont
var insideResult: Any?
try {
// args!! is safe to do here (even though it doesn't make sense)
insideResult = cachedMethod.invoke(connection, implObject, args)
} catch (ex: Exception) {
insideResult = ex.cause
// added to prevent a stack overflow when references is false, (because 'cause' == "this").
// See:
// https://groups.google.com/forum/?fromgroups=#!topic/kryo-users/6PDs71M1e9Y
if (insideResult == null) {
insideResult = ex
var insideResult: Any?
try {
insideResult = cachedMethod.invoke(connection, implObject, args)
} catch (ex: Throwable) {
insideResult = ex.cause
// added to prevent a stack overflow when references is false, (because 'cause' == "this").
// See:
// https://groups.google.com/forum/?fromgroups=#!topic/kryo-users/6PDs71M1e9Y
if (insideResult == null) {
insideResult = ex
}
else {
insideResult.initCause(null)
}
}
else {
insideResult.initCause(null)
}
}
insideResult
}
if (suspendResult === kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED) {
// we were suspending, and the stack will resume when possible, then it will call the response below
}
else {
if (suspendResult === Unit) {
// kotlin suspend functions that DO NOT have a return value REALLY return kotlin.Unit. This means there is no
// return value!
suspendResult = null
} else if (suspendResult is Exception) {
RmiUtils.cleanStackTraceForImpl(suspendResult, true)
logger.error("Connection ${connection.id}", suspendResult)
insideResult
}
if (sendResponse) {
val rmiMessage = returnRmiMessage(message, suspendResult, logger)
connection.send(rmiMessage)
if (suspendResult === kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED) {
// we were suspending, and the stack will resume when possible, then it will call the response below
}
else {
if (suspendResult === Unit) {
// kotlin suspend functions that DO NOT have a return value REALLY return kotlin.Unit. This means there is no
// return value!
suspendResult = null
} else if (suspendResult is Throwable) {
RmiUtils.cleanStackTraceForImpl(suspendResult, true)
logger.error("Connection ${connection.id}", suspendResult)
}
if (sendResponse) {
val rmiMessage = returnRmiMessage(message, suspendResult, logger)
connection.send(rmiMessage)
}
}
}
}
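The block above drives a suspend implementation method through plain reflection: the compiler-generated last parameter is a Continuation, and a return of COROUTINE_SUSPENDED means the real result will be delivered later through that continuation. The following is a self-contained sketch of the same mechanism; the class and method names are assumptions.

import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlinx.coroutines.runBlocking

class Demo {
    suspend fun greet(name: String): String = "hello $name"   // hypothetical impl method
}

fun main() = runBlocking {
    // suspend fun greet(name) compiles to greet(name, Continuation) returning Object
    val method = Demo::class.java.methods.first { it.name == "greet" }

    val result = suspendCoroutineUninterceptedOrReturn<Any?> { cont ->
        method.invoke(Demo(), "rmi", cont)   // pass the current continuation as the hidden last arg
    }

    // A result of COROUTINE_SUSPENDED would mean the value arrives later via the continuation;
    // this simple body never suspends, so the value is returned directly.
    check(result !== COROUTINE_SUSPENDED)
    println(result)   // prints "hello rmi"
}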
@ -271,8 +274,10 @@ internal class RmiManagerGlobal<CONNECTION: Connection>(logger: KLogger) : RmiOb
}
}
private fun returnRmiMessage(message: MethodRequest, result: Any?, logger: KLogger): MethodResponse {
logger.trace { "RMI return. Send: ${RmiUtils.unpackUnsignedRight(message.packedId)}" }
private fun returnRmiMessage(message: MethodRequest, result: Any?, logger: Logger): MethodResponse {
if (logger.isTraceEnabled) {
logger.trace("RMI return. Send: ${RmiUtils.unpackUnsignedRight(message.packedId)}")
}
val rmiMessage = MethodResponse()
rmiMessage.packedId = message.packedId


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
*/
package dorkbox.network.rmi
import mu.KLogger
import org.slf4j.Logger
/**
* Cache for implementation and proxy objects.
@ -23,11 +23,15 @@ import mu.KLogger
* The impl/proxy objects CANNOT be stored in the same data structure, because their IDs are not tied to the same ID source (and there
* would be conflicts in the data structure)
*/
open class RmiObjectCache(logger: KLogger) {
open class RmiObjectCache(val logger: Logger) {
private val implObjects = RemoteObjectStorage(logger)
/**
* This object will be saved again if we send the object "over the wire", automatically!
*
* So if we DELETE the object (on side A), and then later on side A sends the object to side B, then side A will save it again when it sends.
*
* @return the newly registered RMI ID for this object. [RemoteObjectStorage.INVALID_RMI] means it was invalid (an error log will be emitted)
*/
internal fun saveImplObject(rmiObject: Any): Int {
@ -50,7 +54,7 @@ open class RmiObjectCache(logger: KLogger) {
}
/**
* Removes the object using the ID registered.
* Removes the object using the registered ID.
*
* @return the object or null if not found
*/
@ -61,7 +65,26 @@ open class RmiObjectCache(logger: KLogger) {
/**
* @return the ID registered for the specified object, or INVALID_RMI if not found.
*/
internal fun <T> getId(implObject: T): Int {
internal fun <T: Any> getId(implObject: T): Int {
return implObjects.getId(implObject)
}
/**
* @return all the saved RMI implementation objects along with their RMI ID. This is used by session management in order to preserve RMI functionality.
*/
internal fun getAllImplObjects(): List<Pair<Int, Any>> {
return implObjects.getAll()
}
/**
* Restores all the saved RMI implementation objects along with their RMI IDs. This is used by session management in order to preserve RMI functionality.
*/
internal fun restoreImplObjects(implObjects: List<Pair<Int, Any>>) {
this.implObjects.restoreAll(implObjects)
}
internal open fun clear() {
this.implObjects.clear()
}
}
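The two new accessors above exist so that session management can snapshot the impl objects of one connection and replay them into another, keeping the same RMI ids valid. A rough sketch of that round-trip, assuming it runs inside the same module (both methods are internal):

// Sketch only: 'oldCache' and 'newCache' are RmiObjectCache instances from an old and a new session.
fun preserveImplObjects(oldCache: RmiObjectCache, newCache: RmiObjectCache) {
    val snapshot: List<Pair<Int, Any>> = oldCache.getAllImplObjects()   // (rmiId, implObject) pairs
    newCache.restoreImplObjects(snapshot)                               // the same ids resolve again
}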


@ -16,14 +16,15 @@
package dorkbox.network.rmi
import dorkbox.classUtil.ClassHelper
import dorkbox.collections.LockFreeIntMap
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import dorkbox.network.rmi.messages.ConnectionObjectCreateRequest
import dorkbox.network.rmi.messages.ConnectionObjectDeleteRequest
import dorkbox.network.serialization.Serialization
import dorkbox.util.classes.ClassHelper
import mu.KLogger
import org.slf4j.Logger
import java.lang.reflect.Proxy
/**
* Only the server can create or delete a global object
@ -34,20 +35,32 @@ import mu.KLogger
*
* Connection scope objects can be remotely created or deleted by either end of the connection. Only the server can create/delete a global scope object
*/
class RmiSupportConnection<CONNECTION: Connection> internal constructor(
private val logger: KLogger,
private val connection: CONNECTION,
private val responseManager: ResponseManager,
private val serialization: Serialization<CONNECTION>,
class RmiSupportConnection<CONNECTION: Connection> : RmiObjectCache {
private val connection: CONNECTION
private val responseManager: ResponseManager
val serialization: Serialization<CONNECTION>
private val getGlobalAction: (connection: CONNECTION, objectId: Int, interfaceClass: Class<*>) -> Any
) : RmiObjectCache(logger) {
internal constructor(
logger: Logger,
connection: CONNECTION,
responseManager: ResponseManager,
serialization: Serialization<CONNECTION>,
getGlobalAction: (connection: CONNECTION, objectId: Int, interfaceClass: Class<*>) -> Any
) : super(logger) {
this.connection = connection
this.responseManager = responseManager
this.serialization = serialization
this.getGlobalAction = getGlobalAction
this.proxyObjects = LockFreeIntMap<RemoteObject<*>>()
this.remoteObjectCreationCallbacks = RemoteObjectStorage(logger)
}
// It is critical that all of the RMI proxy objects are unique, and are saved/cached PER CONNECTION. These cannot be shared between connections!
private val proxyObjects = LockFreeIntMap<RemoteObject<*>>()
private val proxyObjects: LockFreeIntMap<RemoteObject<*>>
// callbacks for when a REMOTE object has been created
private val remoteObjectCreationCallbacks = RemoteObjectStorage(logger)
private val remoteObjectCreationCallbacks: RemoteObjectStorage
/**
* Removes a proxy object from the system
@ -66,15 +79,56 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
proxyObjects.put(rmiId, remoteObject)
}
private fun <Iface> registerCallback(callback: suspend Iface.() -> Unit): Int {
private fun <Iface> registerCallback(callback: Iface.(Int) -> Unit): Int {
return remoteObjectCreationCallbacks.register(callback)
}
internal fun removeCallback(callbackId: Int): suspend Any.() -> Unit {
internal fun removeCallback(callbackId: Int): Any.(Int) -> Unit {
// callbacks are always correct, because we track them ourselves.
return remoteObjectCreationCallbacks.remove(callbackId)!!
}
internal fun getAllCallbacks(): List<Pair<Int, Any.(Int) -> Unit>> {
@Suppress("UNCHECKED_CAST")
return remoteObjectCreationCallbacks.getAll() as List<Pair<Int, Any.(Int) -> Unit>>
}
internal fun restoreCallbacks(oldProxyCallbacks: List<Pair<Int, Any.(Int) -> Unit>>) {
remoteObjectCreationCallbacks.restoreAll(oldProxyCallbacks)
}
/**
* @return all the RMI proxy objects used by this connection. This is used by session management in order to preserve RMI functionality.
*/
internal fun getAllProxyObjects(): List<RemoteObject<*>> {
return proxyObjects.values.toList()
}
/**
* Recreate all the proxy objects for this connection. This is used by session management in order to preserve RMI functionality.
*/
internal fun recreateProxyObjects(oldProxyObjects: List<RemoteObject<*>>) {
oldProxyObjects.forEach {
// the interface we care about is ALWAYS the second one!
val iface = it.javaClass.interfaces[1]
val kryoId = connection.endPoint.serialization.getKryoIdForRmiClient(iface)
val rmiClient = Proxy.getInvocationHandler(it) as RmiClient
val rmiId = rmiClient.rmiObjectId
val proxyObject = RmiManagerGlobal.createProxyObject(
rmiClient.isGlobal,
connection,
serialization,
responseManager,
kryoId, rmiId,
iface
)
saveProxyObject(rmiId, proxyObject)
}
}
/**
@ -134,7 +188,8 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
/**
* Creates a new proxy object where the implementation exists in a remote connection.
*
* The callback will be notified when the remote object has been created.
* We use a callback to notify us when the object is ready. We can't "create this on the fly" because we
* have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here.
*
* Methods that return a value will throw [TimeoutException] if the response is not received with the response timeout [RemoteObject.responseTimeout].
*
@ -146,20 +201,26 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
*
* @see RemoteObject
*/
suspend fun <Iface> create(vararg objectParameters: Any?, callback: suspend Iface.() -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function1::class.java, callback.javaClass, 0)
fun <Iface> create(vararg objectParameters: Any?, callback: Iface.(rmiId: Int) -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function1::class.java, callback.javaClass, 0) ?: callback.javaClass
val kryoId = serialization.getKryoIdForRmiClient(iFaceClass)
@Suppress("UNCHECKED_CAST")
objectParameters as Array<Any?>
createRemoteObject(connection, kryoId, objectParameters, callback)
val callbackId = registerCallback(callback)
// There is no rmiID yet, because we haven't created it!
val message = ConnectionObjectCreateRequest(RmiUtils.packShorts(callbackId, kryoId), objectParameters)
connection.send(message)
}
/**
* Creates a new proxy object where the implementation exists in a remote connection.
*
* The callback will be notified when the remote object has been created.
* We use a callback to notify us when the object is ready. We can't "create this on the fly" because we
* have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here.
*
* NOTE:: Methods can throw [TimeoutException] if the response is not received with the response timeout [RemoteObject.responseTimeout].
*
@ -167,15 +228,20 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
* will have the proxy object replaced with the registered (non-proxy) object.
*
* If one wishes to change the default behavior, cast the object to access the different methods.
* ie: `val remoteObject = test as RemoteObject`
* ie: `val remoteObject = RemoteObject.cast(obj)`
*
* @see RemoteObject
*/
suspend fun <Iface> create(callback: suspend Iface.() -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function1::class.java, callback.javaClass, 0)
fun <Iface> create(callback: Iface.(rmiId: Int) -> Unit) {
val iFaceClass = ClassHelper.getGenericParameterAsClassForSuperClass(Function1::class.java, callback.javaClass, 0) ?: callback.javaClass
val kryoId = serialization.getKryoIdForRmiClient(iFaceClass)
createRemoteObject(connection, kryoId, null, callback)
val callbackId = registerCallback(callback)
// There is no rmiID yet, because we haven't created it!
val message = ConnectionObjectCreateRequest(RmiUtils.packShorts(callbackId, kryoId), null)
connection.send(message)
}
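A hypothetical use of the callback-style create(): the lambda runs once the remote side has built the implementation and assigned an RMI id, with the new proxy as the receiver. The interface and connection variable below are illustrative assumptions.

// Hypothetical RMI interface; the real interface must be registered with the serialization manager.
interface ChatRoom {
    fun post(message: String)
}

// Assumed call site:
// connection.rmi.create<ChatRoom>("general") { rmiId ->
//     // 'this' is the freshly created ChatRoom proxy; rmiId identifies it on both sides
//     post("hello from rmi object $rmiId")
// }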
/**
@ -185,7 +251,7 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
*
* Future '.get' requests will succeed, as they do not check the existence of the implementation object (methods called on it will fail)
*/
suspend fun delete(rmiObjectId: Int) {
fun delete(rmiObjectId: Int) {
// we only create the proxy + execute the callback if the RMI id is valid!
if (rmiObjectId == RemoteObjectStorage.INVALID_RMI) {
val exception = Exception("Unable to delete RMI object!")
@ -194,6 +260,12 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
return
}
// it DOESN'T matter which "side" we are, just delete both (RMI id's must always represent the same object on both sides)
removeProxyObject(rmiObjectId)
removeImplObject<Any?>(rmiObjectId)
// ALWAYS send a message because we don't know if we are the "client" or the "server" - and we want ALL sides cleaned up
connection.send(ConnectionObjectDeleteRequest(rmiObjectId))
}
@ -241,6 +313,14 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
}
/**
* Casts this remote object (specified by its RMI ID) to the "RemoteObject" type, so that those methods can more easily be called
*/
inline fun <reified T> cast(rmiId: Int): RemoteObject<T> {
val obj = get<T>(rmiId)
@Suppress("UNCHECKED_CAST")
return obj as RemoteObject<T>
}
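Conceptually, cast() just re-types an existing proxy: the object created for an RMI id also implements RemoteObject, so the same instance can be viewed through either interface. A hedged illustration, where ChatRoom, rmiId, and the settability of responseTimeout are assumptions:

// Assumed call site, inside code that has an RmiSupportConnection available as 'rmi':
// val room: ChatRoom = rmi.get(rmiId)
// val asRemote: RemoteObject<ChatRoom> = rmi.cast(rmiId)   // same underlying proxy, RemoteObject view
// asRemote.responseTimeout = 5_000                         // bounds how long sync calls wait (assumed settable)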
/**
@ -294,25 +374,9 @@ class RmiSupportConnection<CONNECTION: Connection> internal constructor(
return proxyObject as Iface
}
/**
* on the "client" to create a connection-specific remote object (that exists on the server)
*/
private suspend fun <Iface> createRemoteObject(connection: CONNECTION, kryoId: Int, objectParameters: Array<Any?>?, callback: suspend Iface.() -> Unit) {
val callbackId = registerCallback(callback)
// There is no rmiID yet, because we haven't created it!
val message = ConnectionObjectCreateRequest(RmiUtils.packShorts(callbackId, kryoId), objectParameters)
// We use a callback to notify us when the object is ready. We can't "create this on the fly" because we
// have to wait for the object to be created + ID to be assigned on the remote system BEFORE we can create the proxy instance here.
// this means we are creating a NEW object on the server
connection.send(message)
}
internal fun clear() {
override fun clear() {
super.clear()
proxyObjects.clear()
remoteObjectCreationCallbacks.close()
remoteObjectCreationCallbacks.clear()
}
}


@ -18,7 +18,7 @@ package dorkbox.network.rmi
import dorkbox.network.connection.Connection
import dorkbox.network.connection.ListenerManager.Companion.cleanStackTrace
import mu.KLogger
import org.slf4j.Logger
/**
* Only the server can create or delete a global object
@ -30,7 +30,7 @@ import mu.KLogger
* Connection scope objects can be remotely created or deleted by either end of the connection. Only the server can create/delete a global scope object
*/
class RmiSupportServer<CONNECTION : Connection> internal constructor(
private val logger: KLogger,
private val logger: Logger,
private val rmiGlobalSupport: RmiManagerGlobal<CONNECTION>
) {
/**


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,9 +18,9 @@ package dorkbox.network.rmi
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.reflectasm.MethodAccess
import dorkbox.classUtil.ClassHelper
import dorkbox.network.connection.Connection
import dorkbox.util.classes.ClassHelper
import mu.KLogger
import org.slf4j.Logger
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.util.*
@ -75,7 +75,7 @@ object RmiUtils {
throw RuntimeException("Two methods with same signature! ('$o1Name', '$o2Name')")
}
private fun getReflectAsmMethod(logger: KLogger, clazz: Class<*>): MethodAccess? {
private fun getReflectAsmMethod(logger: Logger, clazz: Class<*>): MethodAccess? {
return try {
val methodAccess = MethodAccess.get(clazz)
@ -95,7 +95,7 @@ object RmiUtils {
* @param iFace this is never null.
* @param impl this is NULL on the rmi "client" side. This is NOT NULL on the "server" side (where the object lives)
*/
fun getCachedMethods(logger: KLogger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array<CachedMethod> {
fun getCachedMethods(logger: Logger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array<CachedMethod> {
var ifaceAsmMethodAccess: MethodAccess? = null
var implAsmMethodAccess: MethodAccess? = null
@ -495,7 +495,7 @@ object RmiUtils {
*
* We do this because these stack frames are not useful in resolving exception handling from a user's perspective, and only clutter the stacktrace.
*/
fun cleanStackTraceForProxy(localException: Exception, remoteException: Exception? = null) {
fun cleanStackTraceForProxy(localException: Throwable, remoteException: Throwable? = null) {
val myClassName = RmiClient::class.java.name
val stackTrace = localException.stackTrace
var newStartIndex = 0
@ -553,7 +553,7 @@ object RmiUtils {
*
* Neither of these is useful in resolving exception handling from a user's perspective, and they only clutter the stacktrace.
*/
fun cleanStackTraceForImpl(exception: Exception, isSuspendFunction: Boolean) {
fun cleanStackTraceForImpl(exception: Throwable, isSuspendFunction: Boolean) {
val packageName = RmiUtils::class.java.packageName
val stackTrace = exception.stackTrace
@ -578,7 +578,8 @@ object RmiUtils {
// step 2: starting at newEndIndex -> 0, find the start of reflection information (we are java11+ ONLY, so this is easy)
for (i in newEndIndex downTo 0) {
// this will be either JAVA reflection or ReflectASM reflection
val stackModule = stackTrace[i].moduleName
val stackTraceElement: StackTraceElement = stackTrace[i]
val stackModule = stackTraceElement.moduleName
if (stackModule == "java.base") {
newEndIndex--
} else {


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import kotlin.coroutines.Continuation
class ContinuationSerializer() : Serializer<Continuation<*>>() {
internal class ContinuationSerializer() : Serializer<Continuation<*>>() {
init {
isImmutable = true
}


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@ package dorkbox.network.rmi.messages
import dorkbox.network.rmi.CachedMethod
import dorkbox.network.rmi.RmiUtils
import java.util.*
/**
* Internal message to invoke methods remotely.
@ -47,6 +48,21 @@ class MethodRequest : RmiMessage {
var args: Array<Any>? = null
override fun toString(): String {
return "MethodRequest(isGlobal=$isGlobal, rmiObjectId=${RmiUtils.unpackLeft(packedId)}, rmiId=${RmiUtils.unpackRight(packedId)}, cachedMethod=$cachedMethod, args=${args?.contentToString()})"
var argString = ""
val args1 = args
if (!args1.isNullOrEmpty()) {
// long byte arrays have SERIOUS problems!
argString = Arrays.deepToString(args1.map {
when (it) {
is ByteArray -> { "${it::class.java.simpleName}(length=${it.size})"}
is Array<*> -> { "${it::class.java.simpleName}(length=${it.size})"}
is Collection<*> -> { "${it::class.java.simpleName}(length=${it.size})"}
else -> { it }
}
}.toTypedArray())
argString = argString.substring(1, argString.length - 1)
}
return "MethodRequest(isGlobal=$isGlobal, rmiObjectId=${RmiUtils.unpackLeft(packedId)}, rmiId=${RmiUtils.unpackRight(packedId)}, cachedMethod=$cachedMethod, args=${argString})"
}
}
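The toString() above (and the matching trace logging in RmiManagerGlobal) summarises bulky arguments by type and length so a single log line cannot balloon to megabytes when a byte array is passed. The same idea as a small standalone helper; the function name is an assumption.

// Standalone sketch: render invocation args, eliding large payloads down to "Type(length=N)".
fun summarizeArgs(args: Array<Any?>?): String {
    if (args.isNullOrEmpty()) return ""
    return args.joinToString(", ") { arg ->
        when (arg) {
            is ByteArray     -> "ByteArray(length=${arg.size})"
            is Array<*>      -> "${arg::class.java.simpleName}(length=${arg.size})"
            is Collection<*> -> "${arg::class.java.simpleName}(size=${arg.size})"
            else             -> arg.toString()
        }
    }
}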


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,7 +12,8 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
@ -42,7 +43,7 @@ import com.esotericsoftware.kryo.io.Output
import dorkbox.network.connection.Connection
import dorkbox.network.rmi.CachedMethod
import dorkbox.network.rmi.RmiUtils
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.KryoReader
import org.agrona.collections.Int2ObjectHashMap
import java.lang.reflect.Method
@ -50,7 +51,7 @@ import java.lang.reflect.Method
* Internal message to invoke methods remotely.
*/
@Suppress("ConstantConditionIf")
class MethodRequestSerializer<CONNECTION: Connection>(private val methodCache: Int2ObjectHashMap<Array<CachedMethod>>) : Serializer<MethodRequest>() {
internal class MethodRequestSerializer<CONNECTION: Connection>(private val methodCache: Int2ObjectHashMap<Array<CachedMethod>>) : Serializer<MethodRequest>() {
override fun write(kryo: Kryo, output: Output, methodRequest: MethodRequest) {
val method = methodRequest.cachedMethod
@ -83,7 +84,7 @@ class MethodRequestSerializer<CONNECTION: Connection>(private val methodCache: I
val methodIndex = RmiUtils.unpackRight(methodInfo)
val isGlobal = input.readBoolean()
kryo as KryoExtra<CONNECTION>
kryo as KryoReader<CONNECTION>
val cachedMethod = try {
methodCache[methodClassId][methodIndex]


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,7 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
class MethodResponseSerializer() : Serializer<MethodResponse>() {
internal class MethodResponseSerializer() : Serializer<MethodResponse>() {
override fun write(kryo: Kryo, output: Output, response: MethodResponse) {
output.writeInt(response.packedId)
kryo.writeClassAndObject(output, response.result)


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,7 +23,7 @@ import dorkbox.network.connection.Connection
import dorkbox.network.connection.EndPoint
import dorkbox.network.rmi.RmiClient
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.KryoReader
import java.lang.reflect.Proxy
/**
@ -56,7 +56,7 @@ import java.lang.reflect.Proxy
* If the impl object 'lives' on the SERVER, then the server must tell the client about the iface ID
*/
@Suppress("UNCHECKED_CAST")
class RmiClientSerializer<CONNECTION: Connection>: Serializer<Any>() {
internal class RmiClientSerializer<CONNECTION: Connection>: Serializer<Any>() {
override fun write(kryo: Kryo, output: Output, proxyObject: Any) {
val handler = Proxy.getInvocationHandler(proxyObject) as RmiClient
output.writeBoolean(handler.isGlobal)
@ -67,7 +67,7 @@ class RmiClientSerializer<CONNECTION: Connection>: Serializer<Any>() {
val isGlobal = input.readBoolean()
val objectId = input.readInt(true)
kryo as KryoExtra<CONNECTION>
kryo as KryoReader<CONNECTION>
val endPoint: EndPoint<CONNECTION> = kryo.connection.endPoint as EndPoint<CONNECTION>
return if (isGlobal) {


@ -1,5 +1,5 @@
/*
* Copyright 2020 dorkbox, llc
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,7 +12,8 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
* Copyright (c) 2008, Nathan Sweet
* All rights reserved.
*
@ -41,7 +42,8 @@ import com.esotericsoftware.kryo.io.Output
import dorkbox.network.connection.Connection
import dorkbox.network.rmi.RemoteObjectStorage
import dorkbox.network.rmi.RmiSupportConnection
import dorkbox.network.serialization.KryoExtra
import dorkbox.network.serialization.KryoReader
import dorkbox.network.serialization.KryoWriter
/**
* This is to manage serializing RMI objects across the wire...
@ -73,10 +75,10 @@ import dorkbox.network.serialization.KryoExtra
* If the impl object 'lives' on the SERVER, then the server must tell the client about the iface ID
*/
@Suppress("UNCHECKED_CAST")
class RmiServerSerializer<CONNECTION: Connection> : Serializer<Any>(false) {
internal class RmiServerSerializer<CONNECTION: Connection> : Serializer<Any>(false) {
override fun write(kryo: Kryo, output: Output, `object`: Any) {
val kryoExtra = kryo as KryoExtra<CONNECTION>
val kryoExtra = kryo as KryoWriter<CONNECTION>
val connection = kryoExtra.connection
val rmi = connection.rmi
// have to write what the rmi ID is ONLY. A remote object sent via a connection IS ONLY a connection-scope object!
@ -96,7 +98,7 @@ class RmiServerSerializer<CONNECTION: Connection> : Serializer<Any>(false) {
}
override fun read(kryo: Kryo, input: Input, interfaceClass: Class<*>): Any? {
val kryoExtra = kryo as KryoExtra<CONNECTION>
val kryoExtra = kryo as KryoReader<CONNECTION>
val rmiId = input.readInt(true)
val connection = kryoExtra.connection


@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi.messages;


@ -0,0 +1,17 @@
/*
* Copyright 2023 dorkbox, llc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi;

Some files were not shown because too many files have changed in this diff.