Skip to content

Commit

Permalink
Support refresh token for Token Exchange
Browse files Browse the repository at this point in the history
Closes gh-15534
  • Loading branch information
sjohnr committed Sep 27, 2024
1 parent e11c188 commit 9ba2435
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, grantRequest);

return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken());
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
}

private OAuth2Token resolveSubjectToken(OAuth2AuthorizationContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
.onErrorMap(OAuth2AuthorizationException.class,
(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken()));
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
}

private Mono<OAuth2Token> resolveSubjectToken(OAuth2AuthorizationContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() {
issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
// @formatter:off
Expand All @@ -228,6 +230,7 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() {
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -248,7 +251,9 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry
// Shorten the lifespan of the access token by 90 seconds, which will ultimately
// force it to expire on the client
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
// @formatter:off
Expand All @@ -263,6 +268,7 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -285,7 +291,9 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotReso

@Test
public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
// @formatter:off
Expand All @@ -299,6 +307,7 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThe
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -312,7 +321,9 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() {
Function<OAuth2AuthorizationContext, OAuth2Token> subjectTokenResolver = mock(Function.class);
given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.subjectToken);
this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
Expand All @@ -327,6 +338,7 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() {
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
verify(subjectTokenResolver).apply(authorizationContext);
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
Expand All @@ -341,7 +353,9 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() {
Function<OAuth2AuthorizationContext, OAuth2Token> actorTokenResolver = mock(Function.class);
given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.actorToken);
this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);
// @formatter:off
Expand All @@ -355,6 +369,7 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() {
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
verify(actorTokenResolver).apply(authorizationContext);
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() {
issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
// @formatter:off
Expand All @@ -231,6 +233,7 @@ public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() {
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -251,7 +254,9 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry
// Shorten the lifespan of the access token by 90 seconds, which will ultimately
// force it to expire on the client
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
// @formatter:off
Expand All @@ -267,6 +272,7 @@ public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiry
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -289,7 +295,9 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotReso

@Test
public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
// @formatter:off
Expand All @@ -303,6 +311,7 @@ public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThe
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture());
Expand All @@ -317,7 +326,9 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() {
given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class)))
.willReturn(Mono.just(this.subjectToken));
this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
Expand All @@ -332,6 +343,7 @@ public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() {
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
verify(subjectTokenResolver).apply(authorizationContext);
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
Expand All @@ -346,7 +358,9 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() {
Function<OAuth2AuthorizationContext, Mono<OAuth2Token>> actorTokenResolver = mock(Function.class);
given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(Mono.just(this.actorToken));
this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
.refreshToken("refresh")
.build();
given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
// @formatter:off
Expand All @@ -360,6 +374,7 @@ public void authorizeWhenCustomActorTokenResolverSetThenCalled() {
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
verify(actorTokenResolver).apply(authorizationContext);
ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
Expand Down

0 comments on commit 9ba2435

Please sign in to comment.