diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/AuthenticatedActionsHandler.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/AuthenticatedActionsHandler.java index 2b218733fb..f86a68bdb3 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/AuthenticatedActionsHandler.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/AuthenticatedActionsHandler.java @@ -36,6 +36,8 @@ * @author Farah Juma */ public class AuthenticatedActionsHandler { + + private static LogoutHandler logoutHandler = new LogoutHandler(); private OidcClientConfiguration deployment; private OidcHttpFacade facade; @@ -52,6 +54,11 @@ public boolean handledRequest() { queryBearerToken(); return true; } + + if (logoutHandler.tryLogout(facade)) { + return true; + } + return false; } diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/IDToken.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/IDToken.java index d40be6bfce..2c733cdacd 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/IDToken.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/IDToken.java @@ -53,6 +53,7 @@ public class IDToken extends JsonWebToken { public static final String CLAIMS_LOCALES = "claims_locales"; public static final String ACR = "acr"; public static final String S_HASH = "s_hash"; + public static final String SID = "sid"; /** * Construct a new instance. @@ -228,4 +229,12 @@ public String getAcr() { return getClaimValueAsString(ACR); } + /** + * Get the sid claim. + * + * @return the sid claim + */ + public String getSid() { + return getClaimValueAsString(SID); + } } diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/LogoutHandler.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/LogoutHandler.java new file mode 100644 index 0000000000..e8fd9269a0 --- /dev/null +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/LogoutHandler.java @@ -0,0 +1,267 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2021 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wildfly.security.http.oidc; + +import static java.util.Collections.synchronizedMap; +import static org.wildfly.security.http.oidc.ElytronMessages.log; + +import java.net.URISyntaxException; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.http.HttpStatus; +import org.apache.http.client.utils.URIBuilder; +import org.jose4j.jwt.JwtClaims; +import org.wildfly.security.http.HttpConstants; +import org.wildfly.security.http.oidc.OidcHttpFacade.Request; + +/** + * @author Pedro Igor + */ +final class LogoutHandler { + + private static final String POST_LOGOUT_REDIRECT_URI_PARAM = "post_logout_redirect_uri"; + private static final String ID_TOKEN_HINT_PARAM = "id_token_hint"; + private static final String LOGOUT_TOKEN_PARAM = "logout_token"; + private static final String LOGOUT_TOKEN_TYPE = "Logout"; + private static final String SID = "sid"; + private static final String ISS = "iss"; + + /** + * A bounded map to store sessions marked for invalidation after receiving logout requests through the back-channel + */ + private Map sessionsMarkedForInvalidation = synchronizedMap(new LinkedHashMap(16, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + boolean remove = sessionsMarkedForInvalidation.size() > eldest.getValue().getLogoutSessionWaitingLimit(); + + if (remove) { + log.debugf("Limit [%s] reached for sessions waiting [%s] for logout", eldest.getValue().getLogoutSessionWaitingLimit(), sessionsMarkedForInvalidation.size()); + } + + return remove; + } + }); + + boolean tryLogout(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = getSecurityContext(facade); + + if (securityContext == null) { + // no active session + return false; + } + + if (isSessionMarkedForInvalidation(facade)) { + // session marked for invalidation, invalidate it + log.debug("Invalidating pending logout session"); + facade.getTokenStore().logout(false); + return true; + } + + if (isRpInitiatedLogoutUri(facade)) { + redirectEndSessionEndpoint(facade); + return true; + } + + if (isLogoutCallbackUri(facade)) { + handleLogoutRequest(facade); + return true; + } + + return false; + } + + private boolean isSessionMarkedForInvalidation(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = getSecurityContext(facade); + IDToken idToken = securityContext.getIDToken(); + + if (idToken == null) { + return false; + } + + return sessionsMarkedForInvalidation.remove(idToken.getSid()) != null; + } + + private void redirectEndSessionEndpoint(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = getSecurityContext(facade); + OidcClientConfiguration clientConfiguration = securityContext.getOidcClientConfiguration(); + String logoutUri; + + try { + URIBuilder redirectUriBuilder = new URIBuilder(clientConfiguration.getEndSessionEndpointUrl()) + .addParameter(ID_TOKEN_HINT_PARAM, securityContext.getIDTokenString()); + String postLogoutUri = clientConfiguration.getPostLogoutUri(); + + if (postLogoutUri != null) { + redirectUriBuilder.addParameter(POST_LOGOUT_REDIRECT_URI_PARAM, getRedirectUri(facade) + postLogoutUri); + } + + logoutUri = redirectUriBuilder.build().toString(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + + log.debugf("Sending redirect to the end_session_endpoint: %s", logoutUri); + facade.getResponse().setStatus(HttpStatus.SC_MOVED_TEMPORARILY); + facade.getResponse().setHeader(HttpConstants.LOCATION, logoutUri); + } + + private void handleLogoutRequest(OidcHttpFacade facade) { + if (isFrontChannel(facade)) { + handleFrontChannelLogoutRequest(facade); + } else if (isBackChannel(facade)) { + handleBackChannelLogoutRequest(facade); + } else { + // logout requests should arrive either as a HTTP GET or POST + facade.getResponse().setStatus(HttpStatus.SC_METHOD_NOT_ALLOWED); + facade.authenticationFailed(); + } + } + + private void handleBackChannelLogoutRequest(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = getSecurityContext(facade); + String logoutToken = facade.getRequest().getFirstParam(LOGOUT_TOKEN_PARAM); + TokenValidator tokenValidator = TokenValidator.builder(securityContext.getOidcClientConfiguration()) + .setSkipExpirationValidator() + .setTokenType(LOGOUT_TOKEN_TYPE) + .build(); + JwtClaims claims; + + try { + claims = tokenValidator.verify(logoutToken); + } catch (Exception cause) { + log.debug("Unexpected error when verifying logout token", cause); + facade.getResponse().setStatus(HttpStatus.SC_BAD_REQUEST); + facade.authenticationFailed(); + return; + } + + if (!isSessionRequiredOnLogout(facade)) { + log.warn("Back-channel logout request received but can not infer sid from logout token to mark it for invalidation"); + facade.getResponse().setStatus(HttpStatus.SC_BAD_REQUEST); + facade.authenticationFailed(); + return; + } + + String sessionId = claims.getClaimValueAsString(SID); + + if (sessionId == null) { + facade.getResponse().setStatus(HttpStatus.SC_BAD_REQUEST); + facade.authenticationFailed(); + return; + } + + log.debug("Marking session for invalidation during back-channel logout"); + sessionsMarkedForInvalidation.put(sessionId, securityContext.getOidcClientConfiguration()); + } + + private void handleFrontChannelLogoutRequest(OidcHttpFacade facade) { + if (isSessionRequiredOnLogout(facade)) { + Request request = facade.getRequest(); + String sessionId = request.getQueryParamValue(SID); + + if (sessionId == null) { + facade.getResponse().setStatus(HttpStatus.SC_BAD_REQUEST); + facade.authenticationFailed(); + return; + } + + RefreshableOidcSecurityContext context = getSecurityContext(facade); + IDToken idToken = context.getIDToken(); + String issuer = request.getQueryParamValue(ISS); + + if (idToken == null || !sessionId.equals(idToken.getSid()) || !idToken.getIssuer().equals(issuer)) { + facade.getResponse().setStatus(HttpStatus.SC_BAD_REQUEST); + facade.authenticationFailed(); + return; + } + } + + log.debug("Invalidating session during front-channel logout"); + facade.getTokenStore().logout(false); + } + + private String getRedirectUri(OidcHttpFacade facade) { + String uri = facade.getRequest().getURI(); + + if (uri.indexOf('?') != -1) { + uri = uri.substring(0, uri.indexOf('?')); + } + + int logoutPathIndex = uri.indexOf(getLogoutUri(facade)); + + if (logoutPathIndex != -1) { + uri = uri.substring(0, logoutPathIndex); + } + + return uri; + } + + private boolean isLogoutCallbackUri(OidcHttpFacade facade) { + String path = facade.getRequest().getRelativePath(); + return path.endsWith(getLogoutCallbackUri(facade)); + } + + private boolean isRpInitiatedLogoutUri(OidcHttpFacade facade) { + String path = facade.getRequest().getRelativePath(); + return path.endsWith(getLogoutUri(facade)); + } + + private boolean isSessionRequiredOnLogout(OidcHttpFacade facade) { + return getOidcClientConfiguration(facade).isSessionRequiredOnLogout(); + } + + private OidcClientConfiguration getOidcClientConfiguration(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = getSecurityContext(facade); + + if (securityContext == null) { + return null; + } + + return securityContext.getOidcClientConfiguration(); + } + + private RefreshableOidcSecurityContext getSecurityContext(OidcHttpFacade facade) { + RefreshableOidcSecurityContext securityContext = (RefreshableOidcSecurityContext) facade.getSecurityContext(); + + if (securityContext == null) { + facade.getResponse().setStatus(HttpStatus.SC_UNAUTHORIZED); + facade.authenticationFailed(); + return null; + } + + return securityContext; + } + + private String getLogoutUri(OidcHttpFacade facade) { + return getOidcClientConfiguration(facade).getLogoutUri(); + } + + private String getLogoutCallbackUri(OidcHttpFacade facade) { + return getOidcClientConfiguration(facade).getLogoutCallbackUrl(); + } + + private boolean isBackChannel(OidcHttpFacade facade) { + return "post".equalsIgnoreCase(facade.getRequest().getMethod()); + } + + private boolean isFrontChannel(OidcHttpFacade facade) { + return "get".equalsIgnoreCase(facade.getRequest().getMethod()); + } +} diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientConfiguration.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientConfiguration.java index ca56da2863..be5806c9d5 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientConfiguration.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientConfiguration.java @@ -75,7 +75,7 @@ public enum RelativeUrlsUsed { protected String providerUrl; protected String authUrl; protected String tokenUrl; - protected String logoutUrl; + protected String endSessionEndpointUrl; protected String accountUrl; protected String registerNodeUrl; protected String unregisterNodeUrl; @@ -144,6 +144,12 @@ public enum RelativeUrlsUsed { protected String requestObjectSigningKeyStoreType; protected JWKEncPublicKeyLocator encryptionPublicKeyLocator; + private String postLogoutUri; + private boolean sessionRequiredOnLogout = true; + private String logoutUri = "/logout"; + private String logoutCallbackUrl = "/logout/callback"; + private int logoutSessionWaitingLimit = 100; + public OidcClientConfiguration() { } @@ -202,7 +208,7 @@ public void setAuthServerBaseUrl(OidcJsonConfiguration config) { protected void resetUrls() { authUrl = null; tokenUrl = null; - logoutUrl = null; + endSessionEndpointUrl = null; accountUrl = null; registerNodeUrl = null; unregisterNodeUrl = null; @@ -238,7 +244,7 @@ protected void resolveUrls() { authUrl = config.getAuthorizationEndpoint(); issuerUrl = config.getIssuer(); tokenUrl = config.getTokenEndpoint(); - logoutUrl = config.getLogoutEndpoint(); + endSessionEndpointUrl = config.getLogoutEndpoint(); jwksUrl = config.getJwksUri(); requestParameterSupported = config.getRequestParameterSupported(); requestObjectSigningAlgValuesSupported = config.getRequestObjectSigningAlgValuesSupported(); @@ -323,9 +329,13 @@ public String getTokenUrl() { return tokenUrl; } - public String getLogoutUrl() { + public String getEndSessionEndpointUrl() { resolveUrls(); - return logoutUrl; + return endSessionEndpointUrl; + } + + public String getLogoutUri() { + return logoutUri; } public String getAccountUrl() { @@ -779,4 +789,39 @@ public void setEncryptionPublicKeyLocator(JWKEncPublicKeyLocator publicKeySetExt public JWKEncPublicKeyLocator getEncryptionPublicKeyLocator() { return this.encryptionPublicKeyLocator; } + + public void setPostLogoutUri(String postLogoutUri) { + this.postLogoutUri = postLogoutUri; + } + + public String getPostLogoutUri() { + return postLogoutUri; + } + + public boolean isSessionRequiredOnLogout() { + return sessionRequiredOnLogout; + } + + public void setSessionRequiredOnLogout(boolean sessionRequiredOnLogout) { + this.sessionRequiredOnLogout = sessionRequiredOnLogout; + } + + public void setLogoutUri(String logoutUri) { + this.logoutUri = logoutUri; + } + + public String getLogoutCallbackUrl() { + return logoutCallbackUrl; + } + + public void setLogoutCallbackUrl(String logoutCallbackUrl) { + this.logoutCallbackUrl = logoutCallbackUrl; + } + public int getLogoutSessionWaitingLimit() { + return logoutSessionWaitingLimit; + } + + public void setLogoutSessionWaitingLimit(int logoutSessionWaitingLimit) { + this.logoutSessionWaitingLimit = logoutSessionWaitingLimit; + } } diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientContext.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientContext.java index f5d930bd52..eed194681a 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientContext.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/OidcClientContext.java @@ -136,8 +136,8 @@ public String getTokenUrl() { } @Override - public String getLogoutUrl() { - return (this.logoutUrl != null) ? this.logoutUrl : delegate.getLogoutUrl(); + public String getEndSessionEndpointUrl() { + return (this.endSessionEndpointUrl != null) ? this.endSessionEndpointUrl : delegate.getEndSessionEndpointUrl(); } @Override diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/ServerRequest.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/ServerRequest.java index 3a203541ee..8af9d75f32 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/ServerRequest.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/ServerRequest.java @@ -102,7 +102,7 @@ public static AccessAndIDTokenResponse invokeRefresh(OidcClientConfiguration dep public static void invokeLogout(OidcClientConfiguration deployment, String refreshToken) throws IOException, HttpFailure { HttpClient client = deployment.getClient(); - String uri = deployment.getLogoutUrl(); + String uri = deployment.getEndSessionEndpointUrl(); List formparams = new ArrayList<>(); formparams.add(new BasicNameValuePair(Oidc.REFRESH_TOKEN, refreshToken)); diff --git a/http/oidc/src/main/java/org/wildfly/security/http/oidc/TokenValidator.java b/http/oidc/src/main/java/org/wildfly/security/http/oidc/TokenValidator.java index 746318043f..1c2a0e9108 100644 --- a/http/oidc/src/main/java/org/wildfly/security/http/oidc/TokenValidator.java +++ b/http/oidc/src/main/java/org/wildfly/security/http/oidc/TokenValidator.java @@ -69,10 +69,12 @@ public Boolean run() { private static final int HEADER_INDEX = 0; private JwtConsumerBuilder jwtConsumerBuilder; private OidcClientConfiguration clientConfiguration; + private String tokenType; private TokenValidator(Builder builder) { this.jwtConsumerBuilder = builder.jwtConsumerBuilder; this.clientConfiguration = builder.clientConfiguration; + this.tokenType = builder.tokenType; } /** @@ -110,11 +112,17 @@ public VerifiedTokens parseAndVerifyToken(final String idToken, final String acc * @throws OidcException if the bearer token is invalid */ public AccessToken parseAndVerifyToken(final String bearerToken) throws OidcException { + return new AccessToken(verify(bearerToken)); + } + + public JwtClaims verify(String bearerToken) throws OidcException { + JwtClaims jwtClaims; + try { JwtContext jwtContext = setVerificationKey(bearerToken, jwtConsumerBuilder); jwtConsumerBuilder.setRequireSubject(); if (! DISABLE_TYP_CLAIM_VALIDATION_PROPERTY) { - jwtConsumerBuilder.registerValidator(new TypeValidator("Bearer")); + jwtConsumerBuilder.registerValidator(new TypeValidator(tokenType)); } if (clientConfiguration.isVerifyTokenAudience()) { jwtConsumerBuilder.setExpectedAudience(clientConfiguration.getResourceName()); @@ -123,15 +131,15 @@ public AccessToken parseAndVerifyToken(final String bearerToken) throws OidcExce } // second pass to validate jwtConsumerBuilder.build().processContext(jwtContext); - JwtClaims jwtClaims = jwtContext.getJwtClaims(); + jwtClaims = jwtContext.getJwtClaims(); if (jwtClaims == null) { throw log.invalidBearerTokenClaims(); } - return new AccessToken(jwtClaims); } catch (InvalidJwtException e) { log.tracef("Problem parsing bearer token: " + bearerToken, e); throw log.invalidBearerToken(e); } + return jwtClaims; } private JwtContext setVerificationKey(final String token, final JwtConsumerBuilder jwtConsumerBuilder) throws InvalidJwtException { @@ -164,6 +172,8 @@ public static Builder builder(OidcClientConfiguration clientConfiguration) { } public static class Builder { + + public String tokenType = "Bearer"; private OidcClientConfiguration clientConfiguration; private String expectedIssuer; private String clientId; @@ -171,6 +181,7 @@ public static class Builder { private PublicKeyLocator publicKeyLocator; private SecretKey clientSecretKey; private JwtConsumerBuilder jwtConsumerBuilder; + private boolean skipExpirationValidator; /** * Construct a new uninitialized instance. @@ -213,11 +224,24 @@ public TokenValidator build() throws IllegalArgumentException { jwtConsumerBuilder = new JwtConsumerBuilder() .setExpectedIssuer(expectedIssuer) .setJwsAlgorithmConstraints( - new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.PERMIT, expectedJwsAlgorithm)) - .setRequireExpirationTime(); + new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.PERMIT, expectedJwsAlgorithm)); + + if (!skipExpirationValidator) { + jwtConsumerBuilder.setRequireExpirationTime(); + } return new TokenValidator(this); } + + public Builder setSkipExpirationValidator() { + this.skipExpirationValidator = true; + return this; + } + + public Builder setTokenType(String tokenType) { + this.tokenType = tokenType; + return this; + } } private static class AzpValidator implements ErrorCodeValidator { diff --git a/http/oidc/src/test/java/org/wildfly/security/http/oidc/AbstractLogoutTest.java b/http/oidc/src/test/java/org/wildfly/security/http/oidc/AbstractLogoutTest.java new file mode 100644 index 0000000000..ab0bb8341b --- /dev/null +++ b/http/oidc/src/test/java/org/wildfly/security/http/oidc/AbstractLogoutTest.java @@ -0,0 +1,217 @@ +package org.wildfly.security.http.oidc; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assume.assumeTrue; +import static org.wildfly.security.http.oidc.Oidc.OIDC_NAME; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import io.restassured.RestAssured; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.keycloak.representations.idm.ClientRepresentation; +import org.keycloak.representations.idm.RealmRepresentation; +import org.wildfly.security.http.HttpAuthenticationException; +import org.wildfly.security.http.HttpConstants; +import org.wildfly.security.http.HttpScope; +import org.wildfly.security.http.HttpServerAuthenticationMechanism; +import org.wildfly.security.http.Scope; + +/** + * @author Pedro Igor + */ +public abstract class AbstractLogoutTest extends OidcBaseTest { + + private ElytronDispatcher dispatcher; + private OidcClientConfiguration clientConfig; + + @BeforeClass + public static void onBeforeClass() { + assumeTrue("Docker isn't available, OIDC tests will be skipped", isDockerAvailable()); + KEYCLOAK_CONTAINER = new KeycloakContainer(); + KEYCLOAK_CONTAINER.start(); + System.setProperty("oidc.provider.url", KEYCLOAK_CONTAINER.getAuthServerUrl() + "/realms/" + TEST_REALM); + } + + @AfterClass + public static void onAfterClass() { + System.clearProperty("oidc.provider.url"); + } + + @AfterClass + public static void generalCleanup() { + // no-op + } + + @Before + public void onBefore() throws Exception { + OidcBaseTest.client = new MockWebServer(); + OidcBaseTest.client.start(new InetSocketAddress(0).getAddress(), CLIENT_PORT); + configureDispatcher(); + RealmRepresentation realm = KeycloakConfiguration.getRealmRepresentation(TEST_REALM, CLIENT_ID, CLIENT_SECRET, CLIENT_HOST_NAME, CLIENT_PORT, CLIENT_APP, false); + + realm.setAccessTokenLifespan(100); + realm.setSsoSessionMaxLifespan(100); + + ClientRepresentation client = realm.getClients().get(0); + + client.setAttributes(new HashMap<>()); + + doConfigureClient(client); + + List redirectUris = new ArrayList<>(client.getRedirectUris()); + + redirectUris.add("*"); + + client.setRedirectUris(redirectUris); + + sendRealmCreationRequest(realm); + } + + @After + public void onAfter() throws IOException { + client.shutdown(); + RestAssured + .given() + .auth().oauth2(KeycloakConfiguration.getAdminAccessToken(KEYCLOAK_CONTAINER.getAuthServerUrl())) + .when() + .delete(KEYCLOAK_CONTAINER.getAuthServerUrl() + "/admin/realms/" + TEST_REALM).then().statusCode(204); + } + + protected void doConfigureClient(ClientRepresentation client) { + } + + protected OidcJsonConfiguration getClientConfiguration() { + OidcJsonConfiguration config = new OidcJsonConfiguration(); + + config.setRealm(TEST_REALM); + config.setResource(CLIENT_ID); + config.setPublicClient(false); + config.setAuthServerUrl(KEYCLOAK_CONTAINER.getAuthServerUrl()); + config.setSslRequired("EXTERNAL"); + config.setCredentials(new HashMap<>()); + config.getCredentials().put("secret", CLIENT_SECRET); + + return config; + } + + protected TestingHttpServerRequest getCurrentRequest() { + return dispatcher.getCurrentRequest(); + } + + protected HttpScope getCurrentSession() { + return getCurrentRequest().getScope(Scope.SESSION); + } + + protected OidcClientConfiguration getClientConfig() { + return clientConfig; + } + + protected TestingHttpServerResponse getCurrentResponse() { + try { + return dispatcher.getCurrentRequest().getResponse(); + } catch (HttpAuthenticationException e) { + throw new RuntimeException(e); + } + } + + class ElytronDispatcher extends Dispatcher { + + volatile TestingHttpServerRequest currentRequest; + + private final HttpServerAuthenticationMechanism mechanism; + private Dispatcher beforeDispatcher; + private HttpScope sessionScope; + + public ElytronDispatcher(HttpServerAuthenticationMechanism mechanism, Dispatcher beforeDispatcher) { + this.mechanism = mechanism; + this.beforeDispatcher = beforeDispatcher; + } + + @Override + public MockResponse dispatch(RecordedRequest serverRequest) throws InterruptedException { + if (beforeDispatcher != null) { + MockResponse response = beforeDispatcher.dispatch(serverRequest); + + if (response != null) { + return response; + } + } + + MockResponse mockResponse = new MockResponse(); + + try { + currentRequest = new TestingHttpServerRequest(serverRequest, sessionScope); + + mechanism.evaluateRequest(currentRequest); + + TestingHttpServerResponse response = currentRequest.getResponse(); + + if (Status.COMPLETE.equals(currentRequest.getResult())) { + mockResponse.setBody("Welcome, authenticated user"); + sessionScope = currentRequest.getScope(Scope.SESSION); + } else { + boolean statusSet = response.getStatusCode() > 0; + + if (statusSet) { + mockResponse.setResponseCode(response.getStatusCode()); + + if (response.getLocation() != null) { + mockResponse.setHeader(HttpConstants.LOCATION, response.getLocation()); + } + } else { + mockResponse.setResponseCode(201); + mockResponse.setBody("from " + serverRequest.getPath()); + } + } + } catch (Exception cause) { + cause.printStackTrace(); + mockResponse.setResponseCode(500); + } + + return mockResponse; + } + + public TestingHttpServerRequest getCurrentRequest() { + return currentRequest; + } + } + + protected void configureDispatcher() { + configureDispatcher(OidcClientConfigurationBuilder.build(getClientConfiguration()), null); + } + + protected void configureDispatcher(OidcClientConfiguration clientConfig, Dispatcher beforeDispatch) { + this.clientConfig = clientConfig; + OidcClientContext oidcClientContext = new OidcClientContext(clientConfig); + oidcFactory = new OidcMechanismFactory(oidcClientContext); + HttpServerAuthenticationMechanism mechanism; + try { + mechanism = oidcFactory.createAuthenticationMechanism(OIDC_NAME, Collections.emptyMap(), getCallbackHandler()); + } catch (HttpAuthenticationException e) { + throw new RuntimeException(e); + } + dispatcher = new ElytronDispatcher(mechanism, beforeDispatch); + client.setDispatcher(dispatcher); + } + + protected void assertUserNotAuthenticated() { + assertNull(getCurrentSession().getAttachment(OidcAccount.class.getName())); + } + + protected void assertUserAuthenticated() { + assertNotNull(getCurrentSession().getAttachment(OidcAccount.class.getName())); + } +} diff --git a/http/oidc/src/test/java/org/wildfly/security/http/oidc/BackChannelLogoutTest.java b/http/oidc/src/test/java/org/wildfly/security/http/oidc/BackChannelLogoutTest.java new file mode 100644 index 0000000000..fd04b6e7d1 --- /dev/null +++ b/http/oidc/src/test/java/org/wildfly/security/http/oidc/BackChannelLogoutTest.java @@ -0,0 +1,78 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2021 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wildfly.security.http.oidc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.net.InetAddress; +import java.net.URI; +import java.net.UnknownHostException; +import java.util.List; + +import com.gargoylesoftware.htmlunit.Page; +import com.gargoylesoftware.htmlunit.WebClient; +import org.apache.http.HttpStatus; +import org.junit.Test; +import org.keycloak.representations.idm.ClientRepresentation; + +public class BackChannelLogoutTest extends AbstractLogoutTest { + + @Override + protected void doConfigureClient(ClientRepresentation client) { + List redirectUris = client.getRedirectUris(); + String redirectUri = redirectUris.get(0); + + client.setFrontchannelLogout(false); + client.getAttributes().put("backchannel.logout.session.required", "true"); + client.getAttributes().put("backchannel.logout.url", rewriteHost(redirectUri) + "/logout/callback"); + } + + private static String rewriteHost(String redirectUri) { + try { + return redirectUri.replace("localhost", InetAddress.getLocalHost().getHostAddress()); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + } + + @Test + public void testRPInitiatedLogout() throws Exception { + URI requestUri = new URI(getClientUrl()); + WebClient webClient = getWebClient(); + webClient.getPage(getClientUrl()); + TestingHttpServerResponse response = getCurrentResponse(); + assertEquals(HttpStatus.SC_MOVED_TEMPORARILY, response.getStatusCode()); + assertEquals(Status.NO_AUTH, getCurrentRequest().getResult()); + + webClient = getWebClient(); + Page page = loginToKeycloak(webClient, KeycloakConfiguration.ALICE, KeycloakConfiguration.ALICE_PASSWORD, + requestUri, response.getLocation(), + response.getCookies()) + .click(); + assertTrue(page.getWebResponse().getContentAsString().contains("Welcome, authenticated user")); + + // logged out after finishing the redirections during frontchannel logout + assertUserAuthenticated(); + webClient.getPage(getClientUrl() + "/logout"); + assertUserAuthenticated(); + webClient.getPage(getClientUrl()); + assertUserNotAuthenticated(); + } +} \ No newline at end of file diff --git a/http/oidc/src/test/java/org/wildfly/security/http/oidc/FrontChannelLogoutTest.java b/http/oidc/src/test/java/org/wildfly/security/http/oidc/FrontChannelLogoutTest.java new file mode 100644 index 0000000000..7979a9bd43 --- /dev/null +++ b/http/oidc/src/test/java/org/wildfly/security/http/oidc/FrontChannelLogoutTest.java @@ -0,0 +1,127 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2021 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wildfly.security.http.oidc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.net.URI; +import java.util.List; + +import com.gargoylesoftware.htmlunit.Page; +import com.gargoylesoftware.htmlunit.TextPage; +import com.gargoylesoftware.htmlunit.WebClient; +import com.gargoylesoftware.htmlunit.html.HtmlForm; +import com.gargoylesoftware.htmlunit.html.HtmlPage; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.QueueDispatcher; +import okhttp3.mockwebserver.RecordedRequest; +import org.apache.http.HttpStatus; +import org.junit.Test; +import org.keycloak.representations.idm.ClientRepresentation; + +/** + * Tests for the OpenID Connect authentication mechanism. + * + * @author Farah Juma + */ +public class FrontChannelLogoutTest extends AbstractLogoutTest { + + @Override + protected void doConfigureClient(ClientRepresentation client) { + client.setFrontchannelLogout(true); + List redirectUris = client.getRedirectUris(); + String redirectUri = redirectUris.get(0); + + client.getAttributes().put("frontchannel.logout.url", redirectUri + "/logout/callback"); + } + + @Test + public void testRPInitiatedLogout() throws Exception { + URI requestUri = new URI(getClientUrl()); + WebClient webClient = getWebClient(); + webClient.getPage(getClientUrl()); + TestingHttpServerResponse response = getCurrentResponse(); + assertEquals(HttpStatus.SC_MOVED_TEMPORARILY, response.getStatusCode()); + assertEquals(Status.NO_AUTH, getCurrentRequest().getResult()); + + webClient = getWebClient(); + Page page = loginToKeycloak(webClient, KeycloakConfiguration.ALICE, KeycloakConfiguration.ALICE_PASSWORD, + requestUri, response.getLocation(), + response.getCookies()) + .click(); + assertTrue(page.getWebResponse().getContentAsString().contains("Welcome, authenticated user")); + + // logged out after finishing the redirections during frontchannel logout + assertUserAuthenticated(); + webClient.getPage(getClientUrl() + "/logout"); + assertUserNotAuthenticated(); + } + + @Test + public void testRPInitiatedLogoutWithPostLogoutUri() throws Exception { + OidcClientConfiguration oidcClientConfiguration = getClientConfig(); + oidcClientConfiguration.setPostLogoutUri("/post-logout"); + configureDispatcher(oidcClientConfiguration, new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + if (request.getPath().contains("/post-logout")) { + return new MockResponse() + .setBody("you are logged out from app"); + } + return null; + } + }); + + URI requestUri = new URI(getClientUrl()); + WebClient webClient = getWebClient(); + webClient.getPage(getClientUrl()); + TestingHttpServerResponse response = getCurrentResponse(); + Page page = loginToKeycloak(webClient, KeycloakConfiguration.ALICE, KeycloakConfiguration.ALICE_PASSWORD, requestUri, response.getLocation(), + response.getCookies()).click(); + assertTrue(page.getWebResponse().getContentAsString().contains("Welcome, authenticated user")); + + assertUserAuthenticated(); + HtmlPage continueLogout = webClient.getPage(getClientUrl() + "/logout"); + page = continueLogout.getElementById("continue").click(); + assertUserNotAuthenticated(); + assertTrue(page.getWebResponse().getContentAsString().contains("you are logged out from app")); + } + + @Test + public void testFrontChannelLogout() throws Exception { + try { + URI requestUri = new URI(getClientUrl()); + WebClient webClient = getWebClient(); + webClient.getPage(getClientUrl()); + TextPage page = loginToKeycloak(webClient, KeycloakConfiguration.ALICE, KeycloakConfiguration.ALICE_PASSWORD, requestUri, getCurrentResponse().getLocation(), + getCurrentResponse().getCookies()).click(); + assertTrue(page.getContent().contains("Welcome, authenticated user")); + + HtmlPage logoutPage = webClient.getPage(getClientConfig().getEndSessionEndpointUrl() + "?client_id=" + CLIENT_ID); + HtmlForm form = logoutPage.getForms().get(0); + assertUserAuthenticated(); + form.getInputByName("confirmLogout").click(); + assertUserNotAuthenticated(); + } finally { + client.setDispatcher(new QueueDispatcher()); + } + } +} \ No newline at end of file diff --git a/http/oidc/src/test/java/org/wildfly/security/http/oidc/OidcBaseTest.java b/http/oidc/src/test/java/org/wildfly/security/http/oidc/OidcBaseTest.java index 6eb698160a..bbcdf50102 100644 --- a/http/oidc/src/test/java/org/wildfly/security/http/oidc/OidcBaseTest.java +++ b/http/oidc/src/test/java/org/wildfly/security/http/oidc/OidcBaseTest.java @@ -279,6 +279,7 @@ static WebClient getWebClient() { WebClient webClient = new WebClient(); webClient.setCssErrorHandler(new SilentCssErrorHandler()); webClient.setJavaScriptErrorListener(new SilentJavaScriptErrorListener()); + webClient.getOptions().setMaxInMemory(50000 * 1024); return webClient; } @@ -291,7 +292,10 @@ protected static String getClientUrlForTenant(String tenant) { } protected HtmlInput loginToKeycloak(String username, String password, URI requestUri, String location, List cookies) throws IOException { - WebClient webClient = getWebClient(); + return loginToKeycloak(getWebClient(), username, password, requestUri, location, cookies); + } + + protected HtmlInput loginToKeycloak(WebClient webClient, String username, String password, URI requestUri, String location, List cookies) throws IOException { if (cookies != null) { for (HttpServerCookie cookie : cookies) { webClient.addCookie(getCookieString(cookie), requestUri.toURL(), null); diff --git a/tests/base/src/test/java/org/wildfly/security/http/impl/AbstractBaseHttpTest.java b/tests/base/src/test/java/org/wildfly/security/http/impl/AbstractBaseHttpTest.java index 7b8308fd8c..146cc785ff 100644 --- a/tests/base/src/test/java/org/wildfly/security/http/impl/AbstractBaseHttpTest.java +++ b/tests/base/src/test/java/org/wildfly/security/http/impl/AbstractBaseHttpTest.java @@ -51,6 +51,8 @@ import javax.security.auth.x500.X500Principal; import javax.security.sasl.AuthorizeCallback; import javax.security.sasl.RealmCallback; + +import okhttp3.mockwebserver.RecordedRequest; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.junit.Assert; @@ -144,6 +146,8 @@ protected enum Status { protected static class TestingHttpServerRequest implements HttpServerRequest { + private String contentType; + private String body; private Status result; private HttpServerMechanismsResponder responder; private String remoteUser; @@ -153,6 +157,7 @@ protected static class TestingHttpServerRequest implements HttpServerRequest { private Map> requestHeaders = new HashMap<>(); private X500Principal testPrincipal = null; private Map sessionScopeAttachments = new HashMap<>(); + private HttpScope sessionScope; public TestingHttpServerRequest(String[] authorization) { if (authorization != null) { @@ -221,6 +226,14 @@ public TestingHttpServerRequest(String[] authorization, URI requestURI, String c } } + public TestingHttpServerRequest(RecordedRequest request, HttpScope sessionScope) { + this(new String[0], request.getRequestUrl().uri(), request.getHeader("Cookie")); + this.requestMethod = request.getMethod(); + this.body = request.getBody().readUtf8(); + this.contentType = request.getHeader("Content-Type"); + this.sessionScope = sessionScope; + } + public Status getResult() { return result; } @@ -292,7 +305,7 @@ public URI getRequestURI() { } public String getRequestPath() { - throw new IllegalStateException(); + return requestURI.getPath(); } public Map> getParameters() { @@ -308,6 +321,19 @@ public List getParameterValues(String name) { } public String getFirstParameterValue(String name) { + if ("application/x-www-form-urlencoded".equals(contentType)) { + if (body == null) { + return null; + } + + for (String keyValue : body.split("&")) { + String key = keyValue.substring(0, keyValue.indexOf('=')); + + if (key.equals(name)) { + return keyValue.substring(keyValue.indexOf('=') + 1); + } + } + } throw new IllegalStateException(); } @@ -334,46 +360,48 @@ public boolean resumeRequest() { public HttpScope getScope(Scope scope) { if (scope.equals(Scope.SSL_SESSION)) { return null; - } else { - return new HttpScope() { + } else if (sessionScope != null) { + return sessionScope; + } - @Override - public boolean exists() { - return true; - } + return new HttpScope() { - @Override - public boolean create() { - return false; - } + @Override + public boolean exists() { + return true; + } - @Override - public boolean supportsAttachments() { - return true; - } + @Override + public boolean create() { + return false; + } - @Override - public boolean supportsInvalidation() { - return false; - } + @Override + public boolean supportsAttachments() { + return true; + } - @Override - public void setAttachment(String key, Object value) { - if (scope.equals(Scope.SESSION)) { - sessionScopeAttachments.put(key, value); - } + @Override + public boolean supportsInvalidation() { + return false; + } + + @Override + public void setAttachment(String key, Object value) { + if (scope.equals(Scope.SESSION)) { + sessionScopeAttachments.put(key, value); } + } - @Override - public Object getAttachment(String key) { - if (scope.equals(Scope.SESSION)) { - return sessionScopeAttachments.get(key); - } else { - return null; - } + @Override + public Object getAttachment(String key) { + if (scope.equals(Scope.SESSION)) { + return sessionScopeAttachments.get(key); + } else { + return null; } - }; - } + } + }; } public Collection getScopeIds(Scope scope) {