diff --git a/pom.xml b/pom.xml
index 01eb04c..4dceb65 100644
--- a/pom.xml
+++ b/pom.xml
@@ -248,6 +248,13 @@
test
+
+ org.wiremock
+ wiremock
+ 3.9.1
+ test
+
+
diff --git a/src/main/java/org/folio/rest/impl/SamlAPI.java b/src/main/java/org/folio/rest/impl/SamlAPI.java
index d9bafb5..7eaa64f 100644
--- a/src/main/java/org/folio/rest/impl/SamlAPI.java
+++ b/src/main/java/org/folio/rest/impl/SamlAPI.java
@@ -257,7 +257,8 @@ private void doPostSamlCallback(String body, RoutingContext routingContext, Map<
if (isLegacyResponse(tokenSignEndpoint)) {
return redirectResponseLegacy(jsonResponse, stripesBaseUrl, originalUrl);
} else {
- return redirectResponse(jsonResponse, stripesBaseUrl, originalUrl);
+ var okapiPath = UrlUtil.getPathFromOkapiUrl(parsedHeaders.getUrl());
+ return redirectResponse(jsonResponse, stripesBaseUrl, originalUrl, okapiPath);
}
});
});
@@ -321,7 +322,9 @@ private Response redirectResponseLegacy(JsonObject jsonObject, URI stripesBaseUr
return PostSamlCallbackResponse.respond302(headers);
}
- private Response redirectResponse(JsonObject jsonObject, URI stripesBaseUrl, URI originalUrl) {
+ private Response redirectResponse(JsonObject jsonObject,
+ URI stripesBaseUrl, URI originalUrl, String okapiPath) {
+
String accessToken = jsonObject.getString(ACCESS_TOKEN);
String refreshToken = jsonObject.getString(REFRESH_TOKEN);
String accessTokenExpiration = jsonObject.getString(ACCESS_TOKEN_EXPIRATION);
@@ -339,13 +342,14 @@ private Response redirectResponse(JsonObject jsonObject, URI stripesBaseUrl, URI
// NOTE RMB doesn't support sending multiple headers with the same key so we
// make our own response.
- return Response.status(302).header(SET_COOKIE, accessTokenCookie(accessToken, accessTokenExpiration))
- .header(SET_COOKIE, refreshTokenCookie(refreshToken, refreshTokenExpiration))
+ return Response.status(302)
+ .header(SET_COOKIE, accessTokenCookie(accessToken, accessTokenExpiration, okapiPath))
+ .header(SET_COOKIE, refreshTokenCookie(refreshToken, refreshTokenExpiration, okapiPath))
.header(LOCATION, location)
.build();
}
- private String refreshTokenCookie(String refreshToken, String refreshTokenExpiration) {
+ private String refreshTokenCookie(String refreshToken, String refreshTokenExpiration, String okapiPath) {
// The refresh token expiration is the time after which the token will be
// considered expired.
var exp = Instant.parse(refreshTokenExpiration).getEpochSecond();
@@ -360,7 +364,7 @@ private String refreshTokenCookie(String refreshToken, String refreshTokenExpira
var rtCookie = Cookie.cookie(FOLIO_REFRESH_TOKEN, refreshToken)
.setMaxAge(ttlSeconds)
.setSecure(true)
- .setPath("/authn")
+ .setPath(okapiPath + "/authn")
.setHttpOnly(true)
.setSameSite(CookieSameSiteConfig.get())
.setDomain(null)
@@ -371,7 +375,7 @@ private String refreshTokenCookie(String refreshToken, String refreshTokenExpira
return rtCookie;
}
- private String accessTokenCookie(String accessToken, String accessTokenExpiration) {
+ private String accessTokenCookie(String accessToken, String accessTokenExpiration, String okapiPath) {
// The refresh token expiration is the time after which the token will be
// considered expired.
var exp = Instant.parse(accessTokenExpiration).getEpochSecond();
@@ -386,7 +390,7 @@ private String accessTokenCookie(String accessToken, String accessTokenExpiratio
var atCookie = Cookie.cookie(FOLIO_ACCESS_TOKEN, accessToken)
.setMaxAge(ttlSeconds)
.setSecure(true)
- .setPath("/")
+ .setPath(okapiPath + "/")
.setHttpOnly(true)
.setSameSite(CookieSameSiteConfig.get())
.encode();
diff --git a/src/main/java/org/folio/util/UrlUtil.java b/src/main/java/org/folio/util/UrlUtil.java
index 526c118..0475a39 100644
--- a/src/main/java/org/folio/util/UrlUtil.java
+++ b/src/main/java/org/folio/util/UrlUtil.java
@@ -1,7 +1,13 @@
package org.folio.util;
+import java.io.UncheckedIOException;
import java.net.ConnectException;
+import java.net.MalformedURLException;
import java.net.URI;
+import java.net.URL;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
@@ -12,11 +18,30 @@
* @author rsass
*/
public class UrlUtil {
+ private static final Logger logger = LogManager.getLogger(UrlUtil.class);
private UrlUtil() {
}
+ /**
+ * Return the path component of the okapiUrl without tailing '/',
+ * return empty string on parse error.
+ */
+ public static String getPathFromOkapiUrl(String okapiUrl) {
+ try {
+ var path = new URL(okapiUrl).getPath();
+ if (path.endsWith("/")) {
+ return path.substring(0, path.length() - 1);
+ }
+ return path;
+ } catch (MalformedURLException e) {
+ var message = "Malformed Okapi URL: " + okapiUrl;
+ logger.error("{}", message, e);
+ throw new UncheckedIOException(message, e);
+ }
+ }
+
public static URI parseBaseUrl(URI originalUrl) {
return URI.create(originalUrl.getScheme() + "://" + originalUrl.getAuthority());
}
diff --git a/src/test/java/org/folio/rest/impl/SamlAPITest.java b/src/test/java/org/folio/rest/impl/SamlAPITest.java
index 27e9335..2417917 100644
--- a/src/test/java/org/folio/rest/impl/SamlAPITest.java
+++ b/src/test/java/org/folio/rest/impl/SamlAPITest.java
@@ -3,12 +3,16 @@
import static io.restassured.RestAssured.given;
import static io.restassured.module.jsv.JsonSchemaValidator.matchesJsonSchemaInClasspath;
import static org.folio.util.Base64AwareXsdMatcher.matchesBase64XsdInClasspath;
+import static com.github.tomakehurst.wiremock.client.WireMock.any;
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
+import com.github.tomakehurst.wiremock.junit.WireMockClassRule;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
@@ -34,6 +38,7 @@
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
+import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
@@ -47,10 +52,10 @@
import org.pac4j.core.redirect.RedirectionActionBuilder;
import org.pac4j.saml.client.SAML2Client;
import org.w3c.dom.ls.LSResourceResolver;
-
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import io.restassured.http.Header;
+import io.restassured.matcher.RestAssuredMatchers;
import io.restassured.response.ExtractableResponse;
import io.restassured.response.Response;
import io.vertx.core.DeploymentOptions;
@@ -79,11 +84,17 @@ public class SamlAPITest extends TestBase {
public static final int IDP_MOCK_PORT = NetworkUtils.nextFreePort();
private static final int MOCK_SERVER_PORT = NetworkUtils.nextFreePort();
+ private static final int OKAPI_PROXY_PORT = NetworkUtils.nextFreePort();
private static final Header OKAPI_URL_HEADER= new Header("X-Okapi-Url", "http://localhost:" + MOCK_SERVER_PORT);
+ private static final Header OKAPI_PROXY_URL_HEADER=
+ new Header("X-Okapi-Url", "http://localhost:" + OKAPI_PROXY_PORT + "/okapi");
private static final MockJsonExtended mock = new MockJsonExtended();
private DataMigrationHelper dataMigrationHelper = new DataMigrationHelper(TENANT_HEADER, TOKEN_HEADER, OKAPI_URL_HEADER);
+ @ClassRule
+ public static WireMockClassRule okapiProxy = new WireMockClassRule(OKAPI_PROXY_PORT);
+
@Rule
public TestName testName = new TestName();
@@ -92,6 +103,11 @@ public static void setupOnce(TestContext context) {
RestAssured.port = TestBase.modulePort;
RestAssured.enableLoggingOfRequestAndResponseIfValidationFails();
+ okapiProxy.stubFor(any(urlMatching("/okapi/.*"))
+ .willReturn(aResponse()
+ .proxiedFrom("http://localhost:" + MOCK_SERVER_PORT)
+ .withProxyUrlPrefixToRemove("/okapi")));
+
DeploymentOptions idpOptions = new DeploymentOptions()
.setConfig(new JsonObject().put("http.port", IDP_MOCK_PORT));
@@ -728,6 +744,20 @@ public void callbackEndpointTests(TestContext context) {
samlResponse, TENANT_HEADER, TOKEN_HEADER, OKAPI_URL_HEADER);
CookieSameSiteConfig.set(Map.of());
+ log.info("=== Test - POST /saml/callback-with-expiry with okapi path ===");
+ given()
+ .header(TENANT_HEADER)
+ .header(TOKEN_HEADER)
+ .header(OKAPI_PROXY_URL_HEADER)
+ .cookie(SamlAPI.RELAY_STATE, cookie)
+ .formParam("SAMLResponse", samlResponse)
+ .formParam("RelayState", relayState)
+ .post("/saml/callback-with-expiry")
+ .then()
+ .statusCode(302)
+ .cookie("folioRefreshToken", RestAssuredMatchers.detailedCookie().path("/okapi/authn"))
+ .cookie("folioAccessToken", RestAssuredMatchers.detailedCookie().path("/okapi/"));
+
log.info("=== Test - POST /saml/callback-with-expiry - failure (wrong cookie) ===");
given()
.header(TENANT_HEADER)
diff --git a/src/test/java/org/folio/util/UrlUtilTest.java b/src/test/java/org/folio/util/UrlUtilTest.java
index 134619c..cd957e0 100644
--- a/src/test/java/org/folio/util/UrlUtilTest.java
+++ b/src/test/java/org/folio/util/UrlUtilTest.java
@@ -1,17 +1,22 @@
package org.folio.util;
+import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+import java.io.UncheckedIOException;
+import java.net.MalformedURLException;
import org.folio.rest.tools.utils.NetworkUtils;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
import org.junit.runner.RunWith;
import io.vertx.core.DeploymentOptions;
@@ -50,6 +55,23 @@ public static void afterOnce(TestContext context) {
mockVertx.close(context.asyncAssertSuccess());
}
+ void assertMalformedUrlException(ThrowingRunnable runnable) {
+ var e = assertThrows(UncheckedIOException.class, runnable);
+ assertThat(e.getMessage(), startsWith("Malformed Okapi URL: "));
+ assertThat(e.getCause(), instanceOf(MalformedURLException.class));
+ }
+
+ @Test
+ public void getPathFromOkapiUrl() {
+ assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl(null));
+ assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl(""));
+ assertMalformedUrlException(() -> UrlUtil.getPathFromOkapiUrl(":"));
+ assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost"), is(""));
+ assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/"), is(""));
+ assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/okapi"), is("/okapi"));
+ assertThat(UrlUtil.getPathFromOkapiUrl("https://localhost/okapi/"), is("/okapi"));
+ }
+
@Test
public void checkIdpUrl(TestContext context) {
UrlUtil.checkIdpUrl("http://localhost:" + MOCK_PORT + "/xml", vertx)