Skip to content

Commit

Permalink
feat: clearer compile exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
yusshu committed Jan 1, 2024
1 parent b4a8d2d commit d426539
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 112 deletions.
10 changes: 10 additions & 0 deletions src/main/java/team/unnamed/mocha/runtime/JavaTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
*/
package team.unnamed.mocha.runtime;

import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtPrimitiveType;
import javassist.NotFoundException;
import javassist.bytecode.Bytecode;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -94,6 +96,14 @@ private JavaTypes() {
}
}

public static @NotNull CtClass getClassUnchecked(final @NotNull ClassPool cp, final @NotNull Class<?> javaClass) {
try {
return cp.get(javaClass.getName());
} catch (final NotFoundException e) {
throw new IllegalStateException("CtClass not found for Java class: " + javaClass, e);
}
}

public static boolean isWrapper(final @NotNull CtClass type) {
requireNonNull(type, "type");
return WRAPPER_TYPE_NAMES.contains(type.getName());
Expand Down
253 changes: 141 additions & 112 deletions src/main/java/team/unnamed/mocha/runtime/MolangCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
*/
package team.unnamed.mocha.runtime;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtConstructor;
import javassist.CtField;
import javassist.CtMethod;
import javassist.NotFoundException;
import javassist.bytecode.BadBytecode;
import javassist.bytecode.Bytecode;
import javassist.bytecode.Descriptor;
import javassist.bytecode.MethodInfo;
Expand All @@ -42,7 +44,9 @@
import team.unnamed.mocha.runtime.compiled.Named;
import team.unnamed.mocha.util.CaseInsensitiveStringHashMap;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
Expand Down Expand Up @@ -139,148 +143,173 @@ public void postCompile(final @Nullable Consumer<byte @NotNull []> postCompile)
}
}

final CtClass interfaceCtClass = JavaTypes.getClassUnchecked(classPool, clazz);
final String scriptClassName = getClass().getPackage().getName() + ".MolangFunctionImpl_" + clazz.getSimpleName() + "_" + implementedMethod.getName()
+ "_" + Long.toHexString(System.currentTimeMillis()) + "_" + Integer.toHexString(RANDOM.nextInt(2024));

try {
CtClass scriptCtClass = classPool.makeClass(scriptClassName);
scriptCtClass.addInterface(classPool.get(clazz.getName()));
scriptCtClass.setModifiers(Modifier.FINAL | Modifier.PUBLIC);

final Class<?> returnType = implementedMethod.getReturnType();
final CtClass returnCtType = classPool.get(returnType.getName());

final Bytecode bytecode = new Bytecode(scriptCtClass.getClassFile().getConstPool());

final FunctionCompileState compileState = new FunctionCompileState(
this,
classPool,
scriptCtClass,
bytecode,
implementedMethod,
scope,
argumentParameterIndexes
);

// compute initial max locals
{
int maxLocals = 1; // 1: this
for (final CtClass paramType : ctParameters) {
if (paramType == CtClass.doubleType || paramType == CtClass.longType) {
maxLocals += 2; // doubles and longs take 2 places
} else {
maxLocals++;
}
final CtClass scriptCtClass = classPool.makeClass(scriptClassName);
scriptCtClass.addInterface(interfaceCtClass);
scriptCtClass.setModifiers(Modifier.PUBLIC);

final Class<?> returnType = implementedMethod.getReturnType();
final CtClass returnCtType = JavaTypes.getClassUnchecked(classPool, returnType);

final Bytecode bytecode = new Bytecode(scriptCtClass.getClassFile().getConstPool());
final FunctionCompileState compileState = new FunctionCompileState(this, classPool, scriptCtClass, bytecode, implementedMethod, scope, argumentParameterIndexes);

// compute initial max locals
{
int maxLocals = 1; // 1: this
for (final CtClass paramType : ctParameters) {
if (paramType == CtClass.doubleType || paramType == CtClass.longType) {
maxLocals += 2; // doubles and longs take 2 places
} else {
maxLocals++;
}
compileState.maxLocals(maxLocals);
}
compileState.maxLocals(maxLocals);
}

if (expressions.isEmpty()) {
// add only a "return 0", "return" or "return null" instruction
bytecode.addConstZero(returnCtType);
bytecode.addReturn(returnCtType);
} else {
final MolangCompilingVisitor visitor = new MolangCompilingVisitor(compileState);
CompileVisitResult lastVisitResult = null;
if (expressions.isEmpty()) {
// add only a "return 0", "return" or "return null" instruction
bytecode.addConstZero(returnCtType);
bytecode.addReturn(returnCtType);
} else {
final MolangCompilingVisitor compiler = new MolangCompilingVisitor(compileState);
CompileVisitResult lastVisitResult = null;

for (final Expression expression : expressions) {
lastVisitResult = expression.visit(new ExpressionInliner(new ExpressionInterpreter<>(null, scope), scope)).visit(visitor);
}
final ExpressionInliner inliner = new ExpressionInliner(new ExpressionInterpreter<>(null, scope), scope);

if (lastVisitResult == null || !lastVisitResult.returned()) {
if (lastVisitResult == null || lastVisitResult.lastPushedType() != returnCtType) {
JavaTypes.addCast(
bytecode,
lastVisitResult == null ? CtClass.doubleType : lastVisitResult.lastPushedType(),
returnCtType
);
}
for (final Expression expression : expressions) {
lastVisitResult = expression.visit(inliner).visit(compiler);
}

visitor.endVisit();
if (lastVisitResult == null || !lastVisitResult.returned()) {
if (lastVisitResult == null || lastVisitResult.lastPushedType() != returnCtType) {
JavaTypes.addCast(
bytecode,
lastVisitResult == null ? CtClass.doubleType : lastVisitResult.lastPushedType(),
returnCtType
);
}

compiler.endVisit();
}
}

bytecode.setMaxLocals(compileState.maxLocals());

final MethodInfo method = new MethodInfo(
scriptCtClass.getClassFile().getConstPool(),
implementedMethod.getName(),
Descriptor.ofMethod(
classPool.get(returnType.getName()),
ctParameters
)
);
method.setAccessFlags(Modifier.PUBLIC | Modifier.FINAL);
method.setCodeAttribute(bytecode.toCodeAttribute());
bytecode.setMaxLocals(compileState.maxLocals());

final MethodInfo method = new MethodInfo(scriptCtClass.getClassFile().getConstPool(), implementedMethod.getName(), Descriptor.ofMethod(returnCtType, ctParameters));
method.setAccessFlags(Modifier.PUBLIC | Modifier.FINAL);
method.setCodeAttribute(bytecode.toCodeAttribute());
final StackMapTable stackMapTable;

try {
method.getCodeAttribute().computeMaxStack();
stackMapTable = MapMaker.make(classPool, method);
} catch (final BadBytecode e) {
throw new IllegalStateException("Generated bad bytecode, open an issue at https://github.com/unnamed/mocha/issues", e);
}

final StackMapTable stackMapTable = MapMaker.make(classPool, method);
if (stackMapTable != null) {
method.getCodeAttribute().setAttribute(stackMapTable);
}
if (stackMapTable != null) {
method.getCodeAttribute().setAttribute(stackMapTable);
}

try {
scriptCtClass.addMethod(CtMethod.make(method, scriptCtClass));
} catch (final CannotCompileException e) {
throw new IllegalStateException("Couldn't compile main function method", e);
}

final Map<String, Object> requirements = compileState.requirements();
final Map<String, Object> requirements = compileState.requirements();

// add fields for the requirements
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
final String fieldName = entry.getKey();
final Object fieldValue = entry.getValue();
final CtClass fieldType = classPool.get(fieldValue.getClass().getName());
// add fields for the requirements
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
final String fieldName = entry.getKey();
final Object fieldValue = entry.getValue();
final CtClass fieldType = JavaTypes.getClassUnchecked(classPool, fieldValue.getClass());
try {
scriptCtClass.addField(new CtField(fieldType, fieldName, scriptCtClass));
} catch (final CannotCompileException e) {
throw new IllegalStateException("Couldn't compile field " + fieldName + " with type " + fieldType.getName(), e);
}
}

// add constructor that needs requirements and initializes them
final CtClass[] constructorParameterCtTypes = new CtClass[requirements.size()];
int j = 0;
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
constructorParameterCtTypes[j] = classPool.get(entry.getValue().getClass().getName());
++j;
}
// add constructor that needs requirements and initializes them
final CtClass[] constructorParameterCtTypes = new CtClass[requirements.size()];
int j = 0;
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
constructorParameterCtTypes[j] = JavaTypes.getClassUnchecked(classPool, entry.getValue().getClass());
++j;
}

{
final CtConstructor ctConstructor = new CtConstructor(constructorParameterCtTypes, scriptCtClass);
final Bytecode constructorBytecode = new Bytecode(scriptCtClass.getClassFile().getConstPool());
{
final CtConstructor ctConstructor = new CtConstructor(constructorParameterCtTypes, scriptCtClass);
final Bytecode constructorBytecode = new Bytecode(scriptCtClass.getClassFile().getConstPool());
constructorBytecode.addAload(0); // load this
constructorBytecode.addInvokespecial(JavaTypes.getClassUnchecked(classPool, Object.class), "<init>", "()V"); // invoke superclass constructor
// put!
int parameterIndex = 0;
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
final String fieldName = entry.getKey();
final Object fieldValue = entry.getValue();
constructorBytecode.addAload(0); // load this
constructorBytecode.addInvokespecial(classPool.get(Object.class.getName()), "<init>", "()V"); // invoke superclass constructor
// put!
int parameterIndex = 0;
for (final Map.Entry<String, Object> entry : requirements.entrySet()) {
final String fieldName = entry.getKey();
final Object fieldValue = entry.getValue();
constructorBytecode.addAload(0); // load this
constructorBytecode.addAload(parameterIndex + 1); // load parameter
constructorBytecode.addPutfield(scriptCtClass, fieldName, Descriptor.of(classPool.get(fieldValue.getClass().getName()))); // set!
parameterIndex++;
}
constructorBytecode.addReturn(null); // return
ctConstructor.getMethodInfo().setCodeAttribute(constructorBytecode.toCodeAttribute());
constructorBytecode.addAload(parameterIndex + 1); // load parameter
constructorBytecode.addPutfield(scriptCtClass, fieldName, Descriptor.of(JavaTypes.getClassUnchecked(classPool, fieldValue.getClass()))); // set!
parameterIndex++;
}
constructorBytecode.addReturn(null); // return
ctConstructor.getMethodInfo().setCodeAttribute(constructorBytecode.toCodeAttribute());
try {
ctConstructor.getMethodInfo().getCodeAttribute().computeMaxStack();
ctConstructor.getMethodInfo().getCodeAttribute().setMaxLocals(constructorParameterCtTypes.length + 1);
} catch (final BadBytecode e) {
throw new IllegalStateException("Generated bad bytecode, open an issue at https://github.com/unnamed/mocha/issues", e);
}

ctConstructor.getMethodInfo().getCodeAttribute().setMaxLocals(constructorParameterCtTypes.length + 1);
try {
scriptCtClass.addConstructor(ctConstructor);
} catch (final CannotCompileException e) {
throw new IllegalStateException("Couldn't compile script constructor", e);
}
}

if (postCompile != null) {
if (postCompile != null) {
try {
postCompile.accept(scriptCtClass.toBytecode());
} catch (IOException | CannotCompileException e) {
throw new IllegalStateException("Couldn't collect script bytecode", e);
}
final Class<?> compiledClass = classPool.toClass(scriptCtClass, getClass(), classLoader, null);

// find the constructor with the requirements
final Class<?>[] constructorParameterTypes = new Class[requirements.size()];
final Object[] constructorArguments = new Object[requirements.size()];
int i = 0;
for (final Object requirement : requirements.values()) {
constructorParameterTypes[i] = requirement.getClass();
constructorArguments[i] = requirement;
++i;
}
}
final Class<?> compiledClass;
try {
compiledClass = classPool.toClass(scriptCtClass, getClass(), classLoader, null);
} catch (final CannotCompileException e) {
throw new IllegalStateException("Couldn't compile script class", e);
}

final Constructor<?> constructor = compiledClass.getDeclaredConstructor(constructorParameterTypes);
final Object instance = constructor.newInstance(constructorArguments);
return clazz.cast(instance);
} catch (Exception e) {
throw new RuntimeException(e);
// find the constructor with the requirements
final Class<?>[] constructorParameterTypes = new Class[requirements.size()];
final Object[] constructorArguments = new Object[requirements.size()];
int i = 0;
for (final Object requirement : requirements.values()) {
constructorParameterTypes[i] = requirement.getClass();
constructorArguments[i] = requirement;
++i;
}

final Constructor<?> constructor;
try {
constructor = compiledClass.getDeclaredConstructor(constructorParameterTypes);
} catch (final NoSuchMethodException e) {
throw new IllegalStateException("Couldn't find constructor with parameters " + requirements.keySet(), e);
}
final Object instance;
try {
instance = constructor.newInstance(constructorArguments);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw new IllegalStateException("Couldn't instantiate script class", e);
}
return clazz.cast(instance);
}
}

0 comments on commit d426539

Please sign in to comment.