Skip to content

Commit

Permalink
Merge pull request #36634 from sberyozkin/oidc_client_request_customizer
Browse files Browse the repository at this point in the history
Introduce OidcClientRequestFilter
  • Loading branch information
sberyozkin authored Oct 23, 2023
2 parents 13ba9d0 + ee61fa7 commit 1e69886
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ quarkus.oidc.introspection-credentials.name=introspection-user-name
quarkus.oidc.introspection-credentials.secret=introspection-user-secret
----

[[oidc-client-filters]]
==== OIDC client request customization

You can customize OIDC client requests by registering one or more `OidcClientRequestFiler` implementations which can update or add new request headers, please see xref:security-openid-connect-client-reference#oidc-client-filters[Client request customization] for more information.

==== Redirecting to and from the OIDC provider

When a user is redirected to the OpenID Connect provider to authenticate, the redirect URL includes a `redirect_uri` query parameter, which indicates to the provider where the user has to be redirected to when the authentication is complete.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,42 @@ quarkus.log.category."io.quarkus.oidc.client.runtime.OidcClientRecorder".level=T
quarkus.log.category."io.quarkus.oidc.client.runtime.OidcClientRecorder".min-level=TRACE
----

[[oidc-client-filters]]
== Client request customization

You can customize OIDC client requests by registering one or more `OidcClientRequestFiler` implementations which can update or add new request headers, for example, a filter can analyze the request body and add its digest as a new header value:

[source,java]
----
package io.quarkus.it.keycloak;
import jakarta.enterprise.context.ApplicationScoped;
import io.quarkus.arc.Unremovable;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.vertx.core.http.HttpMethod;
import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;
@ApplicationScoped
@Unremovable
public class OidcClientRequestCustomizer implements OidcClientRequestFilter {
@Override
public void filter(HttpRequest<Buffer> request, Buffer buffer) {
HttpMethod method = request.method();
String uri = request.uri();
if (method == HttpMethod.POST && uri.endsWith("/service") && buffer != null) {
request.putHeader("Digest", calculateDigest(buffer.toString()));
}
}
private String calculateDigest(String bodyString) {
// Apply the required digest algorithm to the body string
}
}
----

[[token-propagation-reactive]]
== Token Propagation Reactive

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.security.Key;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

Expand All @@ -16,6 +17,7 @@
import io.quarkus.oidc.client.OidcClientConfig;
import io.quarkus.oidc.client.OidcClientException;
import io.quarkus.oidc.client.Tokens;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.oidc.common.runtime.OidcConstants;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -44,17 +46,20 @@ public class OidcClientImpl implements OidcClient {
private final String clientSecretBasicAuthScheme;
private final Key clientJwtKey;
private final OidcClientConfig oidcConfig;
private final List<OidcClientRequestFilter> filters;
private volatile boolean closed;

public OidcClientImpl(WebClient client, String tokenRequestUri, String tokenRevokeUri, String grantType,
MultiMap tokenGrantParams, MultiMap commonRefreshGrantParams, OidcClientConfig oidcClientConfig) {
MultiMap tokenGrantParams, MultiMap commonRefreshGrantParams, OidcClientConfig oidcClientConfig,
List<OidcClientRequestFilter> filters) {
this.client = client;
this.tokenRequestUri = tokenRequestUri;
this.tokenRevokeUri = tokenRevokeUri;
this.tokenGrantParams = tokenGrantParams;
this.commonRefreshGrantParams = commonRefreshGrantParams;
this.grantType = grantType;
this.oidcConfig = oidcClientConfig;
this.filters = filters;
this.clientSecretBasicAuthScheme = OidcCommonUtils.initClientSecretBasicAuth(oidcClientConfig);
this.clientJwtKey = OidcCommonUtils.initClientJwtKey(oidcClientConfig);
}
Expand Down Expand Up @@ -159,7 +164,8 @@ private UniOnItem<HttpResponse<Buffer>> postRequest(HttpRequest<Buffer> request,
}
}
// Retry up to three times with a one-second delay between the retries if the connection is closed
Uni<HttpResponse<Buffer>> response = request.sendBuffer(OidcCommonUtils.encodeForm(body))
Buffer buffer = OidcCommonUtils.encodeForm(body);
Uni<HttpResponse<Buffer>> response = filter(request, buffer).sendBuffer(buffer)
.onFailure(ConnectException.class)
.retry()
.atMost(oidcConfig.connectionRetryCount)
Expand Down Expand Up @@ -252,4 +258,11 @@ private void checkClosed() {
throw new IllegalStateException("OidcClient " + oidcConfig.getId().get() + " is closed");
}
}

private HttpRequest<Buffer> filter(HttpRequest<Buffer> request, Buffer body) {
for (OidcClientRequestFilter filter : filters) {
filter.filter(request, body);
}
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand All @@ -16,6 +17,7 @@
import io.quarkus.oidc.client.OidcClientException;
import io.quarkus.oidc.client.OidcClients;
import io.quarkus.oidc.client.Tokens;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.oidc.common.runtime.OidcConstants;
import io.quarkus.runtime.TlsConfig;
Expand Down Expand Up @@ -120,6 +122,8 @@ protected static Uni<OidcClient> createOidcClientUni(OidcClientConfig oidcConfig

WebClient client = WebClient.create(new io.vertx.mutiny.core.Vertx(vertx.get()), options);

List<OidcClientRequestFilter> clientRequestFilters = OidcCommonUtils.getClientRequestCustomizer();

Uni<OidcConfigurationMetadata> tokenUrisUni = null;
if (OidcCommonUtils.isAbsoluteUrl(oidcConfig.tokenPath)) {
tokenUrisUni = Uni.createFrom().item(
Expand All @@ -133,7 +137,7 @@ protected static Uni<OidcClient> createOidcClientUni(OidcClientConfig oidcConfig
OidcCommonUtils.getOidcEndpointUrl(authServerUriString, oidcConfig.tokenPath),
OidcCommonUtils.getOidcEndpointUrl(authServerUriString, oidcConfig.revokePath)));
} else {
tokenUrisUni = discoverTokenUris(client, authServerUriString.toString(), oidcConfig);
tokenUrisUni = discoverTokenUris(client, clientRequestFilters, authServerUriString.toString(), oidcConfig);
}
}
return tokenUrisUni.onItemOrFailure()
Expand Down Expand Up @@ -188,7 +192,8 @@ public OidcClient apply(OidcConfigurationMetadata metadata, Throwable t) {
return new OidcClientImpl(client, metadata.tokenRequestUri, metadata.tokenRevokeUri, grantType,
tokenGrantParams,
commonRefreshGrantParams,
oidcConfig);
oidcConfig,
clientRequestFilters);
}

});
Expand All @@ -205,10 +210,11 @@ private static void setGrantClientParams(OidcClientConfig oidcConfig, MultiMap g
}
}

private static Uni<OidcConfigurationMetadata> discoverTokenUris(WebClient client, String authServerUrl,
OidcClientConfig oidcConfig) {
private static Uni<OidcConfigurationMetadata> discoverTokenUris(WebClient client,
List<OidcClientRequestFilter> clientRequestFilters,
String authServerUrl, OidcClientConfig oidcConfig) {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
return OidcCommonUtils.discoverMetadata(client, authServerUrl, connectionDelayInMillisecs)
return OidcCommonUtils.discoverMetadata(client, clientRequestFilters, authServerUrl, connectionDelayInMillisecs)
.onItem().transform(json -> new OidcConfigurationMetadata(json.getString("token_endpoint"),
json.getString("revocation_endpoint")));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkus.oidc.common;

import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;

/**
* Request filter which can be used to customize OIDC client requests
*/
public interface OidcClientRequestFilter {
/**
* Filter OIDC client requests
*
* @param request HTTP request
* @param body request body, will be null for HTTP GET methods, may be null for other HTTP methods
*/
void filter(HttpRequest<Buffer> request, Buffer body);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@
import java.security.PrivateKey;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import javax.crypto.SecretKey;

import org.jboss.logging.Logger;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.credentials.CredentialsProvider;
import io.quarkus.credentials.runtime.CredentialsProviderFinder;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonConfig.Credentials;
import io.quarkus.oidc.common.runtime.OidcCommonConfig.Credentials.Provider;
import io.quarkus.oidc.common.runtime.OidcCommonConfig.Credentials.Secret;
Expand All @@ -45,6 +50,7 @@
import io.vertx.core.net.ProxyOptions;
import io.vertx.mutiny.core.MultiMap;
import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;
import io.vertx.mutiny.ext.web.client.WebClient;

public class OidcCommonUtils {
Expand Down Expand Up @@ -421,9 +427,14 @@ public static Predicate<? super Throwable> oidcEndpointNotAvailable() {
|| (t instanceof OidcEndpointAccessException && ((OidcEndpointAccessException) t).getErrorStatus() == 404));
}

public static Uni<JsonObject> discoverMetadata(WebClient client, String authServerUrl, long connectionDelayInMillisecs) {
public static Uni<JsonObject> discoverMetadata(WebClient client, List<OidcClientRequestFilter> filters,
String authServerUrl, long connectionDelayInMillisecs) {
final String discoveryUrl = authServerUrl + OidcConstants.WELL_KNOWN_CONFIGURATION;
return client.getAbs(discoveryUrl).send().onItem().transform(resp -> {
HttpRequest<Buffer> request = client.getAbs(discoveryUrl);
for (OidcClientRequestFilter filter : filters) {
filter.filter(request, null);
}
return request.send().onItem().transform(resp -> {
if (resp.statusCode() == 200) {
return resp.bodyAsJsonObject();
} else {
Expand Down Expand Up @@ -466,4 +477,13 @@ private static byte[] doRead(InputStream is) throws IOException {
}
return out.toByteArray();
}

public static List<OidcClientRequestFilter> getClientRequestCustomizer() {
ArcContainer container = Arc.container();
if (container != null) {
return container.listAll(OidcClientRequestFilter.class).stream().map(handle -> handle.get())
.collect(Collectors.toList());
}
return List.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.ConnectException;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.List;
import java.util.Map;

import org.jboss.logging.Logger;
Expand All @@ -14,6 +15,7 @@
import io.quarkus.oidc.OidcTenantConfig;
import io.quarkus.oidc.TokenIntrospection;
import io.quarkus.oidc.UserInfo;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.oidc.common.runtime.OidcConstants;
import io.quarkus.oidc.common.runtime.OidcEndpointAccessException;
Expand Down Expand Up @@ -43,16 +45,19 @@ public class OidcProviderClient implements Closeable {
private final String clientSecretBasicAuthScheme;
private final String introspectionBasicAuthScheme;
private final Key clientJwtKey;
private final List<OidcClientRequestFilter> filters;

public OidcProviderClient(WebClient client,
OidcConfigurationMetadata metadata,
OidcTenantConfig oidcConfig) {
OidcTenantConfig oidcConfig,
List<OidcClientRequestFilter> filters) {
this.client = client;
this.metadata = metadata;
this.oidcConfig = oidcConfig;
this.clientSecretBasicAuthScheme = OidcCommonUtils.initClientSecretBasicAuth(oidcConfig);
this.clientJwtKey = OidcCommonUtils.initClientJwtKey(oidcConfig);
this.introspectionBasicAuthScheme = initIntrospectionBasicAuthScheme(oidcConfig);
this.filters = filters;
}

private static String initIntrospectionBasicAuthScheme(OidcTenantConfig oidcConfig) {
Expand All @@ -70,13 +75,13 @@ public OidcConfigurationMetadata getMetadata() {
}

public Uni<JsonWebKeySet> getJsonWebKeySet() {
return client.getAbs(metadata.getJsonWebKeySetUri()).send().onItem()
return filter(client.getAbs(metadata.getJsonWebKeySetUri()), null).send().onItem()
.transform(resp -> getJsonWebKeySet(resp));
}

public Uni<UserInfo> getUserInfo(String token) {
LOG.debugf("Get UserInfo on: %s auth: %s", metadata.getUserInfoUri(), OidcConstants.BEARER_SCHEME + " " + token);
return client.getAbs(metadata.getUserInfoUri())
return filter(client.getAbs(metadata.getUserInfoUri()), null)
.putHeader(AUTHORIZATION_HEADER, OidcConstants.BEARER_SCHEME + " " + token)
.send().onItem().transform(resp -> getUserInfo(resp));
}
Expand Down Expand Up @@ -157,7 +162,8 @@ private UniOnItem<HttpResponse<Buffer>> getHttpResponse(String uri, MultiMap for
}
LOG.debugf("Get token on: %s params: %s headers: %s", metadata.getTokenUri(), formBody, request.headers());
// Retry up to three times with a one-second delay between the retries if the connection is closed.
Uni<HttpResponse<Buffer>> response = request.sendBuffer(OidcCommonUtils.encodeForm(formBody))
Buffer buffer = OidcCommonUtils.encodeForm(formBody);
Uni<HttpResponse<Buffer>> response = filter(request, buffer).sendBuffer(buffer)
.onFailure(ConnectException.class)
.retry()
.atMost(oidcConfig.connectionRetryCount).onFailure().transform(t -> t.getCause());
Expand Down Expand Up @@ -212,4 +218,11 @@ public void close() {
public Key getClientJwtKey() {
return clientJwtKey;
}

private HttpRequest<Buffer> filter(HttpRequest<Buffer> request, Buffer body) {
for (OidcClientRequestFilter filter : filters) {
filter.filter(request, body);
}
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.quarkus.oidc.OidcTenantConfig.Roles.Source;
import io.quarkus.oidc.OidcTenantConfig.TokenStateManager.Strategy;
import io.quarkus.oidc.TenantConfigResolver;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonConfig;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.runtime.LaunchMode;
Expand Down Expand Up @@ -424,12 +425,15 @@ protected static Uni<OidcProviderClient> createOidcClientUni(OidcTenantConfig oi

WebClient client = WebClient.create(new io.vertx.mutiny.core.Vertx(vertx), options);

List<OidcClientRequestFilter> clientRequestFilters = OidcCommonUtils.getClientRequestCustomizer();

Uni<OidcConfigurationMetadata> metadataUni = null;
if (!oidcConfig.discoveryEnabled.orElse(true)) {
metadataUni = Uni.createFrom().item(createLocalMetadata(oidcConfig, authServerUriString));
} else {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
metadataUni = OidcCommonUtils.discoverMetadata(client, authServerUriString, connectionDelayInMillisecs)
metadataUni = OidcCommonUtils
.discoverMetadata(client, clientRequestFilters, authServerUriString, connectionDelayInMillisecs)
.onItem()
.transform(new Function<JsonObject, OidcConfigurationMetadata>() {
@Override
Expand Down Expand Up @@ -465,7 +469,8 @@ public Uni<OidcProviderClient> apply(OidcConfigurationMetadata metadata, Throwab
"UserInfo is required but the OpenID Provider UserInfo endpoint is not configured."
+ " Use 'quarkus.oidc.user-info-path' if the discovery is disabled."));
}
return Uni.createFrom().item(new OidcProviderClient(client, metadata, oidcConfig));
return Uni.createFrom()
.item(new OidcProviderClient(client, metadata, oidcConfig, clientRequestFilters));
}

});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkus.it.keycloak;

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkus.arc.Unremovable;
import io.quarkus.oidc.common.OidcClientRequestFilter;
import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;

@ApplicationScoped
@Unremovable
public class OidcRequestCustomizer implements OidcClientRequestFilter {

@Override
public void filter(HttpRequest<Buffer> request, Buffer buffer) {
String uri = request.uri();
if (uri.endsWith("/non-standard-tokens")) {
request.putHeader("GrantType", getGrantType(buffer.toString()));
}
}

private String getGrantType(String formString) {
for (String formValue : formString.split("&")) {
if (formValue.startsWith("grant_type=")) {
return formValue.substring("grant_type=".length());
}
}
return "";
}
}
Loading

0 comments on commit 1e69886

Please sign in to comment.