Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: oauth 인증 클래스 리팩토링 #124

Merged
merged 4 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
Expand Down Expand Up @@ -88,7 +89,7 @@ public ResponseEntity<Void> deleteOrder(

@GetMapping("/payed")
public ResponseEntity<FindPayedOrdersResponse> findPayedOrders(
@RequestBody @Valid FindPayedOrdersRequest findPayedOrdersRequest,
@ModelAttribute @Valid FindPayedOrdersRequest findPayedOrdersRequest,
@LoginUser Long userId) {
FindPayedOrdersCommand command = FindPayedOrdersCommand.of(
userId,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,100 @@
package com.prgrms.nabmart.global.auth.oauth.client;

import com.prgrms.nabmart.domain.user.service.response.FindUserDetailResponse;
import com.prgrms.nabmart.global.auth.oauth.OAuthProvider;
import com.prgrms.nabmart.global.auth.oauth.dto.OAuthHttpMessage;
import com.prgrms.nabmart.global.infrastructure.ApiService;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.stereotype.Component;

public interface OAuthRestClient {
@Slf4j
@Component
@RequiredArgsConstructor
public class OAuthRestClient {

void callUnlinkOAuthUser(FindUserDetailResponse userDetailResponse);
private final ApiService apiService;
private final OAuth2AuthorizedClientService authorizedClientService;

void callRefreshAccessToken(FindUserDetailResponse userDetailResponse);
public void callUnlinkOAuthUser(final FindUserDetailResponse userDetailResponse) {
OAuthProvider oAuthProvider = OAuthProvider.getOAuthProvider(userDetailResponse.provider());
OAuthHttpMessageProvider oAuthHttpMessageProvider = oAuthProvider.getOAuthHttpMessageProvider();
OAuth2AuthorizedClient oAuth2AuthorizedClient = authorizedClientService.loadAuthorizedClient(
userDetailResponse.provider(),
userDetailResponse.providerId());

refreshAccessTokenIfNotValid(userDetailResponse, oAuth2AuthorizedClient);

OAuthHttpMessage unlinkHttpMessage = oAuthHttpMessageProvider.createUnlinkUserRequest(
userDetailResponse, oAuth2AuthorizedClient);

Map<String, Object> response = sendPostApiRequest(unlinkHttpMessage);
log.info("회원의 연결이 종료되었습니다. 회원 ID={}", response);

oAuthHttpMessageProvider.verifySuccessUnlinkUserRequest(response);
authorizedClientService.removeAuthorizedClient(
userDetailResponse.provider(),
userDetailResponse.provider());
}

private void refreshAccessTokenIfNotValid(FindUserDetailResponse userDetailResponse,
OAuth2AuthorizedClient oAuth2AuthorizedClient) {
Instant expiresAt = oAuth2AuthorizedClient.getAccessToken().getExpiresAt();
if(expiresAt.isBefore(Instant.now())) {
callRefreshAccessToken(userDetailResponse);
}
}

public void callRefreshAccessToken(final FindUserDetailResponse userDetailResponse) {
OAuthProvider oAuthProvider = OAuthProvider.getOAuthProvider(userDetailResponse.provider());
OAuthHttpMessageProvider oAuthHttpMessageProvider = oAuthProvider.getOAuthHttpMessageProvider();
OAuth2AuthorizedClient oAuth2AuthorizedClient = authorizedClientService.loadAuthorizedClient(
userDetailResponse.provider(),
userDetailResponse.providerId());
OAuthHttpMessage refreshAccessTokenRequest
= oAuthHttpMessageProvider.createRefreshAccessTokenRequest(oAuth2AuthorizedClient);

Map response = sendPostApiRequest(refreshAccessTokenRequest);

OAuth2AccessToken refreshedAccessToken
= oAuthHttpMessageProvider.extractAccessToken(response);
OAuth2RefreshToken refreshedRefreshToken
= oAuthHttpMessageProvider.extractRefreshToken(response)
.orElse(oAuth2AuthorizedClient.getRefreshToken());

OAuth2AuthorizedClient updatedAuthorizedClient = new OAuth2AuthorizedClient(
oAuth2AuthorizedClient.getClientRegistration(),
oAuth2AuthorizedClient.getPrincipalName(),
refreshedAccessToken,
refreshedRefreshToken);
Authentication authenticationForTokenRefresh
= getAuthenticationForTokenRefresh(updatedAuthorizedClient);
authorizedClientService.saveAuthorizedClient(
updatedAuthorizedClient,
authenticationForTokenRefresh);
}

private Authentication getAuthenticationForTokenRefresh(
OAuth2AuthorizedClient updatedAuthorizedClient) {
String principalName = updatedAuthorizedClient.getPrincipalName();
return UsernamePasswordAuthenticationToken.authenticated(
principalName, null, List.of());
}

private Map sendPostApiRequest(OAuthHttpMessage request) {
return apiService.getResult(
request.httpMessage(),
request.uri(),
Map.class,
request.uriVariables());
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.prgrms.nabmart.global.exception.ExternalApiException;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.HttpEntity;
import org.springframework.http.ResponseEntity;
Expand All @@ -21,7 +23,12 @@ public ApiService(RestTemplateBuilder restTemplateBuilder) {
}

public <T> T getResult(HttpEntity httpEntity, String url, Class<T> clazz) {
ResponseEntity<T> response = callExternalApi(url, httpEntity, clazz);
return getResult(httpEntity, url, clazz, Collections.emptyMap());
}

public <T> T getResult(HttpEntity httpEntity, String url, Class<T> clazz,
Map<String, ?> uriVariables) {
ResponseEntity<T> response = callExternalApi(url, httpEntity, clazz, uriVariables);
if (response.getStatusCode().isError()) {
throw new ExternalApiException("외부 API 호출 과정에서 오류가 발생했습니다");
}
Expand All @@ -32,9 +39,10 @@ public <T> T getResult(HttpEntity httpEntity, String url, Class<T> clazz) {
private <T> ResponseEntity<T> callExternalApi(
String url,
HttpEntity httpEntity,
Class<T> clazz) {
Class<T> clazz,
Map<String, ?> uriVariables) {
try {
return restTemplate.postForEntity(url, httpEntity, clazz);
return restTemplate.postForEntity(url, httpEntity, clazz, uriVariables);
} catch (Exception exception) {
throw new ExternalApiException("외부 API 호출 과정에서 오류가 발생했습니다");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import static org.springframework.restdocs.payload.JsonFieldType.NUMBER;
import static org.springframework.restdocs.payload.JsonFieldType.STRING;
import static org.springframework.restdocs.payload.PayloadDocumentation.fieldWithPath;
import static org.springframework.restdocs.payload.PayloadDocumentation.requestFields;
import static org.springframework.restdocs.payload.PayloadDocumentation.responseFields;
import static org.springframework.restdocs.request.RequestDocumentation.parameterWithName;
import static org.springframework.restdocs.request.RequestDocumentation.pathParameters;
Expand All @@ -32,12 +31,11 @@
import com.prgrms.nabmart.domain.coupon.UserCoupon;
import com.prgrms.nabmart.domain.order.Order;
import com.prgrms.nabmart.domain.order.controller.request.CreateOrderRequest;
import com.prgrms.nabmart.domain.order.controller.request.FindPayedOrdersRequest;
import com.prgrms.nabmart.domain.order.service.response.FindPayedOrdersResponse;
import com.prgrms.nabmart.domain.order.service.response.FindPayedOrdersResponse.FindPayedOrderResponse;
import com.prgrms.nabmart.domain.order.service.response.CreateOrderResponse;
import com.prgrms.nabmart.domain.order.service.response.FindOrderDetailResponse;
import com.prgrms.nabmart.domain.order.service.response.FindOrdersResponse;
import com.prgrms.nabmart.domain.order.service.response.FindPayedOrdersResponse;
import com.prgrms.nabmart.domain.order.service.response.FindPayedOrdersResponse.FindPayedOrderResponse;
import com.prgrms.nabmart.domain.order.service.response.UpdateOrderByCouponResponse;
import com.prgrms.nabmart.domain.user.User;
import java.util.List;
Expand Down Expand Up @@ -247,7 +245,6 @@ class FindPayedOrdersTest {
void findPayedOrders() throws Exception {
//given
int page = 0;
FindPayedOrdersRequest findPayedOrdersRequest = new FindPayedOrdersRequest(page);
FindPayedOrderResponse findPayedOrderResponse
= new FindPayedOrderResponse(1L, "비비고 왕교자 외 2개", 20000);
FindPayedOrdersResponse findPayedOrdersResponse = new FindPayedOrdersResponse(
Expand All @@ -258,17 +255,16 @@ void findPayedOrders() throws Exception {
//when
ResultActions resultActions = mockMvc.perform(get("/api/v1/orders/payed")
.header(AUTHORIZATION, accessToken)
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(findPayedOrdersRequest)));
.param("page", String.valueOf(page)));

//then
resultActions.andExpect(status().isOk())
.andDo(restDocs.document(
requestHeaders(
headerWithName(AUTHORIZATION).description("액세스 토큰")
),
requestFields(
fieldWithPath("page").type(NUMBER).description("페이지 번호")
queryParameters(
parameterWithName("page").description("페이지 번호")
),
responseFields(
fieldWithPath("orders").type(ARRAY).description("주문 목록"),
Expand Down
Loading