Skip to content

Commit

Permalink
Update AssertJ recipes to current recipe code style (#633)
Browse files Browse the repository at this point in the history
* Update AssertJ recipes to current recipe code style

* Apply formatter and remove unused imports

* Apply suggestions from code review

---------

Co-authored-by: Tim te Beek <[email protected]>
  • Loading branch information
jevanlingen and timtebeek authored Nov 5, 2024
1 parent 671addc commit 909f0eb
Show file tree
Hide file tree
Showing 15 changed files with 719 additions and 1,041 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesType;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
Expand All @@ -32,11 +32,14 @@
import java.util.List;

public class JUnitAssertArrayEqualsToAssertThat extends Recipe {
private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions";

private static final String JUNIT = "org.junit.jupiter.api.Assertions";
private static final String ASSERTJ = "org.assertj.core.api.Assertions";
private static final MethodMatcher ASSERT_ARRAY_EQUALS_MATCHER = new MethodMatcher(JUNIT + " assertArrayEquals(..)", true);

@Override
public String getDisplayName() {
return "JUnit `assertArrayEquals` To AssertJ";
return "JUnit `assertArrayEquals` to assertJ";
}

@Override
Expand All @@ -46,93 +49,72 @@ public String getDescription() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(new UsesType<>(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME, false), new AssertArrayEqualsToAssertThatVisitor());
}

public static class AssertArrayEqualsToAssertThatVisitor extends JavaIsoVisitor<ExecutionContext> {
private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)");

private JavaParser.Builder<?, ?> assertionsParser;

private JavaParser.Builder<?, ?> assertionsParser(ExecutionContext ctx) {
if (assertionsParser == null) {
assertionsParser = JavaParser.fromJavaVersion()
.classpathFromResources(ctx, "assertj-core-3.24");
}
return assertionsParser;
}


@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
if (!JUNIT_ASSERT_EQUALS.matches(method)) {
return method;
}

List<Expression> args = method.getArguments();
Expression expected = args.get(0);
Expression actual = args.get(1);

// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

if (args.size() == 2) {
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});")
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected);
} else if (args.size() == 3 && !isFloatingPointType(args.get(2))) {
Expression message = args.get(2);
JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()});") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()});");
return template
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, message, expected);
} else if (args.size() == 3) {
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
// assert is using floating points with a delta and no message.
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));")
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
return Preconditions.check(new UsesMethod<>(ASSERT_ARRAY_EQUALS_MATCHER), new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
J.MethodInvocation md = super.visitMethodInvocation(method, ctx);
if (!ASSERT_ARRAY_EQUALS_MATCHER.matches(md)) {
return md;
}

maybeAddImport(ASSERTJ, "assertThat", false);
maybeRemoveImport(JUNIT);

List<Expression> args = md.getArguments();
Expression expected = args.get(0);
Expression actual = args.get(1);
if (args.size() == 2) {
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});")
.staticImports(ASSERTJ + ".assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, expected);
}
if (args.size() == 3 && isFloatingPointType(args.get(2))) {
maybeAddImport(ASSERTJ, "within", false);
// assert is using floating points with a delta and no message.
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));")
.staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, expected, args.get(2));
}
if (args.size() == 3) {
Expression message = args.get(2);
return JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any()}).containsExactly(#{anyArray()});")
.staticImports(ASSERTJ + ".assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, message, expected);
}

maybeAddImport(ASSERTJ, "within", false);

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
return JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any()}).containsExactly(#{anyArray()}, within(#{}));")
.staticImports(ASSERTJ + ".assertThat", ASSERTJ + ".within")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2));
.apply(getCursor(), md.getCoordinates().replace(), actual, message, expected, args.get(2));
}

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
maybeAddImport("org.assertj.core.api.Assertions", "within", false);

JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()}, within(#{}));");
return template
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, message, expected, args.get(2));
}

/**
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private static boolean isFloatingPointType(Expression expression) {

JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName);
/**
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private boolean isFloatingPointType(Expression expression) {
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName);
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}
});
}
}
Loading

0 comments on commit 909f0eb

Please sign in to comment.