diff --git a/src/dorkbox/network/rmi/RmiUtils.kt b/src/dorkbox/network/rmi/RmiUtils.kt index 76e08478..c46b8f9a 100644 --- a/src/dorkbox/network/rmi/RmiUtils.kt +++ b/src/dorkbox/network/rmi/RmiUtils.kt @@ -19,10 +19,11 @@ import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.reflectasm.MethodAccess import dorkbox.network.connection.Connection import dorkbox.util.classes.ClassHelper -import org.slf4j.Logger +import mu.KLogger import java.lang.reflect.Method import java.lang.reflect.Modifier import java.util.* +import kotlin.coroutines.Continuation /** * Utility methods for creating a method cache for a class or interface. @@ -73,14 +74,16 @@ object RmiUtils { throw RuntimeException("Two methods with same signature! ('$o1Name', '$o2Name'") } - private fun getReflectAsmMethod(logger: Logger, clazz: Class<*>): MethodAccess? { + private fun getReflectAsmMethod(logger: KLogger, clazz: Class<*>): MethodAccess? { return try { val methodAccess = MethodAccess.get(clazz) - if (methodAccess.methodNames.size == 0 && methodAccess.parameterTypes.size == 0 && methodAccess.returnTypes.size == 0) { + if (methodAccess.methodNames.isEmpty() && methodAccess.parameterTypes.isEmpty() && methodAccess.returnTypes.isEmpty()) { // there was NOTHING that reflectASM found, so trying to use it doesn't do us any good null - } else methodAccess + } else { + methodAccess + } } catch (e: Exception) { logger.error("Unable to create ReflectASM method access", e) null @@ -91,7 +94,7 @@ object RmiUtils { * @param iFace this is never null. * @param impl this is NULL on the rmi "client" side. This is NOT NULL on the "server" side (where the object lives) */ - fun getCachedMethods(logger: Logger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array { + fun getCachedMethods(logger: KLogger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array { var ifaceAsmMethodAccess: MethodAccess? = null var implAsmMethodAccess: MethodAccess? = null @@ -124,6 +127,9 @@ object RmiUtils { if (asmEnabled) { ifaceAsmMethodAccess = getReflectAsmMethod(logger, iFace) } + + val hasConnectionOverrideMethods = hasOverwriteMethodWithConnectionParam(implMethods) + for (i in 0 until size) { val method = methods[i] @@ -132,18 +138,14 @@ object RmiUtils { // Store the serializer for each final parameter. // this is ONLY for the ORIGINAL method, not the overridden one. - val serializers = arrayOfNulls?>(parameterTypes.size) - var ii = 0 - val nn = parameterTypes.size - while (ii < nn) { - if (kryo.isFinal(parameterTypes[ii])) { - serializers[ii] = kryo.getSerializer(parameterTypes[ii]) - } - ii++ - } - @Suppress("UNCHECKED_CAST") /// we know this is correct, so it is safe to suppress the warning - serializers as Array> + val serializers = arrayOfNulls>(parameterTypes.size) + parameterTypes.forEachIndexed { index, parameterType -> + val paramClazz = parameterTypes[index] + if (kryo.isFinal(paramClazz) || paramClazz === Continuation::class.java) { + serializers[index] = kryo.getSerializer(parameterType) + } + } // copy because they can be overridden var cachedMethod: CachedMethod? = null @@ -156,18 +158,20 @@ object RmiUtils { var overwrittenMethod: Method? = null // this is how we detect if the method has been changed from the interface -> implementation + connection parameter - if (implMethods != null) { + if (implMethods != null && hasConnectionOverrideMethods) { overwrittenMethod = getOverwriteMethodWithConnectionParam(implMethods, method) if (overwrittenMethod != null) { - if (logger.isTraceEnabled) { - logger.trace("Overridden method: {}.{}", impl, method.name) + logger.trace { + "Overridden method: $impl.${method.name}" } // still might be null! iface_OR_ImplMethodAccess = implAsmMethodAccess } } + + if (canUseAsm) { try { val index = if (overwrittenMethod != null) { @@ -209,6 +213,27 @@ object RmiUtils { return cachedMethods as Array } + /** + * Check to see if there are ANY methods in this class that start with a "Connection" parameter. + */ + private fun hasOverwriteMethodWithConnectionParam(implMethods: Array?): Boolean { + if (implMethods == null) { + return false + } + + // maybe there is a method that starts with a "Connection" parameter. + for (implMethod in implMethods) { + val implParameters = implMethod.parameterTypes + + // check if the FIRST parameter is "Connection" + if (implParameters.isNotEmpty() && ClassHelper.hasInterface(Connection::class.java, implParameters[0])) { + return true + } + } + + return false + } + /** * This will overwrite an original (iface based) method with a method from the implementation ONLY if there is the extra 'Connection' parameter (as per above) * NOTE: does not null check @@ -245,15 +270,15 @@ object RmiUtils { for (implMethod in implMethods) { val implName = implMethod.name - val implTypes = implMethod.parameterTypes - val implLength = implTypes.size + val implParameters = implMethod.parameterTypes + val implLength = implParameters.size + if (origLength != implLength || origName != implName) { continue } - // checkLength > 0 - val shouldBeConnectionType = implTypes[0] - if (ClassHelper.hasInterface(Connection::class.java, shouldBeConnectionType)) { + // check if the FIRST parameter is "Connection" + if (ClassHelper.hasInterface(Connection::class.java, implParameters[0])) { // now we check to see if our "check" method is equal to our "cached" method + Connection if (implLength == 1) { // we only have "Connection" as a parameter @@ -261,7 +286,7 @@ object RmiUtils { } else { var found = true for (k in 1 until implLength) { - if (origTypes[k - 1] != implTypes[k]) { + if (origTypes[k - 1] != implParameters[k]) { // make sure all the parameters match. Cannot use arrays.equals(*), because one will have "Connection" as // a parameter - so we check that the rest match found = false