Skip to content

Commit

Permalink
Tests for swapping and substitution
Browse files Browse the repository at this point in the history
  • Loading branch information
Matyrobbrt committed Jan 21, 2024
1 parent d04f00e commit d9fc28f
Show file tree
Hide file tree
Showing 10 changed files with 491 additions and 26 deletions.
12 changes: 11 additions & 1 deletion definition/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ java {
withSourcesJar()
}

val testClasses = sourceSets.create("testClasses")

repositories {
mavenCentral()
maven {
Expand All @@ -38,7 +40,7 @@ dependencies {
implementation(group = "com.mojang", name = "logging", version = "1.1.1")
implementation(group = "com.google.guava", "guava", version = "32.1.2-jre")
implementation(group = "org.slf4j", "slf4j-api", "2.0.0")
implementation(group = "net.fabricmc", name = "sponge-mixin", version = "0.12.5+mixin.0.8.5")
"testClassesImplementation"(implementation(group = "net.fabricmc", name = "sponge-mixin", version = "0.12.5+mixin.0.8.5"))
compileOnly(group = "org.jetbrains", name = "annotations", version = "24.0.1")
implementation(group = "io.github.llamalad7", name = "mixinextras-common", version = "0.3.1")

Expand All @@ -51,6 +53,9 @@ dependencies {

testImplementation(platform("org.junit:junit-bom:5.9.1"))
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.assertj:assertj-core:3.25.1")

"testRuntimeOnly"(testClasses.output)
}

tasks {
Expand All @@ -63,6 +68,11 @@ tasks {
test {
useJUnitPlatform()
systemProperty("adapter.definition.paramdiff.debug", true)
outputs.upToDateWhen { false }
}

named("compileTestClassesJava", JavaCompile::class.java) {
options.compilerArgs = listOf("-parameters")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package dev.su5ed.sinytra.adapter.patch.transformer;

import dev.su5ed.sinytra.adapter.patch.util.AdapterUtil;
import dev.su5ed.sinytra.adapter.patch.util.SingleValueHandle;
import it.unimi.dsi.fastutil.ints.Int2IntArrayMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.jetbrains.annotations.NotNull;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.MethodNode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class LVTSnapshot {
private final List<LocalVar> locals;
private final int[] vars;

public LVTSnapshot(List<LocalVar> locals, int[] vars) {
this.locals = locals;
this.vars = vars;
}

public void applyDifference(MethodNode newNode) {
final LVTSnapshot newLVT = take(newNode);

// A new local var was removed
// Shift all other vars after it
final List<LocalVar> removed = new ArrayList<>(this.locals);
removed.removeIf(newLVT.locals::contains);
int[] newVars = Arrays.copyOf(this.vars, this.vars.length);
for (LocalVar local : removed) {
for (int i = 0; i < newVars.length; i++) {
if (newVars[i] > local.index) {
newVars[i] -= local.desc.getSize();
}
}
newNode.localVariables.forEach(node -> {
if (node.index > local.index) {
node.index -= local.desc.getSize();
}
});
}

// A new local var was added
// Shift all other vars after it, including the one that was replaced
final List<LocalVar> added = new ArrayList<>(newLVT.locals);
added.removeIf(this.locals::contains);
for (LocalVar local : added) {
for (int i = 0; i < newVars.length; i++) {
if (newVars[i] >= local.index) {
newVars[i] += local.desc.getSize();
}
}
newNode.localVariables.forEach(node -> {
if (node.index >= local.index) {
node.index += local.desc.getSize();
}
});
}

final Int2IntMap old2New = new Int2IntArrayMap();
for (int i = 0; i < this.vars.length; i++) {
old2New.put(this.vars[i], newVars[i]);
}

newNode.instructions.forEach(insn -> {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
final int idx = handle.get();
handle.set(old2New.getOrDefault(idx, idx));
}
});
}

public static LVTSnapshot take(MethodNode node) {
final List<LocalVar> locals = new ArrayList<>();
final IntSet vars = new IntArraySet();
node.localVariables.forEach(local -> locals.add(new LocalVar(local.name, Type.getType(local.desc), local.index)));
node.instructions.forEach(insn -> {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
vars.add((int) handle.get());
}
});
final int[] varsArray = vars.toIntArray();
Arrays.sort(varsArray);
Collections.sort(locals);
return new LVTSnapshot(locals, varsArray);
}

public record LocalVar(String name, Type desc, int index) implements Comparable<LocalVar> {
@Override
public int compareTo(@NotNull LVTSnapshot.LocalVar o) {
return Comparator.<Integer>naturalOrder().compare(this.index, o.index);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
continue;
}

final LVTSnapshot snapshot = LVTSnapshot.take(methodNode);

int lvtOrdinal = offset + index;
int lvtIndex;
if (index > offset) {
Expand All @@ -123,8 +125,8 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
newParameterTypes.add(paramOrdinal, type);
methodNode.parameters.add(paramOrdinal, newParameter);

int varOffset = AdapterUtil.getLVTOffsetForType(type);
offsetLVT(methodNode, lvtIndex, varOffset);
snapshot.applyDifference(methodNode);

offsetParameters(methodNode, paramOrdinal);

offsetSwaps.replaceAll(integerIntegerPair -> integerIntegerPair.mapFirst(j -> j >= paramOrdinal ? j + 1 : j));
Expand Down Expand Up @@ -185,19 +187,23 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
this.context.substitutes.forEach(pair -> {
int paramIndex = pair.getFirst();
int substituteParamIndex = pair.getSecond();
int localIndex = offset + paramIndex;
int substituteIndex = offset + substituteParamIndex;
int localIndex = calculateLVTIndex(newParameterTypes, isNonStatic, paramIndex);
LVTSnapshot lvtSnapshot = LVTSnapshot.take(methodNode);
if (methodNode.parameters.size() > paramIndex) {
LOGGER.info("Substituting parameter {} for {} in {}.{}", paramIndex, substituteParamIndex, classNode.name, methodNode.name);
methodNode.parameters.remove(paramIndex);
newParameterTypes.remove(paramIndex);
int substituteIndex = calculateLVTIndex(newParameterTypes, isNonStatic, substituteParamIndex);
methodNode.localVariables.removeIf(lvn -> lvn.index == localIndex);
for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null && handle.get() == localIndex) {
if (handle == null) continue;

if (handle.get() == localIndex) {
handle.set(substituteIndex);
}
}
lvtSnapshot.applyDifference(methodNode);
}
});
for (Pair<Integer, Integer> swapPair : offsetSwaps) {
Expand All @@ -206,20 +212,24 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
ParameterNode fromNode = methodNode.parameters.get(from);
ParameterNode toNode = methodNode.parameters.get(to);

final String fromName = fromNode.name;
fromNode.name = toNode.name;
toNode.name = fromName;
int fromOldLVT = calculateLVTIndex(newParameterTypes, isNonStatic, from);
int toOldLVT = calculateLVTIndex(newParameterTypes, isNonStatic, to);

methodNode.parameters.set(from, toNode);
methodNode.parameters.set(to, fromNode);
Type fromType = newParameterTypes.get(from);
Type toType = newParameterTypes.get(to);
newParameterTypes.set(from, toType);
newParameterTypes.set(to, fromType);
LOGGER.info(MIXINPATCH, "Swapped parameters at positions {} and {}", from, to);
LOGGER.info(MIXINPATCH, "Swapped parameters at positions {}({}) and {}({}) in {}.{}", from, fromNode.name, to, toNode.name, classNode.name, methodNode.name);

swapLVT(methodNode, offset, to, from)
.andThen(swapLVT(methodNode, offset, from, to))
int fromNewLVT = calculateLVTIndex(newParameterTypes, isNonStatic, from);
int toNewLVT = calculateLVTIndex(newParameterTypes, isNonStatic, to);

// Account for "big" LVT variables (like longs and doubles)
// Uses of the old parameter need to be the new parameter and vice versa
swapLVT(methodNode, fromOldLVT, toNewLVT)
.andThen(swapLVT(methodNode, toOldLVT, fromNewLVT))
.accept(null);
}

Expand Down Expand Up @@ -278,19 +288,28 @@ public Result apply(ClassNode classNode, MethodNode methodNode, MethodContext me
return this.context.shouldComputeFrames() ? Result.COMPUTE_FRAMES : Result.APPLY;
}

private Consumer<Void> swapLVT(MethodNode methodNode, int offset, int from, int to) {
private int calculateLVTIndex(List<Type> parameters, boolean nonStatic, int index) {
int lvt = nonStatic ? 1 : 0;
for (int i = 0; i < index; i++) {
lvt += parameters.get(i).getSize();
}
return lvt;
}

private Consumer<Void> swapLVT(MethodNode methodNode, int from, int to) {
Consumer<Void> r = v -> {};
for (LocalVariableNode lvn : methodNode.localVariables) {
if (lvn.index == offset + from) {
r = r.andThen(v -> lvn.index = offset + to);
if (lvn.index == from) {
r = r.andThen(v -> lvn.index = to);
}
}

for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
if (handle.get() == offset + from) {
r = r.andThen(v -> handle.set(offset + to));
if (handle.get() == from) {
LOGGER.info("Swapping in LVT: {} to {}", from, to);
r = r.andThen(v -> handle.set(to));
}
}
}
Expand All @@ -299,29 +318,22 @@ private Consumer<Void> swapLVT(MethodNode methodNode, int offset, int from, int
}

private static Pair<@Nullable ParameterNode, @Nullable LocalVariableNode> removeLocalVariable(MethodNode methodNode, int paramIndex, int lvtOffset, int replaceIndex, List<Type> newParameterTypes) {
final LVTSnapshot snapshot = LVTSnapshot.take(methodNode);
ParameterNode parameter = paramIndex < methodNode.parameters.size() ? methodNode.parameters.remove(paramIndex) : null;
methodNode.localVariables.sort(Comparator.comparingInt(lvn -> lvn.index));
LocalVariableNode lvn = methodNode.localVariables.remove(paramIndex + lvtOffset);
if (lvn != null) {
int varOffset = AdapterUtil.getLVTOffsetForType(Type.getType(lvn.desc));
for (LocalVariableNode local : methodNode.localVariables) {
if (local.index > lvn.index) {
local.index -= varOffset;
}
}
for (AbstractInsnNode insn : methodNode.instructions) {
SingleValueHandle<Integer> handle = AdapterUtil.handleLocalVarInsnValue(insn);
if (handle != null) {
if (handle.get() == lvn.index) {
handle.set(replaceIndex);
}
if (handle.get() > lvn.index) {
handle.set(handle.get() - varOffset);
}
}
}
}
newParameterTypes.remove(paramIndex);
snapshot.applyDifference(methodNode);
return Pair.of(parameter, lvn);
}

Expand Down
Loading

0 comments on commit d9fc28f

Please sign in to comment.