From 7208cd671b53a84a7a06b6c07e9ee74e307e9838 Mon Sep 17 00:00:00 2001 From: Carter Kozak Date: Tue, 26 Nov 2024 10:45:35 -0500 Subject: [PATCH] Safety propagation takes into account known subtypes (#2703) Safety propagation takes into account known subtypes --- .../baseline/errorprone/MoreASTHelpers.java | 40 ++++++ .../errorprone/SafeLoggingPropagation.java | 26 +++- .../errorprone/safety/SafetyAnnotations.java | 44 ++++++ .../SafeLoggingPropagationTest.java | 134 ++++++++++++++++++ changelog/@unreleased/pr-2703.v2.yml | 5 + 5 files changed, 242 insertions(+), 7 deletions(-) create mode 100644 changelog/@unreleased/pr-2703.v2.yml diff --git a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/MoreASTHelpers.java b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/MoreASTHelpers.java index 9432e7931..98a1fd8d4 100644 --- a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/MoreASTHelpers.java +++ b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/MoreASTHelpers.java @@ -22,13 +22,19 @@ import com.google.errorprone.fixes.SuggestedFixes; import com.google.errorprone.util.ASTHelpers; import com.sun.source.tree.CatchTree; +import com.sun.source.tree.ClassTree; import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.NewArrayTree; import com.sun.source.tree.Tree; import com.sun.source.tree.TryTree; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Symbol.ClassSymbol; import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.Types; +import java.lang.reflect.InvocationTargetException; +import java.util.Collections; import java.util.Comparator; +import java.util.List; import java.util.Optional; import javax.annotation.Nullable; @@ -123,5 +129,39 @@ public static Type getResultType(Tree tree) { : ASTHelpers.getType(tree); } + public static boolean isSealed(@Nullable ClassSymbol classSymbol) { + if (classSymbol == null) { + return false; + } + long flags = classSymbol.flags(); + return (flags & (1L << 62)) != 0; + } + + public static List getPermitsClause(@Nullable ClassTree classTree) { + if (classTree == null) { + return Collections.emptyList(); + } + try { + return (List) + ClassTree.class.getMethod("getPermitsClause").invoke(classTree); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Failed to extract permitted classes", e); + } catch (NoSuchMethodException e) { + // This is expected on older JDKs which do not support the permits clause + return Collections.emptyList(); + } + } + + public static List unwrapArray(@Nullable ExpressionTree expressionTree) { + if (expressionTree == null) { + return Collections.emptyList(); + } + if (expressionTree instanceof NewArrayTree) { + NewArrayTree tree = (NewArrayTree) expressionTree; + return tree.getInitializers(); + } + return Collections.singletonList(expressionTree); + } + private MoreASTHelpers() {} } diff --git a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/SafeLoggingPropagation.java b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/SafeLoggingPropagation.java index f288048a7..e4c2351bb 100644 --- a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/SafeLoggingPropagation.java +++ b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/SafeLoggingPropagation.java @@ -92,6 +92,8 @@ public Description matchClass(ClassTree classTree, VisitorState state) { if (classSymbol == null || classSymbol.isAnonymous()) { return Description.NO_MATCH; } + TypeSymbol tsym = classSymbol.type.tsym; + tsym.getModifiers(); if (ASTHelpers.isRecord(classSymbol)) { return matchRecord(classTree, classSymbol, state); } else { @@ -146,6 +148,7 @@ private static boolean immutablesDefaultAsDefault(Attribute.Compound styleAnnota private Description matchRecord(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) { Safety existingClassSafety = SafetyAnnotations.getSafety(classTree, state); Safety safety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state); + safety = safety.leastUpperBound(SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state)); for (VarSymbol recordComponent : Records.getRecordComponents(classSymbol)) { Safety symbolSafety = SafetyAnnotations.getSafety(recordComponent, state); Safety typeSafety = SafetyAnnotations.getSafety(recordComponent.type, state); @@ -160,7 +163,7 @@ private Description matchClassOrInterface(ClassTree classTree, ClassSymbol class if (ASTHelpers.hasAnnotation(classSymbol, "org.immutables.value.Value.Immutable", state)) { return matchImmutables(classTree, classSymbol, state); } - return matchBasedOnToString(classTree, classSymbol, state); + return matchArbitraryObject(classTree, classSymbol, state); } private static boolean isImmutablesField( @@ -239,6 +242,7 @@ private static Safety scanSymbolMethods(ClassSymbol begin, VisitorState state, b private Description matchImmutables(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) { Safety existingClassSafety = SafetyAnnotations.getAnnotatedSafety(classTree, state); Safety safety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state); + safety = safety.leastUpperBound(SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state)); boolean isJson = hasJacksonAnnotation(classSymbol, state); ClassSymbol symbol = ASTHelpers.getSymbol(classTree); Safety scanned = scanSymbolMethods(symbol, state, isJson); @@ -246,15 +250,23 @@ private Description matchImmutables(ClassTree classTree, ClassSymbol classSymbol return handleSafety(classTree, classTree.getModifiers(), state, existingClassSafety, safety); } - private Description matchBasedOnToString(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) { + private Safety getToStringSafety(ClassSymbol classSymbol, VisitorState state) { MethodSymbol toStringSymbol = ASTHelpers.resolveExistingMethod( state, classSymbol, TO_STRING_NAME.get(state), ImmutableList.of(), ImmutableList.of()); - if (toStringSymbol == null) { - return Description.NO_MATCH; - } + return SafetyAnnotations.getSafety(toStringSymbol, state); + } + + private Description matchArbitraryObject(ClassTree classTree, ClassSymbol classSymbol, VisitorState state) { + Safety toStringSafety = getToStringSafety(classSymbol, state); + Safety subtypeSafety = SafetyAnnotations.getTypeSafetyFromKnownSubtypes(classTree, state); + Safety ancestorSafety = SafetyAnnotations.getTypeSafetyFromAncestors(classTree, state); Safety existingClassSafety = SafetyAnnotations.getSafety(classTree, state); - Safety symbolSafety = SafetyAnnotations.getSafety(toStringSymbol, state); - return handleSafety(classTree, classTree.getModifiers(), state, existingClassSafety, symbolSafety); + return handleSafety( + classTree, + classTree.getModifiers(), + state, + existingClassSafety, + Safety.mergeAssumingUnknownIsSame(toStringSafety, subtypeSafety, ancestorSafety)); } private Description handleSafety( diff --git a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/safety/SafetyAnnotations.java b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/safety/SafetyAnnotations.java index 3da4491bf..573555339 100644 --- a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/safety/SafetyAnnotations.java +++ b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/safety/SafetyAnnotations.java @@ -18,8 +18,12 @@ import com.google.common.collect.Multimap; import com.google.errorprone.VisitorState; +import com.google.errorprone.matchers.AnnotationMatcherUtils; +import com.google.errorprone.matchers.AnnotationType; +import com.google.errorprone.matchers.Matcher; import com.google.errorprone.suppliers.Suppliers; import com.google.errorprone.util.ASTHelpers; +import com.palantir.baseline.errorprone.MoreASTHelpers; import com.sun.source.tree.AnnotationTree; import com.sun.source.tree.ClassTree; import com.sun.source.tree.ExpressionTree; @@ -61,6 +65,11 @@ public final class SafetyAnnotations { private static final com.google.errorprone.suppliers.Supplier throwableSupplier = Suppliers.typeFromClass(Throwable.class); + private static final Matcher JSON_TYPE_INFO_MATCHER = + new AnnotationType("com.fasterxml.jackson.annotation.JsonTypeInfo"); + private static final Matcher JSON_SUBTYPES_MATCHER = + new AnnotationType("com.fasterxml.jackson.annotation.JsonSubTypes"); + private static final TypeArgumentHandlers SAFETY_IS_COMBINATION_OF_TYPE_ARGUMENTS = new TypeArgumentHandlers( new TypeArgumentHandler(Iterable.class), new TypeArgumentHandler(Iterator.class), @@ -152,6 +161,41 @@ public static Safety getTypeSafetyFromAncestors(ClassTree classTree, VisitorStat return safety; } + public static Safety getTypeSafetyFromKnownSubtypes(ClassTree classTree, VisitorState state) { + Safety safety = Safety.UNKNOWN; + ClassSymbol symbol = ASTHelpers.getSymbol(classTree); + if (MoreASTHelpers.isSealed(symbol)) { + for (ExpressionTree permitted : MoreASTHelpers.getPermitsClause(classTree)) { + safety = Safety.mergeAssumingUnknownIsSame(safety, SafetyAnnotations.getSafety(permitted, state)); + } + } + for (AnnotationTree annotationTree : classTree.getModifiers().getAnnotations()) { + if (JSON_TYPE_INFO_MATCHER.matches(annotationTree, state)) { + ExpressionTree expressionTree = AnnotationMatcherUtils.getArgument(annotationTree, "defaultImpl"); + if (expressionTree != null) { + Safety defaultImplSafety = + SafetyAnnotations.getSafety(ASTHelpers.getReceiver(expressionTree), state); + safety = Safety.mergeAssumingUnknownIsSame(safety, defaultImplSafety); + } + } else if (JSON_SUBTYPES_MATCHER.matches(annotationTree, state)) { + ExpressionTree tree = AnnotationMatcherUtils.getArgument(annotationTree, "value"); + for (ExpressionTree subtype : MoreASTHelpers.unwrapArray(tree)) { + if (subtype instanceof AnnotationTree) { + ExpressionTree value = AnnotationMatcherUtils.getArgument((AnnotationTree) subtype, "value"); + if (value != null) { + Safety subtypeSafety = SafetyAnnotations.getSafety(ASTHelpers.getReceiver(value), state); + safety = Safety.mergeAssumingUnknownIsSame(safety, subtypeSafety); + } + } + } + + safety = Safety.mergeAssumingUnknownIsSame(safety, safety); + } + } + + return safety; + } + public static Safety getDirectSafety(@Nullable Symbol symbol, VisitorState state) { if (symbol != null) { if (containsAttributeNamed(symbol, doNotLogName.get(state))) { diff --git a/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/SafeLoggingPropagationTest.java b/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/SafeLoggingPropagationTest.java index 5b76a3c56..11053270e 100644 --- a/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/SafeLoggingPropagationTest.java +++ b/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/SafeLoggingPropagationTest.java @@ -1051,6 +1051,28 @@ void testSafetyAnnotatedReturnTypeDoesNotAnnotateMethod() { .doTest(); } + @Test + void testAddsAnnotation_sealedTypes() { + fix().addInputLines( + "Test.java", + "import com.palantir.logsafe.*;", + "class Test {", + " sealed interface Base permits Dnl {}", + " @DoNotLog", + " final class Dnl implements Base {}", + "}") + .addOutputLines( + "Test.java", + "import com.palantir.logsafe.*;", + "class Test {", + " @DoNotLog", + " sealed interface Base permits Dnl {}", + " @DoNotLog", + " final class Dnl implements Base {}", + "}") + .doTest(); + } + @Test void testSafetyAnnotatedArrayTypeDoesNotAnnotateMethod() { fix().addInputLines( @@ -1065,6 +1087,46 @@ void testSafetyAnnotatedArrayTypeDoesNotAnnotateMethod() { .doTest(); } + @Test + void testAddsAnnotation_jacksonSubTypes_defaultImpl() { + fix().addInputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\",", + " defaultImpl = Dnl.class)", + " @JsonSubTypes(value = {", + " @JsonSubTypes.Type(value = Unmarked.class, name = \"u\")", + " })", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + " class Unmarked implements Base {}", + "}") + .addOutputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @DoNotLog", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\",", + " defaultImpl = Dnl.class)", + " @JsonSubTypes(value = {", + " @JsonSubTypes.Type(value = Unmarked.class, name = \"u\")", + " })", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + " class Unmarked implements Base {}", + "}") + .doTest(); + } + @Test void testSafetyAnnotatedCollectionTypeDoesNotAnnotateMethod() { fix().addInputLines( @@ -1080,6 +1142,78 @@ void testSafetyAnnotatedCollectionTypeDoesNotAnnotateMethod() { .doTest(); } + @Test + void testAddsAnnotation_jacksonSubTypes_subtypes_array() { + fix().addInputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\")", + " @JsonSubTypes(value = {", + " @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")", + " })", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + "}") + .addOutputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @DoNotLog", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\")", + " @JsonSubTypes(value = {", + " @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")", + " })", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + "}") + .doTest(); + } + + @Test + void testAddsAnnotation_jacksonSubTypes_subtypes_implicitArray() { + fix().addInputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\")", + " @JsonSubTypes(", + " @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")", + " )", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + "}") + .addOutputLines( + "Test.java", + "import com.fasterxml.jackson.annotation.*;", + "import com.palantir.logsafe.*;", + "class Test {", + " @DoNotLog", + " @JsonTypeInfo(", + " use = JsonTypeInfo.Id.NAME,", + " property = \"type\")", + " @JsonSubTypes(", + " @JsonSubTypes.Type(value = Dnl.class, name = \"dnl\")", + " )", + " interface Base {}", + " @DoNotLog", + " class Dnl implements Base {}", + "}") + .doTest(); + } + private RefactoringValidator fix(String... args) { return RefactoringValidator.of(SafeLoggingPropagation.class, getClass(), args); } diff --git a/changelog/@unreleased/pr-2703.v2.yml b/changelog/@unreleased/pr-2703.v2.yml new file mode 100644 index 000000000..824a45cd3 --- /dev/null +++ b/changelog/@unreleased/pr-2703.v2.yml @@ -0,0 +1,5 @@ +type: improvement +improvement: + description: Safety propagation takes into account known subtypes + links: + - https://github.com/palantir/gradle-baseline/pull/2703