diff --git a/build.gradle.kts b/build.gradle.kts index 66255d735..fbf42e1b5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -158,6 +158,7 @@ dependencies { // for generating properties migration configurations testImplementation("io.github.classgraph:classgraph:latest.release") testImplementation("org.openrewrite:rewrite-java-17") + testImplementation("org.openrewrite:rewrite-kotlin:$rewriteVersion") testImplementation("org.openrewrite.recipe:rewrite-migrate-java:$rewriteVersion") testImplementation("org.openrewrite.recipe:rewrite-testing-frameworks:$rewriteVersion") diff --git a/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java b/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java index b129bef22..e256c1ea8 100644 --- a/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java +++ b/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java @@ -23,6 +23,7 @@ import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.search.UsesMethod; +import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; @@ -30,7 +31,7 @@ public class SimplifyWebTestClientCalls extends Recipe { - private static final MethodMatcher IS_EQUAL_TO_INT_MATCHER = new MethodMatcher("org.springframework.test.web.reactive.server.StatusAssertions isEqualTo(int)"); + private static final MethodMatcher IS_EQUAL_TO_MATCHER = new MethodMatcher("org.springframework.test.web.reactive.server.StatusAssertions isEqualTo(..)"); @Override public String getDisplayName() { @@ -44,12 +45,12 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesMethod<>(IS_EQUAL_TO_INT_MATCHER), new JavaIsoVisitor() { + return Preconditions.check(new UsesMethod<>(IS_EQUAL_TO_MATCHER), new JavaIsoVisitor() { @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { J.MethodInvocation m = super.visitMethodInvocation(method, ctx); - if (IS_EQUAL_TO_INT_MATCHER.matches(m.getMethodType())) { - int statusCode = (int) ((J.Literal) m.getArguments().get(0)).getValue(); + if (IS_EQUAL_TO_MATCHER.matches(m.getMethodType())) { + final int statusCode = extractStatusCode(m.getArguments().get(0)); switch (statusCode) { case 200: return replaceMethod(m, "isOk()"); @@ -82,11 +83,26 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return m; } + private int extractStatusCode(Expression expression) { + if (expression instanceof J.Literal) { + Object raw = ((J.Literal) expression).getValue(); + if (raw instanceof Integer) { + return (int) raw; + } + } + return -1; // HttpStatus is not yet supported + } + private J.MethodInvocation replaceMethod(J.MethodInvocation method, String methodName) { - JavaTemplate template = JavaTemplate.builder(methodName).build(); - J.MethodInvocation methodInvocation = template.apply(getCursor(), method.getCoordinates().replaceMethod()); - JavaType.Method type = methodInvocation.getMethodType().withParameterNames(emptyList()).withParameterTypes(emptyList()); - return methodInvocation.withArguments(emptyList()).withMethodType(type).withName(methodInvocation.getName().withType(type)); + J.MethodInvocation methodInvocation = JavaTemplate.apply(methodName, getCursor(), method.getCoordinates().replaceMethod()); + JavaType.Method type = methodInvocation + .getMethodType() + .withParameterNames(emptyList()) + .withParameterTypes(emptyList()); + return methodInvocation + .withArguments(emptyList()) + .withMethodType(type) + .withName(methodInvocation.getName().withType(type)); } }); diff --git a/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java b/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java index b447146d7..7d64a2e28 100644 --- a/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java +++ b/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java @@ -22,10 +22,12 @@ import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; import org.openrewrite.java.JavaParser; +import org.openrewrite.kotlin.KotlinParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.kotlin.Assertions.kotlin; class SimplifyWebTestClientCallsTest implements RewriteTest { @@ -34,25 +36,27 @@ public void defaults(RecipeSpec spec) { spec .recipe(new SimplifyWebTestClientCalls()) .parser(JavaParser.fromJavaVersion() + .classpathFromResources(new InMemoryExecutionContext(), "spring-web-6", "spring-test-6")) + .parser(KotlinParser.builder() .classpathFromResources(new InMemoryExecutionContext(), "spring-web-6", "spring-test-6")); } @DocumentExample @ParameterizedTest @CsvSource({ - "200,isOk()", - "201,isCreated()", - "202,isAccepted()", - "204,isNoContent()", - "302,isFound()", - "303,isSeeOther()", - "304,isNotModified()", - "307,isTemporaryRedirect()", - "308,isPermanentRedirect()", - "400,isBadRequest()", - "401,isUnauthorized()", - "403,isForbidden()", - "404,isNotFound()" + "200,isOk()", + "201,isCreated()", + "202,isAccepted()", + "204,isNoContent()", + "302,isFound()", + "303,isSeeOther()", + "304,isNotModified()", + "307,isTemporaryRedirect()", + "308,isPermanentRedirect()", + "400,isBadRequest()", + "401,isUnauthorized()", + "403,isForbidden()", + "404,isNotFound()" }) void replacesAllIntStatusCodes(String httpStatus, String method) { rewriteRun( @@ -94,6 +98,47 @@ void someMethod() { ); } + @Test + void replaceKotlinInt() { + rewriteRun( + //language=kotlin + kotlin( + """ + import org.springframework.test.web.reactive.server.WebTestClient + + class Test { + val webClient: WebTestClient = WebTestClient.bindToServer().build() + fun someMethod() { + webClient + .post() + .uri("/some/url") + .bodyValue("someValue") + .exchange() + .expectStatus() + .isEqualTo(200) + } + } + """, + """ + import org.springframework.test.web.reactive.server.WebTestClient + + class Test { + val webClient: WebTestClient = WebTestClient.bindToServer().build() + fun someMethod() { + webClient + .post() + .uri("/some/url") + .bodyValue("someValue") + .exchange() + .expectStatus() + .isOk() + } + } + """ + ) + ); + } + @Test void doesNotReplaceUnspecificStatusCode() { rewriteRun(