Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Deprecated Usages of RemoteJWKSet #16296

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,12 @@

package org.springframework.security.oauth2.jwt;

import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
Expand All @@ -26,6 +32,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
Expand All @@ -35,11 +42,8 @@

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
Expand Down Expand Up @@ -80,6 +84,7 @@
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @author Daeho Kwon
* @since 5.2
*/
public final class NimbusJwtDecoder implements JwtDecoder {
Expand Down Expand Up @@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
.build();
// @formatter:on
}
catch (RemoteKeySourceException ex) {
catch (KeySourceException ex) {
this.logger.trace("Failed to retrieve JWK set", ex);
if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
Expand Down Expand Up @@ -377,11 +382,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever);
if(this.cache == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update this so that the member variable is set to NoOpCache. In this way, the null checking here and downstream is unnecessary.

return new SpringURLBasedJWKSource(urlBasedJWKSetSource);
}
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
SpringJWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, jwkSetCache);
}

JWTProcessor<SecurityContext> processor() {
Expand Down Expand Up @@ -414,7 +420,80 @@ private static URL toURL(String url) {
}
}

private static final class SpringJWKSetCache implements JWKSetCache {
private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {

private final URLBasedJWKSetSource urlBasedJWKSetSource;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are doing our own implementation of JWKSource, it should be possible to use RestOperations and Cache directly instead of implementing additional Nimbus interfaces.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jzheaux, I have a couple of questions regarding the feedback:

  1. Regarding "Since we are doing our own implementation of JWKSource, it should be possible to use RestOperations and Cache directly instead of implementing additional Nimbus interfaces":

    Does this mean we should replace the custom SpringJWKSetCache implementation with direct usage of Spring's Cache?

  2. I am also considering whether to streamline the code by either:

    • Using RestOperations directly, or
    • Using RestOperationsResourceRetriever

Below are the two possible approaches:

  • Direct RestOperations:

    private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {
    
        private final RestOperations restOperations;
    
        // Implementation logic uses restOperations directly
    }
  • Using ResourceRetriever:

    private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {
    
        private final ResourceRetriever resourceRetriever;
    
        // Implementation logic continues to leverage resourceRetriever
    }

Would it be better to reuse the already implemented RestOperationsResourceRetriever, or to use RestOperations directly within the class?

If we go with the direct RestOperations approach, should the RestOperationsResourceRetriever implementation be removed entirely?

Thank you for your feedback!


private final SpringJWKSetCache jwkSetCache;

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) {
this(urlBasedJWKSetSource, null);
}

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource, SpringJWKSetCache jwkSetCache) {
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
this.jwkSetCache = jwkSetCache;
}

@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
if (this.jwkSetCache != null) {
JWKSet jwkSet = this.jwkSetCache.get();
if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
synchronized (this) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please replace this synchronized statement with an ReentrantLock. Otherwise this could cause thread pinning with virtual threads.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holzerch Thank you for the feedback!
I’ll replace the synchronized block with a ReentrantLock as suggested.

jwkSet = fetchJWKSet();
this.jwkSetCache.put(jwkSet);
}
}
List<JWK> matches = jwkSelector.select(jwkSet);
if(!matches.isEmpty()) {
return matches;
}
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
if (soughtKeyID == null) {
return Collections.emptyList();
}
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
return Collections.emptyList();
}
synchronized (this) {
if(jwkSet == this.jwkSetCache.get()) {
jwkSet = fetchJWKSet();
this.jwkSetCache.put(jwkSet);
} else {
jwkSet = this.jwkSetCache.get();
}
}
if(jwkSet == null) {
return Collections.emptyList();
}
return jwkSelector.select(jwkSet);
}
return jwkSelector.select(fetchJWKSet());
}

private JWKSet fetchJWKSet() throws KeySourceException {
return this.urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(),
System.currentTimeMillis(), null);
}

private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
Set<String> keyIDs = jwkMatcher.getKeyIDs();

if (keyIDs == null || keyIDs.isEmpty()) {
return null;
}

for (String id: keyIDs) {
if (id != null) {
return id;
}
}
return null;
}
}

private static final class SpringJWKSetCache {

private final String jwkSetUri;

Expand All @@ -440,20 +519,16 @@ private void updateJwkSetFromCache() {
}
}

// Note: Only called from inside a synchronized block in RemoteJWKSet.
@Override
// Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
public void put(JWKSet jwkSet) {
this.jwkSet = jwkSet;
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
}

@Override
public JWKSet get() {
return (!requiresRefresh()) ? this.jwkSet : null;

}

@Override
public boolean requiresRefresh() {
return this.cache.get(this.jwkSetUri) == null;
}
Expand Down