diff --git a/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java b/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java index e256c1ea..6c46b9b4 100644 --- a/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java +++ b/src/main/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCalls.java @@ -27,6 +27,8 @@ import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; +import java.util.List; + import static java.util.Collections.emptyList; public class SimplifyWebTestClientCalls extends Recipe { @@ -84,16 +86,65 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu } private int extractStatusCode(Expression expression) { - if (expression instanceof J.Literal) { + if (expression instanceof J.FieldAccess) { + //isEqualTo(HttpStatus.OK) + J.FieldAccess fa = (J.FieldAccess) expression; + if (fa.getTarget() instanceof J.Identifier) { + if ("HttpStatus".equals(((J.Identifier) fa.getTarget()).getSimpleName())) { + switch (fa.getSimpleName()) { + case "OK": + return 200; + case "CREATED": + return 201; + case "ACCEPTED": + return 202; + case "NO_CONTENT": + return 204; + case "FOUND": + return 302; + case "SEE_OTHER": + return 303; + case "NOT_MODIFIED": + return 304; + case "TEMPORARY_REDIRECT": + return 307; + case "PERMANENT_REDIRECT": + return 308; + case "BAD_REQUEST": + return 400; + case "UNAUTHORIZED": + return 401; + case "FORBIDDEN": + return 403; + case "NOT_FOUND": + return 404; + } + } + } + } else if (expression instanceof J.Literal) { + //isEqualTo(200) Object raw = ((J.Literal) expression).getValue(); if (raw instanceof Integer) { return (int) raw; } + } else if (expression instanceof J.MethodInvocation) { + //isEqualTo(HttpStatus.valueOf(200)) + //isEqualTo(HttpStatusCode.valueOf(200)) + J.MethodInvocation methodInvocation = (J.MethodInvocation) expression; + List arguments = methodInvocation.getArguments(); + if (arguments.size() == 1 && arguments.get(0) instanceof J.Literal) { + Object raw = ((J.Literal) arguments.get(0)).getValue(); + if (raw instanceof Integer) { + return (int) raw; + } + } } - return -1; // HttpStatus is not yet supported + return -1; } private J.MethodInvocation replaceMethod(J.MethodInvocation method, String methodName) { + maybeRemoveImport("org.springframework.http.HttpStatus"); + maybeRemoveImport("org.springframework.http.HttpStatusCode"); J.MethodInvocation methodInvocation = JavaTemplate.apply(methodName, getCursor(), method.getCoordinates().replaceMethod()); JavaType.Method type = methodInvocation .getMethodType() @@ -103,7 +154,6 @@ private J.MethodInvocation replaceMethod(J.MethodInvocation method, String metho .withArguments(emptyList()) .withMethodType(type) .withName(methodInvocation.getName().withType(type)); - } }); } diff --git a/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java b/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java index 7d64a2e2..14059211 100644 --- a/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java +++ b/src/test/java/org/openrewrite/java/spring/http/SimplifyWebTestClientCallsTest.java @@ -15,7 +15,6 @@ */ package org.openrewrite.java.spring.http; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -165,14 +164,14 @@ void someMethod() { } @Test - @Disabled("Yet to be implemented") - void usesIsOkForHttpStatus200() { + void usesIsOkForHttpStatusValueOf200() { rewriteRun( //language=java java( """ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.http.HttpStatus; + class Test { private final WebTestClient webClient = WebTestClient.bindToServer().build(); void someMethod() { @@ -187,6 +186,47 @@ void someMethod() { """, """ import org.springframework.test.web.reactive.server.WebTestClient; + + class Test { + private final WebTestClient webClient = WebTestClient.bindToServer().build(); + void someMethod() { + webClient + .post() + .uri("/some/value") + .exchange() + .expectStatus() + .isOk(); + } + } + """ + ) + ); + } + + @Test + void usesIsOkForHttpStatusValueCodeOf200() { + rewriteRun( + //language=java + java( + """ + import org.springframework.test.web.reactive.server.WebTestClient; + import org.springframework.http.HttpStatusCode; + + class Test { + private final WebTestClient webClient = WebTestClient.bindToServer().build(); + void someMethod() { + webClient + .post() + .uri("/some/value") + .exchange() + .expectStatus() + .isEqualTo(HttpStatusCode.valueOf(200)); + } + } + """, + """ + import org.springframework.test.web.reactive.server.WebTestClient; + class Test { private final WebTestClient webClient = WebTestClient.bindToServer().build(); void someMethod() { @@ -203,6 +243,61 @@ void someMethod() { ); } + @ParameterizedTest + @CsvSource({ + "OK,isOk()", + "CREATED,isCreated()", + "ACCEPTED,isAccepted()", + "NO_CONTENT,isNoContent()", + "FOUND,isFound()", + "SEE_OTHER,isSeeOther()", + "NOT_MODIFIED,isNotModified()", + "TEMPORARY_REDIRECT,isTemporaryRedirect()", + "PERMANENT_REDIRECT,isPermanentRedirect()", + "BAD_REQUEST,isBadRequest()", + "UNAUTHORIZED,isUnauthorized()", + "FORBIDDEN,isForbidden()", + "NOT_FOUND,isNotFound()" + }) + void usesIsOkForHttpStatusValue(String httpStatus, String method) { + rewriteRun( + //language=java + java( + """ + import org.springframework.test.web.reactive.server.WebTestClient; + import org.springframework.http.HttpStatus; + + class Test { + private final WebTestClient webClient = WebTestClient.bindToServer().build(); + void someMethod() { + webClient + .post() + .uri("/some/value") + .exchange() + .expectStatus() + .isEqualTo(HttpStatus.%s); + } + } + """.formatted(httpStatus), + """ + import org.springframework.test.web.reactive.server.WebTestClient; + + class Test { + private final WebTestClient webClient = WebTestClient.bindToServer().build(); + void someMethod() { + webClient + .post() + .uri("/some/value") + .exchange() + .expectStatus() + .%s; + } + } + """.formatted(method) + ) + ); + } + @Test void doesNotUseIsOkForHttpStatus300() { rewriteRun(