From 909f0ebbc5f5ad3f599476306efcd3af0476ff09 Mon Sep 17 00:00:00 2001 From: Jacob van Lingen Date: Tue, 5 Nov 2024 17:34:10 +0100 Subject: [PATCH] Update AssertJ recipes to current recipe code style (#633) * Update AssertJ recipes to current recipe code style * Apply formatter and remove unused imports * Apply suggestions from code review --------- Co-authored-by: Tim te Beek --- .../AdoptAssertJDurationAssertions.java | 369 +++++++++--------- .../JUnitAssertArrayEqualsToAssertThat.java | 156 ++++---- .../JUnitAssertEqualsToAssertThat.java | 156 +++----- .../assertj/JUnitAssertFalseToAssertThat.java | 85 ++-- .../JUnitAssertInstanceOfToAssertThat.java | 18 +- .../JUnitAssertNotEqualsToAssertThat.java | 163 +++----- .../JUnitAssertNotNullToAssertThat.java | 88 ++--- .../assertj/JUnitAssertNullToAssertThat.java | 86 ++-- .../assertj/JUnitAssertSameToAssertThat.java | 90 ++--- ...UnitAssertThrowsToAssertExceptionType.java | 70 ++-- .../assertj/JUnitAssertTrueToAssertThat.java | 87 ++--- .../assertj/JUnitFailToAssertJFail.java | 162 +++----- .../assertj/SimplifyAssertJAssertion.java | 54 ++- .../SimplifyChainedAssertJAssertion.java | 165 ++++---- ...AssertThrowsToAssertExceptionTypeTest.java | 11 +- 15 files changed, 719 insertions(+), 1041 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/assertj/AdoptAssertJDurationAssertions.java b/src/main/java/org/openrewrite/java/testing/assertj/AdoptAssertJDurationAssertions.java index ef26a84bd..b84ca6133 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/AdoptAssertJDurationAssertions.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/AdoptAssertJDurationAssertions.java @@ -32,17 +32,48 @@ import java.util.*; -public class AdoptAssertJDurationAssertions extends Recipe { - - static final String DURATION_ASSERT_HAS_LONG = "org.assertj.core.api.AbstractDurationAssert has*(long)"; +import static org.openrewrite.Preconditions.or; - static final String INTEGER_ASSERT_IS_EQUAL_TO = "org.assertj.core.api.AbstractIntegerAssert isEqualTo(..)"; - static final String INTEGER_ASSERT_IS_GREATER_THAN = "org.assertj.core.api.AbstractIntegerAssert isGreaterThan(..)"; - static final String INTEGER_ASSERT_IS_LESS_THAN = "org.assertj.core.api.AbstractIntegerAssert isLessThan(..)"; +public class AdoptAssertJDurationAssertions extends Recipe { - static final String LONG_ASSERT_IS_LESS_THAN = "org.assertj.core.api.AbstractLongAssert isLessThan(..)"; - static final String LONG_ASSERT_IS_GREATER_THAN = "org.assertj.core.api.AbstractLongAssert isGreaterThan(..)"; - static final String LONG_ASSERT_IS_EQUAL_TO = "org.assertj.core.api.AbstractLongAssert isEqualTo(..)"; + private static final String DURATION_ASSERT_HAS_LONG = "org.assertj.core.api.AbstractDurationAssert has*(long)"; + private static final String INTEGER_ASSERT_IS_EQUAL_TO = "org.assertj.core.api.AbstractIntegerAssert isEqualTo(..)"; + private static final String INTEGER_ASSERT_IS_GREATER_THAN = "org.assertj.core.api.AbstractIntegerAssert isGreaterThan(..)"; + private static final String INTEGER_ASSERT_IS_LESS_THAN = "org.assertj.core.api.AbstractIntegerAssert isLessThan(..)"; + private static final String LONG_ASSERT_IS_LESS_THAN = "org.assertj.core.api.AbstractLongAssert isLessThan(..)"; + private static final String LONG_ASSERT_IS_GREATER_THAN = "org.assertj.core.api.AbstractLongAssert isGreaterThan(..)"; + private static final String LONG_ASSERT_IS_EQUAL_TO = "org.assertj.core.api.AbstractLongAssert isEqualTo(..)"; + + private static final MethodMatcher ASSERT_THAT_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); + private static final MethodMatcher GET_NANO_MATCHER = new MethodMatcher("java.time.Duration getNano()"); + private static final MethodMatcher GET_SECONDS_MATCHER = new MethodMatcher("java.time.Duration getSeconds()"); + private static final MethodMatcher AS_MATCHER = new MethodMatcher("org.assertj.core.api.AbstractObjectAssert as(..)"); + private static final MethodMatcher TIME_UNIT_MATCHERS = new MethodMatcher(DURATION_ASSERT_HAS_LONG, true); + + private static final List IS_MATCHERS = Arrays.asList( + new MethodMatcher(INTEGER_ASSERT_IS_EQUAL_TO, true), + new MethodMatcher(INTEGER_ASSERT_IS_GREATER_THAN, true), + new MethodMatcher(INTEGER_ASSERT_IS_LESS_THAN, true), + + new MethodMatcher(LONG_ASSERT_IS_EQUAL_TO, true), + new MethodMatcher(LONG_ASSERT_IS_GREATER_THAN, true), + new MethodMatcher(LONG_ASSERT_IS_LESS_THAN, true) + ); + + private static final Map METHOD_MAP = new HashMap() {{ + put("getSeconds", "hasSeconds"); + put("getNano", "hasNanos"); + + put("hasNanos", "hasMillis"); + put("hasMillis", "hasSeconds"); + put("hasSeconds", "hasMinutes"); + put("hasMinutes", "hasHours"); + put("hasHours", "hasDays"); + + put("isGreaterThan", "isPositive"); + put("isLessThan", "isNegative"); + put("isEqualTo", "isZero"); + }}; @Override public String getDisplayName() { @@ -56,190 +87,156 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(Preconditions.or( + return Preconditions.check( + or( new UsesMethod<>(DURATION_ASSERT_HAS_LONG, true), - new UsesMethod<>(INTEGER_ASSERT_IS_EQUAL_TO, true), new UsesMethod<>(INTEGER_ASSERT_IS_GREATER_THAN, true), new UsesMethod<>(INTEGER_ASSERT_IS_LESS_THAN, true), - new UsesMethod<>(LONG_ASSERT_IS_EQUAL_TO, true), new UsesMethod<>(LONG_ASSERT_IS_GREATER_THAN, true), new UsesMethod<>(LONG_ASSERT_IS_LESS_THAN, true) - ), new AdoptAssertJDurationAssertionsVisitor() - ); - } - - @SuppressWarnings("DataFlowIssue") - private static class AdoptAssertJDurationAssertionsVisitor extends JavaIsoVisitor { - private static final MethodMatcher ASSERT_THAT_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); - private static final MethodMatcher GET_NANO_MATCHER = new MethodMatcher("java.time.Duration getNano()"); - private static final MethodMatcher GET_SECONDS_MATCHER = new MethodMatcher("java.time.Duration getSeconds()"); - private static final MethodMatcher AS_MATCHER = new MethodMatcher("org.assertj.core.api.AbstractObjectAssert as(..)"); - private static final MethodMatcher TIME_UNIT_MATCHERS = new MethodMatcher(DURATION_ASSERT_HAS_LONG, true); - private static final List IS_MATCHERS = Arrays.asList( - new MethodMatcher(INTEGER_ASSERT_IS_EQUAL_TO, true), - new MethodMatcher(INTEGER_ASSERT_IS_GREATER_THAN, true), - new MethodMatcher(INTEGER_ASSERT_IS_LESS_THAN, true), - - new MethodMatcher(LONG_ASSERT_IS_EQUAL_TO, true), - new MethodMatcher(LONG_ASSERT_IS_GREATER_THAN, true), - new MethodMatcher(LONG_ASSERT_IS_LESS_THAN, true) - ); - private static final Map METHOD_MAP = new HashMap() {{ - put("getSeconds", "hasSeconds"); - put("getNano", "hasNanos"); - - put("hasNanos", "hasMillis"); - put("hasMillis", "hasSeconds"); - put("hasSeconds", "hasMinutes"); - put("hasMinutes", "hasHours"); - put("hasHours", "hasDays"); - - put("isGreaterThan", "isPositive"); - put("isLessThan", "isNegative"); - put("isEqualTo", "isZero"); - }}; - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); - if (TIME_UNIT_MATCHERS.matches(mi)) { - return simplifyTimeUnits(mi, ctx); - } else if (IS_MATCHERS.stream().anyMatch(matcher -> matcher.matches(mi))) { - return simplifyMultipleAssertions(mi, ctx); - } - return mi; - } - - private J.MethodInvocation simplifyMultipleAssertions(J.MethodInvocation m, ExecutionContext ctx) { - Expression isEqualToArg = m.getArguments().get(0); - Expression select = m.getSelect(); - List templateParameters = new ArrayList<>(); - templateParameters.add(null); - Expression asDescription = null; - - if (AS_MATCHER.matches(select)) { - asDescription = ((J.MethodInvocation) select).getArguments().get(0); - select = ((J.MethodInvocation) select).getSelect(); - templateParameters.add(asDescription); - } - - if (!ASSERT_THAT_MATCHER.matches(select)) { - return m; - } - - Expression assertThatArgumentExpr = ((J.MethodInvocation) select).getArguments().get(0); - if (!(assertThatArgumentExpr instanceof J.MethodInvocation)) { - return m; - } - J.MethodInvocation assertThatArg = (J.MethodInvocation) assertThatArgumentExpr; - - if (isZero(isEqualToArg) && checkIfRelatedToDuration(assertThatArg)) { - String formatted_template = formatTemplate("assertThat(#{any()}).%s();", m.getSimpleName(), asDescription); - templateParameters.set(0, assertThatArg); - return applyTemplate(ctx, m, formatted_template, templateParameters.toArray()); - } - - if (GET_NANO_MATCHER.matches(assertThatArg) || GET_SECONDS_MATCHER.matches(assertThatArg)) { - Expression assertThatArgSelect = assertThatArg.getSelect(); - String methodName = assertThatArg.getSimpleName(); - String formatted_template = formatTemplate("assertThat(#{any()}).%s(#{any()});", methodName, asDescription); - templateParameters.set(0, assertThatArgSelect); - templateParameters.add(isEqualToArg); - - return applyTemplate(ctx, m, formatted_template, templateParameters.toArray()); - } - - return m; - } - - private boolean isZero(Expression isEqualToArg) { - if (isEqualToArg instanceof J.Literal) { - J.Literal literal = (J.Literal) isEqualToArg; - return literal.getValue() instanceof Number && ((Number) literal.getValue()).longValue() == 0; - } - return false; - } - - private J.MethodInvocation simplifyTimeUnits(J.MethodInvocation m, ExecutionContext ctx) { - Expression arg = m.getArguments().get(0); - Long argValue = SimplifyDurationCreationUnits.getConstantIntegralValue(arg); - if (argValue == null) { - return m; - } - - List unitInfo = getUnitInfo(m.getSimpleName(), Math.toIntExact(argValue)); - String methodName = (String) unitInfo.get(0); - int methodArg = (int) unitInfo.get(1); - if (!(m.getSimpleName().equals(methodName))) { - // update method invocation with new name and arg - String template = String.format("#{any()}.%s(%d)", methodName, methodArg); - return applyTemplate(ctx, m, template, m.getSelect()); - } - - return m; - } - - private static List getUnitInfo(String name, int argValue) { - final int timeLength; - if (name.equals("hasSeconds") || name.equals("hasMinutes")) { - timeLength = 60; - } else if (name.equals("hasNanos") || name.equals("hasMillis")) { - timeLength = 1000; - } else if (name.equals("hasHours")) { - timeLength = 24; - } else { - return Arrays.asList(name, argValue); - } - - if (argValue % timeLength == 0) { - String newName = METHOD_MAP.get(name); - return getUnitInfo(newName, argValue / timeLength); - } else { - // returning name, newArg - return Arrays.asList(name, argValue); - } - } - - private J.MethodInvocation applyTemplate(ExecutionContext ctx, J.MethodInvocation m, String template, Object... parameters) { - J.MethodInvocation invocation = JavaTemplate.builder(template) - .contextSensitive() - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) - .build() - .apply(getCursor(), m.getCoordinates().replace(), parameters); - - // retain whitespace formatting - if (invocation.getPadding().getSelect() != null && m.getPadding().getSelect() != null) { - return invocation.getPadding() - .withSelect( - invocation.getPadding().getSelect() - .withAfter(m.getPadding().getSelect().getAfter()) - ); - } - return invocation; - } - - private boolean checkIfRelatedToDuration(J.MethodInvocation argument) { - // assertThat(.).isEqual(0) - if (argument.getSelect() != null) { - if (argument.getSelect() instanceof J.MethodInvocation) { - J.MethodInvocation selectMethod = (J.MethodInvocation) argument.getSelect(); - return TypeUtils.isOfType(selectMethod.getType(), JavaType.buildType("java.time.Duration")); + ), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (TIME_UNIT_MATCHERS.matches(mi)) { + return simplifyTimeUnits(mi, ctx); + } else if (IS_MATCHERS.stream().anyMatch(matcher -> matcher.matches(mi))) { + return simplifyMultipleAssertions(mi, ctx); + } + return mi; + } + + private J.MethodInvocation simplifyMultipleAssertions(J.MethodInvocation m, ExecutionContext ctx) { + Expression isEqualToArg = m.getArguments().get(0); + Expression select = m.getSelect(); + List templateParameters = new ArrayList<>(); + templateParameters.add(null); + Expression asDescription = null; + + if (AS_MATCHER.matches(select)) { + asDescription = ((J.MethodInvocation) select).getArguments().get(0); + select = ((J.MethodInvocation) select).getSelect(); + templateParameters.add(asDescription); + } + + if (!ASSERT_THAT_MATCHER.matches(select)) { + return m; + } + + Expression assertThatArgumentExpr = ((J.MethodInvocation) select).getArguments().get(0); + if (!(assertThatArgumentExpr instanceof J.MethodInvocation)) { + return m; + } + J.MethodInvocation assertThatArg = (J.MethodInvocation) assertThatArgumentExpr; + + if (isZero(isEqualToArg) && checkIfRelatedToDuration(assertThatArg)) { + String formatted_template = formatTemplate("assertThat(#{any()}).%s();", m.getSimpleName(), asDescription); + templateParameters.set(0, assertThatArg); + return applyTemplate(ctx, m, formatted_template, templateParameters.toArray()); + } + + if (GET_NANO_MATCHER.matches(assertThatArg) || GET_SECONDS_MATCHER.matches(assertThatArg)) { + Expression assertThatArgSelect = assertThatArg.getSelect(); + String methodName = assertThatArg.getSimpleName(); + String formatted_template = formatTemplate("assertThat(#{any()}).%s(#{any()});", methodName, asDescription); + templateParameters.set(0, assertThatArgSelect); + templateParameters.add(isEqualToArg); + + return applyTemplate(ctx, m, formatted_template, templateParameters.toArray()); + } + + return m; + } + + private boolean isZero(Expression isEqualToArg) { + if (isEqualToArg instanceof J.Literal) { + J.Literal literal = (J.Literal) isEqualToArg; + return literal.getValue() instanceof Number && ((Number) literal.getValue()).longValue() == 0; + } + return false; + } + + private J.MethodInvocation simplifyTimeUnits(J.MethodInvocation m, ExecutionContext ctx) { + Expression arg = m.getArguments().get(0); + Long argValue = SimplifyDurationCreationUnits.getConstantIntegralValue(arg); + if (argValue == null) { + return m; + } + + List unitInfo = getUnitInfo(m.getSimpleName(), Math.toIntExact(argValue)); + String methodName = (String) unitInfo.get(0); + int methodArg = (int) unitInfo.get(1); + if (!(m.getSimpleName().equals(methodName))) { + // update method invocation with new name and arg + String template = String.format("#{any()}.%s(%d)", methodName, methodArg); + return applyTemplate(ctx, m, template, m.getSelect()); + } + + return m; + } + + private List getUnitInfo(String name, int argValue) { + final int timeLength; + if (name.equals("hasSeconds") || name.equals("hasMinutes")) { + timeLength = 60; + } else if (name.equals("hasNanos") || name.equals("hasMillis")) { + timeLength = 1000; + } else if (name.equals("hasHours")) { + timeLength = 24; + } else { + return Arrays.asList(name, argValue); + } + + if (argValue % timeLength == 0) { + String newName = METHOD_MAP.get(name); + return getUnitInfo(newName, argValue / timeLength); + } else { + // returning name, newArg + return Arrays.asList(name, argValue); + } + } + + private J.MethodInvocation applyTemplate(ExecutionContext ctx, J.MethodInvocation m, String template, Object... parameters) { + J.MethodInvocation invocation = JavaTemplate.builder(template) + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), m.getCoordinates().replace(), parameters); + + // retain whitespace formatting + if (invocation.getPadding().getSelect() != null && m.getPadding().getSelect() != null) { + return invocation.getPadding() + .withSelect( + invocation.getPadding().getSelect() + .withAfter(m.getPadding().getSelect().getAfter()) + ); + } + return invocation; + } + + private boolean checkIfRelatedToDuration(J.MethodInvocation argument) { + if (argument.getSelect() != null) { + if (argument.getSelect() instanceof J.MethodInvocation) { + J.MethodInvocation selectMethod = (J.MethodInvocation) argument.getSelect(); + return TypeUtils.isOfType(selectMethod.getType(), JavaType.buildType("java.time.Duration")); + } + } + return false; + } + + @SuppressWarnings("ConstantValue") + private String formatTemplate(String template, String methodName, Object asDescriptionArg) { + String replacementMethod = METHOD_MAP.get(methodName); + if (asDescriptionArg == null) { + return String.format(template, replacementMethod); + } + StringBuilder newTemplate = new StringBuilder(template); + newTemplate.insert(newTemplate.indexOf(").") + 1, ".as(#{any()})"); + return String.format(newTemplate.toString(), replacementMethod); + } } - } - return false; - } - - @SuppressWarnings("ConstantValue") - private String formatTemplate(String template, String methodName, Object asDescriptionArg) { - String replacementMethod = METHOD_MAP.get(methodName); - if (asDescriptionArg == null) { - return String.format(template, replacementMethod); - } - StringBuilder newTemplate = new StringBuilder(template); - newTemplate.insert(newTemplate.indexOf(").") + 1, ".as(#{any()})"); - return String.format(newTemplate.toString(), replacementMethod); - } + ); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java index ecdde2479..2c636e98f 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java @@ -23,7 +23,7 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; @@ -32,11 +32,14 @@ import java.util.List; public class JUnitAssertArrayEqualsToAssertThat extends Recipe { - private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions"; + + private static final String JUNIT = "org.junit.jupiter.api.Assertions"; + private static final String ASSERTJ = "org.assertj.core.api.Assertions"; + private static final MethodMatcher ASSERT_ARRAY_EQUALS_MATCHER = new MethodMatcher(JUNIT + " assertArrayEquals(..)", true); @Override public String getDisplayName() { - return "JUnit `assertArrayEquals` To AssertJ"; + return "JUnit `assertArrayEquals` to assertJ"; } @Override @@ -46,93 +49,72 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME, false), new AssertArrayEqualsToAssertThatVisitor()); - } - - public static class AssertArrayEqualsToAssertThatVisitor extends JavaIsoVisitor { - private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)"); - - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_EQUALS.matches(method)) { - return method; - } - - List args = method.getArguments(); - Expression expected = args.get(0); - Expression actual = args.get(1); - - // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME); - - if (args.size() == 2) { - return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, expected); - } else if (args.size() == 3 && !isFloatingPointType(args.get(2))) { - Expression message = args.get(2); - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()});") : - JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()});"); - return template - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, message, expected); - } else if (args.size() == 3) { - maybeAddImport("org.assertj.core.api.Assertions", "within", false); - // assert is using floating points with a delta and no message. - return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));") - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) + return Preconditions.check(new UsesMethod<>(ASSERT_ARRAY_EQUALS_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation md = super.visitMethodInvocation(method, ctx); + if (!ASSERT_ARRAY_EQUALS_MATCHER.matches(md)) { + return md; + } + + maybeAddImport(ASSERTJ, "assertThat", false); + maybeRemoveImport(JUNIT); + + List args = md.getArguments(); + Expression expected = args.get(0); + Expression actual = args.get(1); + if (args.size() == 2) { + return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), md.getCoordinates().replace(), actual, expected); + } + if (args.size() == 3 && isFloatingPointType(args.get(2))) { + maybeAddImport(ASSERTJ, "within", false); + // assert is using floating points with a delta and no message. + return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), md.getCoordinates().replace(), actual, expected, args.get(2)); + } + if (args.size() == 3) { + Expression message = args.get(2); + return JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any()}).containsExactly(#{anyArray()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), md.getCoordinates().replace(), actual, message, expected); + } + + maybeAddImport(ASSERTJ, "within", false); + + // The assertEquals is using a floating point with a delta argument and a message. + Expression message = args.get(3); + return JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any()}).containsExactly(#{anyArray()}, within(#{}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2)); + .apply(getCursor(), md.getCoordinates().replace(), actual, message, expected, args.get(2)); } - // The assertEquals is using a floating point with a delta argument and a message. - Expression message = args.get(3); - maybeAddImport("org.assertj.core.api.Assertions", "within", false); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") : - JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()}, within(#{}));"); - return template - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) - .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, message, expected, args.get(2)); - } - - /** - * Returns true if the expression's type is either a primitive float/double or their object forms Float/Double - * - * @param expression The expression parsed from the original AST. - * @return true if the type is a floating point number. - */ - private static boolean isFloatingPointType(Expression expression) { - - JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); - if (fullyQualified != null) { - String typeName = fullyQualified.getFullyQualifiedName(); - return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + /** + * Returns true if the expression's type is either a primitive float/double or their object forms Float/Double + * + * @param expression The expression parsed from the original AST. + * @return true if the type is a floating point number. + */ + private boolean isFloatingPointType(Expression expression) { + JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); + if (fullyQualified != null) { + String typeName = fullyQualified.getFullyQualifiedName(); + return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + } + + JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); + return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; } - - JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); - return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index 276c3870d..cbfad48c2 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -23,7 +23,7 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; @@ -33,6 +33,10 @@ public class JUnitAssertEqualsToAssertThat extends Recipe { + private static final String JUNIT = "org.junit.jupiter.api.Assertions"; + private static final String ASSERTJ = "org.assertj.core.api.Assertions"; + private static final MethodMatcher ASSERT_EQUALS_MATCHER = new MethodMatcher(JUNIT + " assertEquals(..)", true); + @Override public String getDisplayName() { return "JUnit `assertEquals` to AssertJ"; @@ -45,105 +49,67 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertEqualsToAssertThatVisitor()); - } - - public static class AssertEqualsToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertEquals(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_EQUALS.matches(method)) { - return method; - } - - List args = method.getArguments(); - Expression expected = args.get(0); - Expression actual = args.get(1); - - //always add the import (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - if (args.size() == 2) { - return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, expected); - } else if (args.size() == 3 && !isFloatingPointType(args.get(2))) { - Expression message = args.get(2); - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isEqualTo(#{any()});") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isEqualTo(#{any()});"); - return template - .staticImports("org.assertj.core.api.Assertions.assertThat") + return Preconditions.check(new UsesMethod<>(ASSERT_EQUALS_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_EQUALS_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport(ASSERTJ, "assertThat", false); + maybeRemoveImport(JUNIT); + + List args = mi.getArguments(); + Expression expected = args.get(0); + Expression actual = args.get(1); + if (args.size() == 2) { + return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected); + } + if (args.size() == 3 && !isFloatingPointType(args.get(2))) { + Expression message = args.get(2); + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .imports("java.util.function.Supplier") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected); + } + if (args.size() == 3) { + maybeAddImport(ASSERTJ, "within", false); + return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected, args.get(2)); + } + + maybeAddImport(ASSERTJ, "within", false); + + // The assertEquals is using a floating point with a delta argument and a message. + Expression message = args.get(3); + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isCloseTo(#{any()}, within(#{any()}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") .imports("java.util.function.Supplier") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message, - expected - ); - } else if (args.size() == 3) { - //always add the import (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "within", false); - return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2)); - + .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected, args.get(2)); } - // The assertEquals is using a floating point with a delta argument and a message. - Expression message = args.get(3); - - //always add the import (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "within", false); - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isCloseTo(#{any()}, within(#{any()}));") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isCloseTo(#{any()}, within(#{any()}));"); - return template - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .imports("java.util.function.Supplier") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message, - expected, - args.get(2) - ); - } + private boolean isFloatingPointType(Expression expression) { + JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); + if (fullyQualified != null) { + String typeName = fullyQualified.getFullyQualifiedName(); + return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + } - private static boolean isFloatingPointType(Expression expression) { - - JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); - if (fullyQualified != null) { - String typeName = fullyQualified.getFullyQualifiedName(); - return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); + return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; } - - JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); - return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java index 149e5b3df..84d1462a1 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java @@ -23,15 +23,16 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import java.util.List; public class JUnitAssertFalseToAssertThat extends Recipe { + private static final MethodMatcher ASSERT_FALSE_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertFalse(boolean, ..)", true); + @Override public String getDisplayName() { return "JUnit `assertFalse` to AssertJ"; @@ -44,66 +45,34 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertFalseToAssertThatVisitor()); - } - - public static class AssertFalseToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_FALSE = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertFalse(boolean, ..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_FALSE.matches(method)) { - return method; - } + return Preconditions.check(new UsesMethod<>(ASSERT_FALSE_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_FALSE_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); + + List args = mi.getArguments(); + Expression actual = args.get(0); + if (args.size() == 1) { + return JavaTemplate.builder("assertThat(#{any(boolean)}).isFalse();") + .staticImports("org.assertj.core.api.Assertions.assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual); + } - List args = method.getArguments(); - Expression actual = args.get(0); - - if (args.size() == 1) { - method = JavaTemplate.builder("assertThat(#{any(boolean)}).isFalse();") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual - ); - } else { Expression message = args.get(1); - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any(boolean)}).as(#{any(String)}).isFalse();") : - JavaTemplate.builder("assertThat(#{any(boolean)}).as(#{any(java.util.function.Supplier)}).isFalse();"); - - method = template + return JavaTemplate.builder("assertThat(#{any(boolean)}).as(#{any()}).isFalse();") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message - ); + .apply(getCursor(), mi.getCoordinates().replace(), actual, message); } - - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - return method; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertInstanceOfToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertInstanceOfToAssertThat.java index 6a727a403..f5921f80c 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertInstanceOfToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertInstanceOfToAssertThat.java @@ -46,30 +46,30 @@ public TreeVisitor getVisitor() { return Preconditions.check(new UsesMethod<>(ASSERT_INSTANCE_OF_MATCHER), new JavaIsoVisitor() { @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation md = super.visitMethodInvocation(method, ctx); - if (!ASSERT_INSTANCE_OF_MATCHER.matches(md)) { - return md; + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_INSTANCE_OF_MATCHER.matches(mi)) { + return mi; } maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); maybeRemoveImport("org.junit.jupiter.api.Assertions"); - Expression expectedType = md.getArguments().get(0); - Expression actualValue = md.getArguments().get(1); - if (md.getArguments().size() == 2) { + Expression expected = mi.getArguments().get(0); + Expression actual = mi.getArguments().get(1); + if (mi.getArguments().size() == 2) { return JavaTemplate.builder("assertThat(#{any()}).isInstanceOf(#{any()});") .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply(getCursor(), method.getCoordinates().replace(), actualValue, expectedType); + .apply(getCursor(), method.getCoordinates().replace(), actual, expected); } - Expression messageOrSupplier = md.getArguments().get(2); + Expression messageOrSupplier = mi.getArguments().get(2); return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isInstanceOf(#{any()});") .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply(getCursor(), method.getCoordinates().replace(), actualValue, messageOrSupplier, expectedType); + .apply(getCursor(), method.getCoordinates().replace(), actual, messageOrSupplier, expected); } }); } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java index d78506810..f8dfce9e7 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java @@ -23,7 +23,7 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; @@ -33,6 +33,10 @@ public class JUnitAssertNotEqualsToAssertThat extends Recipe { + private static final String JUNIT = "org.junit.jupiter.api.Assertions"; + private static final String ASSERTJ = "org.assertj.core.api.Assertions"; + private static final MethodMatcher ASSERT_NOT_EQUALS_MATCHER = new MethodMatcher(JUNIT + " assertNotEquals(..)", true); + @Override public String getDisplayName() { return "JUnit `assertNotEquals` to AssertJ"; @@ -45,117 +49,64 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertNotEqualsToAssertThatVisitor()); - } - - public static class AssertNotEqualsToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertNotEquals(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_EQUALS.matches(method)) { - return method; - } - - List args = method.getArguments(); - - Expression expected = args.get(0); - Expression actual = args.get(1); - - if (args.size() == 2) { - method = JavaTemplate.builder("assertThat(#{any()}).isNotEqualTo(#{any()});") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - expected - ); - } else if (args.size() == 3 && !isFloatingPointType(args.get(2))) { - Expression message = args.get(2); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isNotEqualTo(#{any()});") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotEqualTo(#{any()});"); - + return Preconditions.check(new UsesMethod<>(ASSERT_NOT_EQUALS_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_NOT_EQUALS_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport(ASSERTJ, "assertThat", false); + maybeRemoveImport(JUNIT); + + List args = mi.getArguments(); + Expression expected = args.get(0); + Expression actual = args.get(1); + if (args.size() == 2) { + return JavaTemplate.builder("assertThat(#{any()}).isNotEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected); + } + if (args.size() == 3 && isFloatingPointType(args.get(2))) { + maybeAddImport(ASSERTJ, "within", false); + return JavaTemplate.builder("assertThat(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected, args.get(2)); + } + if (args.size() == 3) { + Expression message = args.get(2); + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isNotEqualTo(#{any()});") + .staticImports(ASSERTJ + ".assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected); + } + + maybeAddImport(ASSERTJ, "within", false); - method = template - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message, - expected - ); - } else if (args.size() == 3) { - method = JavaTemplate.builder("assertThat(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));") - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - expected, - args.get(2) - ); - maybeAddImport("org.assertj.core.api.Assertions", "within", false); - } else { Expression message = args.get(3); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isNotCloseTo(#{any()}, within(#{any()}));") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotCloseTo(#{any()}, within(#{any()}));"); - - method = template - .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));") + .staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message, - expected, - args.get(2) - ); - - maybeAddImport("org.assertj.core.api.Assertions", "within", false); + .apply(getCursor(), method.getCoordinates().replace(), actual, message, expected, args.get(2)); } - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); + private boolean isFloatingPointType(Expression expression) { + JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); + if (fullyQualified != null) { + String typeName = fullyQualified.getFullyQualifiedName(); + return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + } - return method; - } - - private static boolean isFloatingPointType(Expression expression) { - JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType()); - if (fullyQualified != null) { - String typeName = fullyQualified.getFullyQualifiedName(); - return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName); + JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); + return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; } - - JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType()); - return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java index dea4c6ffe..f6d4696ad 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java @@ -23,15 +23,16 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import java.util.List; public class JUnitAssertNotNullToAssertThat extends Recipe { + private static final MethodMatcher ASSERT_NOT_NULL_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertNotNull(..)", true); + @Override public String getDisplayName() { return "JUnit `assertNotNull` to AssertJ"; @@ -44,68 +45,35 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertNotNullToAssertThatVisitor()); - } - - public static class AssertNotNullToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_NOT_NULL_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertNotNull(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_NOT_NULL_MATCHER.matches(method)) { - return method; - } - - List args = method.getArguments(); - Expression actual = args.get(0); + return Preconditions.check(new UsesMethod<>(ASSERT_NOT_NULL_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_NOT_NULL_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); + + List args = mi.getArguments(); + Expression actual = args.get(0); + if (args.size() == 1) { + return JavaTemplate.builder("assertThat(#{any()}).isNotNull();") + .staticImports("org.assertj.core.api.Assertions.assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual); + + } - if (args.size() == 1) { - method = JavaTemplate.builder("assertThat(#{any()}).isNotNull();") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual - ); - - } else { Expression message = args.get(1); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isNotNull();") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotNull();"); - - method = template + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isNotNull();") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message - ); + .apply(getCursor(), mi.getCoordinates().replace(), actual, message); } - - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - //And if there are no longer references to the JUnit assertions class, we can remove the import. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - return method; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java index c2c916f20..7dc9586e8 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java @@ -23,15 +23,16 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import java.util.List; public class JUnitAssertNullToAssertThat extends Recipe { + private static final MethodMatcher ASSERT_NULL_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertNull(..)", true); + @Override public String getDisplayName() { return "JUnit `assertNull` to AssertJ"; @@ -44,67 +45,34 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertNullToAssertThatVisitor()); - } - - public static class AssertNullToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_NULL_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertNull(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_NULL_MATCHER.matches(method)) { - return method; - } + return Preconditions.check(new UsesMethod<>(ASSERT_NULL_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_NULL_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); + + List args = mi.getArguments(); + Expression actual = args.get(0); + if (args.size() == 1) { + return JavaTemplate.builder("assertThat(#{any()}).isNull();") + .staticImports("org.assertj.core.api.Assertions.assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual); + } - List args = method.getArguments(); - Expression actual = args.get(0); - - if (args.size() == 1) { - method = JavaTemplate.builder("assertThat(#{any()}).isNull();") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual - ); - } else { Expression message = args.get(1); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isNull();") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNull();"); - - method = template + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isNull();") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message - ); + .apply(getCursor(), mi.getCoordinates().replace(), actual, message); } - - // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - return method; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java index 1241d70e0..88a3a584e 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java @@ -23,15 +23,16 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import java.util.List; public class JUnitAssertSameToAssertThat extends Recipe { + private static final MethodMatcher ASSERT_SAME_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertSame(..)", true); + @Override public String getDisplayName() { return "JUnit `assertSame` to AssertJ"; @@ -44,70 +45,35 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertSameToAssertThatVisitor()); - } - - public static class AssertSameToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_SAME_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertSame(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_SAME_MATCHER.matches(method)) { - return method; - } + return Preconditions.check(new UsesMethod<>(ASSERT_SAME_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_SAME_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); + + List args = mi.getArguments(); + Expression expected = args.get(0); + Expression actual = args.get(1); + if (args.size() == 2) { + return JavaTemplate.builder("assertThat(#{any()}).isSameAs(#{any()});") + .staticImports("org.assertj.core.api.Assertions.assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual, expected); + } - List args = method.getArguments(); - Expression expected = args.get(0); - Expression actual = args.get(1); - - if (args.size() == 2) { - method = JavaTemplate.builder("assertThat(#{any()}).isSameAs(#{any()});") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - expected - ); - } else { Expression message = args.get(2); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isSameAs(#{any()});") : - JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isSameAs(#{any()});"); - - method = template + return JavaTemplate.builder("assertThat(#{any()}).as(#{any()}).isSameAs(#{any()});") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message, - expected - ); + .apply(getCursor(), mi.getCoordinates().replace(), actual, message, expected); } - - // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - return method; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java index 6a880f816..0960100d0 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java @@ -30,6 +30,9 @@ public class JUnitAssertThrowsToAssertExceptionType extends Recipe { + private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertThrows(..)"); + private static final JavaType THROWING_CALLABLE_TYPE = JavaType.buildType("org.assertj.core.api.ThrowableAssert.ThrowingCallable"); + @Override public String getDisplayName() { return "JUnit AssertThrows to AssertJ exceptionType"; @@ -42,47 +45,38 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesMethod<>("org.junit.jupiter.api.Assertions assertThrows(..)"), new AssertExceptionTypeVisitor()); - } - - private static class AssertExceptionTypeVisitor extends JavaIsoVisitor { - private static final MethodMatcher ASSERT_THROWS_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertThrows(..)"); - private static final JavaType THROWING_CALLABLE_TYPE = JavaType.buildType("org.assertj.core.api.ThrowableAssert.ThrowingCallable"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); - if (ASSERT_THROWS_MATCHER.matches(mi) && - mi.getArguments().size() == 2 && - getCursor().getParentTreeCursor().getValue() instanceof J.Block) { - J executable = mi.getArguments().get(1); - if (executable instanceof J.Lambda) { - executable = ((J.Lambda) executable).withType(THROWING_CALLABLE_TYPE); - } else if (executable instanceof J.MemberReference) { - executable = ((J.MemberReference) executable).withType(THROWING_CALLABLE_TYPE); - } else { - executable = null; - } + return Preconditions.check(new UsesMethod<>(ASSERT_THROWS_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (ASSERT_THROWS_MATCHER.matches(mi) && + mi.getArguments().size() == 2 && + getCursor().getParentTreeCursor().getValue() instanceof J.Block) { + J executable = mi.getArguments().get(1); + if (executable instanceof J.Lambda) { + executable = ((J.Lambda) executable).withType(THROWING_CALLABLE_TYPE); + } else if (executable instanceof J.MemberReference) { + executable = ((J.MemberReference) executable).withType(THROWING_CALLABLE_TYPE); + } else { + executable = null; + } - if (executable != null) { - mi = JavaTemplate - .builder("assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) - .staticImports("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType") - .build() - .apply( - getCursor(), - mi.getCoordinates().replace(), - mi.getArguments().get(0), executable - ); - maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType", false); - maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows"); - maybeRemoveImport("org.junit.jupiter.api.Assertions"); + if (executable != null) { + mi = JavaTemplate + .builder("assertThatExceptionOfType(#{any(java.lang.Class)}).isThrownBy(#{any(org.assertj.core.api.ThrowableAssert.ThrowingCallable)})") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .staticImports("org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType") + .build() + .apply(getCursor(), mi.getCoordinates().replace(), mi.getArguments().get(0), executable); + maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows"); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); - doAfterVisit(new LambdaBlockToExpression().getVisitor()); + doAfterVisit(new LambdaBlockToExpression().getVisitor()); + } } + return mi; } - return mi; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java index b478716dc..8541a4f2c 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java @@ -23,14 +23,16 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import java.util.List; public class JUnitAssertTrueToAssertThat extends Recipe { + + private static final MethodMatcher ASSERT_TRUE_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions assertTrue(boolean, ..)"); + @Override public String getDisplayName() { return "JUnit `assertTrue` to AssertJ"; @@ -43,67 +45,34 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new AssertTrueToAssertThatVisitor()); - } - - public static class AssertTrueToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_TRUE = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertTrue(boolean, ..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!JUNIT_ASSERT_TRUE.matches(method)) { - return method; - } + return Preconditions.check(new UsesMethod<>(ASSERT_TRUE_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!ASSERT_TRUE_MATCHER.matches(mi)) { + return mi; + } + + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + maybeRemoveImport("org.junit.jupiter.api.Assertions"); + + List args = mi.getArguments(); + Expression actual = args.get(0); + if (args.size() == 1) { + return JavaTemplate.builder("assertThat(#{any(boolean)}).isTrue();") + .staticImports("org.assertj.core.api.Assertions.assertThat") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), actual); + } - List args = method.getArguments(); - Expression actual = args.get(0); - - if (args.size() == 1) { - method = JavaTemplate.builder("assertThat(#{any(boolean)}).isTrue();") - .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual - ); - } else { Expression message = args.get(1); - - JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? - JavaTemplate.builder("assertThat(#{any(boolean)}).as(#{any(String)}).isTrue();") : - JavaTemplate.builder("assertThat(#{any(boolean)}).as(#{any(java.util.function.Supplier)}).isTrue();"); - - method = template + return JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isTrue();") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - actual, - message - ); + .apply(getCursor(), mi.getCoordinates().replace(), actual, message); } - - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); - - // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. - maybeRemoveImport("org.junit.jupiter.api.Assertions"); - - return method; - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java index b93ee3c9e..aef7ddfde 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java @@ -20,14 +20,20 @@ import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.java.*; -import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.TypeUtils; +import java.util.Collections; import java.util.List; public class JUnitFailToAssertJFail extends Recipe { + + private static final String JUNIT = "org.junit.jupiter.api.Assertions"; + private static final String ASSERTJ = "org.assertj.core.api.Assertions"; + private static final MethodMatcher FAIL_MATCHER = new MethodMatcher(JUNIT + " fail(..)"); + @Override public String getDisplayName() { return "JUnit fail to AssertJ"; @@ -40,118 +46,70 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.junit.jupiter.api.Assertions", false), new JUnitFailToAssertJFailVisitor()); - } - - public static class JUnitFailToAssertJFailVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_FAIL_MATCHER = new MethodMatcher("org.junit.jupiter.api.Assertions" + " fail(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation m = method; - - if (!JUNIT_FAIL_MATCHER.matches(m)) { - return m; - } - - List args = m.getArguments(); + return Preconditions.check(new UsesMethod<>(FAIL_MATCHER), new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = method; + if (!FAIL_MATCHER.matches(mi)) { + return mi; + } - if (args.size() == 1) { - // fail(), fail(String), fail(Supplier), fail(Throwable) - if (args.get(0) instanceof J.Empty) { - m = JavaTemplate.builder("org.assertj.core.api.Assertions.fail(\"\");") - .javaParser(assertionsParser(ctx)) - .build() - .apply(getCursor(), m.getCoordinates().replace()); - } else if (args.get(0) instanceof J.Literal || - TypeUtils.isAssignableTo("java.lang.String", args.get(0).getType())) { - m = JavaTemplate.builder("org.assertj.core.api.Assertions.fail(#{any()});") - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - m.getCoordinates().replace(), - args.get(0) - ); + List args = mi.getArguments(); + if (args.size() == 1) { + // fail(), fail(String), fail(Supplier), fail(Throwable) + if (args.get(0) instanceof J.Empty) { + mi = JavaTemplate.builder(ASSERTJ + ".fail(\"\");") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace()); + } else if (args.get(0) instanceof J.Literal || + TypeUtils.isAssignableTo("java.lang.String", args.get(0).getType())) { + mi = JavaTemplate.builder(ASSERTJ + ".fail(#{any()});") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), args.get(0)); + } else { + mi = JavaTemplate.builder(ASSERTJ + ".fail(\"\", #{any()});") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), args.get(0)); + } } else { - m = JavaTemplate.builder("org.assertj.core.api.Assertions.fail(\"\", #{any()});") - .javaParser(assertionsParser(ctx)) + // fail(String, Throwable) + String anyArgs = String.join(",", Collections.nCopies(args.size(), "#{any()}")); + mi = JavaTemplate.builder(ASSERTJ + ".fail(" + anyArgs + ");") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() - .apply( - getCursor(), - m.getCoordinates().replace(), - args.get(0) - ); + .apply(getCursor(), mi.getCoordinates().replace(), args.toArray()); } - } else { - // fail(String, Throwable) - StringBuilder templateBuilder = new StringBuilder("org.assertj.core.api.Assertions.fail("); - for (int i = 0; i < args.size(); i++) { - templateBuilder.append("#{any()}"); - if (i < args.size() - 1) { - templateBuilder.append(", "); - } - } - templateBuilder.append(");"); - m = JavaTemplate.builder(templateBuilder.toString()) - .javaParser(assertionsParser(ctx)) - .build() - .apply( - getCursor(), - m.getCoordinates().replace(), - args.toArray() - ); + doAfterVisit(new RemoveUnusedImports().getVisitor()); + doAfterVisit(new UnqualifiedMethodInvocations()); + return mi; } - doAfterVisit(new RemoveUnusedImports().getVisitor()); - doAfterVisit(new UnqualifiedMethodInvocations()); - return m; - } + class UnqualifiedMethodInvocations extends JavaIsoVisitor { + private final MethodMatcher INTERNAL_FAIL_MATCHER = new MethodMatcher(ASSERTJ + " fail(..)"); - private static class UnqualifiedMethodInvocations extends JavaIsoVisitor { - private static final MethodMatcher ASSERTJ_FAIL_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions" + " fail(..)"); + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); + if (!INTERNAL_FAIL_MATCHER.matches(mi)) { + return mi; + } - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - if (!ASSERTJ_FAIL_MATCHER.matches(method)) { - return method; - } + maybeAddImport(ASSERTJ, "fail", false); + maybeRemoveImport(JUNIT + ".fail"); - StringBuilder templateBuilder = new StringBuilder("fail("); - List arguments = method.getArguments(); - for (int i = 0; i < arguments.size(); i++) { - templateBuilder.append("#{any()}"); - if (i < arguments.size() - 1) { - templateBuilder.append(", "); - } + List arguments = mi.getArguments(); + String anyArgs = String.join(",", Collections.nCopies(arguments.size(), "#{any()}")); + return JavaTemplate.builder("fail(" + anyArgs + ");") + .staticImports(ASSERTJ + ".fail") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), arguments.toArray()); } - templateBuilder.append(");"); - - method = JavaTemplate.builder(templateBuilder.toString()) - .staticImports("org.assertj.core.api.Assertions" + ".fail") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) - .build() - .apply( - getCursor(), - method.getCoordinates().replace(), - arguments.toArray() - ); - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) - maybeAddImport("org.assertj.core.api.Assertions", "fail", false); - maybeRemoveImport("org.junit.jupiter.api.Assertions.fail"); - return super.visitMethodInvocation(method, ctx); } - } + }); } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyAssertJAssertion.java b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyAssertJAssertion.java index 1d3685955..de1d6be55 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyAssertJAssertion.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyAssertJAssertion.java @@ -33,6 +33,8 @@ @NoArgsConstructor public class SimplifyAssertJAssertion extends Recipe { + private static final MethodMatcher ASSERT_THAT_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); + @Option(displayName = "AssertJ assertion", description = "The assertion method that should be replaced.", example = "hasSize", @@ -67,38 +69,34 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return new ShorthenChainedAssertJAssertionsVisitor(); - } - - private class ShorthenChainedAssertJAssertionsVisitor extends JavaIsoVisitor { - private final MethodMatcher ASSERT_THAT_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); - private final MethodMatcher ASSERT_TO_REPLACE = new MethodMatcher("org.assertj.core.api.* " + assertToReplace + "(..)"); + final MethodMatcher assertToReplace = new MethodMatcher("org.assertj.core.api.* " + this.assertToReplace + "(..)"); + return new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(method, ctx); - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { - J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx); + // Match the end of the chain first, then the select to avoid matching the wrong method chain + if (!assertToReplace.matches(mi) || !ASSERT_THAT_MATCHER.matches(mi.getSelect())) { + return mi; + } - // Match the end of the chain first, then the select to avoid matching the wrong method chain - if (!ASSERT_TO_REPLACE.matches(mi) || !ASSERT_THAT_MATCHER.matches(mi.getSelect())) { - return mi; - } + // Compare argument with passed in literal + if (!(mi.getArguments().get(0) instanceof J.Literal) || + !literalArgument.equals(((J.Literal) mi.getArguments().get(0)).getValueSource())) { // Implies "null" is `null` + return mi; + } - // Compare argument with passed in literal - if (!(mi.getArguments().get(0) instanceof J.Literal) || - !literalArgument.equals(((J.Literal) mi.getArguments().get(0)).getValueSource())) { // Implies "null" is `null` - return mi; - } + // Check argument type of assertThat + if (!TypeUtils.isAssignableTo(requiredType, ((J.MethodInvocation) mi.getSelect()).getArguments().get(0).getType())) { + return mi; + } - // Check argument type of assertThat - if (!TypeUtils.isAssignableTo(requiredType, ((J.MethodInvocation) mi.getSelect()).getArguments().get(0).getType())) { - return mi; + // Assume zero argument replacement method + return JavaTemplate.builder("#{any()}." + dedicatedAssertion + "()") + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), mi.getSelect()); } - - // Assume zero argument replacement method - return JavaTemplate.builder("#{any()}." + dedicatedAssertion + "()") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) - .build() - .apply(getCursor(), mi.getCoordinates().replace(), mi.getSelect()); - } + }; } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java index ba431a43c..e6d37912e 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java @@ -35,6 +35,7 @@ @AllArgsConstructor @NoArgsConstructor public class SimplifyChainedAssertJAssertion extends Recipe { + @Option(displayName = "AssertJ chained assertion", description = "The chained AssertJ assertion to move to dedicated assertion.", example = "equals", @@ -81,95 +82,89 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return new SimplifyChainedAssertJAssertionsVisitor(); - } - - private class SimplifyChainedAssertJAssertionsVisitor extends JavaIsoVisitor { - private final MethodMatcher ASSERT_THAT_MATCHER = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); - private final MethodMatcher CHAINED_ASSERT_MATCHER = new MethodMatcher("java..* " + chainedAssertion + "(..)"); - private final MethodMatcher ASSERT_TO_REPLACE = new MethodMatcher("org.assertj.core.api.* " + assertToReplace + "(..)"); - - @Override - public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { - J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx); - - // assert has correct assertion - if (!ASSERT_TO_REPLACE.matches(mi) || mi.getArguments().size() != 1) { - return mi; + MethodMatcher assertThatMatcher = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)"); + MethodMatcher chainedAssertMatcher = new MethodMatcher("java..* " + chainedAssertion + "(..)"); + MethodMatcher assertToReplace = new MethodMatcher("org.assertj.core.api.* " + this.assertToReplace + "(..)"); + + return new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocation, ExecutionContext ctx) { + J.MethodInvocation mi = super.visitMethodInvocation(methodInvocation, ctx); + + // assert has correct assertion + if (!assertToReplace.matches(mi) || mi.getArguments().size() != 1) { + return mi; + } + + // assertThat has method call + J.MethodInvocation assertThat = (J.MethodInvocation) mi.getSelect(); + if (!assertThatMatcher.matches(assertThat) || !(assertThat.getArguments().get(0) instanceof J.MethodInvocation)) { + return mi; + } + + J.MethodInvocation assertThatArg = (J.MethodInvocation) assertThat.getArguments().get(0); + if (!chainedAssertMatcher.matches(assertThatArg)) { + return mi; + } + + // Extract the actual argument for the new assertThat call + Expression actual = assertThatArg.getSelect() != null ? assertThatArg.getSelect() : assertThatArg; + if (!TypeUtils.isAssignableTo(requiredType, actual.getType())) { + return mi; + } + List arguments = new ArrayList<>(); + arguments.add(actual); + + String template = getStringTemplateAndAppendArguments(assertThatArg, mi, arguments); + return JavaTemplate.builder(String.format(template, dedicatedAssertion)) + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5.9", "assertj-core-3.24")) + .build() + .apply(getCursor(), mi.getCoordinates().replace(), arguments.toArray()); } - // assertThat has method call - J.MethodInvocation assertThat = (J.MethodInvocation) mi.getSelect(); - if (!ASSERT_THAT_MATCHER.matches(assertThat) || !(assertThat.getArguments().get(0) instanceof J.MethodInvocation)) { - return mi; + private String getStringTemplateAndAppendArguments(J.MethodInvocation assertThatArg, J.MethodInvocation methodToReplace, List arguments) { + Expression assertThatArgument = assertThatArg.getArguments().get(0); + Expression methodToReplaceArgument = methodToReplace.getArguments().get(0); + boolean assertThatArgumentIsEmpty = assertThatArgument instanceof J.Empty; + boolean methodToReplaceArgumentIsEmpty = methodToReplaceArgument instanceof J.Empty; + + // If both arguments are empty, then the select is already added to the arguments list, and we use a minimal template + if (assertThatArgumentIsEmpty && methodToReplaceArgumentIsEmpty) { + return "assertThat(#{any()}).%s()"; + } + + // If both arguments are not empty, then we add both to the arguments to the arguments list, and return a template with two arguments + if (!assertThatArgumentIsEmpty && !methodToReplaceArgumentIsEmpty) { + // This should only happen for map assertions using a key and value + arguments.add(assertThatArgument); + arguments.add(methodToReplaceArgument); + return "assertThat(#{any()}).%s(#{any()}, #{any()})"; + } + + // If either argument is empty, we choose which one to add to the arguments list, and optionally extract the select + arguments.add(extractEitherArgument(assertThatArgumentIsEmpty, assertThatArgument, methodToReplaceArgument)); + + // Special case for Path.of() assertions + if ("java.nio.file.Path".equals(requiredType) && dedicatedAssertion.contains("Raw") && + TypeUtils.isAssignableTo("java.lang.String", assertThatArgument.getType())) { + maybeAddImport("java.nio.file.Path"); + return "assertThat(#{any()}).%s(Path.of(#{any()}))"; + } + + return "assertThat(#{any()}).%s(#{any()})"; } - J.MethodInvocation assertThatArg = (J.MethodInvocation) assertThat.getArguments().get(0); - if (!CHAINED_ASSERT_MATCHER.matches(assertThatArg)) { - return mi; - } - - // Extract the actual argument for the new assertThat call - Expression actual = assertThatArg.getSelect() != null ? assertThatArg.getSelect() : assertThatArg; - if (!TypeUtils.isAssignableTo(requiredType, actual.getType())) { - return mi; - } - List arguments = new ArrayList<>(); - arguments.add(actual); - - String template = getStringTemplateAndAppendArguments(assertThatArg, mi, arguments); - return applyTemplate(String.format(template, dedicatedAssertion), arguments, mi, ctx); - } - - private J.MethodInvocation applyTemplate(String formattedTemplate, List arguments, J.MethodInvocation mi, ExecutionContext ctx) { - return JavaTemplate.builder(formattedTemplate) - .contextSensitive() - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5.9", "assertj-core-3.24")) - .build() - .apply(getCursor(), mi.getCoordinates().replace(), arguments.toArray()); - } - - private String getStringTemplateAndAppendArguments(J.MethodInvocation assertThatArg, J.MethodInvocation methodToReplace, List arguments) { - Expression assertThatArgument = assertThatArg.getArguments().get(0); - Expression methodToReplaceArgument = methodToReplace.getArguments().get(0); - boolean assertThatArgumentIsEmpty = assertThatArgument instanceof J.Empty; - boolean methodToReplaceArgumentIsEmpty = methodToReplaceArgument instanceof J.Empty; - - // If both arguments are empty, then the select is already added to the arguments list, and we use a minimal template - if (assertThatArgumentIsEmpty && methodToReplaceArgumentIsEmpty) { - return "assertThat(#{any()}).%s()"; - } - - // If both arguments are not empty, then we add both to the arguments to the arguments list, and return a template with two arguments - if (!assertThatArgumentIsEmpty && !methodToReplaceArgumentIsEmpty) { - // This should only happen for map assertions using a key and value - arguments.add(assertThatArgument); - arguments.add(methodToReplaceArgument); - return "assertThat(#{any()}).%s(#{any()}, #{any()})"; - } - - // If either argument is empty, we choose which one to add to the arguments list, and optionally extract the select - arguments.add(extractEitherArgument(assertThatArgumentIsEmpty, assertThatArgument, methodToReplaceArgument)); - - // Special case for Path.of() assertions - if ("java.nio.file.Path".equals(requiredType) && dedicatedAssertion.contains("Raw") && - TypeUtils.isAssignableTo("java.lang.String", assertThatArgument.getType())) { - maybeAddImport("java.nio.file.Path"); - return "assertThat(#{any()}).%s(Path.of(#{any()}))"; - } - - return "assertThat(#{any()}).%s(#{any()})"; - } - - private Expression extractEitherArgument(boolean assertThatArgumentIsEmpty, Expression assertThatArgument, Expression methodToReplaceArgument) { - if (assertThatArgumentIsEmpty) { - return methodToReplaceArgument; - } - // Only on the assertThat argument do we possibly replace the argument with the select; such as list.size() -> list - if (CHAINED_ASSERT_MATCHER.matches(assertThatArgument)) { - return Objects.requireNonNull(((J.MethodInvocation) assertThatArgument).getSelect()); + private Expression extractEitherArgument(boolean assertThatArgumentIsEmpty, Expression assertThatArgument, Expression methodToReplaceArgument) { + if (assertThatArgumentIsEmpty) { + return methodToReplaceArgument; + } + // Only on the assertThat argument do we possibly replace the argument with the select; such as list.size() -> list + if (chainedAssertMatcher.matches(assertThatArgument)) { + return Objects.requireNonNull(((J.MethodInvocation) assertThatArgument).getSelect()); + } + return assertThatArgument; } - return assertThatArgument; - } + }; } } diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java index 1b5e7fd24..85b9f6a01 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionTypeTest.java @@ -44,12 +44,10 @@ void toAssertExceptionOfType() { java( """ import static org.junit.jupiter.api.Assertions.assertThrows; - + public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { - assertThrows(NullPointerException.class, () -> { - foo(); - }); + assertThrows(NullPointerException.class, () -> foo()); } void foo() { throw new NullPointerException(); @@ -58,11 +56,10 @@ void foo() { """, """ import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; - + public class SimpleExpectedExceptionTest { public void throwsExceptionWithSpecificType() { - assertThatExceptionOfType(NullPointerException.class).isThrownBy(() -> - foo()); + assertThatExceptionOfType(NullPointerException.class).isThrownBy(() -> foo()); } void foo() { throw new NullPointerException();