Skip to content

Commit

Permalink
Add support for intersection types in casts (#405)
Browse files Browse the repository at this point in the history
* Fix intersection casts for direct methods

* Add support for variable assignments

* Use var for variables that represent intersection types and update tests

* Add comments and remove println

* Add support for when an intersection type is casted back to it's
original type
  • Loading branch information
coehlrich authored Jun 23, 2024
1 parent 883a7aa commit 4ecf1dd
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ public static RootStatement codeToJava(StructClass cl, StructMethod mt, MethodDe
continue;
}

if (IntersectionCastProcessor.makeIntersectionCasts(root)) {
decompileRecord.add("intersectionCasts", root);
continue;
}

if (DecompilerContext.getOption(IFernflowerPreferences.PATTERN_MATCHING)) {
if (cl.getVersion().hasIfPatternMatching()) {
if (IfPatternMatchProcessor.matchInstanceof(root)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package org.jetbrains.java.decompiler.modules.decompiler;

import org.jetbrains.java.decompiler.code.CodeConstants;
import org.jetbrains.java.decompiler.main.DecompilerContext;
import org.jetbrains.java.decompiler.modules.decompiler.exps.*;
import org.jetbrains.java.decompiler.modules.decompiler.exps.FunctionExprent.FunctionType;
import org.jetbrains.java.decompiler.modules.decompiler.stats.RootStatement;
import org.jetbrains.java.decompiler.modules.decompiler.stats.Statement;
import org.jetbrains.java.decompiler.struct.StructClass;
import org.jetbrains.java.decompiler.struct.StructMethod;
import org.jetbrains.java.decompiler.struct.gen.CodeType;
import org.jetbrains.java.decompiler.struct.gen.TypeFamily;
import org.jetbrains.java.decompiler.struct.gen.VarType;
import org.jetbrains.java.decompiler.struct.gen.generics.GenericMethodDescriptor;
import org.jetbrains.java.decompiler.util.Pair;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class IntersectionCastProcessor {

public static boolean makeIntersectionCasts(RootStatement root) {
return makeIntersectionCastsRec(root, root);
}

private static boolean makeIntersectionCastsRec(Statement stat, RootStatement root) {
boolean result = false;
if (stat.getExprents() != null) {
for (Exprent e : stat.getExprents()) {
result |= makeIntersectionCasts(e, root);
}
} else {
for (Object o : stat.getSequentialObjects()) {
if (o instanceof Statement s) {
result |= makeIntersectionCastsRec(s, root);
} else if (o instanceof Exprent e) {
result |= makeIntersectionCasts(e, root);
}
}
}
return result;
}

private static boolean makeIntersectionCasts(Exprent exp, RootStatement root) {
if (exp instanceof InvocationExprent inv) {
if (handleInvocation(inv, root)) {
return true;
}
} else if (exp instanceof AssignmentExprent assignment) {
if (handleAssignment(assignment, root)) {
return true;
}
}
boolean result = false;
for (Exprent sub : exp.getAllExprents()) {
result |= makeIntersectionCasts(sub, root);
}
return result;
}

private static boolean handleInvocation(InvocationExprent exp, RootStatement root) {
List<Exprent> lstParameters = exp.getLstParameters();
boolean result = false;
for (int i = 0; i < lstParameters.size(); i++) {
Exprent parameter = lstParameters.get(i);
if (parameter instanceof FunctionExprent cast && isValidCast(cast)) {
Pair<List<Exprent>, Exprent> casts = getCasts(cast);
List<Exprent> types = casts.a;
Exprent inner = casts.b;
// Checks for any bounds not supported by the current list of casts
List<VarType> bounds = getBounds(exp, i).stream()
.filter(bound -> !types
.stream()
.anyMatch(constant -> DecompilerContext.getStructContext().instanceOf(constant.getExprType().value, bound.value)))
.toList();

// Checks if the original type supports the remaining bounds
if (!bounds.isEmpty() && bounds.stream().allMatch(bound -> DecompilerContext.getStructContext().instanceOf(inner.getExprType().value, bound.value))) {
types.add(new ConstExprent(inner.getExprType(), null, null));
}
result |= replaceCasts(cast, types, inner);
}
}
return result;
}

private static boolean handleAssignment(AssignmentExprent exp, RootStatement root) {
if (exp.getLeft() instanceof VarExprent varExp) {
Exprent assigned = exp.getRight();
if (assigned instanceof FunctionExprent cast && isValidCast(cast)) {
Pair<List<Exprent>, Exprent> casts = getCasts(cast);
List<Exprent> types = casts.a;
Exprent inner = casts.b;
List<VariablePosition> references = findReferences(varExp, root);

// Convert the variable references into a set of bounds
Set<VarType> bounds = new HashSet<>();
for (VariablePosition position : references) {
bounds.addAll(switch (position.position) {
case METHOD_PARAMETER -> getBounds((InvocationExprent) position.exp, position.index);
case CASTED -> {
FunctionExprent func = (FunctionExprent) position.exp;
if (func.getLstOperands().size() == 2) {
yield List.of(func.getLstOperands().get(1).getExprType());
} else {
yield List.of();
}
}
});
}

// Checks for any bounds not supported by the current list of casts
bounds = bounds.stream()
.filter(bound -> !types
.stream()
.anyMatch(constant -> DecompilerContext.getStructContext().instanceOf(constant.getExprType().value, bound.value)))
.collect(Collectors.toSet());

// Checks if the original type supports the remaining bounds
if (!bounds.isEmpty() && bounds.stream().anyMatch(bound -> DecompilerContext.getStructContext().instanceOf(inner.getExprType().value, bound.value))) {
types.add(new ConstExprent(inner.getExprType(), null, null));
}
if (replaceCasts(cast, types, inner)) {
// If the casts were replaced make sure that the variable uses "var" instead of
// a type
varExp.setIntersectionType(true);
return true;
}
}
}
return false;
}

private static List<VarType> getBounds(InvocationExprent exp, int parameter) {
// Gets the bounds of a type parameter of a parameter of a method
StructMethod method = exp.getDesc();
GenericMethodDescriptor gmd = method != null ? method.getSignature() : null;
int start = gmd != null && DecompilerContext.getStructContext().getClass(method.getClassQualifiedName()).hasModifier(CodeConstants.ACC_ENUM) && method.getName().equals(CodeConstants.INIT_NAME) ? 2 : 0;
if (gmd != null) {
int index = parameter - start;
VarType type = gmd.parameterTypes.get(index);
if (type.type == CodeType.GENVAR) {
int typeParameterIndex = gmd.typeParameters.indexOf(type.value);
if (typeParameterIndex != -1) {
return gmd.typeParameterBounds.get(typeParameterIndex);
}
}
}
return List.of();
}

/**
* Searches for where a variable is referenced and returns the context
*/
private static List<VariablePosition> findReferences(VarExprent varExp, RootStatement root) {
List<VariablePosition> list = new ArrayList<>();
findReferencesRec(varExp, root, root, list);
return list;
}

private static void findReferencesRec(VarExprent varExp, Statement stat, RootStatement root, List<VariablePosition> list) {
if (stat.getExprents() != null) {
for (Exprent e : stat.getExprents()) {
findReferences(varExp, e, root, list);
}
} else {
for (Object o : stat.getSequentialObjects()) {
if (o instanceof Statement s) {
findReferencesRec(varExp, s, root, list);
} else if (o instanceof Exprent e) {
findReferences(varExp, e, root, list);
}
}
}
}

private static void findReferences(VarExprent varExp, Exprent exp, RootStatement root, List<VariablePosition> list) {
if (exp instanceof InvocationExprent inv) {
findReferences(varExp, inv, list);
} else if (exp instanceof FunctionExprent func && func.getFuncType() == FunctionType.CAST) {
if (func.getLstOperands().get(0) instanceof VarExprent otherVar && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) {
list.add(new VariablePosition(VariablePositionEnum.CASTED, exp, -1));
}
}
for (Exprent sub : exp.getAllExprents()) {
findReferences(varExp, sub, root, list);
}
}

private static void findReferences(VarExprent varExp, InvocationExprent inv, List<VariablePosition> list) {
List<Exprent> lstParameters = inv.getLstParameters();
for (int i = 0; i < lstParameters.size(); i++) {
Exprent parameter = lstParameters.get(i);
if (parameter instanceof VarExprent otherVar && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) {
list.add(new VariablePosition(VariablePositionEnum.METHOD_PARAMETER, inv, i));
}
}
}

private static Pair<List<Exprent>, Exprent> getCasts(Exprent exp) {
// Gets the list of casts done and gets the original exprent
List<Exprent> types = new ArrayList<>();
Exprent inner = exp;
while (inner instanceof FunctionExprent cast && isValidCast(cast)) {
types.add(cast.getLstOperands().get(1));
inner = cast.getLstOperands().get(0);
}
return Pair.of(types, inner);
}

private static boolean isValidCast(FunctionExprent cast) {
if (cast.getFuncType() == FunctionType.CAST && cast.getLstOperands().size() == 2) {
VarType type = cast.getLstOperands().get(1).getExprType();
// Intersection casts cannot include arrays
return type.typeFamily == TypeFamily.OBJECT && type.arrayDim == 0;
}
return false;
}

private static boolean replaceCasts(FunctionExprent cast, List<Exprent> types, Exprent inner) {
if (types.size() > 1) {
// Reorders the list of types to make sure that the class is always first
Exprent nonInterface = null;
for (Exprent type : types) {
StructClass clazz = DecompilerContext.getStructContext().getClass(type.getExprType().value);
if (clazz != null && !clazz.hasModifier(CodeConstants.ACC_INTERFACE)) {
if (nonInterface == null) {
nonInterface = type;
} else {
return false;
}
}
}
if (nonInterface != null) {
types.remove(types.indexOf(nonInterface));
types.add(0, nonInterface);
}
// Replaces the operands of the cast with the casted exprent and the list of needed casts
cast.getLstOperands().clear();
cast.getLstOperands().add(inner);
cast.getLstOperands().addAll(types);
return true;
}
return false;
}

private static record VariablePosition(VariablePositionEnum position, Exprent exp, int index) {

}

private static enum VariablePositionEnum {
METHOD_PARAMETER,
CASTED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ public VarType getInferredExprType(VarType upperBound) {
if (funcType == FunctionType.CAST) {
this.needsCast = true;
VarType right = lstOperands.get(0).getInferredExprType(upperBound);
VarType cast = lstOperands.get(1).getExprType();
List<VarType> cast = lstOperands.subList(1, lstOperands.size()).stream().map(Exprent::getExprType).toList();

if (upperBound != null && (upperBound.isGeneric() || right.isGeneric())) {
Map<VarType, List<VarType>> names = this.getNamedGenerics();
Expand All @@ -258,12 +258,8 @@ public VarType getInferredExprType(VarType upperBound) {
}

if (types != null) {
boolean anyMatch = false; //TODO: allMatch instead of anyMatch?
for (VarType type : types) {
anyMatch |= DecompilerContext.getStructContext().instanceOf(type.value, cast.value);
}

if (anyMatch) {
List<VarType> finalTypes = types;
if (cast.stream().allMatch(castType -> finalTypes.stream().anyMatch(type -> DecompilerContext.getStructContext().instanceOf(type.value, castType.value)))) {
this.needsCast = false;
}
} else {
Expand All @@ -278,7 +274,8 @@ public VarType getInferredExprType(VarType upperBound) {
return right;
}
} else { //TODO: Capture generics to make cast better?
this.needsCast = right.type == CodeType.NULL || !DecompilerContext.getStructContext().instanceOf(right.value, cast.value) || right.arrayDim != cast.arrayDim;
final VarType finalRight = right;
this.needsCast = right.type == CodeType.NULL || cast.stream().anyMatch(castType -> !DecompilerContext.getStructContext().instanceOf(finalRight.value, castType.value)) || cast.stream().anyMatch(castType -> finalRight.arrayDim != castType.arrayDim);
}

return getExprType();
Expand Down Expand Up @@ -606,7 +603,13 @@ else if (left instanceof ConstExprent) {
if (!needsCast) {
return buf.append(lstOperands.get(0).toJava(indent));
}
return buf.append(lstOperands.get(1).toJava(indent)).encloseWithParens().append(wrapOperandString(lstOperands.get(0), true, indent));
for (int i = 1; i < lstOperands.size(); i++) {
if (i > 1) {
buf.append(" & ");
}
buf.append(lstOperands.get(i).toJava(indent));
}
return buf.encloseWithParens().append(wrapOperandString(lstOperands.get(0), true, indent));
case ARRAY_LENGTH:
Exprent arr = lstOperands.get(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,10 @@ public Map<VarType, VarType> getGenericsMap() {
}

public StructMethod getDesc() {
if (desc == null) {
StructClass cl = DecompilerContext.getStructContext().getClass(classname);
desc = cl != null ? cl.getMethodRecursive(name, stringDescriptor) : null;
}
return desc;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class VarExprent extends Exprent {
private Instruction backing = null;
private boolean isEffectivelyFinal = false;
private VarType boundType;
private boolean isIntersectionType = false;

public VarExprent(int index, VarType varType, VarProcessor processor) {
this(index, varType, processor, null);
Expand Down Expand Up @@ -131,7 +132,7 @@ public TextBuffer toJava(int indent) {
}
VarType definitionType = getDefinitionVarType();
String name = ExprProcessor.getCastTypeName(definitionType);
if (name.equals(ExprProcessor.UNREPRESENTABLE_TYPE_STRING)) {
if (name.equals(ExprProcessor.UNREPRESENTABLE_TYPE_STRING) || isIntersectionType) {
buffer.append("var");
} else {
buffer.appendCastTypeName(definitionType);
Expand Down Expand Up @@ -514,6 +515,14 @@ public String toString() {
return "VarExprent[" + index + ',' + version + (definition ? " Def" : "") + "]: {" + super.toString() + "}";
}

public void setIntersectionType(boolean intersection) {
this.isIntersectionType = intersection;
}

public boolean isIntersectionType() {
return this.isIntersectionType;
}

// *****************************************************************************
// IMatchable implementation
// *****************************************************************************
Expand Down
1 change: 1 addition & 0 deletions test/org/jetbrains/java/decompiler/SingleClassesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ private void registerDefault() {
register(JAVA_8, "TestInnerClassesJ8");
register(JAVA_8, "TestSwitchInTry");
register(JAVA_21, "TestSwitchPatternMatchingJ21");
register(JAVA_21, "TestCastIntersectionJ21");
}

private void registerEntireClassPath() {
Expand Down
Loading

0 comments on commit 4ecf1dd

Please sign in to comment.