Skip to content

Commit

Permalink
fix: #202 fix oauth2 login for error
Browse files Browse the repository at this point in the history
  • Loading branch information
KartVen committed Dec 7, 2024
1 parent e3ee211 commit 9ca0492
Show file tree
Hide file tree
Showing 21 changed files with 641 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.OAuth2ReactiveAuthorizationManager;
import reactor.core.publisher.Mono;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.handler.AuthenticationFailureHandler;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.handler.AuthenticationSuccessHandler;

import java.util.function.Function;

Expand All @@ -38,32 +38,22 @@ public SecurityWebFilterChain springSecurityFilterChain(
ServerHttpSecurity http,
ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver,
@Value("${app.security.oauth2.endpoint.callback}") String callbackEndpoint,
OAuth2ReactiveAuthorizationManager reactiveAuthenticationManager
OAuth2ReactiveAuthorizationManager reactiveAuthenticationManager,
AuthenticationSuccessHandler authenticationSuccessHandler,
AuthenticationFailureHandler authenticationFailureHandler
) {
http
.authorizeExchange(auth -> auth.anyExchange().permitAll())
.oauth2Login(oauth2 -> oauth2
.authorizationRequestResolver(authorizationRequestResolver)
.authenticationMatcher(callbackMatcher(callbackEndpoint))
.authenticationManager(reactiveAuthenticationManager)
.authenticationSuccessHandler(authenticationSuccessHandler())
// TODO check if need change
//.authenticationFailureHandler(authenticationFailureHandler())
.authenticationSuccessHandler(authenticationSuccessHandler)
.authenticationFailureHandler(authenticationFailureHandler)
);
return http.build();
}

private ServerAuthenticationSuccessHandler authenticationSuccessHandler() {
return (webFilterExchange, authentication) -> webFilterExchange
.getChain()
.filter(webFilterExchange.getExchange())
.and(Mono.empty());
}

/*private ServerAuthenticationFailureHandler authenticationFailureHandler() {
return (webFilterExchange, exception) -> Mono.empty();
}*/

@Bean
public ServerOAuth2AuthorizationRequestResolver oAuth2AuthorizationRequestResolver(
ReactiveClientRegistrationRepository clientRegistrationRepository,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.lang.NonNull;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyExtractors;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.model.AuthorizeResponse;
import reactor.core.publisher.Mono;

import java.util.Map;

public class OAuth2AuthorizeResponseBodyExtractor implements BodyExtractor<Mono<AuthorizeResponse>, ReactiveHttpInputMessage> {
private static final BodyExtractor<Mono<Map<String, Object>>, ReactiveHttpInputMessage> DELEGATE;

static {
DELEGATE = BodyExtractors.toMono(new ParameterizedTypeReference<>() {
});
}

@Override
@NonNull
public Mono<AuthorizeResponse> extract(@NonNull ReactiveHttpInputMessage inputMessage, @NonNull Context context) {
return DELEGATE.extract(inputMessage, context).map(this::parse).flatMap(this::validate);
}

private AuthorizeResponse parse(Map<String, Object> jsonMap) {
try {
return AuthorizeResponse.parse(jsonMap);
} catch (RuntimeException ex) {
OAuth2Error oAuth2Error = new OAuth2Error(
"invalid_token_response", "An error occurred parsing the Authorize response: " + ex.getMessage(), null
);
throw new OAuth2AuthorizationException(oAuth2Error, ex);
}
}

private Mono<AuthorizeResponse.Success> validate(AuthorizeResponse response) {
if (response.hasError()) {
AuthorizeResponse.Error errorResponse = (AuthorizeResponse.Error) response;
return Mono.error(new OAuth2AuthorizationException(new OAuth2Error(
errorResponse.getError(), errorResponse.getErrorDescription(), null))
);
}
return Mono.just(response).cast(AuthorizeResponse.Success.class);
}
}
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2;

import lombok.RequiredArgsConstructor;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.stereotype.Component;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.model.AuthorizeResponse;
import reactor.core.publisher.Mono;

import java.time.Duration;
import java.time.Instant;
import java.util.Map;

@Component
@RequiredArgsConstructor
public class OAuth2ReactiveAuthorizationManager implements ReactiveAuthenticationManager {
private final ReactiveOAuth2AuthorizeClient client;

@Override
public Mono<Authentication> authenticate(Authentication authentication) {
return Mono.defer(() -> {
Expand All @@ -29,8 +30,13 @@ public Mono<Authentication> authenticate(Authentication authentication) {
if (!isStateEqually(exchange)) {
return Mono.error(new OAuth2AuthorizationException(new OAuth2Error("invalid_state_parameter")));
}
return Mono.just(toOAuth2LoginAuthenticationToken(token)).onErrorMap(OAuth2AuthorizationException.class,
(e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString(), e));
return this.client.authorize(new ReactiveOAuth2AuthorizeClient.Request(
token.getClientRegistration().getRegistrationId(), exchange.getAuthorizationResponse().getCode()
))
.cast(AuthorizeResponse.Success.class)
.map(response -> onSuccess(response, token))
.onErrorMap(OAuth2AuthorizationException.class,
(e) -> new OAuth2AuthenticationException(e.getError(), e.getMessage(), e));
});
}

Expand All @@ -39,14 +45,15 @@ private boolean isStateEqually(OAuth2AuthorizationExchange exchange) {
.equals(exchange.getAuthorizationResponse().getState());
}

private Authentication toOAuth2LoginAuthenticationToken(OAuth2AuthorizationCodeAuthenticationToken token) {
private OAuth2LoginAuthenticationToken onSuccess(AuthorizeResponse.Success response, OAuth2AuthorizationCodeAuthenticationToken token) {
Instant issuedAt = Instant.now();
Instant expiresAt = Instant.from(issuedAt).plus(Duration.ofMinutes(10));
var accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, response.getAccessToken(), issuedAt, expiresAt);
var refreshToken = new OAuth2RefreshToken(response.getRefreshToken(), issuedAt, expiresAt);
return new OAuth2LoginAuthenticationToken(
token.getClientRegistration(),
token.getAuthorizationExchange(),
new DefaultOAuth2User(token.getAuthorities(), Map.of("name", token.getName()), "name"),
token.getAuthorities(),
new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "tokenValue", Instant.now(), Instant.now()),
token.getRefreshToken()
token.getClientRegistration(), token.getAuthorizationExchange(),
new ReactiveOAuth2User(token.getName(), accessToken, refreshToken),
token.getAuthorities(), accessToken, refreshToken
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClient;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.model.AuthorizeResponse;
import reactor.core.publisher.Mono;

import java.util.function.BiFunction;

@Component
public class ReactiveOAuth2AuthorizeClient {
private final WebClient webClient = WebClient.builder().build();
private final OAuth2AuthorizeResponseBodyExtractor BODY_EXTRACTOR = new OAuth2AuthorizeResponseBodyExtractor();
private final BiFunction<String, String, String> authorizeUriFunction;

public ReactiveOAuth2AuthorizeClient(
@Value("${service.baseUrl.auth}") String authBaseUrl,
@Value("${app.security.oauth2.endpoint.authorize}") String authorizeEndpoint
) {
this.authorizeUriFunction = (registrationId, code) ->
authBaseUrl + authorizeEndpoint + "/" + registrationId + "?code=" + code;
}

public Mono<AuthorizeResponse> authorize(Request request) {
return this.webClient.get()
.uri(authorizeUriFunction.apply(request.getRegistrationId(), request.getCode()))
.exchangeToMono(clientResponse -> clientResponse.body(BODY_EXTRACTOR));
}

@lombok.Value
public static class Request {
String registrationId;
String code;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2;

import lombok.Getter;
import lombok.Value;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.user.OAuth2User;

import java.io.Serial;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;

@Getter
public class ReactiveOAuth2User implements OAuth2User, Serializable {
@Serial
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
private final String name;
private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken;

public ReactiveOAuth2User(String name, OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken) {
this.name = name;
this.accessToken = accessToken;
this.refreshToken = refreshToken;
}

@Override
public Map<String, Object> getAttributes() {
return Collections.emptyMap();
}

@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return Collections.emptyList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2.handler;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

import java.net.URI;

@Component
public class AuthenticationFailureHandler implements ServerAuthenticationFailureHandler {
private final String frontBaseUrl;

public AuthenticationFailureHandler(@Value("${service.baseUrl.front}") String frontBaseUrl) {
this.frontBaseUrl = frontBaseUrl;
}

@Override
public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange, AuthenticationException exception) {
ServerHttpResponse response = webFilterExchange.getExchange().getResponse();
performRedirect(response, exception);
return Mono.empty();
}

private void performRedirect(ServerHttpResponse response, AuthenticationException exception) {
URI location = UriComponentsBuilder.fromUriString(frontBaseUrl)
.queryParams(createParams(exception))
.build().toUri();

response.setStatusCode(HttpStatus.FOUND);
response.getHeaders().setLocation(location);
}

private MultiValueMap<String, String> createParams(AuthenticationException exception) {
MultiValueMap<String, String> params = new LinkedMultiValueMap<>();
params.add("auth", "failure");
if (exception instanceof OAuth2AuthenticationException oauth2Exception) {
OAuth2Error error = oauth2Exception.getError();
if (error != null) {
params.add("error", error.getErrorCode());
params.add("details", error.getDescription());
return params;
}
}
params.add("error", OAuth2ErrorCodes.SERVER_ERROR);
params.add("details", "Unknown authorization error");
return params;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package pl.sknikod.kodemygateway.infrastructure.module.oauth2.handler;

import lombok.NonNull;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseCookie;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.stereotype.Component;
import org.springframework.web.util.UriComponentsBuilder;
import pl.sknikod.kodemygateway.infrastructure.module.oauth2.ReactiveOAuth2User;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;

@Component
public class AuthenticationSuccessHandler implements ServerAuthenticationSuccessHandler {
private static final String ACCESS_TOKEN_COOKIE = "AUTH_CONTEXT";
private static final String REFRESH_TOKEN_COOKIE = "AUTH_PERSIST";
private final String frontBaseUrl;

public AuthenticationSuccessHandler(@Value("${service.baseUrl.front}") String frontBaseUrl) {
this.frontBaseUrl = frontBaseUrl;
}

@Override
public Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange, Authentication authentication) {
ServerHttpResponse response = webFilterExchange.getExchange().getResponse();
modifyHeaders(response, authentication);
performRedirect(response);
return Mono.empty();
}

private void modifyHeaders(ServerHttpResponse response, Authentication token) {
ReactiveOAuth2User user = (ReactiveOAuth2User) ((OAuth2AuthenticationToken) token).getPrincipal();
Instant now = Instant.now();
var accessToken = createCookie(
ACCESS_TOKEN_COOKIE, user.getAccessToken().getTokenValue(),
Duration.between(now, user.getAccessToken().getExpiresAt())
);
var refreshToken = createCookie(
REFRESH_TOKEN_COOKIE, user.getRefreshToken().getTokenValue(),
Duration.between(now, user.getRefreshToken().getExpiresAt())
);
response.getHeaders()
.addAll(HttpHeaders.SET_COOKIE, List.of(accessToken.toString(), refreshToken.toString()));
}

private ResponseCookie createCookie(@NonNull String name, @NonNull String value, @NonNull Duration age) {
return ResponseCookie.from(name, value)
.path("/")
.httpOnly(true)
.sameSite("Lax")
.maxAge(age)
.build();
}

private void performRedirect(ServerHttpResponse response) {
URI location = UriComponentsBuilder.fromUriString(frontBaseUrl)
.queryParam("auth", "success")
.build().toUri();

response.setStatusCode(HttpStatus.FOUND);
response.getHeaders().setLocation(location);
}
}
Loading

0 comments on commit 9ca0492

Please sign in to comment.