Skip to content

Commit

Permalink
Add toDirectMethodHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum committed Dec 12, 2020
1 parent a55a379 commit 87a7cc5
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 17 deletions.
80 changes: 75 additions & 5 deletions src/main/java/io/airlift/bytecode/FastMethodHandleProxies.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
import static io.airlift.bytecode.expression.BytecodeExpressions.newArray;
import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance;
import static java.lang.invoke.MethodHandles.lookup;
import static java.lang.invoke.MethodType.methodType;
import static java.util.Arrays.stream;

public final class FastMethodHandleProxies
{
private static final String BASE_PACKAGE = FastMethodHandleProxies.class.getPackage().getName() + ".proxy";

private static final String DIRECT_METHOD_NAME = "invokeDirect";

private static final Method LOOKUP_FIND_VIRTUAL;
private static final Method LOOKUP_LOOKUP_CLASS;
private static final Method CLASS_GET_CLASSLOADER;
Expand Down Expand Up @@ -153,11 +156,7 @@ private static void defineProxyMethod(ClassDefinition classDefinition, Method ta
parameters);

BytecodeNode invocation = invokeDynamic(
new BootstrapMethod(
classDefinition.getType(),
"$bootstrap",
type(CallSite.class),
ImmutableList.of(type(Lookup.class), type(String.class), type(MethodType.class))),
getBootstrapMethod(classDefinition),
ImmutableList.of(),
target.getName(),
type(target.getReturnType()),
Expand Down Expand Up @@ -188,6 +187,68 @@ private static void defineProxyMethod(ClassDefinition classDefinition, Method ta
method.getBody().append(invocation);
}

public static MethodHandle toDirectMethodHandle(MethodHandle target, ClassLoader parentClassLoader)
{
String className = uniqueClassName(BASE_PACKAGE, "MethodHandle").getClassName();
return toDirectMethodHandle(className, target, parentClassLoader);
}

public static MethodHandle toDirectMethodHandle(String className, MethodHandle target, ClassLoader parentClassLoader)
{
try {
lookup().revealDirect(target);
return target;
}
catch (RuntimeException ignored) {
}

ClassDefinition classDefinition = new ClassDefinition(
a(PUBLIC, FINAL, SYNTHETIC),
typeFromJavaClassName(className),
type(Object.class));

classDefinition.declareDefaultConstructor(a(PUBLIC));

defineDirectMethod(classDefinition, target.type());

defineBootstrapMethod(classDefinition);

DynamicClassLoader dynamicClassLoader = new DynamicClassLoader(parentClassLoader, ImmutableMap.of(0L, target));
Class<?> newClass = classGenerator(dynamicClassLoader).defineClass(classDefinition, Object.class);
try {
return lookup().findStatic(newClass, DIRECT_METHOD_NAME, target.type());
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
}

private static void defineDirectMethod(ClassDefinition classDefinition, MethodType methodType)
{
List<Parameter> parameters = new ArrayList<>();
for (int i = 0; i < methodType.parameterCount(); i++) {
parameters.add(arg("arg" + i, methodType.parameterType(i)));
}

MethodDefinition method = classDefinition.declareMethod(
a(PUBLIC, STATIC),
DIRECT_METHOD_NAME,
type(methodType.returnType()),
parameters);

BytecodeExpression invocation = invokeDynamic(
getBootstrapMethod(classDefinition),
ImmutableList.of(),
DIRECT_METHOD_NAME,
type(methodType.returnType()),
parameters.stream()
.map(BytecodeExpression::getType)
.collect(toImmutableList()),
parameters);

method.getBody().append(invocation.ret());
}

private static void defineBootstrapMethod(ClassDefinition classDefinition)
{
Parameter callerLookup = arg("callerLookup", Lookup.class);
Expand Down Expand Up @@ -224,6 +285,15 @@ private static void defineBootstrapMethod(ClassDefinition classDefinition)
method.getBody().append(callSite.ret());
}

private static BootstrapMethod getBootstrapMethod(ClassDefinition classDefinition)
{
return new BootstrapMethod(
classDefinition.getType(),
"$bootstrap",
type(CallSite.class),
ImmutableList.of(type(Lookup.class), type(String.class), type(MethodType.class)));
}

public static <T> Method getSingleAbstractMethod(Class<T> type)
{
return stream(type.getMethods())
Expand Down
65 changes: 53 additions & 12 deletions src/test/java/io/airlift/bytecode/TestFastMethodHandleProxies.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandleProxies;
import java.lang.invoke.MethodType;
import java.lang.invoke.MutableCallSite;
import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;
import java.util.function.Consumer;
import java.util.function.BiConsumer;
import java.util.function.IntSupplier;
import java.util.function.LongFunction;
import java.util.function.LongUnaryOperator;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.airlift.bytecode.FastMethodHandleProxies.getSingleAbstractMethod;
import static io.airlift.bytecode.FastMethodHandleProxies.toDirectMethodHandle;
import static java.lang.invoke.MethodHandles.lookup;
import static java.lang.invoke.MethodHandles.privateLookupIn;
import static java.lang.invoke.MethodType.methodType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
Expand All @@ -40,7 +48,7 @@ public void testBasic()
assertInterface(
LongUnaryOperator.class,
lookup().findStatic(getClass(), "increment", methodType(long.class, long.class)),
addOne -> assertEquals(addOne.applyAsLong(1), 2L));
(addOne, wrapped) -> assertEquals(addOne.applyAsLong(1), 2L));
}

private static long increment(long x)
Expand All @@ -55,7 +63,7 @@ public void testGeneric()
assertInterface(
LongFunction.class,
lookup().findStatic(getClass(), "incrementAndPrint", methodType(String.class, long.class)),
print -> assertEquals(print.apply(1), "2"));
(print, wrapped) -> assertEquals(print.apply(1), "2"));
}

private static String incrementAndPrint(long x)
Expand All @@ -70,7 +78,7 @@ public void testObjectAndDefaultMethods()
assertInterface(
StringLength.class,
lookup().findStatic(getClass(), "stringLength", methodType(int.class, String.class)),
length -> {
(length, wrapped) -> {
assertEquals(length.length("abc"), 3);
assertEquals(length.theAnswer(), 42);
});
Expand Down Expand Up @@ -101,7 +109,7 @@ public void testUncheckedException()
assertInterface(
Runnable.class,
lookup().findStatic(getClass(), "throwUncheckedException", methodType(void.class)),
runnable -> assertThatThrownBy(runnable::run)
(runnable, wrapped) -> assertThatThrownBy(runnable::run)
.isInstanceOf(VerifyException.class));
}

Expand All @@ -117,9 +125,17 @@ public void testCheckedException()
assertInterface(
Runnable.class,
lookup().findStatic(getClass(), "throwCheckedException", methodType(void.class)),
runnable -> assertThatThrownBy(runnable::run)
.isInstanceOf(UndeclaredThrowableException.class)
.hasCauseInstanceOf(IOException.class));
(runnable, wrapped) -> {
if (wrapped) {
assertThatThrownBy(runnable::run)
.isInstanceOf(UndeclaredThrowableException.class)
.hasCauseInstanceOf(IOException.class);
}
else {
assertThatThrownBy(runnable::run)
.isInstanceOf(IOException.class);
}
});
}

private static void throwCheckedException()
Expand All @@ -139,7 +155,7 @@ public void testMutableCallSite()
assertInterface(
IntSupplier.class,
callSite.dynamicInvoker(),
supplier -> {
(supplier, wrapped) -> {
callSite.setTarget(one);
assertEquals(supplier.getAsInt(), 1);
callSite.setTarget(two);
Expand All @@ -157,9 +173,34 @@ private static int two()
return 2;
}

private static <T> void assertInterface(Class<T> interfaceType, MethodHandle target, Consumer<T> consumer)
private static <T> void assertInterface(Class<T> interfaceType, MethodHandle target, BiConsumer<T, Boolean> consumer)
{
consumer.accept(MethodHandleProxies.asInterfaceInstance(interfaceType, target), true);
consumer.accept(FastMethodHandleProxies.asInterfaceInstance(interfaceType, target), true);
consumer.accept(toInterfaceInstance(interfaceType, target), false);
}

@SuppressWarnings("unchecked")
private static <T> T toInterfaceInstance(Class<T> type, MethodHandle target)
{
consumer.accept(MethodHandleProxies.asInterfaceInstance(interfaceType, target));
consumer.accept(FastMethodHandleProxies.asInterfaceInstance(interfaceType, target));
Method method = getSingleAbstractMethod(type);
target = toDirectMethodHandle(target, type.getClassLoader());
try {
MethodHandle handle = lookup().unreflect(method);
MethodType methodType = handle.type().dropParameterTypes(0, 1);
Class<?> targetClass = lookup().revealDirect(target).getDeclaringClass();
CallSite callSite = LambdaMetafactory.metafactory(
privateLookupIn(targetClass, lookup()),
method.getName(),
methodType(type),
methodType,
target,
methodType);
return (T) callSite.getTarget().invoke();
}
catch (Throwable t) {
throwIfUnchecked(t);
throw new IllegalArgumentException(t);
}
}
}

0 comments on commit 87a7cc5

Please sign in to comment.