From 4ecf1dd9acde50295788dbe2f4cab4ece9f7cdd1 Mon Sep 17 00:00:00 2001 From: coehlrich Date: Mon, 24 Jun 2024 04:15:59 +1200 Subject: [PATCH] Add support for intersection types in casts (#405) * Fix intersection casts for direct methods * Add support for variable assignments * Use var for variables that represent intersection types and update tests * Add comments and remove println * Add support for when an intersection type is casted back to it's original type --- .../decompiler/main/rels/MethodProcessor.java | 5 + .../decompiler/IntersectionCastProcessor.java | 258 ++++++++++++++++++ .../decompiler/exps/FunctionExprent.java | 21 +- .../decompiler/exps/InvocationExprent.java | 4 + .../modules/decompiler/exps/VarExprent.java | 11 +- .../java/decompiler/SingleClassesTest.java | 1 + .../results/pkg/TestCastIntersectionJ21.dec | 128 +++++++++ .../results/pkg/TestKotlinConstructorKt.dec | 5 +- .../java21/pkg/TestCastIntersectionJ21.java | 36 +++ 9 files changed, 455 insertions(+), 14 deletions(-) create mode 100644 src/org/jetbrains/java/decompiler/modules/decompiler/IntersectionCastProcessor.java create mode 100644 testData/results/pkg/TestCastIntersectionJ21.dec create mode 100644 testData/src/java21/pkg/TestCastIntersectionJ21.java diff --git a/src/org/jetbrains/java/decompiler/main/rels/MethodProcessor.java b/src/org/jetbrains/java/decompiler/main/rels/MethodProcessor.java index c7ef4f4033..4affc00bab 100644 --- a/src/org/jetbrains/java/decompiler/main/rels/MethodProcessor.java +++ b/src/org/jetbrains/java/decompiler/main/rels/MethodProcessor.java @@ -304,6 +304,11 @@ public static RootStatement codeToJava(StructClass cl, StructMethod mt, MethodDe continue; } + if (IntersectionCastProcessor.makeIntersectionCasts(root)) { + decompileRecord.add("intersectionCasts", root); + continue; + } + if (DecompilerContext.getOption(IFernflowerPreferences.PATTERN_MATCHING)) { if (cl.getVersion().hasIfPatternMatching()) { if (IfPatternMatchProcessor.matchInstanceof(root)) { diff --git a/src/org/jetbrains/java/decompiler/modules/decompiler/IntersectionCastProcessor.java b/src/org/jetbrains/java/decompiler/modules/decompiler/IntersectionCastProcessor.java new file mode 100644 index 0000000000..7ca02a01a3 --- /dev/null +++ b/src/org/jetbrains/java/decompiler/modules/decompiler/IntersectionCastProcessor.java @@ -0,0 +1,258 @@ +package org.jetbrains.java.decompiler.modules.decompiler; + +import org.jetbrains.java.decompiler.code.CodeConstants; +import org.jetbrains.java.decompiler.main.DecompilerContext; +import org.jetbrains.java.decompiler.modules.decompiler.exps.*; +import org.jetbrains.java.decompiler.modules.decompiler.exps.FunctionExprent.FunctionType; +import org.jetbrains.java.decompiler.modules.decompiler.stats.RootStatement; +import org.jetbrains.java.decompiler.modules.decompiler.stats.Statement; +import org.jetbrains.java.decompiler.struct.StructClass; +import org.jetbrains.java.decompiler.struct.StructMethod; +import org.jetbrains.java.decompiler.struct.gen.CodeType; +import org.jetbrains.java.decompiler.struct.gen.TypeFamily; +import org.jetbrains.java.decompiler.struct.gen.VarType; +import org.jetbrains.java.decompiler.struct.gen.generics.GenericMethodDescriptor; +import org.jetbrains.java.decompiler.util.Pair; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +public class IntersectionCastProcessor { + + public static boolean makeIntersectionCasts(RootStatement root) { + return makeIntersectionCastsRec(root, root); + } + + private static boolean makeIntersectionCastsRec(Statement stat, RootStatement root) { + boolean result = false; + if (stat.getExprents() != null) { + for (Exprent e : stat.getExprents()) { + result |= makeIntersectionCasts(e, root); + } + } else { + for (Object o : stat.getSequentialObjects()) { + if (o instanceof Statement s) { + result |= makeIntersectionCastsRec(s, root); + } else if (o instanceof Exprent e) { + result |= makeIntersectionCasts(e, root); + } + } + } + return result; + } + + private static boolean makeIntersectionCasts(Exprent exp, RootStatement root) { + if (exp instanceof InvocationExprent inv) { + if (handleInvocation(inv, root)) { + return true; + } + } else if (exp instanceof AssignmentExprent assignment) { + if (handleAssignment(assignment, root)) { + return true; + } + } + boolean result = false; + for (Exprent sub : exp.getAllExprents()) { + result |= makeIntersectionCasts(sub, root); + } + return result; + } + + private static boolean handleInvocation(InvocationExprent exp, RootStatement root) { + List lstParameters = exp.getLstParameters(); + boolean result = false; + for (int i = 0; i < lstParameters.size(); i++) { + Exprent parameter = lstParameters.get(i); + if (parameter instanceof FunctionExprent cast && isValidCast(cast)) { + Pair, Exprent> casts = getCasts(cast); + List types = casts.a; + Exprent inner = casts.b; + // Checks for any bounds not supported by the current list of casts + List bounds = getBounds(exp, i).stream() + .filter(bound -> !types + .stream() + .anyMatch(constant -> DecompilerContext.getStructContext().instanceOf(constant.getExprType().value, bound.value))) + .toList(); + + // Checks if the original type supports the remaining bounds + if (!bounds.isEmpty() && bounds.stream().allMatch(bound -> DecompilerContext.getStructContext().instanceOf(inner.getExprType().value, bound.value))) { + types.add(new ConstExprent(inner.getExprType(), null, null)); + } + result |= replaceCasts(cast, types, inner); + } + } + return result; + } + + private static boolean handleAssignment(AssignmentExprent exp, RootStatement root) { + if (exp.getLeft() instanceof VarExprent varExp) { + Exprent assigned = exp.getRight(); + if (assigned instanceof FunctionExprent cast && isValidCast(cast)) { + Pair, Exprent> casts = getCasts(cast); + List types = casts.a; + Exprent inner = casts.b; + List references = findReferences(varExp, root); + + // Convert the variable references into a set of bounds + Set bounds = new HashSet<>(); + for (VariablePosition position : references) { + bounds.addAll(switch (position.position) { + case METHOD_PARAMETER -> getBounds((InvocationExprent) position.exp, position.index); + case CASTED -> { + FunctionExprent func = (FunctionExprent) position.exp; + if (func.getLstOperands().size() == 2) { + yield List.of(func.getLstOperands().get(1).getExprType()); + } else { + yield List.of(); + } + } + }); + } + + // Checks for any bounds not supported by the current list of casts + bounds = bounds.stream() + .filter(bound -> !types + .stream() + .anyMatch(constant -> DecompilerContext.getStructContext().instanceOf(constant.getExprType().value, bound.value))) + .collect(Collectors.toSet()); + + // Checks if the original type supports the remaining bounds + if (!bounds.isEmpty() && bounds.stream().anyMatch(bound -> DecompilerContext.getStructContext().instanceOf(inner.getExprType().value, bound.value))) { + types.add(new ConstExprent(inner.getExprType(), null, null)); + } + if (replaceCasts(cast, types, inner)) { + // If the casts were replaced make sure that the variable uses "var" instead of + // a type + varExp.setIntersectionType(true); + return true; + } + } + } + return false; + } + + private static List getBounds(InvocationExprent exp, int parameter) { + // Gets the bounds of a type parameter of a parameter of a method + StructMethod method = exp.getDesc(); + GenericMethodDescriptor gmd = method != null ? method.getSignature() : null; + int start = gmd != null && DecompilerContext.getStructContext().getClass(method.getClassQualifiedName()).hasModifier(CodeConstants.ACC_ENUM) && method.getName().equals(CodeConstants.INIT_NAME) ? 2 : 0; + if (gmd != null) { + int index = parameter - start; + VarType type = gmd.parameterTypes.get(index); + if (type.type == CodeType.GENVAR) { + int typeParameterIndex = gmd.typeParameters.indexOf(type.value); + if (typeParameterIndex != -1) { + return gmd.typeParameterBounds.get(typeParameterIndex); + } + } + } + return List.of(); + } + + /** + * Searches for where a variable is referenced and returns the context + */ + private static List findReferences(VarExprent varExp, RootStatement root) { + List list = new ArrayList<>(); + findReferencesRec(varExp, root, root, list); + return list; + } + + private static void findReferencesRec(VarExprent varExp, Statement stat, RootStatement root, List list) { + if (stat.getExprents() != null) { + for (Exprent e : stat.getExprents()) { + findReferences(varExp, e, root, list); + } + } else { + for (Object o : stat.getSequentialObjects()) { + if (o instanceof Statement s) { + findReferencesRec(varExp, s, root, list); + } else if (o instanceof Exprent e) { + findReferences(varExp, e, root, list); + } + } + } + } + + private static void findReferences(VarExprent varExp, Exprent exp, RootStatement root, List list) { + if (exp instanceof InvocationExprent inv) { + findReferences(varExp, inv, list); + } else if (exp instanceof FunctionExprent func && func.getFuncType() == FunctionType.CAST) { + if (func.getLstOperands().get(0) instanceof VarExprent otherVar && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) { + list.add(new VariablePosition(VariablePositionEnum.CASTED, exp, -1)); + } + } + for (Exprent sub : exp.getAllExprents()) { + findReferences(varExp, sub, root, list); + } + } + + private static void findReferences(VarExprent varExp, InvocationExprent inv, List list) { + List lstParameters = inv.getLstParameters(); + for (int i = 0; i < lstParameters.size(); i++) { + Exprent parameter = lstParameters.get(i); + if (parameter instanceof VarExprent otherVar && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) { + list.add(new VariablePosition(VariablePositionEnum.METHOD_PARAMETER, inv, i)); + } + } + } + + private static Pair, Exprent> getCasts(Exprent exp) { + // Gets the list of casts done and gets the original exprent + List types = new ArrayList<>(); + Exprent inner = exp; + while (inner instanceof FunctionExprent cast && isValidCast(cast)) { + types.add(cast.getLstOperands().get(1)); + inner = cast.getLstOperands().get(0); + } + return Pair.of(types, inner); + } + + private static boolean isValidCast(FunctionExprent cast) { + if (cast.getFuncType() == FunctionType.CAST && cast.getLstOperands().size() == 2) { + VarType type = cast.getLstOperands().get(1).getExprType(); + // Intersection casts cannot include arrays + return type.typeFamily == TypeFamily.OBJECT && type.arrayDim == 0; + } + return false; + } + + private static boolean replaceCasts(FunctionExprent cast, List types, Exprent inner) { + if (types.size() > 1) { + // Reorders the list of types to make sure that the class is always first + Exprent nonInterface = null; + for (Exprent type : types) { + StructClass clazz = DecompilerContext.getStructContext().getClass(type.getExprType().value); + if (clazz != null && !clazz.hasModifier(CodeConstants.ACC_INTERFACE)) { + if (nonInterface == null) { + nonInterface = type; + } else { + return false; + } + } + } + if (nonInterface != null) { + types.remove(types.indexOf(nonInterface)); + types.add(0, nonInterface); + } + // Replaces the operands of the cast with the casted exprent and the list of needed casts + cast.getLstOperands().clear(); + cast.getLstOperands().add(inner); + cast.getLstOperands().addAll(types); + return true; + } + return false; + } + + private static record VariablePosition(VariablePositionEnum position, Exprent exp, int index) { + + } + + private static enum VariablePositionEnum { + METHOD_PARAMETER, + CASTED; + } +} diff --git a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/FunctionExprent.java b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/FunctionExprent.java index 788a98da35..82eba81d78 100644 --- a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/FunctionExprent.java +++ b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/FunctionExprent.java @@ -240,7 +240,7 @@ public VarType getInferredExprType(VarType upperBound) { if (funcType == FunctionType.CAST) { this.needsCast = true; VarType right = lstOperands.get(0).getInferredExprType(upperBound); - VarType cast = lstOperands.get(1).getExprType(); + List cast = lstOperands.subList(1, lstOperands.size()).stream().map(Exprent::getExprType).toList(); if (upperBound != null && (upperBound.isGeneric() || right.isGeneric())) { Map> names = this.getNamedGenerics(); @@ -258,12 +258,8 @@ public VarType getInferredExprType(VarType upperBound) { } if (types != null) { - boolean anyMatch = false; //TODO: allMatch instead of anyMatch? - for (VarType type : types) { - anyMatch |= DecompilerContext.getStructContext().instanceOf(type.value, cast.value); - } - - if (anyMatch) { + List finalTypes = types; + if (cast.stream().allMatch(castType -> finalTypes.stream().anyMatch(type -> DecompilerContext.getStructContext().instanceOf(type.value, castType.value)))) { this.needsCast = false; } } else { @@ -278,7 +274,8 @@ public VarType getInferredExprType(VarType upperBound) { return right; } } else { //TODO: Capture generics to make cast better? - this.needsCast = right.type == CodeType.NULL || !DecompilerContext.getStructContext().instanceOf(right.value, cast.value) || right.arrayDim != cast.arrayDim; + final VarType finalRight = right; + this.needsCast = right.type == CodeType.NULL || cast.stream().anyMatch(castType -> !DecompilerContext.getStructContext().instanceOf(finalRight.value, castType.value)) || cast.stream().anyMatch(castType -> finalRight.arrayDim != castType.arrayDim); } return getExprType(); @@ -606,7 +603,13 @@ else if (left instanceof ConstExprent) { if (!needsCast) { return buf.append(lstOperands.get(0).toJava(indent)); } - return buf.append(lstOperands.get(1).toJava(indent)).encloseWithParens().append(wrapOperandString(lstOperands.get(0), true, indent)); + for (int i = 1; i < lstOperands.size(); i++) { + if (i > 1) { + buf.append(" & "); + } + buf.append(lstOperands.get(i).toJava(indent)); + } + return buf.encloseWithParens().append(wrapOperandString(lstOperands.get(0), true, indent)); case ARRAY_LENGTH: Exprent arr = lstOperands.get(0); diff --git a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/InvocationExprent.java b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/InvocationExprent.java index 541d6a4744..87fd2fe2d9 100644 --- a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/InvocationExprent.java +++ b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/InvocationExprent.java @@ -1809,6 +1809,10 @@ public Map getGenericsMap() { } public StructMethod getDesc() { + if (desc == null) { + StructClass cl = DecompilerContext.getStructContext().getClass(classname); + desc = cl != null ? cl.getMethodRecursive(name, stringDescriptor) : null; + } return desc; } diff --git a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/VarExprent.java b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/VarExprent.java index 8d71deaa04..1e7086516c 100644 --- a/src/org/jetbrains/java/decompiler/modules/decompiler/exps/VarExprent.java +++ b/src/org/jetbrains/java/decompiler/modules/decompiler/exps/VarExprent.java @@ -57,6 +57,7 @@ public class VarExprent extends Exprent { private Instruction backing = null; private boolean isEffectivelyFinal = false; private VarType boundType; + private boolean isIntersectionType = false; public VarExprent(int index, VarType varType, VarProcessor processor) { this(index, varType, processor, null); @@ -131,7 +132,7 @@ public TextBuffer toJava(int indent) { } VarType definitionType = getDefinitionVarType(); String name = ExprProcessor.getCastTypeName(definitionType); - if (name.equals(ExprProcessor.UNREPRESENTABLE_TYPE_STRING)) { + if (name.equals(ExprProcessor.UNREPRESENTABLE_TYPE_STRING) || isIntersectionType) { buffer.append("var"); } else { buffer.appendCastTypeName(definitionType); @@ -514,6 +515,14 @@ public String toString() { return "VarExprent[" + index + ',' + version + (definition ? " Def" : "") + "]: {" + super.toString() + "}"; } + public void setIntersectionType(boolean intersection) { + this.isIntersectionType = intersection; + } + + public boolean isIntersectionType() { + return this.isIntersectionType; + } + // ***************************************************************************** // IMatchable implementation // ***************************************************************************** diff --git a/test/org/jetbrains/java/decompiler/SingleClassesTest.java b/test/org/jetbrains/java/decompiler/SingleClassesTest.java index 2d462889c2..9b9ca4703c 100644 --- a/test/org/jetbrains/java/decompiler/SingleClassesTest.java +++ b/test/org/jetbrains/java/decompiler/SingleClassesTest.java @@ -706,6 +706,7 @@ private void registerDefault() { register(JAVA_8, "TestInnerClassesJ8"); register(JAVA_8, "TestSwitchInTry"); register(JAVA_21, "TestSwitchPatternMatchingJ21"); + register(JAVA_21, "TestCastIntersectionJ21"); } private void registerEntireClassPath() { diff --git a/testData/results/pkg/TestCastIntersectionJ21.dec b/testData/results/pkg/TestCastIntersectionJ21.dec new file mode 100644 index 0000000000..a19ed62c35 --- /dev/null +++ b/testData/results/pkg/TestCastIntersectionJ21.dec @@ -0,0 +1,128 @@ +package pkg; + +public class TestCastIntersectionJ21 { + public void test1(TestCastIntersectionJ21.I1 i1) { + this.method((TestCastIntersectionJ21.I1 & TestCastIntersectionJ21.I2)i1);// 5 + }// 6 + + public void test2(TestCastIntersectionJ21.I2 i2) { + this.method((TestCastIntersectionJ21.I1 & TestCastIntersectionJ21.I2)i2);// 9 + }// 10 + + public void test3(TestCastIntersectionJ21.I1 i1) { + var i = (TestCastIntersectionJ21.I1 & TestCastIntersectionJ21.I2)i1;// 13 + this.method(i);// 14 + }// 15 + + public void test4(TestCastIntersectionJ21.I2 i2) { + var i = (TestCastIntersectionJ21.I1 & TestCastIntersectionJ21.I2)i2;// 18 + this.method(i);// 19 + }// 20 + + public void test5(TestCastIntersectionJ21.I2 i2) { + var i = (TestCastIntersectionJ21.I1 & TestCastIntersectionJ21.I2)i2;// 23 + ((TestCastIntersectionJ21.I2)i).method();// 24 + }// 25 + + public void method(I i) { + }// 28 + + private static class I1 { + } + + private interface I2 { + void method(); + } +} + +class 'pkg/TestCastIntersectionJ21' { + method 'test1 (Lpkg/TestCastIntersectionJ21$I1;)V' { + 0 4 + 1 4 + 5 4 + 6 4 + 7 4 + 8 4 + 9 4 + a 4 + b 5 + } + + method 'test2 (Lpkg/TestCastIntersectionJ21$I2;)V' { + 0 8 + 1 8 + 2 8 + 3 8 + 4 8 + 5 8 + 6 8 + 7 8 + 8 9 + } + + method 'test3 (Lpkg/TestCastIntersectionJ21$I1;)V' { + 0 12 + 4 12 + 5 12 + 6 12 + 7 12 + 8 13 + 9 13 + a 13 + b 13 + c 13 + d 14 + } + + method 'test4 (Lpkg/TestCastIntersectionJ21$I2;)V' { + 0 17 + 1 17 + 2 17 + 3 17 + 4 17 + 5 18 + 6 18 + 7 18 + 8 18 + 9 18 + a 19 + } + + method 'test5 (Lpkg/TestCastIntersectionJ21$I2;)V' { + 0 22 + 1 22 + 2 22 + 3 22 + 4 22 + 5 23 + 6 23 + 7 23 + 8 23 + 9 23 + a 23 + b 23 + c 23 + d 23 + e 24 + } + + method 'method (Lpkg/TestCastIntersectionJ21$I1;)V' { + 0 27 + } +} + +Lines mapping: +5 <-> 5 +6 <-> 6 +9 <-> 9 +10 <-> 10 +13 <-> 13 +14 <-> 14 +15 <-> 15 +18 <-> 18 +19 <-> 19 +20 <-> 20 +23 <-> 23 +24 <-> 24 +25 <-> 25 +28 <-> 28 diff --git a/testData/results/pkg/TestKotlinConstructorKt.dec b/testData/results/pkg/TestKotlinConstructorKt.dec index dca2b8ebf4..682b7ef365 100644 --- a/testData/results/pkg/TestKotlinConstructorKt.dec +++ b/testData/results/pkg/TestKotlinConstructorKt.dec @@ -17,7 +17,7 @@ public final class TestKotlinConstructorKt { private static final List foo(Collection list) { Iterable $this$map$iv = list; int $i$f$map = 0; - Collection destination$iv$iv = new ArrayList(CollectionsKt.collectionSizeOrDefault($this$map$iv, 10)); + var destination$iv$iv = new ArrayList(CollectionsKt.collectionSizeOrDefault($this$map$iv, 10)); int $i$f$mapTo = 0; for (Object item$iv$iv : $this$map$iv) {// 12 13 @@ -109,9 +109,6 @@ class 'pkg/TestKotlinConstructorKt' { 66 31 6b 34 6c 34 - 6d 34 - 6e 34 - 6f 34 71 34 72 34 73 34 diff --git a/testData/src/java21/pkg/TestCastIntersectionJ21.java b/testData/src/java21/pkg/TestCastIntersectionJ21.java new file mode 100644 index 0000000000..81f5eae5b1 --- /dev/null +++ b/testData/src/java21/pkg/TestCastIntersectionJ21.java @@ -0,0 +1,36 @@ +package pkg; + +public class TestCastIntersectionJ21 { + public void test1(I1 i1) { + method((I1 & I2) i1); + } + + public void test2(I2 i2) { + method((I1 & I2) i2); + } + + public void test3(I1 i1) { + var i = (I1 & I2) i1; + method(i); + } + + public void test4(I2 i2) { + var i = (I1 & I2) i2; + method(i); + } + + public void test5(I2 i2) { + var i = (I1 & I2) i2; + i.method(); + } + + public void method(I i) { + } + + private static class I1 { + } + + private static interface I2 { + public void method(); + } +}