diff --git a/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNull.java b/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNull.java index fd138df1e..aa11dc38d 100644 --- a/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNull.java +++ b/src/main/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNull.java @@ -44,26 +44,14 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { return Preconditions.check(new UsesMethod<>(ASSERT_FALSE), new JavaVisitor() { - - JavaParser.Builder javaParser = null; - - private JavaParser.Builder javaParser(ExecutionContext ctx) { - if (javaParser == null) { - javaParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "junit-jupiter-api-5.9"); - } - return javaParser; - } - @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); - if (ASSERT_FALSE.matches(mi) && isEqualBinary(mi)) { - StringBuilder sb = new StringBuilder(); - + if (ASSERT_FALSE.matches(mi) && isEqualBinaryWithNull(mi)) { J.Binary binary = (J.Binary) mi.getArguments().get(0); Expression nonNullExpression = getNonNullExpression(binary); + StringBuilder sb = new StringBuilder(); if (mi.getSelect() == null) { maybeRemoveImport("org.junit.jupiter.api.Assertions"); maybeAddImport("org.junit.jupiter.api.Assertions", "assertNotNull"); @@ -71,6 +59,7 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) sb.append("Assertions."); } sb.append("assertNotNull(#{any(java.lang.Object)}"); + Object[] args; if (mi.getArguments().size() == 2) { sb.append(", #{any()}"); @@ -79,41 +68,38 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) args = new J[]{nonNullExpression}; } sb.append(")"); - JavaTemplate t; if (mi.getSelect() == null) { t = JavaTemplate.builder(sb.toString()) .contextSensitive() .staticImports("org.junit.jupiter.api.Assertions.assertNotNull") - .javaParser(javaParser(ctx)) + .javaParser(JavaParser.fromJavaVersion() + .classpathFromResources(ctx, "junit-jupiter-api-5.9")) .build(); } else { t = JavaTemplate.builder(sb.toString()) .contextSensitive() .imports("org.junit.jupiter.api.Assertions") - .javaParser(javaParser(ctx)) + .javaParser(JavaParser.fromJavaVersion() + .classpathFromResources(ctx, "junit-jupiter-api-5.9")) .build(); } - return t.apply(updateCursor(mi), mi.getCoordinates().replace(), args); + return t.apply(updateCursor(mi), mi.getCoordinates().replace(), args); } return mi; } - private Expression getNonNullExpression(J.Binary binary) { - if (binary.getRight() instanceof J.Literal) { boolean isNull = ((J.Literal) binary.getRight()).getValue() == null; if (isNull) { return binary.getLeft(); } } - return binary.getRight(); } - private boolean isEqualBinary(J.MethodInvocation method) { - + private boolean isEqualBinaryWithNull(J.MethodInvocation method) { if (method.getArguments().isEmpty()) { return false; } @@ -124,10 +110,12 @@ private boolean isEqualBinary(J.MethodInvocation method) { } J.Binary binary = (J.Binary) firstArgument; - J.Binary.Type operator = binary.getOperator(); - return operator.equals(J.Binary.Type.Equal); + if (binary.getOperator() != J.Binary.Type.Equal) { + return false; + } + return binary.getLeft() instanceof J.Literal && ((J.Literal) binary.getLeft()).getValue() == null || + binary.getRight() instanceof J.Literal && ((J.Literal) binary.getRight()).getValue() == null; } }); } - } diff --git a/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNull.java b/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNull.java index a35046f44..5d725947b 100644 --- a/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNull.java +++ b/src/main/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNull.java @@ -44,22 +44,10 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { return Preconditions.check(new UsesMethod<>(ASSERT_TRUE), new JavaVisitor() { - - JavaParser.Builder javaParser = null; - - private JavaParser.Builder javaParser(ExecutionContext ctx) { - if (javaParser == null) { - javaParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "junit-jupiter-api-5.9"); - } - return javaParser; - } - - @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); - if (ASSERT_TRUE.matches(mi) && isEqualBinary(mi)) { + if (ASSERT_TRUE.matches(mi) && isEqualBinaryWithNull(mi)) { J.Binary binary = (J.Binary) mi.getArguments().get(0); Expression nonNullExpression = getNonNullExpression(binary); @@ -75,9 +63,9 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) Object[] args; if (mi.getArguments().size() == 2) { sb.append(", #{any()}"); - args = new Object[]{nonNullExpression, mi.getArguments().get(1)}; + args = new J[]{nonNullExpression, mi.getArguments().get(1)}; } else { - args = new Object[]{nonNullExpression}; + args = new J[]{nonNullExpression}; } sb.append(")"); JavaTemplate t; @@ -85,35 +73,33 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) t = JavaTemplate.builder(sb.toString()) .contextSensitive() .staticImports("org.junit.jupiter.api.Assertions.assertNull") - .javaParser(javaParser(ctx)) + .javaParser(JavaParser.fromJavaVersion() + .classpathFromResources(ctx, "junit-jupiter-api-5.9")) .build(); } else { t = JavaTemplate.builder(sb.toString()) .contextSensitive() .imports("org.junit.jupiter.api.Assertions") - .javaParser(javaParser(ctx)) + .javaParser(JavaParser.fromJavaVersion() + .classpathFromResources(ctx, "junit-jupiter-api-5.9")) .build(); } - return t.apply(updateCursor(mi), mi.getCoordinates().replace(), args); + return t.apply(updateCursor(mi), mi.getCoordinates().replace(), args); } return mi; } - private Expression getNonNullExpression(J.Binary binary) { - if (binary.getRight() instanceof J.Literal) { boolean isNull = ((J.Literal) binary.getRight()).getValue() == null; if (isNull) { return binary.getLeft(); } } - return binary.getRight(); } - private boolean isEqualBinary(J.MethodInvocation method) { - + private boolean isEqualBinaryWithNull(J.MethodInvocation method) { if (method.getArguments().isEmpty()) { return false; } @@ -124,8 +110,11 @@ private boolean isEqualBinary(J.MethodInvocation method) { } J.Binary binary = (J.Binary) firstArgument; - J.Binary.Type operator = binary.getOperator(); - return operator.equals(J.Binary.Type.Equal); + if (binary.getOperator() != J.Binary.Type.Equal) { + return false; + } + return binary.getLeft() instanceof J.Literal && ((J.Literal) binary.getLeft()).getValue() == null || + binary.getRight() instanceof J.Literal && ((J.Literal) binary.getRight()).getValue() == null; } }); } diff --git a/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNullTest.java b/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNullTest.java index 13512e0cf..531bb8cb9 100644 --- a/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNullTest.java +++ b/src/test/java/org/openrewrite/java/testing/cleanup/AssertFalseNullToAssertNotNullTest.java @@ -114,4 +114,23 @@ void test() { ) ); } + + @Test + void comparableComparedToZero() { + rewriteRun( + //language=java + java( + """ + import static org.junit.jupiter.api.Assertions.assertFalse; + + public class Test { + void test() { + Integer a = 0; + assertFalse(a.compareTo(0) == 0); + } + } + """ + ) + ); + } } diff --git a/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNullTest.java b/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNullTest.java index 4aa078f74..c6d34c4d5 100644 --- a/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNullTest.java +++ b/src/test/java/org/openrewrite/java/testing/cleanup/AssertTrueNullToAssertNullTest.java @@ -42,7 +42,7 @@ void simplifyToAssertNull() { java( """ import static org.junit.jupiter.api.Assertions.assertTrue; - + public class Test { void test() { String a = null; @@ -57,7 +57,7 @@ void test() { """, """ import static org.junit.jupiter.api.Assertions.assertNull; - + public class Test { void test() { String a = null; @@ -83,7 +83,7 @@ void preserveStyleOfStaticImportOrNot() { java( """ import org.junit.jupiter.api.Assertions; - + public class Test { void test() { String a = null; @@ -98,7 +98,7 @@ void test() { """, """ import org.junit.jupiter.api.Assertions; - + public class Test { void test() { String a = null; @@ -114,4 +114,23 @@ void test() { ) ); } + + @Test + void comparableComparedToZero() { + rewriteRun( + //language=java + java( + """ + import static org.junit.jupiter.api.Assertions.assertTrue; + + public class Test { + void test() { + Integer a = 0; + assertTrue(a.compareTo(0) == 0); + } + } + """ + ) + ); + } }