From 9ba2435cb21a828dd2bccaabcac63dabf12d5dc8 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Fri, 27 Sep 2024 15:57:57 -0500 Subject: [PATCH] Support refresh token for Token Exchange Closes gh-15534 --- ...xchangeOAuth2AuthorizedClientProvider.java | 2 +- ...eactiveOAuth2AuthorizedClientProvider.java | 2 +- ...geOAuth2AuthorizedClientProviderTests.java | 25 +++++++++++++++---- ...veOAuth2AuthorizedClientProviderTests.java | 25 +++++++++++++++---- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java index 256ced675ab..ca22416af95 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java @@ -90,7 +90,7 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, grantRequest); return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken()); + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); } private OAuth2Token resolveSubjectToken(OAuth2AuthorizationContext context) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java index 43e0607d2ea..b3791a5da04 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java @@ -88,7 +88,7 @@ public Mono authorize(OAuth2AuthorizationContext context .onErrorMap(OAuth2AuthorizationException.class, (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex)) .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken())); + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); } private Mono resolveSubjectToken(OAuth2AuthorizationContext context) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java index 8cf3b0fdf0f..ddc9ead28df 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java @@ -213,7 +213,9 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { issuedAt, expiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), accessToken); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(accessTokenResponse); // @formatter:off @@ -228,6 +230,7 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -248,7 +251,9 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry // Shorten the lifespan of the access token by 90 seconds, which will ultimately // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(accessTokenResponse); // @formatter:off @@ -263,6 +268,7 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -285,7 +291,9 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotReso @Test public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(accessTokenResponse); // @formatter:off @@ -299,6 +307,7 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThe assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -312,7 +321,9 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { Function subjectTokenResolver = mock(Function.class); given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.subjectToken); this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(accessTokenResponse); TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); @@ -327,6 +338,7 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); verify(subjectTokenResolver).apply(authorizationContext); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); @@ -341,7 +353,9 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() { Function actorTokenResolver = mock(Function.class); given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.actorToken); this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(accessTokenResponse); // @formatter:off @@ -355,6 +369,7 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); verify(actorTokenResolver).apply(authorizationContext); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java index 2b7250911f7..99787f163e1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java @@ -215,7 +215,9 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { issuedAt, expiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), accessToken); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(Mono.just(accessTokenResponse)); // @formatter:off @@ -231,6 +233,7 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -251,7 +254,9 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry // Shorten the lifespan of the access token by 90 seconds, which will ultimately // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(Mono.just(accessTokenResponse)); // @formatter:off @@ -267,6 +272,7 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -289,7 +295,9 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotReso @Test public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(Mono.just(accessTokenResponse)); // @formatter:off @@ -303,6 +311,7 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThe assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); @@ -317,7 +326,9 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))) .willReturn(Mono.just(this.subjectToken)); this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(Mono.just(accessTokenResponse)); TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); @@ -332,6 +343,7 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); verify(subjectTokenResolver).apply(authorizationContext); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class); @@ -346,7 +358,9 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() { Function> actorTokenResolver = mock(Function.class); given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(Mono.just(this.actorToken)); this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh") + .build(); given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) .willReturn(Mono.just(accessTokenResponse)); // @formatter:off @@ -360,6 +374,7 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); verify(actorTokenResolver).apply(authorizationContext); ArgumentCaptor grantRequestCaptor = ArgumentCaptor .forClass(TokenExchangeGrantRequest.class);