Skip to content

Commit

Permalink
Merge pull request #3 from Sinytra/tests
Browse files Browse the repository at this point in the history
Add a testing system
  • Loading branch information
Su5eD authored Jan 21, 2024
2 parents d04f00e + 9780230 commit 20962f8
Show file tree
Hide file tree
Showing 14 changed files with 570 additions and 37 deletions.
16 changes: 15 additions & 1 deletion definition/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ java {
withSourcesJar()
}

val testClasses = sourceSets.create("testClasses")

repositories {
mavenCentral()
maven {
Expand All @@ -38,7 +40,7 @@ dependencies {
implementation(group = "com.mojang", name = "logging", version = "1.1.1")
implementation(group = "com.google.guava", "guava", version = "32.1.2-jre")
implementation(group = "org.slf4j", "slf4j-api", "2.0.0")
implementation(group = "net.fabricmc", name = "sponge-mixin", version = "0.12.5+mixin.0.8.5")
"testClassesImplementation"(implementation(group = "net.fabricmc", name = "sponge-mixin", version = "0.12.5+mixin.0.8.5"))
compileOnly(group = "org.jetbrains", name = "annotations", version = "24.0.1")
implementation(group = "io.github.llamalad7", name = "mixinextras-common", version = "0.3.1")

Expand All @@ -51,6 +53,9 @@ dependencies {

testImplementation(platform("org.junit:junit-bom:5.9.1"))
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core:3.25.1")

"testRuntimeOnly"(testClasses.output)
}

tasks {
Expand All @@ -63,6 +68,15 @@ tasks {
test {
useJUnitPlatform()
systemProperty("adapter.definition.paramdiff.debug", true)
outputs.upToDateWhen { false }
}

named("compileTestClassesJava", JavaCompile::class.java) {
options.compilerArgs = listOf("-parameters")
}

named("testClasses") {
dependsOn("compileTestClassesJava")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package dev.su5ed.sinytra.adapter.patch.transformer;

import dev.su5ed.sinytra.adapter.patch.util.AdapterUtil;
import dev.su5ed.sinytra.adapter.patch.util.SingleValueHandle;
import it.unimi.dsi.fastutil.ints.Int2IntArrayMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.jetbrains.annotations.NotNull;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.MethodNode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class LVTSnapshot {
private final List<LocalVar> locals;
private final int[] vars;

public LVTSnapshot(List<LocalVar> locals, int[] vars) {
this.locals = locals;
this.vars = vars;
}

public void applyDifference(MethodNode newNode) {
final LVTSnapshot newLVT = take(newNode);

// A new local var was removed
// Shift all other vars after it
final List<LocalVar> removed = new ArrayList<>(this.locals);
removed.removeIf(newLVT.locals::contains);
int[] newVars = Arrays.copyOf(this.vars, this.vars.length);
for (LocalVar local : removed) {
for (int i = 0; i < newVars.length; i++) {
if (newVars[i] > local.index) {
newVars[i] -= local.desc.getSize();
}
}
newNode.localVariables.forEach(node -> {
if (node.index > local.index) {
node.index -= local.desc.getSize();
}
});
}

// A new local var was added
// Shift all other vars after it, including the one that was replaced
final List<LocalVar> added = new ArrayList<>(newLVT.locals);
added.removeIf(this.locals::contains);
for (LocalVar local : added) {
for (int i = 0; i < newVars.length; i++) {
if (newVars[i] >= local.index) {
newVars[i] += local.desc.getSize();
}
}
newNode.localVariables.forEach(node -> {
if (!node.name.equals(local.name) && node.index >= local.index) {
node.index += local.desc.getSize();
}
});
}

final Int2IntMap old2New = new Int2IntArrayMap();
for (int i = 0; i < this.vars.length; i++) {
old2New.put(this.vars[i], newVars[i]);
}

newNode.instructions.forEach(insn -> {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
final int idx = handle.get();
handle.set(old2New.getOrDefault(idx, idx));
}
});
}

public static LVTSnapshot take(MethodNode node) {
final List<LocalVar> locals = new ArrayList<>();
final IntSet vars = new IntArraySet();
node.localVariables.forEach(local -> locals.add(new LocalVar(local.name, Type.getType(local.desc), local.index)));
node.instructions.forEach(insn -> {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
vars.add((int) handle.get());
}
});
final int[] varsArray = vars.toIntArray();
Arrays.sort(varsArray);
Collections.sort(locals);
return new LVTSnapshot(locals, varsArray);
}

public record LocalVar(String name, Type desc, int index) implements Comparable<LocalVar> {
@Override
public int compareTo(@NotNull LVTSnapshot.LocalVar o) {
return Comparator.<Integer>naturalOrder().compare(this.index, o.index);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,21 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
continue;
}

int lvtOrdinal = offset + index;
int lvtIndex;
if (index > offset) {
List<LocalVariableNode> lvt = methodNode.localVariables.stream().sorted(Comparator.comparingInt(lvn -> lvn.index)).toList();
lvtIndex = lvt.get(lvtOrdinal).index;
} else {
lvtIndex = lvtOrdinal;
}
final LVTSnapshot snapshot = LVTSnapshot.take(methodNode);

int lvtIndex = calculateLVTIndex(newParameterTypes, isNonStatic, (needsOffset ? 1 : 0) + index);
int paramOrdinal = isNonStatic && needsOffset ? index + 1 : index;
ParameterNode newParameter = new ParameterNode(null, Opcodes.ACC_SYNTHETIC);
ParameterNode newParameter = new ParameterNode("adapter_injected_" + paramOrdinal, Opcodes.ACC_SYNTHETIC);
newParameterTypes.add(paramOrdinal, type);
methodNode.parameters.add(paramOrdinal, newParameter);

int varOffset = AdapterUtil.getLVTOffsetForType(type);
offsetLVT(methodNode, lvtIndex, varOffset);
offsetParameters(methodNode, paramOrdinal);

offsetSwaps.replaceAll(integerIntegerPair -> integerIntegerPair.mapFirst(j -> j >= paramOrdinal ? j + 1 : j));
offsetMoves.replaceAll(integerIntegerPair -> integerIntegerPair.mapFirst(j -> j >= paramOrdinal ? j + 1 : j).mapSecond(j -> j >= paramOrdinal ? j + 1 : j));

methodNode.localVariables.add(new LocalVariableNode("adapter_injected_" + paramOrdinal, type.getDescriptor(), null, self.start, self.end, lvtIndex));
methodNode.localVariables.add(paramOrdinal + (isNonStatic ? 1 : 0), new LocalVariableNode(newParameter.name, type.getDescriptor(), null, self.start, self.end, lvtIndex));
snapshot.applyDifference(methodNode);
}
LocalVariableLookup lvtLookup = new LocalVariableLookup(methodNode.localVariables);
BytecodeFixerUpper bfu = context.environment().bytecodeFixerUpper();
Expand Down Expand Up @@ -185,19 +179,23 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
this.context.substitutes.forEach(pair -> {
int paramIndex = pair.getFirst();
int substituteParamIndex = pair.getSecond();
int localIndex = offset + paramIndex;
int substituteIndex = offset + substituteParamIndex;
int localIndex = calculateLVTIndex(newParameterTypes, isNonStatic, paramIndex);
LVTSnapshot lvtSnapshot = LVTSnapshot.take(methodNode);
if (methodNode.parameters.size() > paramIndex) {
LOGGER.info("Substituting parameter {} for {} in {}.{}", paramIndex, substituteParamIndex, classNode.name, methodNode.name);
methodNode.parameters.remove(paramIndex);
newParameterTypes.remove(paramIndex);
int substituteIndex = calculateLVTIndex(newParameterTypes, isNonStatic, substituteParamIndex);
methodNode.localVariables.removeIf(lvn -> lvn.index == localIndex);
for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null && handle.get() == localIndex) {
if (handle == null) continue;

if (handle.get() == localIndex) {
handle.set(substituteIndex);
}
}
lvtSnapshot.applyDifference(methodNode);
}
});
for (Pair<Integer, Integer> swapPair : offsetSwaps) {
Expand All @@ -206,20 +204,24 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
ParameterNode fromNode = methodNode.parameters.get(from);
ParameterNode toNode = methodNode.parameters.get(to);

final String fromName = fromNode.name;
fromNode.name = toNode.name;
toNode.name = fromName;
int fromOldLVT = calculateLVTIndex(newParameterTypes, isNonStatic, from);
int toOldLVT = calculateLVTIndex(newParameterTypes, isNonStatic, to);

methodNode.parameters.set(from, toNode);
methodNode.parameters.set(to, fromNode);
Type fromType = newParameterTypes.get(from);
Type toType = newParameterTypes.get(to);
newParameterTypes.set(from, toType);
newParameterTypes.set(to, fromType);
LOGGER.info(MIXINPATCH, "Swapped parameters at positions {} and {}", from, to);
LOGGER.info(MIXINPATCH, "Swapped parameters at positions {}({}) and {}({}) in {}.{}", from, fromNode.name, to, toNode.name, classNode.name, methodNode.name);

int fromNewLVT = calculateLVTIndex(newParameterTypes, isNonStatic, from);
int toNewLVT = calculateLVTIndex(newParameterTypes, isNonStatic, to);

swapLVT(methodNode, offset, to, from)
.andThen(swapLVT(methodNode, offset, from, to))
// Account for "big" LVT variables (like longs and doubles)
// Uses of the old parameter need to be the new parameter and vice versa
swapLVT(methodNode, fromOldLVT, toNewLVT)
.andThen(swapLVT(methodNode, toOldLVT, fromNewLVT))
.accept(null);
}

Expand Down Expand Up @@ -278,19 +280,28 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
return this.context.shouldComputeFrames() ? Result.COMPUTE_FRAMES : Result.APPLY;
}

private Consumer<Void> swapLVT(MethodNode methodNode, int offset, int from, int to) {
private int calculateLVTIndex(List<Type> parameters, boolean nonStatic, int index) {
int lvt = nonStatic ? 1 : 0;
for (int i = 0; i < index; i++) {
lvt += parameters.get(i).getSize();
}
return lvt;
}

private Consumer<Void> swapLVT(MethodNode methodNode, int from, int to) {
Consumer<Void> r = v -> {};
for (LocalVariableNode lvn : methodNode.localVariables) {
if (lvn.index == offset + from) {
r = r.andThen(v -> lvn.index = offset + to);
if (lvn.index == from) {
r = r.andThen(v -> lvn.index = to);
}
}

for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
if (handle.get() == offset + from) {
r = r.andThen(v -> handle.set(offset + to));
if (handle.get() == from) {
LOGGER.info("Swapping in LVT: {} to {}", from, to);
r = r.andThen(v -> handle.set(to));
}
}
}
Expand All @@ -299,29 +310,22 @@ private Consumer<Void> swapLVT(MethodNode methodNode, int offset, int from, int
}

private static Pair<@Nullable ParameterNode, @Nullable LocalVariableNode> removeLocalVariable(MethodNode methodNode, int paramIndex, int lvtOffset, int replaceIndex, List<Type> newParameterTypes) {
final LVTSnapshot snapshot = LVTSnapshot.take(methodNode);
ParameterNode parameter = paramIndex < methodNode.parameters.size() ? methodNode.parameters.remove(paramIndex) : null;
methodNode.localVariables.sort(Comparator.comparingInt(lvn -> lvn.index));
LocalVariableNode lvn = methodNode.localVariables.remove(paramIndex + lvtOffset);
if (lvn != null) {
int varOffset = AdapterUtil.getLVTOffsetForType(Type.getType(lvn.desc));
for (LocalVariableNode local : methodNode.localVariables) {
if (local.index > lvn.index) {
local.index -= varOffset;
}
}
for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
if (handle.get() == lvn.index) {
handle.set(replaceIndex);
}
if (handle.get() > lvn.index) {
handle.set(handle.get() - varOffset);
}
}
}
}
newParameterTypes.remove(paramIndex);
snapshot.applyDifference(methodNode);
return Pair.of(parameter, lvn);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void testMethodParameterNameComparison() throws IOException {
System.out.println("Insertions:");
diff.insertions().forEach(param -> System.out.println("AT " + param.getFirst() + " TYPE " + param.getSecond()));
assertEquals(1, diff.insertions().size());
assertEquals(7, diff.insertions().get(0).getFirst());
assertEquals(9, diff.insertions().get(0).getFirst());
assertEquals(Type.FLOAT_TYPE, diff.insertions().get(0).getSecond());

System.out.println("Replacements:");
Expand Down
Loading

0 comments on commit 20962f8

Please sign in to comment.