Skip to content

Commit

Permalink
Safety propagation takes into account known subtypes (#2703)
Browse files Browse the repository at this point in the history
Safety propagation takes into account known subtypes
  • Loading branch information
carterkozak authored Nov 26, 2024
1 parent 9f6902e commit 7208cd6
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@
import com.google.errorprone.fixes.SuggestedFixes;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.CatchTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.NewArrayTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.TryTree;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.ClassSymbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -123,5 +129,39 @@ public static Type getResultType(Tree tree) {
: ASTHelpers.getType(tree);
}

public static boolean isSealed(@Nullable ClassSymbol classSymbol) {
if (classSymbol == null) {
return false;
}
long flags = classSymbol.flags();
return (flags & (1L << 62)) != 0;
}

public static List<ExpressionTree> getPermitsClause(@Nullable ClassTree classTree) {
if (classTree == null) {
return Collections.emptyList();
}
try {
return (List<ExpressionTree>)
ClassTree.class.getMethod("getPermitsClause").invoke(classTree);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException("Failed to extract permitted classes", e);
} catch (NoSuchMethodException e) {
// This is expected on older JDKs which do not support the permits clause
return Collections.emptyList();
}
}

public static List<? extends ExpressionTree> unwrapArray(@Nullable ExpressionTree expressionTree) {
if (expressionTree == null) {
return Collections.emptyList();
}
if (expressionTree instanceof NewArrayTree) {
NewArrayTree tree = (NewArrayTree) expressionTree;
return tree.getInitializers();
}
return Collections.singletonList(expressionTree);
}

private MoreASTHelpers() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ public Description matchClass(ClassTree classTree, VisitorState state) {
if (classSymbol == null || classSymbol.isAnonymous()) {
return Description.NO_MATCH;
}
TypeSymbol tsym = classSymbol.type.tsym;
tsym.getModifiers();
if (ASTHelpers.isRecord(classSymbol)) {
return matchRecord(classTree, classSymbol, state);
} else {
Expand Down Expand Up @@ -146,6 +148,7 @@ private static boolean immutablesDefaultAsDefault(Attribute.Compound styleAnnota
private Description matchRecord(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) {
Safety existingClassSafety = SafetyAnnotations.getSafety(classTree, state);
Safety safety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state);
safety = safety.leastUpperBound(SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state));
for (VarSymbol recordComponent : Records.getRecordComponents(classSymbol)) {
Safety symbolSafety = SafetyAnnotations.getSafety(recordComponent, state);
Safety typeSafety = SafetyAnnotations.getSafety(recordComponent.type, state);
Expand All @@ -160,7 +163,7 @@ private Description matchClassOrInterface(ClassTree classTree, ClassSymbol class
if (ASTHelpers.hasAnnotation(classSymbol, "org.immutables.value.Value.Immutable", state)) {
return matchImmutables(classTree, classSymbol, state);
}
return matchBasedOnToString(classTree, classSymbol, state);
return matchArbitraryObject(classTree, classSymbol, state);
}

private static boolean isImmutablesField(
Expand Down Expand Up @@ -239,22 +242,31 @@ private static Safety scanSymbolMethods(ClassSymbol begin, VisitorState state, b
private Description matchImmutables(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) {
Safety existingClassSafety = SafetyAnnotations.getAnnotatedSafety(classTree, state);
Safety safety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state);
safety = safety.leastUpperBound(SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state));
boolean isJson = hasJacksonAnnotation(classSymbol, state);
ClassSymbol symbol = ASTHelpers.getSymbol(classTree);
Safety scanned = scanSymbolMethods(symbol, state, isJson);
safety = safety.leastUpperBound(scanned);
return handleSafety(classTree, classTree.getModifiers(), state, existingClassSafety, safety);
}

private Description matchBasedOnToString(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) {
private Safety getToStringSafety(ClassSymbol classSymbol, VisitorState state) {
MethodSymbol toStringSymbol = ASTHelpers.resolveExistingMethod(
state, classSymbol, TO_STRING_NAME.get(state), ImmutableList.of(), ImmutableList.of());
if (toStringSymbol == null) {
return Description.NO_MATCH;
}
return SafetyAnnotations.getSafety(toStringSymbol, state);
}

private Description matchArbitraryObject(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) {
Safety toStringSafety = getToStringSafety(classSymbol, state);
Safety subtypeSafety = SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state);
Safety ancestorSafety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state);
Safety existingClassSafety = SafetyAnnotations.getSafety(classTree, state);
Safety symbolSafety = SafetyAnnotations.getSafety(toStringSymbol, state);
return handleSafety(classTree, classTree.getModifiers(), state, existingClassSafety, symbolSafety);
return handleSafety(
classTree,
classTree.getModifiers(),
state,
existingClassSafety,
Safety.mergeAssumingUnknownIsSame(toStringSafety, subtypeSafety, ancestorSafety));
}

private Description handleSafety(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@

import com.google.common.collect.Multimap;
import com.google.errorprone.VisitorState;
import com.google.errorprone.matchers.AnnotationMatcherUtils;
import com.google.errorprone.matchers.AnnotationType;
import com.google.errorprone.matchers.Matcher;
import com.google.errorprone.suppliers.Suppliers;
import com.google.errorprone.util.ASTHelpers;
import com.palantir.baseline.errorprone.MoreASTHelpers;
import com.sun.source.tree.AnnotationTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.ExpressionTree;
Expand Down Expand Up @@ -61,6 +65,11 @@ public final class SafetyAnnotations {
private static final com.google.errorprone.suppliers.Supplier<Type> throwableSupplier =
Suppliers.typeFromClass(Throwable.class);

private static final Matcher<AnnotationTree> JSON_TYPE_INFO_MATCHER =
new AnnotationType("com.fasterxml.jackson.annotation.JsonTypeInfo");
private static final Matcher<AnnotationTree> JSON_SUBTYPES_MATCHER =
new AnnotationType("com.fasterxml.jackson.annotation.JsonSubTypes");

private static final TypeArgumentHandlers SAFETY_IS_COMBINATION_OF_TYPE_ARGUMENTS = new TypeArgumentHandlers(
new TypeArgumentHandler(Iterable.class),
new TypeArgumentHandler(Iterator.class),
Expand Down Expand Up @@ -152,6 +161,41 @@ public static Safety getTypeSafetyFromAncestors(ClassTree classTree, VisitorStat
return safety;
}

public static Safety getTypeSafetyFromKnownSubtypes(ClassTree classTree, VisitorState state) {
Safety safety = Safety.UNKNOWN;
ClassSymbol symbol = ASTHelpers.getSymbol(classTree);
if (MoreASTHelpers.isSealed(symbol)) {
for (ExpressionTree permitted : MoreASTHelpers.getPermitsClause(classTree)) {
safety = Safety.mergeAssumingUnknownIsSame(safety, SafetyAnnotations.getSafety(permitted, state));
}
}
for (AnnotationTree annotationTree : classTree.getModifiers().getAnnotations()) {
if (JSON_TYPE_INFO_MATCHER.matches(annotationTree, state)) {
ExpressionTree expressionTree = AnnotationMatcherUtils.getArgument(annotationTree, "defaultImpl");
if (expressionTree != null) {
Safety defaultImplSafety =
SafetyAnnotations.getSafety(ASTHelpers.getReceiver(expressionTree), state);
safety = Safety.mergeAssumingUnknownIsSame(safety, defaultImplSafety);
}
} else if (JSON_SUBTYPES_MATCHER.matches(annotationTree, state)) {
ExpressionTree tree = AnnotationMatcherUtils.getArgument(annotationTree, "value");
for (ExpressionTree subtype : MoreASTHelpers.unwrapArray(tree)) {
if (subtype instanceof AnnotationTree) {
ExpressionTree value = AnnotationMatcherUtils.getArgument((AnnotationTree) subtype, "value");
if (value != null) {
Safety subtypeSafety = SafetyAnnotations.getSafety(ASTHelpers.getReceiver(value), state);
safety = Safety.mergeAssumingUnknownIsSame(safety, subtypeSafety);
}
}
}

safety = Safety.mergeAssumingUnknownIsSame(safety, safety);
}
}

return safety;
}

public static Safety getDirectSafety(@Nullable Symbol symbol, VisitorState state) {
if (symbol != null) {
if (containsAttributeNamed(symbol, doNotLogName.get(state))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,28 @@ void testSafetyAnnotatedReturnTypeDoesNotAnnotateMethod() {
.doTest();
}

@Test
void testAddsAnnotation_sealedTypes() {
fix().addInputLines(
"Test.java",
"import com.palantir.logsafe.*;",
"class Test {",
" sealed interface Base permits Dnl {}",
" @DoNotLog",
" final class Dnl implements Base {}",
"}")
.addOutputLines(
"Test.java",
"import com.palantir.logsafe.*;",
"class Test {",
" @DoNotLog",
" sealed interface Base permits Dnl {}",
" @DoNotLog",
" final class Dnl implements Base {}",
"}")
.doTest();
}

@Test
void testSafetyAnnotatedArrayTypeDoesNotAnnotateMethod() {
fix().addInputLines(
Expand All @@ -1065,6 +1087,46 @@ void testSafetyAnnotatedArrayTypeDoesNotAnnotateMethod() {
.doTest();
}

@Test
void testAddsAnnotation_jacksonSubTypes_defaultImpl() {
fix().addInputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\",",
" defaultImpl = Dnl.class)",
" @JsonSubTypes(value = {",
" @JsonSubTypes.Type(value = Unmarked.class, name = \"u\")",
" })",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
" class Unmarked implements Base {}",
"}")
.addOutputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @DoNotLog",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\",",
" defaultImpl = Dnl.class)",
" @JsonSubTypes(value = {",
" @JsonSubTypes.Type(value = Unmarked.class, name = \"u\")",
" })",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
" class Unmarked implements Base {}",
"}")
.doTest();
}

@Test
void testSafetyAnnotatedCollectionTypeDoesNotAnnotateMethod() {
fix().addInputLines(
Expand All @@ -1080,6 +1142,78 @@ void testSafetyAnnotatedCollectionTypeDoesNotAnnotateMethod() {
.doTest();
}

@Test
void testAddsAnnotation_jacksonSubTypes_subtypes_array() {
fix().addInputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\")",
" @JsonSubTypes(value = {",
" @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")",
" })",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
"}")
.addOutputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @DoNotLog",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\")",
" @JsonSubTypes(value = {",
" @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")",
" })",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
"}")
.doTest();
}

@Test
void testAddsAnnotation_jacksonSubTypes_subtypes_implicitArray() {
fix().addInputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\")",
" @JsonSubTypes(",
" @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")",
" )",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
"}")
.addOutputLines(
"Test.java",
"import com.fasterxml.jackson.annotation.*;",
"import com.palantir.logsafe.*;",
"class Test {",
" @DoNotLog",
" @JsonTypeInfo(",
" use = JsonTypeInfo.Id.NAME,",
" property = \"type\")",
" @JsonSubTypes(",
" @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")",
" )",
" interface Base {}",
" @DoNotLog",
" class Dnl implements Base {}",
"}")
.doTest();
}

private RefactoringValidator fix(String... args) {
return RefactoringValidator.of(SafeLoggingPropagation.class, getClass(), args);
}
Expand Down
5 changes: 5 additions & 0 deletions changelog/@unreleased/pr-2703.v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: improvement
improvement:
description: Safety propagation takes into account known subtypes
links:
- https://github.com/palantir/gradle-baseline/pull/2703

0 comments on commit 7208cd6

Please sign in to comment.