diff --git a/pom.xml b/pom.xml index 01eb04c..4dceb65 100644 --- a/pom.xml +++ b/pom.xml @@ -248,6 +248,13 @@ test + + org.wiremock + wiremock + 3.9.1 + test + + diff --git a/src/main/java/org/folio/rest/impl/SamlAPI.java b/src/main/java/org/folio/rest/impl/SamlAPI.java index d9bafb5..7eaa64f 100644 --- a/src/main/java/org/folio/rest/impl/SamlAPI.java +++ b/src/main/java/org/folio/rest/impl/SamlAPI.java @@ -257,7 +257,8 @@ private void doPostSamlCallback(String body, RoutingContext routingContext, Map< if (isLegacyResponse(tokenSignEndpoint)) { return redirectResponseLegacy(jsonResponse, stripesBaseUrl, originalUrl); } else { - return redirectResponse(jsonResponse, stripesBaseUrl, originalUrl); + var okapiPath = UrlUtil.getPathFromOkapiUrl(parsedHeaders.getUrl()); + return redirectResponse(jsonResponse, stripesBaseUrl, originalUrl, okapiPath); } }); }); @@ -321,7 +322,9 @@ private Response redirectResponseLegacy(JsonObject jsonObject, URI stripesBaseUr return PostSamlCallbackResponse.respond302(headers); } - private Response redirectResponse(JsonObject jsonObject, URI stripesBaseUrl, URI originalUrl) { + private Response redirectResponse(JsonObject jsonObject, + URI stripesBaseUrl, URI originalUrl, String okapiPath) { + String accessToken = jsonObject.getString(ACCESS_TOKEN); String refreshToken = jsonObject.getString(REFRESH_TOKEN); String accessTokenExpiration = jsonObject.getString(ACCESS_TOKEN_EXPIRATION); @@ -339,13 +342,14 @@ private Response redirectResponse(JsonObject jsonObject, URI stripesBaseUrl, URI // NOTE RMB doesn't support sending multiple headers with the same key so we // make our own response. - return Response.status(302).header(SET_COOKIE, accessTokenCookie(accessToken, accessTokenExpiration)) - .header(SET_COOKIE, refreshTokenCookie(refreshToken, refreshTokenExpiration)) + return Response.status(302) + .header(SET_COOKIE, accessTokenCookie(accessToken, accessTokenExpiration, okapiPath)) + .header(SET_COOKIE, refreshTokenCookie(refreshToken, refreshTokenExpiration, okapiPath)) .header(LOCATION, location) .build(); } - private String refreshTokenCookie(String refreshToken, String refreshTokenExpiration) { + private String refreshTokenCookie(String refreshToken, String refreshTokenExpiration, String okapiPath) { // The refresh token expiration is the time after which the token will be // considered expired. var exp = Instant.parse(refreshTokenExpiration).getEpochSecond(); @@ -360,7 +364,7 @@ private String refreshTokenCookie(String refreshToken, String refreshTokenExpira var rtCookie = Cookie.cookie(FOLIO_REFRESH_TOKEN, refreshToken) .setMaxAge(ttlSeconds) .setSecure(true) - .setPath("/authn") + .setPath(okapiPath + "/authn") .setHttpOnly(true) .setSameSite(CookieSameSiteConfig.get()) .setDomain(null) @@ -371,7 +375,7 @@ private String refreshTokenCookie(String refreshToken, String refreshTokenExpira return rtCookie; } - private String accessTokenCookie(String accessToken, String accessTokenExpiration) { + private String accessTokenCookie(String accessToken, String accessTokenExpiration, String okapiPath) { // The refresh token expiration is the time after which the token will be // considered expired. var exp = Instant.parse(accessTokenExpiration).getEpochSecond(); @@ -386,7 +390,7 @@ private String accessTokenCookie(String accessToken, String accessTokenExpiratio var atCookie = Cookie.cookie(FOLIO_ACCESS_TOKEN, accessToken) .setMaxAge(ttlSeconds) .setSecure(true) - .setPath("/") + .setPath(okapiPath + "/") .setHttpOnly(true) .setSameSite(CookieSameSiteConfig.get()) .encode(); diff --git a/src/main/java/org/folio/util/UrlUtil.java b/src/main/java/org/folio/util/UrlUtil.java index 526c118..0475a39 100644 --- a/src/main/java/org/folio/util/UrlUtil.java +++ b/src/main/java/org/folio/util/UrlUtil.java @@ -1,7 +1,13 @@ package org.folio.util; +import java.io.UncheckedIOException; import java.net.ConnectException; +import java.net.MalformedURLException; import java.net.URI; +import java.net.URL; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import io.vertx.core.Future; import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; @@ -12,11 +18,30 @@ * @author rsass */ public class UrlUtil { + private static final Logger logger = LogManager.getLogger(UrlUtil.class); private UrlUtil() { } + /** + * Return the path component of the okapiUrl without tailing '/', + * return empty string on parse error. + */ + public static String getPathFromOkapiUrl(String okapiUrl) { + try { + var path = new URL(okapiUrl).getPath(); + if (path.endsWith("/")) { + return path.substring(0, path.length() - 1); + } + return path; + } catch (MalformedURLException e) { + var message = "Malformed Okapi URL: " + okapiUrl; + logger.error("{}", message, e); + throw new UncheckedIOException(message, e); + } + } + public static URI parseBaseUrl(URI originalUrl) { return URI.create(originalUrl.getScheme() + "://" + originalUrl.getAuthority()); } diff --git a/src/test/java/org/folio/rest/impl/SamlAPITest.java b/src/test/java/org/folio/rest/impl/SamlAPITest.java index 27e9335..2417917 100644 --- a/src/test/java/org/folio/rest/impl/SamlAPITest.java +++ b/src/test/java/org/folio/rest/impl/SamlAPITest.java @@ -3,12 +3,16 @@ import static io.restassured.RestAssured.given; import static io.restassured.module.jsv.JsonSchemaValidator.matchesJsonSchemaInClasspath; import static org.folio.util.Base64AwareXsdMatcher.matchesBase64XsdInClasspath; +import static com.github.tomakehurst.wiremock.client.WireMock.any; +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThrows; +import com.github.tomakehurst.wiremock.junit.WireMockClassRule; import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; @@ -34,6 +38,7 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; +import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; @@ -47,10 +52,10 @@ import org.pac4j.core.redirect.RedirectionActionBuilder; import org.pac4j.saml.client.SAML2Client; import org.w3c.dom.ls.LSResourceResolver; - import io.restassured.RestAssured; import io.restassured.http.ContentType; import io.restassured.http.Header; +import io.restassured.matcher.RestAssuredMatchers; import io.restassured.response.ExtractableResponse; import io.restassured.response.Response; import io.vertx.core.DeploymentOptions; @@ -79,11 +84,17 @@ public class SamlAPITest extends TestBase { public static final int IDP_MOCK_PORT = NetworkUtils.nextFreePort(); private static final int MOCK_SERVER_PORT = NetworkUtils.nextFreePort(); + private static final int OKAPI_PROXY_PORT = NetworkUtils.nextFreePort(); private static final Header OKAPI_URL_HEADER= new Header("X-Okapi-Url", "http://localhost:" + MOCK_SERVER_PORT); + private static final Header OKAPI_PROXY_URL_HEADER= + new Header("X-Okapi-Url", "http://localhost:" + OKAPI_PROXY_PORT + "/okapi"); private static final MockJsonExtended mock = new MockJsonExtended(); private DataMigrationHelper dataMigrationHelper = new DataMigrationHelper(TENANT_HEADER, TOKEN_HEADER, OKAPI_URL_HEADER); + @ClassRule + public static WireMockClassRule okapiProxy = new WireMockClassRule(OKAPI_PROXY_PORT); + @Rule public TestName testName = new TestName(); @@ -92,6 +103,11 @@ public static void setupOnce(TestContext context) { RestAssured.port = TestBase.modulePort; RestAssured.enableLoggingOfRequestAndResponseIfValidationFails(); + okapiProxy.stubFor(any(urlMatching("/okapi/.*")) + .willReturn(aResponse() + .proxiedFrom("http://localhost:" + MOCK_SERVER_PORT) + .withProxyUrlPrefixToRemove("/okapi"))); + DeploymentOptions idpOptions = new DeploymentOptions() .setConfig(new JsonObject().put("http.port", IDP_MOCK_PORT)); @@ -728,6 +744,20 @@ public void callbackEndpointTests(TestContext context) { samlResponse, TENANT_HEADER, TOKEN_HEADER, OKAPI_URL_HEADER); CookieSameSiteConfig.set(Map.of()); + log.info("=== Test - POST /saml/callback-with-expiry with okapi path ==="); + given() + .header(TENANT_HEADER) + .header(TOKEN_HEADER) + .header(OKAPI_PROXY_URL_HEADER) + .cookie(SamlAPI.RELAY_STATE, cookie) + .formParam("SAMLResponse", samlResponse) + .formParam("RelayState", relayState) + .post("/saml/callback-with-expiry") + .then() + .statusCode(302) + .cookie("folioRefreshToken", RestAssuredMatchers.detailedCookie().path("/okapi/authn")) + .cookie("folioAccessToken", RestAssuredMatchers.detailedCookie().path("/okapi/")); + log.info("=== Test - POST /saml/callback-with-expiry - failure (wrong cookie) ==="); given() .header(TENANT_HEADER) diff --git a/src/test/java/org/folio/util/UrlUtilTest.java b/src/test/java/org/folio/util/UrlUtilTest.java index 134619c..cd957e0 100644 --- a/src/test/java/org/folio/util/UrlUtilTest.java +++ b/src/test/java/org/folio/util/UrlUtilTest.java @@ -1,17 +1,22 @@ package org.folio.util; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.io.UncheckedIOException; +import java.net.MalformedURLException; import org.folio.rest.tools.utils.NetworkUtils; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.function.ThrowingRunnable; import org.junit.runner.RunWith; import io.vertx.core.DeploymentOptions; @@ -50,6 +55,23 @@ public static void afterOnce(TestContext context) { mockVertx.close(context.asyncAssertSuccess()); } + void assertMalformedUrlException(ThrowingRunnable runnable) { + var e = assertThrows(UncheckedIOException.class, runnable); + assertThat(e.getMessage(), startsWith("Malformed Okapi URL: ")); + assertThat(e.getCause(), instanceOf(MalformedURLException.class)); + } + + @Test + public void getPathFromOkapiUrl() { + assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl(null)); + assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl("")); + assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl(":")); + assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost"), is("")); + assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/"), is("")); + assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/okapi"), is("/okapi")); + assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/okapi/"), is("/okapi")); + } + @Test public void checkIdpUrl(TestContext context) { UrlUtil.checkIdpUrl("http://localhost:" + MOCK_PORT + "/xml", vertx)