From 0beae056acd7765df13c63b649bd33dd57fd527f Mon Sep 17 00:00:00 2001 From: Gabriel Roldan Date: Sat, 20 Jul 2024 16:44:26 -0300 Subject: [PATCH] Preserve response headers when redirecting application error to gateway error pages Commit 37ff94b9 make the `ApplicationError` Gateway filter lose the original response headers when throwing a `ResponseStatusException` for the Gateway to show up the customized HTML error pages instead of the orignal (usually whitelabel) errors. This patch makes it so that the `ApplicationError` filter runs only when `text/html` is accepted by the request, and the request method is idempotent (e.g. GET, HEAD, etc.). Additionally, the original response headers are not lost, since the exception is thrown at `ServerHttpResponseDecorator.beforeCommit()`, and respecting the reactive chain. --- .../ApplicationErrorGatewayFilterFactory.java | 79 +++-- ...licationErrorGatewayFilterFactoryTest.java | 271 +++++++++++++----- gateway/src/test/resources/logback-test.xml | 3 +- 3 files changed, 263 insertions(+), 90 deletions(-) diff --git a/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java b/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java index 15aa2d0d..e56fa7ac 100644 --- a/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java +++ b/gateway/src/main/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactory.java @@ -18,19 +18,19 @@ */ package org.georchestra.gateway.filter.global; -import java.net.URI; +import java.util.function.Supplier; import org.springframework.cloud.gateway.filter.GatewayFilter; import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory; import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory; -import org.springframework.cloud.gateway.support.HttpStatusHolder; -import org.springframework.cloud.gateway.support.ServerWebExchangeUtils; import org.springframework.core.Ordered; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponseDecorator; -import org.springframework.lang.Nullable; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; @@ -39,7 +39,8 @@ /** * Filter to allow custom error pages to be used when an application behind the - * gateways returns an error. + * gateways returns an error, only for idempotent HTTP response status codes + * (i.e. GET, HEAD, OPTIONS). *

* {@link GatewayFilterFactory} providing a {@link GatewayFilter} that throws a * {@link ResponseStatusException} with the proxied response status code if the @@ -80,29 +81,59 @@ public GatewayFilter apply(final Object config) { return new ServiceErrorGatewayFilter(); } - private static class ServiceErrorGatewayFilter implements GatewayFilter, Ordered { - - public @Override Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { - - ApplicationErrorConveyorHttpResponse response; - response = new ApplicationErrorConveyorHttpResponse(exchange.getResponse()); - - exchange = exchange.mutate().response(response).build(); - return chain.filter(exchange); + private class ServiceErrorGatewayFilter implements GatewayFilter, Ordered { + /** + * @return {@link Ordered#HIGHEST_PRECEDENCE} or + * {@link ApplicationErrorConveyorHttpResponse#beforeCommit(Supplier)} + * won't be called + */ + @Override + public int getOrder() { + return Ordered.HIGHEST_PRECEDENCE; } + /** + * If the request method is idempotent and accepts {@literal text/html}, applies + * a filter that when the routed response receives an error status code, will + * throw a {@link ResponseStatusException} with the same status, for the gateway + * to apply the customized error template, also when the status code comes from + * a proxied service response + */ @Override - public int getOrder() { - return ResolveTargetGlobalFilter.ORDER + 1; + public Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { + if (canFilter(exchange.getRequest())) { + exchange = decorate(exchange); + } + return chain.filter(exchange); } + } + + ServerWebExchange decorate(ServerWebExchange exchange) { + var response = new ApplicationErrorConveyorHttpResponse(exchange.getResponse()); + exchange = exchange.mutate().response(response).build(); + return exchange; + } + + boolean canFilter(ServerHttpRequest request) { + return methodIsIdempotent(request.getMethod()) && acceptsHtml(request); + } + + boolean methodIsIdempotent(HttpMethod method) { + return switch (method) { + case GET, HEAD, OPTIONS, TRACE -> true; + default -> false; + }; + } + boolean acceptsHtml(ServerHttpRequest request) { + return request.getHeaders().getAccept().stream().anyMatch(MediaType.TEXT_HTML::isCompatibleWith); } /** * A response decorator that throws a {@link ResponseStatusException} at - * {@link #setStatusCode(HttpStatus)} if the status code is an error code, thus - * letting the gateway render the appropriate custom error page instead of the - * original application response body. + * {@link #beforeCommit} if the status code is an error code, thus letting the + * gateway render the appropriate custom error page instead of the original + * application response body. */ private static class ApplicationErrorConveyorHttpResponse extends ServerHttpResponseDecorator { @@ -111,12 +142,14 @@ public ApplicationErrorConveyorHttpResponse(ServerHttpResponse delegate) { } @Override - public boolean setStatusCode(@Nullable HttpStatus status) { - checkStatusCode(status); - return super.setStatusCode(status); + public void beforeCommit(Supplier> action) { + Mono checkStatus = Mono.fromRunnable(this::checkStatusCode); + Mono checkedAction = checkStatus.then(Mono.fromRunnable(action::get)); + super.beforeCommit(() -> checkedAction); } - private void checkStatusCode(HttpStatus statusCode) { + private void checkStatusCode() { + HttpStatus statusCode = getStatusCode(); log.debug("native status code: {}", statusCode); if (statusCode.is4xxClientError() || statusCode.is5xxServerError()) { log.debug("Conveying {} response status", statusCode); diff --git a/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java b/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java index a2061c96..a4d18bf0 100644 --- a/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java +++ b/gateway/src/test/java/org/georchestra/gateway/filter/global/ApplicationErrorGatewayFilterFactoryTest.java @@ -18,104 +18,243 @@ */ package org.georchestra.gateway.filter.global; +import static com.github.tomakehurst.wiremock.stubbing.StubMapping.buildFrom; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; -import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR; -import java.net.URI; -import java.util.List; +import java.util.Iterator; +import java.util.Map; +import java.util.stream.Stream; -import org.georchestra.gateway.model.HeaderMappings; -import org.georchestra.gateway.model.RoleBasedAccessRule; +import org.georchestra.gateway.app.GeorchestraGatewayApplication; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.springframework.cloud.gateway.filter.GatewayFilter; -import org.springframework.cloud.gateway.filter.GatewayFilterChain; -import org.springframework.cloud.gateway.handler.FilteringWebHandler; -import org.springframework.cloud.gateway.route.Route; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; -import org.springframework.http.server.reactive.ServerHttpResponse; -import org.springframework.mock.http.server.reactive.MockServerHttpRequest; -import org.springframework.mock.http.server.reactive.MockServerHttpResponse; -import org.springframework.mock.web.server.MockServerWebExchange; -import org.springframework.web.server.ResponseStatusException; -import org.springframework.web.server.ServerWebExchange; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; -import reactor.core.publisher.Mono; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.StubMapping; +@SpringBootTest(classes = GeorchestraGatewayApplication.class, // + webEnvironment = WebEnvironment.RANDOM_PORT, // + properties = { // + "server.error.whitelabel.enabled=false", // + "georchestra.gateway.global-access-rules[0].intercept-url=/**", // + "georchestra.gateway.global-access-rules[0].anonymous=true" // + }) +@WireMockTest class ApplicationErrorGatewayFilterFactoryTest { - private GatewayFilter filter; - private MockServerWebExchange exchange; + /** + * saved in {@link #setUpWireMock}, to be used on {@link #registerRoutes} + */ + private static WireMockRuntimeInfo wmRuntimeInfo; - final URI matchedURI = URI.create("http://fake.backend.com:8080"); - private Route matchedRoute; + /** + * Set up stub requests for the wiremock server. WireMock is running on a random + * port, so this method saves {@link #wmRuntimeInfo} for + * {@link #registerRoutes(DynamicPropertyRegistry)} + */ + @BeforeAll + static void saveWireMock(WireMockRuntimeInfo runtimeInfo) { + ApplicationErrorGatewayFilterFactoryTest.wmRuntimeInfo = runtimeInfo; + } - HeaderMappings defaultHeaders; - List defaultRules; + /** + * Set up a gateway route that proxies all requests to the wiremock server + */ + @DynamicPropertySource + static void registerRoutes(DynamicPropertyRegistry registry) { + String targetUrl = wmRuntimeInfo.getHttpBaseUrl(); - @BeforeEach - void setUp() throws Exception { - var factory = new ApplicationErrorGatewayFilterFactory(); - filter = factory.apply(factory.newConfig()); + registry.add("spring.cloud.gateway.routes[0].id", () -> "mockeduproute"); + registry.add("spring.cloud.gateway.routes[0].uri", () -> targetUrl); + registry.add("spring.cloud.gateway.routes[0].predicates[0]", () -> "Path=/**"); + } + + @Autowired + TestRestTemplate testRestTemplate; - matchedRoute = mock(Route.class); - when(matchedRoute.getUri()).thenReturn(matchedURI); + @SpyBean + ApplicationErrorGatewayFilterFactory factory; - MockServerHttpRequest request = MockServerHttpRequest.get("/test").build(); - exchange = MockServerWebExchange.from(request); - exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, matchedRoute); - exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, matchedURI); + @BeforeEach + void setUp(WireMockRuntimeInfo runtimeInfo) throws Exception { + StubMapping defaultResponse = buildFrom(""" + { + "priority": 100, + "request": {"method": "ANY","urlPattern": ".*"}, + "response": { + "status": 418, + "jsonBody": { "status": "Error", "message": "I'm a teapot" }, + "headers": {"Content-Type": "application/json"} + } + } + """); + WireMock wireMock = runtimeInfo.getWireMock(); + wireMock.register(defaultResponse); } @Test - void testNotAnErrorResponse() { - GatewayFilterChain chain = mock(GatewayFilterChain.class); + void testNonIdempotentHttpMethodsIgnored(WireMockRuntimeInfo runtimeInfo) { + StubMapping mapping = buildFrom(""" + { + "priority": 1, + "request": { + "method": "POST", + "url": "/geonetwork", + "headers": { + "Accept": {"contains": "text/html"} + } + }, + "response": { + "status": 400, + "body": "Bad request from downstream", + "headers": { + "Content-Type": "text/plain", + "X-Frame-Options": "ALLOW-FROM *.test.com", + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "same-origin" + } + } + } + """); + runtimeInfo.getWireMock().register(mapping); - filter.filter(exchange, chain); + ResponseEntity response = testRestTemplate.postForEntity("/geonetwork", + withHeaders("Accept", "text/html"), String.class); - ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); - verify(chain).filter(captor.capture()); + verify(factory, times(1)).canFilter(any()); + verify(factory, never()).decorate(any()); - ServerWebExchange mutated = captor.getValue(); - ServerHttpResponse response = mutated.getResponse(); - response.setStatusCode(HttpStatus.CREATED); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + Map headers = response.getHeaders().toSingleValueMap(); + Map expected = Map.of(// + "Content-Type", "text/plain", // + "X-Frame-Options", "ALLOW-FROM *.test.com", // + "X-Content-Type-Options", "nosniff", // + "Referrer-Policy", "same-origin"// - MockServerHttpResponse origResponse = exchange.getResponse(); - assertThat(origResponse.getStatusCode()).isEqualTo(HttpStatus.CREATED); + ); + assertThat(headers).as("response does not contain all original headers").containsAllEntriesOf(expected); + assertThat(response.getBody()).isEqualTo("Bad request from downstream"); } @Test - void test4xx() { - testApplicationError(HttpStatus.BAD_REQUEST); - testApplicationError(HttpStatus.UNAUTHORIZED); - testApplicationError(HttpStatus.FORBIDDEN); - testApplicationError(HttpStatus.NOT_FOUND); + void testNonHtmlAcceptRquestIgnored(WireMockRuntimeInfo runtimeInfo) { + StubMapping mapping = buildFrom(""" + { + "priority": 1, + "request": { + "method": "GET", + "url": "/geonetwork", + "headers": { + "Accept": {"contains": "application/json"} + } + }, + "response": { + "status": 500, + "body": "Internal server error from downstream", + "headers": { + "Content-Type": "text/plain", + "X-Frame-Options": "ALLOW-FROM *.test.com", + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "same-origin" + } + } + } + """); + runtimeInfo.getWireMock().register(mapping); + + RequestEntity req = RequestEntity.get("/geonetwork").header("Accept", "application/json").build(); + ResponseEntity response = testRestTemplate.exchange(req, String.class); + + verify(factory, times(1)).canFilter(any()); + verify(factory, never()).decorate(any()); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + Map headers = response.getHeaders().toSingleValueMap(); + Map expected = Map.of(// + "Content-Type", "text/plain", // + "X-Frame-Options", "ALLOW-FROM *.test.com", // + "X-Content-Type-Options", "nosniff", // + "Referrer-Policy", "same-origin"// + + ); + assertThat(headers).as("response does not contain all original headers").containsAllEntriesOf(expected); + assertThat(response.getBody()).isEqualTo("Internal server error from downstream"); } @Test - void test5xx() { - testApplicationError(HttpStatus.INTERNAL_SERVER_ERROR); - testApplicationError(HttpStatus.SERVICE_UNAVAILABLE); - testApplicationError(HttpStatus.BAD_GATEWAY); - } + void testApplicationErrorToCustomErrorPageMapping(WireMockRuntimeInfo runtimeInfo) { + runtimeInfo.getWireMock().register(buildFrom(""" + { + "priority": 1, + "request": { + "method": "GET", + "url": "/geonetwork", + "headers": { + "Accept": {"contains": "text/html"} + } + }, + "response": { + "status": 500, + "body": "Internal server error from downstream", + "headers": { + "Content-Type": "text/plain", + "X-Frame-Options": "ALLOW-FROM *.test.com", + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "same-origin" + } + } + } + """)); + + RequestEntity req = RequestEntity.get("/geonetwork").header("Accept", "text/html").build(); + ResponseEntity response = testRestTemplate.exchange(req, String.class); + + verify(factory, times(1)).canFilter(any()); + verify(factory, times(1)).decorate(any()); - private void testApplicationError(HttpStatus status) { - GatewayFilterChain chain = mock(GatewayFilterChain.class); - filter.filter(exchange, chain); - ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); - verify(chain).filter(captor.capture()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR); + assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_HTML)) + .as("Expected content type text/html").isTrue(); + + Map headers = response.getHeaders().toSingleValueMap(); + Map expected = Map.of(// + "X-Frame-Options", "ALLOW-FROM *.test.com", // + "X-Content-Type-Options", "nosniff", // + "Referrer-Policy", "same-origin"// + + ); + assertThat(headers).as("response does not contain all original headers").containsAllEntriesOf(expected); + } - ServerWebExchange mutated = captor.getValue(); - ServerHttpResponse response = mutated.getResponse(); - assertThrows(ResponseStatusException.class, () -> response.setStatusCode(status)); + private HttpEntity withHeaders(String... headersKvp) { + assertThat(headersKvp.length % 2).isZero(); + HttpHeaders headers = new HttpHeaders(); + Iterator it = Stream.of(headersKvp).iterator(); + while (it.hasNext()) { + headers.add(it.next(), it.next()); + } + return new HttpEntity<>(headers); } } diff --git a/gateway/src/test/resources/logback-test.xml b/gateway/src/test/resources/logback-test.xml index d3a7d2ff..3956e5b6 100644 --- a/gateway/src/test/resources/logback-test.xml +++ b/gateway/src/test/resources/logback-test.xml @@ -8,5 +8,6 @@ - + + \ No newline at end of file