Skip to content

Commit

Permalink
feat : refresh token 도입 (#231)
Browse files Browse the repository at this point in the history
* refactor : 회원가입, 로그인, 로그아웃 엔드포인트 수정

users -> auth로 수정

* refactor : 로그아웃된 토큰 예외 메세지 수정

* feat : 로그인 시 리프레시 토큰 추가

* test : 로그인 시 리프레시 토큰 테스트 수정

* feat : 토큰 관련 예외 클래스 추가

* feat : refresh token 관련 엔티티 세팅

* feat : 토큰 재발급 API 추가

* refactor : 토큰 관련 response DTO 이름 수정

* feat : 만료된 모든 refresh token 제거 스케줄러 추가
  • Loading branch information
rladmstn authored Dec 8, 2024
1 parent 9fe68e3 commit 0729081
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 72 deletions.
158 changes: 131 additions & 27 deletions src/main/java/com/gamzabat/algohub/common/jwt/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,36 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.UUID;
import java.util.stream.Collectors;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import com.amazonaws.util.StringUtils;
import com.gamzabat.algohub.common.jwt.domain.RefreshToken;
import com.gamzabat.algohub.common.jwt.dto.JwtDTO;
import com.gamzabat.algohub.common.jwt.exception.ExpiredTokenException;
import com.gamzabat.algohub.common.jwt.exception.TokenException;
import com.gamzabat.algohub.common.jwt.repository.RefreshTokenRepository;
import com.gamzabat.algohub.common.redis.RedisService;
import com.gamzabat.algohub.exception.JwtRequestException;
import com.gamzabat.algohub.feature.group.studygroup.exception.CannotFoundUserException;
import com.gamzabat.algohub.feature.user.dto.TokenResponse;
import com.gamzabat.algohub.feature.user.repository.UserRepository;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.SignatureAlgorithm;
Expand All @@ -38,36 +48,65 @@
@Component
@Getter
public class TokenProvider {
private final Key key;
private final Key accessTokenKey;
private final Key refreshTokenKey;
private final RedisService redisService;
private final RefreshTokenRepository refreshTokenRepository;
private final UserRepository userRepository;
@Value("${jwt_expiration_time}")
private long tokenExpiration;

public TokenProvider(@Value("${jwt_secret_key}") String secretKey, RedisService redisService) {
byte[] keyBytes = Decoders.BASE64URL.decode(secretKey);
this.key = Keys.hmacShaKeyFor(keyBytes);
private long accessTokenExpirationTime;
@Value("${refresh_token_expiration_time}")
private long refreshTokenExpirationTime;

public TokenProvider(@Value("${jwt_secret_key}") String accessTokenKey,
@Value("${jwt_refresh_secret_key}") String refreshTokenKey,
RedisService redisService, RefreshTokenRepository refreshTokenRepository,
UserRepository userRepository) {
byte[] accessKeyBytes = Decoders.BASE64URL.decode(accessTokenKey);
byte[] refreshKeyBytes = Decoders.BASE64URL.decode(refreshTokenKey);
this.accessTokenKey = Keys.hmacShaKeyFor(accessKeyBytes);
this.refreshTokenKey = Keys.hmacShaKeyFor(refreshKeyBytes);
this.redisService = redisService;
this.refreshTokenRepository = refreshTokenRepository;
this.userRepository = userRepository;
}

public JwtDTO generateTokens(Authentication authentication) {
String loginId = UUID.randomUUID().toString();
return JwtDTO.builder()
.grantType("Bearer")
.accessToken(generateAccessToken(loginId, authentication))
.refreshToken(generateRefreshToken(loginId, authentication))
.build();
}

public JwtDTO generateToken(Authentication authentication) {
public String generateAccessToken(String loginId, Authentication authentication) {
String authorities = authentication.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.joining(","));

long now = (new Date().getTime());
Date tokenExpireDate = new Date(now + this.accessTokenExpirationTime);
return createNewAccessToken(authentication.getName(), authorities, loginId, tokenExpireDate);
}

Date tokenExpireDate = new Date(now + this.tokenExpiration);
String token = Jwts.builder()
.setSubject(authentication.getName())
.claim("auth", authorities)
.setExpiration(tokenExpireDate)
.signWith(key, SignatureAlgorithm.HS256)
.compact();

return JwtDTO.builder()
.grantType("Bearer")
.token(token)
.build();
public String generateRefreshToken(String loginId, Authentication authentication) {
long now = (new Date().getTime());
Date expirationTime = new Date(now + this.refreshTokenExpirationTime);
String refreshToken = createNewRefreshToken(authentication.getName());

com.gamzabat.algohub.feature.user.domain.User user = userRepository.findByEmail(authentication.getName())
.orElseThrow(() -> new CannotFoundUserException(HttpStatus.NOT_FOUND.value(), "존재하지 않는 유저입니다."));

refreshTokenRepository.save(
RefreshToken.builder()
.refreshToken(refreshToken)
.user(user)
.loginId(loginId)
.expirationDateTime(expirationTime)
.build()
);
return refreshToken;
}

public Authentication getAuthentication(String token) {
Expand All @@ -86,13 +125,13 @@ public Authentication getAuthentication(String token) {
public boolean validateToken(String token) {
try {
if (logout(token))
throw new JwtRequestException(HttpStatus.UNAUTHORIZED.value(), "UNAUTHORIZED", "로그아웃 되었습니다.");
throw new JwtRequestException(HttpStatus.FORBIDDEN.value(), "FORBIDDEN", "로그아웃 된 토큰입니다.");
Jwts.parserBuilder()
.setSigningKey(key)
.setSigningKey(accessTokenKey)
.build().parseClaimsJws(token);
return true;
} catch (SecurityException | MalformedJwtException e) {
throw new JwtRequestException(HttpStatus.UNAUTHORIZED.value(), "UNAUTHORIZED", "검증되지 않은 토큰입니다.");
throw new JwtRequestException(HttpStatus.BAD_REQUEST.value(), "BAD_REQUEST", "검증되지 않은 토큰입니다.");
} catch (ExpiredJwtException e) {
throw new JwtRequestException(HttpStatus.UNAUTHORIZED.value(), "UNAUTHORIZED", "만료된 토큰 입니다.");
} catch (UnsupportedJwtException e) {
Expand All @@ -104,16 +143,14 @@ public boolean validateToken(String token) {

private Claims parseClaims(String token) {
return Jwts.parserBuilder()
.setSigningKey(key)
.setSigningKey(accessTokenKey)
.build()
.parseClaimsJws(token)
.getBody();
}

public String getUserEmail(String authToken) {
String token = authToken.replace("Bearer", "").trim();
Jws<Claims> claimsJws = Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(token);
return claimsJws.getBody().getSubject();
return getClaims(authToken).getSubject();
}

public String resolveToken(HttpServletRequest request) {
Expand All @@ -123,7 +160,74 @@ public String resolveToken(HttpServletRequest request) {
return null;
}

private Claims getClaims(String expiredToken) {
String token = expiredToken.replace("Bearer", "").trim();
return parseClaims(token);
}

@Transactional(propagation = Propagation.REQUIRES_NEW, noRollbackFor = ExpiredTokenException.class)
public TokenResponse reissueTokens(String expiredToken, String inputRefreshToken) {
Claims claims = getClaims(expiredToken);
String subject = claims.getSubject();
com.gamzabat.algohub.feature.user.domain.User user = userRepository.findByEmail(subject)
.orElseThrow(() -> new CannotFoundUserException(HttpStatus.NOT_FOUND.value(), "존재하지 않는 유저입니다."));

String loginId = (String)claims.get("loginId");
RefreshToken refreshToken = refreshTokenRepository.findByLoginIdAndUser(loginId, user)
.orElseThrow(() -> new TokenException(HttpStatus.UNAUTHORIZED.value(), "유효하지 않은 리프레시 토큰입니다. 재로그인이 필요합니다."));

validateTokenPair(inputRefreshToken, loginId, refreshToken);

long now = (new Date().getTime());
Date accessTokenExpireDate = new Date(now + this.accessTokenExpirationTime);
String newAccessToken = createNewAccessToken(
subject, claims.get("auth").toString(), loginId, accessTokenExpireDate
);

Date refreshTokenExpireDate = new Date(now + this.refreshTokenExpirationTime);
String newRefreshToken = createNewRefreshToken(subject);
refreshToken.updateRefreshToken(newRefreshToken, refreshTokenExpireDate);

return new TokenResponse(newAccessToken, newRefreshToken);
}

private void validateTokenPair(String inputRefreshToken, String loginId, RefreshToken refreshToken) {
if (!loginId.equals(refreshToken.getLoginId()) || !inputRefreshToken.equals(refreshToken.getRefreshToken())) {
throw new TokenException(HttpStatus.FORBIDDEN.value(), "토큰의 로그인 정보가 일치하지 않습니다.");
}

if (refreshToken.getExpirationDateTime().before(new Date())) {
refreshTokenRepository.delete(refreshToken);
log.info("success to delete refresh token");
throw new ExpiredTokenException(HttpStatus.UNAUTHORIZED.value(), "리프레시 토큰의 유효기간이 만료되었습니다. 재로그인이 필요합니다.");
}
}

private String createNewAccessToken(String subject, String authorities, String loginId, Date expirationDateTime) {
return Jwts.builder()
.setSubject(subject)
.setIssuedAt(new Date())
.claim("auth", authorities)
.claim("loginId", loginId)
.setExpiration(expirationDateTime)
.signWith(accessTokenKey, SignatureAlgorithm.HS256)
.compact();
}

private String createNewRefreshToken(String subject) {
return Jwts.builder()
.setSubject(subject)
.setIssuedAt(new Date())
.signWith(refreshTokenKey, SignatureAlgorithm.HS256)
.compact();
}

private boolean logout(String token) {
return redisService.getValues(token).equals("logout");
}

@Scheduled(cron = "0 0 0 * * *")
public void clearExpiredRefreshTokens() {
refreshTokenRepository.deleteExpiredRefreshTokens();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.gamzabat.algohub.common.jwt.domain;

import java.util.Date;

import org.hibernate.annotations.DynamicUpdate;

import com.gamzabat.algohub.feature.user.domain.User;

import jakarta.persistence.Entity;
import jakarta.persistence.FetchType;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;

@Entity
@Getter
@NoArgsConstructor
@DynamicUpdate
public class RefreshToken {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "user_id")
private User user;
private String refreshToken;
private String loginId;
private Date expirationDateTime;

@Builder
public RefreshToken(User user, String refreshToken, Date expirationDateTime, String loginId) {
this.user = user;
this.refreshToken = refreshToken;
this.loginId = loginId;
this.expirationDateTime = expirationDateTime;
}

public void updateRefreshToken(String refreshToken, Date expirationDateTime) {
this.refreshToken = refreshToken;
this.expirationDateTime = expirationDateTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
@AllArgsConstructor
public class JwtDTO {
private String grantType;
private String token;
private String accessToken;
private String refreshToken;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.gamzabat.algohub.common.jwt.dto;

public record ReissueTokenRequest(String expiredAccessToken, String refreshToken) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.gamzabat.algohub.common.jwt.exception;

import lombok.Getter;

@Getter
public class ExpiredTokenException extends TokenException {
public ExpiredTokenException(int code, String error) {
super(code, error);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.gamzabat.algohub.common.jwt.exception;

import lombok.Getter;

@Getter
public class TokenException extends RuntimeException {
private final int code;
private final String error;

public TokenException(int code, String error) {
this.code = code;
this.error = error;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.gamzabat.algohub.common.jwt.repository;

import java.util.Optional;

import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;

import com.gamzabat.algohub.common.jwt.domain.RefreshToken;
import com.gamzabat.algohub.feature.user.domain.User;

public interface RefreshTokenRepository extends JpaRepository<RefreshToken, Long> {
Optional<RefreshToken> findByLoginIdAndUser(String loginId, User user);

@Modifying
@Query("DELETE FROM RefreshToken rt WHERE rt.expirationDateTime < NOW()")
void deleteExpiredRefreshTokens();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;

import com.gamzabat.algohub.common.jwt.exception.ExpiredTokenException;
import com.gamzabat.algohub.common.jwt.exception.TokenException;
import com.gamzabat.algohub.feature.comment.exception.CommentValidationException;
import com.gamzabat.algohub.feature.group.ranking.exception.CannotFoundRankingException;
import com.gamzabat.algohub.feature.group.studygroup.exception.CannotFoundGroupException;
Expand Down Expand Up @@ -150,6 +152,18 @@ protected ResponseEntity<ErrorResponse> handler(CannotFoundUserException e) {
.body(new ErrorResponse(e.getCode(), e.getError(), null));
}

@ExceptionHandler(TokenException.class)
protected ResponseEntity<ErrorResponse> handler(TokenException e) {
return ResponseEntity.status(e.getCode())
.body(new ErrorResponse(e.getCode(), e.getError(), null));
}

@ExceptionHandler(ExpiredTokenException.class)
protected ResponseEntity<ErrorResponse> handler(ExpiredTokenException e) {
return ResponseEntity.status(e.getCode())
.body(new ErrorResponse(e.getCode(), e.getError(), null));
}

@ExceptionHandler(AwsS3Exception.class)
protected ResponseEntity<ErrorResponse> handler(AwsS3Exception e) {
return ResponseEntity.internalServerError()
Expand Down
Loading

0 comments on commit 0729081

Please sign in to comment.