Network/src/dorkbox/network/rmi/RmiUtils.kt

419 lines
17 KiB
Kotlin

/*
* Copyright 2019 dorkbox, llc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dorkbox.network.rmi
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.reflectasm.MethodAccess
import dorkbox.network.connection.Connection
import dorkbox.util.classes.ClassHelper
import org.slf4j.Logger
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.util.*
/**
* Utility methods for creating a method cache for a class or interface.
*
* Additionally, this will override methods on the implementation so that methods can be called with a [Connection] parameter as the
* first parameter, with all other parameters being equal to the interface.
*
* This is to support calling RMI methods from an interface (that does NOT pass the connection reference) to
* an implType, that DOES pass the connection reference. The remote side (that initiates the RMI calls), MUST use
* the interface, and the implType may override the method, so that we add the connection as the first in
* the list of parameters.
*
* for example:
* Interface: foo(String x)
*
* Impl: foo(String x) -> not called
* Impl: foo(Connection c, String x) -> this is called instead
*
* The implType (if it exists, with the same name, and with the same signature + connection parameter) will be called from the interface
* instead of the method that would NORMALLY be called.
*/
object RmiUtils {
/**
 * Orders methods by name, then by number of parameters, then (position by position) by the names of their parameter types.
 *
 * Methods are sorted so they can be represented as an index.
 *
 * Two methods that have the same name and the same parameter types cannot be ordered; comparing them throws a [RuntimeException].
 */
private val METHOD_COMPARATOR = Comparator<Method> { o1, o2 ->
    val nameDiff = o1.name.compareTo(o2.name)
    if (nameDiff != 0) {
        return@Comparator nameDiff
    }

    val argTypes1 = o1.parameterTypes
    val argTypes2 = o2.parameterTypes

    // fewer parameters sort first (yields exactly -1 / 1 when the counts differ)
    val sizeDiff = argTypes1.size.compareTo(argTypes2.size)
    if (sizeDiff != 0) {
        return@Comparator sizeDiff
    }

    for (i in argTypes1.indices) {
        val typeDiff = argTypes1[i].name.compareTo(argTypes2[i].name)
        if (typeDiff != 0) {
            return@Comparator typeDiff
        }
    }

    // The names are necessarily identical at this point, so report the full signatures (which include the declaring class)
    // instead of printing the same name twice.
    throw RuntimeException("Two methods with same signature! ('$o1', '$o2')")
}
/**
 * Creates the ReflectASM [MethodAccess] for [clazz].
 *
 * @param logger receives an error entry if ReflectASM fails with an exception
 * @param clazz the type to generate method access for
 *
 * @return the method access, or `null` when ReflectASM reported no method names, no parameter types and no return types
 *         (so using it gains nothing), or when creating it threw an exception
 */
private fun getReflectAsmMethod(logger: Logger, clazz: Class<*>): MethodAccess? {
    try {
        val access = MethodAccess.get(clazz)

        // there was NOTHING that reflectASM found, so trying to use it doesn't do us any good
        val foundNothing = access.methodNames.isEmpty() &&
                access.parameterTypes.isEmpty() &&
                access.returnTypes.isEmpty()

        return if (foundNothing) null else access
    } catch (e: Exception) {
        logger.error("Unable to create ReflectASM method access", e)
        return null
    }
}
/**
 * Builds the table of [CachedMethod]s used to invoke RMI methods by index.
 *
 * One entry is created for every method returned by [getMethods] for [iFace], in that (sorted) order; the position of the entry is also
 * stored in it as its `methodIndex`. When [impl] is supplied and declares a matching method that takes an extra leading 'Connection'
 * parameter, that method is recorded as the entry's `overriddenMethod`.
 *
 * A [CachedAsmMethod] is created when ReflectASM can be used for the method, otherwise a plain (java reflection) [CachedMethod] is used.
 *
 * @param logger used for trace/error output
 * @param kryo used to look up the serializer for every parameter type that kryo considers final
 * @param asmEnabled `true` to try ReflectASM. It doesn't work on android (set correctly by the serialization manager)
 * @param iFace this is never null.
 * @param impl this is NULL on the rmi "client" side. This is NOT NULL on the "server" side (where the object lives)
 * @param classId stored in every entry as its `methodClassId`
 *
 * @return the cached methods, indexed identically to `getMethods(iFace)`
 * @throws IllegalArgumentException if [impl] is an interface
 */
fun getCachedMethods(logger: Logger, kryo: Kryo, asmEnabled: Boolean, iFace: Class<*>, impl: Class<*>?, classId: Int): Array<CachedMethod> {
    // RMI is **ALWAYS** based upon an interface, so we must always make sure to get the methods of the interface, instead of the
    // implementation, otherwise we will have the wrong order of methods, so invoking a method by its index will fail.
    val methods = getMethods(iFace)

    // reflectASM
    //   doesn't work on android (set correctly by the serialization manager)
    //   can't get any method from the 'Object' object (we get from the interface, which is NOT 'Object')
    //   and it MUST be public (iFace is always public)
    val implMethods: Array<Method>?
    val implAsmMethodAccess: MethodAccess?
    if (impl != null) {
        require(!impl.isInterface) { "Cannot have type as an interface, it must be an implementation" }
        implMethods = getMethods(impl)
        implAsmMethodAccess = if (asmEnabled) getReflectAsmMethod(logger, impl) else null
    } else {
        implMethods = null
        implAsmMethodAccess = null
    }

    val ifaceAsmMethodAccess: MethodAccess? = if (asmEnabled) getReflectAsmMethod(logger, iFace) else null

    return Array(methods.size) { i ->
        val method = methods[i]
        val parameterTypes = method.parameterTypes

        // Store the serializer for each final parameter.
        // this is ONLY for the ORIGINAL method, not the overridden one.
        val serializers = arrayOfNulls<Serializer<*>?>(parameterTypes.size)
        for (p in parameterTypes.indices) {
            if (kryo.isFinal(parameterTypes[p])) {
                serializers[p] = kryo.getSerializer(parameterTypes[p])
            }
        }

        @Suppress("UNCHECKED_CAST") /// we know this is correct, so it is safe to suppress the warning
        serializers as Array<Serializer<*>>

        // this is how we detect if the method has been changed from the interface -> implementation + connection parameter
        val overwrittenMethod: Method? = if (implMethods != null) getOverwriteMethodWithConnectionParam(implMethods, method) else null
        if (overwrittenMethod != null && logger.isTraceEnabled) {
            logger.trace("Overridden method: {}.{}", impl, method.name)
        }

        // An overwritten method lives on the implementation, so it needs the implementation's method access.
        // Either one may still be null (ReflectASM disabled, failed, or found nothing).
        val methodAccess: MethodAccess? = if (overwrittenMethod != null) implAsmMethodAccess else ifaceAsmMethodAccess

        // reflectAsm doesn't like "Object" class methods
        val canUseAsm = asmEnabled && method.declaringClass != Any::class.java

        var cachedMethod: CachedMethod? = null
        if (canUseAsm) {
            if (methodAccess == null) {
                // checked explicitly, rather than relying on a NullPointerException being swallowed by the catch block below
                logger.trace("ReflectAsm is unavailable for {}.{} (using java reflection instead)", method.declaringClass, method.name)
            } else {
                try {
                    // have to take into account the overwritten method's first parameter will ALWAYS be "Connection"
                    val asmParameterTypes = overwrittenMethod?.parameterTypes ?: parameterTypes
                    val index = methodAccess.getIndex(method.name, *asmParameterTypes)

                    cachedMethod = CachedAsmMethod(
                        methodAccessIndex = index,
                        methodAccess = methodAccess,
                        name = method.name,
                        method = method,
                        methodIndex = i,
                        methodClassId = classId,
                        serializers = serializers)
                } catch (e: Exception) {
                    logger.trace("Unable to use ReflectAsm for {}.{} (using java reflection instead)", method.declaringClass, method.name, e)
                }
            }
        }

        // fall back to java reflection whenever ReflectASM was not (or could not be) used
        val result = cachedMethod ?: CachedMethod(
            method = method,
            methodIndex = i,
            methodClassId = classId,
            serializers = serializers)

        // this MIGHT be null, but if it is not, this is the method we will invoke INSTEAD of the "normal" method
        result.overriddenMethod = overwrittenMethod
        result
    }
}
/**
 * Replaces, in place, every entry of [origMethods] that has a counterpart in [implMethods] taking the extra leading 'Connection'
 * parameter (see the [RmiUtils] documentation) with that counterpart. Entries without such a counterpart are left untouched.
 * NOTE: does not null check
 *
 * @param implMethods methods from the implementation
 * @param origMethods methods from the interface; this array is modified
 */
private fun overwriteMethodsWithConnectionParam(implMethods: Array<Method>, origMethods: Array<Method>) {
    for (index in origMethods.indices) {
        val replacement = getOverwriteMethodWithConnectionParam(implMethods, origMethods[index])
        if (replacement != null) {
            origMethods[index] = replacement
        }
    }
}
/**
 * Looks for the implementation method that should be invoked instead of [origMethod].
 *
 * A candidate qualifies when it has the same name as [origMethod], exactly one more parameter, a first parameter type that passes
 * `ClassHelper.hasInterface(Connection::class.java, ...)` (presumably: the type is, or implements, [Connection] - verify against
 * ClassHelper), and remaining parameter types equal to those of [origMethod], in order.
 * NOTE: does not null check
 *
 * @param implMethods methods from the implementation
 * @param origMethod original method from the interface
 *
 * @return the first qualifying method (in [implMethods] order), or `null` if there is none
 */
private fun getOverwriteMethodWithConnectionParam(implMethods: Array<Method>, origMethod: Method): Method? {
    val wantedName = origMethod.name
    val wantedTypes = origMethod.parameterTypes

    // the implementation method carries the additional leading "Connection" parameter
    val wantedCount = wantedTypes.size + 1

    return implMethods.firstOrNull { candidate ->
        val candidateTypes = candidate.parameterTypes

        // wantedCount is at least 1, so candidateTypes[0] exists once the sizes match
        candidateTypes.size == wantedCount &&
                candidate.name == wantedName &&
                ClassHelper.hasInterface(Connection::class.java, candidateTypes[0]) &&
                // Cannot compare the arrays as a whole, because one has "Connection" as an extra parameter - so we check that the
                // rest match (trivially true when "Connection" is the only parameter)
                wantedTypes.indices.all { k -> wantedTypes[k] == candidateTypes[k + 1] }
    }
}
/**
 * Collects the callable methods of [type]: from an interface (for RMI), or from an implementation (for "connection" overriding the
 * method signature).
 *
 * The type itself, `Object`, and every interface and superclass reachable from them are visited. Static, private and synthetic methods
 * are skipped, and for each name + parameter-type combination only the first method encountered is kept.
 *
 * @return an array of all found methods for this class, sorted with [METHOD_COMPARATOR]
 */
fun getMethods(type: Class<*>): Array<Method> {
    val collected = ArrayList<Method>()

    // method name -> every parameter list that has already been accepted under that name
    val acceptedSignatures = HashMap<String, ArrayList<Array<Class<*>>>>()

    val pending = LinkedList<Class<*>>()
    pending.add(type)

    // explicitly add Object.class because that can always be called, because it is common to 100% of all java objects (and its methods
    // are not added via parentClass.getMethods()
    pending.add(Any::class.java)

    while (pending.isNotEmpty()) {
        val current = pending.removeFirst()

        for (method in current.methods) {
            // static and private methods cannot be called via RMI (and neither can synthetic ones).
            val modifiers = method.modifiers
            if (Modifier.isStatic(modifiers) || Modifier.isPrivate(modifiers) || method.isSynthetic) {
                continue
            }

            // methods that have been over-ridden by another method cannot be called.
            // the first one in the map, is the "highest" level method, and is what can be called.
            val parameterTypes = method.parameterTypes // length 0 if there are no parameters
            val knownParameterLists = acceptedSignatures.getOrPut(method.name) { ArrayList() }
            if (knownParameterLists.any { it.contentEquals(parameterTypes) }) {
                // the method is overridden, so it should not be called.
                continue
            }

            // remember the signature for checking later
            knownParameterLists.add(parameterTypes)

            // safe to add this method to the list of recognized methods
            collected.add(method)
        }

        // add all interfaces from our class (if any)
        pending.addAll(current.interfaces)

        // If we are an interface, one CANNOT call any methods NOT defined by the interface!
        // also, interfaces don't have a super-class.
        current.superclass?.let { pending.add(it) }
    }

    collected.sortWith(METHOD_COMPARATOR)
    return collected.toTypedArray()
}
/**
 * Instantiates [serializerClass] via reflection.
 *
 * Constructors are tried in this order, and the first one that exists is used:
 *  1. `(Kryo, Class)` - called with [k] and [superClass]
 *  2. `(Kryo)` - called with [k]
 *  3. `(Class)` - called with [superClass]
 *  4. the no-argument constructor
 *
 * @param k the kryo instance handed to the serializer's constructor (when it accepts one)
 * @param superClass the class handed to the serializer's constructor (when it accepts one); also named in the error message
 * @param serializerClass the serializer type to instantiate
 *
 * @return the new serializer instance
 * @throws IllegalArgumentException if no usable constructor exists, or if creating the instance fails for any other reason
 */
fun resolveSerializerInstance(k: Kryo, superClass: Class<*>, serializerClass: Class<out Serializer<*>>): Serializer<*> {
    // Looks up the public constructor with exactly these parameter types; null means "no such constructor, try the next shape".
    fun publicConstructorOrNull(vararg parameterTypes: Class<*>) =
        try {
            serializerClass.getConstructor(*parameterTypes)
        } catch (ignored: NoSuchMethodException) {
            null
        }

    return try {
        publicConstructorOrNull(Kryo::class.java, Class::class.java)?.newInstance(k, superClass)
            ?: publicConstructorOrNull(Kryo::class.java)?.newInstance(k)
            ?: publicConstructorOrNull(Class::class.java)?.newInstance(superClass)
            // Class.newInstance() is deprecated; going through the declared no-arg constructor is its documented replacement
            ?: serializerClass.getDeclaredConstructor().newInstance()
    } catch (ex: Exception) {
        throw IllegalArgumentException("Unable to create serializer \"${serializerClass.name}\" for class: ${superClass.name}", ex)
    }
}
/**
 * Walks the type hierarchy of [clazz] breadth-first: for every visited type, its directly declared interfaces are queued first,
 * followed by its superclass (if it has one).
 *
 * A type that is reachable along more than one path appears once per path.
 *
 * @return every visited type in visiting order, without [clazz] itself
 */
fun getHierarchy(clazz: Class<*>): ArrayList<Class<*>> {
    val visited = ArrayList<Class<*>>()

    val pending = LinkedList<Class<*>>()
    pending.add(clazz)

    while (pending.isNotEmpty()) {
        val current = pending.removeFirst()
        visited.add(current)

        // add all interfaces from our class (if any)
        pending.addAll(current.interfaces)

        current.superclass?.let { pending.add(it) }
    }

    // remove the first class, because we don't need it
    visited.remove(clazz)
    return visited
}
/** Mask that selects the lower 16 bits of an int. */
private const val RIGHT = 0xFFFF

/**
 * Packs two 16-bit values into a single int.
 *
 * @param left placed in the upper 16 bits (anything above its lower 16 bits is shifted out)
 * @param right placed in the lower 16 bits (only its lower 16 bits are used)
 */
fun packShorts(left: Int, right: Int): Int = (left shl 16) or (right and RIGHT)

/**
 * @return the upper 16 bits of [packedInt], as a value in the range 0..65535
 */
fun unpackLeft(packedInt: Int): Int = packedInt ushr 16 // >>> operator 0-fills from left

/**
 * @return the lower 16 bits of [packedInt], as a value in the range 0..65535
 */
fun unpackRight(packedInt: Int): Int = packedInt and RIGHT
}