Skip to content

Commit

Permalink
Merge pull request #171 from kookmin-sw/feature/be/#159-ApiRateLimiter
Browse files Browse the repository at this point in the history
Feature/be/#159 API Rate Limiter 적용
  • Loading branch information
BlueBerrySoda authored May 10, 2024
2 parents 93f68f9 + 6b8553e commit d838cbe
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 26 deletions.
4 changes: 3 additions & 1 deletion back-gateway/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ ext {
}

dependencies {
implementation 'org.springframework.cloud:spring-cloud-starter-gateway'
compileOnly 'org.projectlombok:lombok'
annotationProcessor 'org.projectlombok:lombok'
testImplementation 'org.springframework.boot:spring-boot-starter-test'

implementation 'io.jsonwebtoken:jjwt-api:0.11.5'
implementation 'io.jsonwebtoken:jjwt-impl:0.11.5'
implementation 'io.jsonwebtoken:jjwt-jackson:0.11.5'

implementation 'org.springframework.boot:spring-boot-starter-data-redis-reactive'
implementation 'org.springframework.cloud:spring-cloud-starter-gateway'
}

dependencyManagement {
Expand Down
19 changes: 19 additions & 0 deletions back-gateway/redis-server.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
## Test Version

version: '3'

services:
redis:
image: redis:7.2.0-alpine
container_name: redis
hostname: redis
restart: unless-stopped
environment:
TZ: "Asia/Seoul"
ports:
- 6379:6379
healthcheck:
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
interval: 5s
timeout: 3s
retries: 10
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.gateway.backgateway.config;

import com.gateway.backgateway.filter.AuthorizationHeaderFilter;
import com.gateway.backgateway.filter.RequestRateLimitFilter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
Expand All @@ -14,23 +16,52 @@ public class GatewayConfig {

@Bean
public RouteLocator gatewayRoutes(RouteLocatorBuilder builder,
AuthorizationHeaderFilter authFilter) {
AuthorizationHeaderFilter authFilter,
RequestRateLimitFilter limitFilter) {
/**
* 여기에 JWT토큰 및 API Rate Limiter가 필요한 Api Routing을 작성해주세요.
* Spring Security와 비슷한 느낌으로 해주시면 됩니다.
*/
return builder.routes()
.route("chatbot",r -> r.path("/docs", "/openapi.json")
.route("chatbot-docs",r -> r.path("/docs", "/openapi.json")
.uri(chatbotUrl))
.route("chatbot",r -> r.path("/api/chatbot/**")
.filters(f->f.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");})))
.filters(f->f
.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");}))
.filter(limitFilter.apply(config -> {
config.setRateLimiter(redisRateLimiter());
config.setRouteId("chatbot");
}))
)
.uri(chatbotUrl))
.route("chat", r -> r.path("/api/chat/**")
.filters(f->f.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");})))
.route(r -> r.path("/api/chat/**")
.filters(f->f
.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");}))
.filter(limitFilter.apply(config -> {
config.setRateLimiter(redisRateLimiter());
config.setRouteId("chat");
})))
.uri("http://ruby:3000"))
.route("business", r -> r.path("/api/user/signin", "/api/user/test", "/api/user/signup",
.route("nonJwt-spring", r -> r.path("/api/user/signin", "/api/user/test", "/api/user/signup",
"/api/announcement/**", "/api/menu/**", "/api/speech/**", "/api/question/read", "/api/question/list",
"/api/answer/list", "/api/faq/**", "/api/help/read", "/api/help/list", "/api/auth/**", "/api/swagger-ui/**", "/api/api-docs/**")
.uri("http://spring:8080"))
.route("business", r -> r.path("/api/**")
.filters(f->f.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");})))
.route("spring", r -> r.path("/api/**")
.filters(f->f
.filter(authFilter.apply(config -> {config.setRequiredRole("role_user");}))
.filter(limitFilter.apply(config -> {
config.setRateLimiter(redisRateLimiter());
config.setRouteId("spring");
}))
)
.uri("http://spring:8080"))
.build();
}

//TODO: Custom RedisRateLimiter로 변경 예정
@Bean
public RedisRateLimiter redisRateLimiter() {
// 기본 replenishRate 및 burstCapacity 값을 지정합니다.
return new RedisRateLimiter(20, 60, 3);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.gateway.backgateway.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.ReactiveRedisConnectionFactory;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.data.redis.serializer.StringRedisSerializer;

@Configuration
public class RedisConfig {

@Value("${spring.data.redis.host}")
String redishost;

@Bean
@Primary
public ReactiveRedisConnectionFactory reactiveRedisConnectionFactory() {
return new LettuceConnectionFactory(redishost, 6379);
}

@Bean
@Primary
public ReactiveRedisTemplate<String, Object> reactiveRedisTemplate(ReactiveRedisConnectionFactory factory) {
RedisSerializationContext<String, Object> serializationContext = RedisSerializationContext
.<String, Object>newSerializationContext(new StringRedisSerializer())
.hashKey(new StringRedisSerializer())
.hashValue(new StringRedisSerializer())
.build();

return new ReactiveRedisTemplate<>(factory, serializationContext);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.gateway.backgateway.config;

import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

@Configuration("user-key-resolver")
@Primary
@Slf4j
public class UserIdKeyResolver implements KeyResolver {
@Override
public Mono<String> resolve(ServerWebExchange exchange) {
String userId = exchange.getRequest().getHeaders().getFirst("X-USER-ID");
log.info("The user id is {}", userId);
return Mono.just(userId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public enum ErrorCode {
INTERNAL_SERVER_ERROR(500, "C004", "Server Error"),
INVALID_TYPE_VALUE(400, "C005", "Invalid Type Value"),
HANDLE_ACCESS_DENIED(403, "C006", "Access is Denied"),
TOO_MANY_REQUESTS(429, "C007", "Too Many Requests"),

// JWT Error
INVALID_JWT_TOKEN(401, "J001", "Invalid JWT Token");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
package com.gateway.backgateway.exception;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.gateway.backgateway.dto.ErrorResponse;
import io.jsonwebtoken.JwtException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.web.reactive.error.ErrorWebExceptionHandler;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import static com.gateway.backgateway.exception.ErrorCode.INTERNAL_SERVER_ERROR;
import static com.gateway.backgateway.exception.ErrorCode.INVALID_JWT_TOKEN;
import static com.gateway.backgateway.exception.ErrorCode.*;

@Slf4j
@Order(-1)
Expand All @@ -41,11 +37,19 @@ public Mono<Void> handle(ServerWebExchange exchange, Throwable ex) {
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
if (ex instanceof JwtTokenInvalidException) {
errorCode = INVALID_JWT_TOKEN;
response.setStatusCode(HttpStatusCode.valueOf(errorCode.getStatus()));
response.setStatusCode(HttpStatus.valueOf(errorCode.getStatus()));
}
else if (ex instanceof TooManyRequestException) {
errorCode = TOO_MANY_REQUESTS;
response.setStatusCode(HttpStatus.valueOf(errorCode.getStatus()));
}
else if (ex instanceof BusinessException) {
errorCode = ((BusinessException) ex).getErrorCode();
response.setStatusCode(HttpStatus.valueOf(String.valueOf(((BusinessException) ex).getErrorCode())));
}
else{
errorCode = INTERNAL_SERVER_ERROR;
response.setStatusCode(HttpStatusCode.valueOf(errorCode.getStatus()));
response.setStatusCode(HttpStatus.valueOf(errorCode.getStatus()));
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.gateway.backgateway.exception;

public class TooManyRequestException extends BusinessException {

// 자주 발생할 수 있는 Exception이라 싱글톤화 하는게 좋다고 합니다.
public static final TooManyRequestException INSTANCE = new TooManyRequestException();

private TooManyRequestException() {
super(ErrorCode.TOO_MANY_REQUESTS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.security.Key;
import java.util.Map;
import java.util.function.Function;

@Component
Expand All @@ -46,6 +41,8 @@ public GatewayFilter apply(Config config) {
String token = request.getHeaders()
.getFirst(HttpHeaders.AUTHORIZATION).replace("Bearer ", "");

System.out.println(token);

if (!validateToken(token)) {
throw JwtTokenInvalidException.INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.gateway.backgateway.filter;

import com.gateway.backgateway.exception.TooManyRequestException;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.cloud.gateway.support.HasRouteId;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class RequestRateLimitFilter extends AbstractGatewayFilterFactory<RequestRateLimitFilter.Config> {

private final KeyResolver defaultKeyResolver;
private final RedisRateLimiter defaultRateLimiter;

public RequestRateLimitFilter(
KeyResolver defaultKeyResolver,
RedisRateLimiter redisRateLimiter) {
super(Config.class);
this.defaultKeyResolver = defaultKeyResolver;
this.defaultRateLimiter = redisRateLimiter;
}

@Override
public GatewayFilter apply(Config config) {
log.info("여기 필터 지나는지 확인 1111");
GatewayFilter filter = (exchange, chain) -> {
KeyResolver keyResolver = getOrDefault(config.keyResolver, defaultKeyResolver);
RedisRateLimiter rateLimiter = getOrDefault(config.rateLimiter, defaultRateLimiter);
String routeId = config.getRouteId();
log.info("여기 필터 지나는지 확인 2222222");

return keyResolver.resolve(exchange)
.flatMap(key -> rateLimiter.isAllowed(routeId, key))
.flatMap(rateLimitResponse -> {
if (rateLimitResponse.isAllowed()) {
return chain.filter(exchange);
} else {
throw TooManyRequestException.INSTANCE;
}
});
};

return filter;
}

private <T> T getOrDefault(T configValue, T defaultValue) {
if (configValue != null) return configValue;
else return defaultValue;
}

@Getter
@Setter
public static class Config implements HasRouteId {
private KeyResolver keyResolver;
private RedisRateLimiter rateLimiter;
private String routeId;
}
}
13 changes: 12 additions & 1 deletion back-gateway/src/main/resources/application.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
spring:
application:
name: back-gateway
data:
redis:
host: ${REDIS_HOST}
port: ${REDIS_PORT}
data: 0


jwt:
secret:
key: ${JWT_SECRET}

server:
port: 8081
chatbot-url: ${CHATBOT_URL}
chatbot-url: ${CHATBOT_URL}

logging:
level:
root: debug

Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ ResponseEntity<ApiResult<AnnouncementListWrapper>> getAnnouncementSearchList(
@RequestParam(defaultValue = "KO", value = "language") String language,
@Parameter(description = "어디까지 로드됐는지 가르키는 커서입니다. 입력하지 않으면 처음부터 10개 받아옵니다.")
@RequestParam(defaultValue = "0", value = "cursor") long cursor,
@RequestBody AnnouncementSearchListRequest request
@Parameter(description = "검색어입니다. 문자열을 인코딩해서 보내주셔야됩니다.")
@RequestParam(value = "word") String word
) {

if (request.word().length() < 2) throw new BusinessException(SEARCH_TOO_SHORT);
if (word.length() < 2) throw new BusinessException(SEARCH_TOO_SHORT);


Slice<AnnouncementListResponse> slice = announcementSearchService.getAnnouncementSearchList(cursor, type,
language, request.word());
language, word);

List<AnnouncementListResponse> announcements = slice.getContent();
boolean hasNext = slice.hasNext();
Expand Down

0 comments on commit d838cbe

Please sign in to comment.