Skip to content

Commit

Permalink
Fix variable used in switch head not being inlined
Browse files Browse the repository at this point in the history
  • Loading branch information
coehlrich committed May 20, 2024
1 parent 84cf6f0 commit 9c0f911
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.jetbrains.java.decompiler.main.collectors.CounterContainer;
import org.jetbrains.java.decompiler.main.extern.IFernflowerPreferences;
import org.jetbrains.java.decompiler.modules.decompiler.exps.*;
import org.jetbrains.java.decompiler.modules.decompiler.exps.FunctionExprent.FunctionType;
import org.jetbrains.java.decompiler.modules.decompiler.stats.*;
import org.jetbrains.java.decompiler.struct.consts.PooledConstant;
import org.jetbrains.java.decompiler.struct.consts.PrimitiveConstant;
Expand All @@ -15,6 +16,7 @@
import org.jetbrains.java.decompiler.util.Pair;

import java.util.*;
import java.util.stream.Stream;

public final class SwitchPatternMatchProcessor {
public static boolean processPatternMatching(Statement root) {
Expand Down Expand Up @@ -66,6 +68,7 @@ private static boolean processStatement(SwitchStatement stat, Statement root) {
Exprent realSelector = origParams.get(0);
boolean guarded = true;
boolean isEnumSwitch = value.getName().equals("enumSwitch");
boolean nullCase = false;
List<Pair<Statement, Exprent>> references = new ArrayList<>();
if (origParams.get(1) instanceof VarExprent) {
VarExprent var = (VarExprent) origParams.get(1);
Expand Down Expand Up @@ -134,6 +137,7 @@ private static boolean processStatement(SwitchStatement stat, Statement root) {

// -1 always means null
if (caseValue == -1) {
nullCase = true;
allCases.remove(caseExpr);
ConstExprent nullConst = new ConstExprent(VarType.VARTYPE_NULL, null, null);
// null can be shared with a pattern or default; put it at the end, but before default, to make sure it doesn't get
Expand Down Expand Up @@ -262,6 +266,41 @@ private static boolean processStatement(SwitchStatement stat, Statement root) {
}
}

Exprent oldSelector = realSelector;
// inline head
List<Exprent> basicHead = stat.getBasichead().getExprents();
if (realSelector instanceof VarExprent var && basicHead != null && basicHead.size() >= 1) {
if (basicHead.get(basicHead.size() - 1) instanceof AssignmentExprent assignment && assignment.getLeft() instanceof VarExprent assigned) {
if (var.equals(assigned) && !var.isVarReferenced(root, Stream.concat(Stream.of(assigned), stat.getCaseValues().stream().flatMap(List::stream).filter(exp -> exp instanceof FunctionExprent func && func.getFuncType() == FunctionType.INSTANCEOF && func.getLstOperands().get(0) instanceof VarExprent checked && checked.equals(var)).map(exp -> (VarExprent) ((FunctionExprent) exp).getLstOperands().get(0))).toArray(VarExprent[]::new))) {
realSelector = assignment.getRight();
basicHead.remove(basicHead.size() - 1);
}
}
}

// Check for non null
if (basicHead != null && basicHead.size() >= 1 && realSelector instanceof VarExprent var && !nullCase) {
Exprent last = basicHead.get(basicHead.size() - 1);
if (last instanceof InvocationExprent inv && inv.isStatic() && inv.getClassname().equals("java/util/Objects") && inv.getName().equals("requireNonNull") && inv.getStringDescriptor().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && var.equals(inv.getLstParameters().get(0))) {
basicHead.remove(basicHead.size() - 1);
// Check for other assignment
if (basicHead.size() >= 1 && var.isStack() && !nullCase) {
last = basicHead.get(basicHead.size() - 1);
if (last instanceof AssignmentExprent assignment && assignment.getLeft() instanceof VarExprent assigned && var.equals(assigned)) {
if (!var.isVarReferenced(root, assigned)) {
realSelector = assignment.getRight();
basicHead.remove(basicHead.size() - 1);
}
}
}
}
}

if (oldSelector != realSelector) {
Exprent finalSelector = realSelector;
stat.getCaseValues().stream().flatMap(List::stream).filter(Objects::nonNull).filter(exp -> exp instanceof FunctionExprent func && func.getFuncType() == FunctionType.INSTANCEOF && func.getLstOperands().get(0).equals(oldSelector)).forEach(exp -> ((FunctionExprent) exp).getLstOperands().set(0, finalSelector));
}

head.setValue(realSelector); // SwitchBootstraps.typeSwitch(o, var1) -> o

if (guarded && stat.getParent() instanceof DoStatement) {
Expand Down
18 changes: 18 additions & 0 deletions testData/src/java21/pkg/TestSwitchPatternMatchingJ21.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
package pkg;

import java.util.function.Supplier;

public class TestSwitchPatternMatchingJ21 {
public void test1(Object o) {
System.out.println(switch (o) {
case Integer i -> Integer.toString(i);
case null, default -> "null";
});
}

public String test2(Object o) {
return switch (o) {
case Integer i -> Integer.toString(i);
case String s -> s;
default -> "null";
};
}

public String test3(Supplier<Object> o) {
return switch (o.get()) {
case Integer i -> Integer.toString(i);
case String s -> s;
default -> "null";
};
}
}

0 comments on commit 9c0f911

Please sign in to comment.