Skip to content

Commit

Permalink
Integrated Sliding Window Rate Limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
x-sushant-x committed Nov 10, 2024
1 parent 7fd8399 commit 4ad511a
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 33 deletions.
35 changes: 27 additions & 8 deletions rate_shield/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ const (
)

type Limiter struct {
tokenBucket *TokenBucketService
fixedWindow *FixedWindowService
redisRuleSvc service.RulesService
cachedRules map[string]*models.Rule
tokenBucket *TokenBucketService
fixedWindow *FixedWindowService
slidingWindow *SlidingWindowService
redisRuleSvc service.RulesService
cachedRules map[string]*models.Rule
}

func NewRateLimiterService(tokenBucket *TokenBucketService, fixedWindow *FixedWindowService, redisRuleSvc service.RulesService) Limiter {
func NewRateLimiterService(
tokenBucket *TokenBucketService, fixedWindow *FixedWindowService, slidingWindow *SlidingWindowService, redisRuleSvc service.RulesService) Limiter {

return Limiter{
tokenBucket: tokenBucket,
fixedWindow: fixedWindow,
redisRuleSvc: redisRuleSvc,
tokenBucket: tokenBucket,
fixedWindow: fixedWindow,
redisRuleSvc: redisRuleSvc,
slidingWindow: slidingWindow,
// This is initialized later in StartRateLimiter() function
cachedRules: nil,
}
Expand All @@ -41,6 +44,8 @@ func (l *Limiter) CheckLimit(ip, endpoint string) *models.RateLimitResponse {
return l.processTokenBucketReq(key, rule)
case "FIXED WINDOW COUNTER":
return l.processFixedWindowReq(ip, endpoint, rule)
case "SLIDING WINDOW COUNTER":
return l.processSlidingWindowReq(ip, endpoint, rule)
}
}

Expand Down Expand Up @@ -79,6 +84,20 @@ func (l *Limiter) processFixedWindowReq(ip, endpoint string, rule *models.Rule)
return resp
}

func (l *Limiter) processSlidingWindowReq(ip, endpoint string, rule *models.Rule) *models.RateLimitResponse {
resp := l.slidingWindow.processRequest(ip, endpoint, rule)

if resp.Success {
return resp
}

if rule.AllowOnError {
return utils.BuildRateLimitSuccessResponse(0, 0)
}

return resp
}

func (l *Limiter) GetRule(key string) (*models.Rule, bool, error) {
return l.redisRuleSvc.GetRule(key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ import (
"github.com/x-sushant-x/RateShield/utils"
)

const (
MaxRequest = 10
WindowSize = 60 * time.Second
)

var (
ctx = context.Background()
)
Expand All @@ -29,20 +24,18 @@ func NewSlidingWindowService(redisClient *redis.Client) SlidingWindowService {
}
}

func (s *SlidingWindowService) ProcessRequest(key string) *models.RateLimitResponse {
func (s *SlidingWindowService) processRequest(ip, endpoint string, rule *models.Rule) *models.RateLimitResponse {
key := ip + ":" + endpoint

now := time.Now().Unix()
windowSize := time.Duration(rule.SlidingWindowCounterRule.WindowSize) * time.Second

pipe := s.redisClient.TxPipeline()
pipe.ZRemRangeByScore(ctx, key, "0", fmt.Sprintf("%d", now-int64(WindowSize.Seconds())))
then := fmt.Sprintf("%d", now-int64(windowSize.Seconds()))

pipe.ZAdd(ctx, key, redis.Z{
Score: float64(now),
Member: now,
})

countCmd := pipe.ZCount(ctx, key, fmt.Sprintf("%d", now-int64(WindowSize.Seconds())), fmt.Sprintf("%d", now))
pipe := s.redisClient.TxPipeline()
pipe.ZRemRangeByScore(ctx, key, "0", then)

pipe.Expire(ctx, key, WindowSize)
countCmd := pipe.ZCount(ctx, key, then, fmt.Sprintf("%d", now))

_, err := pipe.Exec(ctx)
if err != nil {
Expand All @@ -54,10 +47,23 @@ func (s *SlidingWindowService) ProcessRequest(key string) *models.RateLimitRespo
return utils.BuildRateLimitErrorResponse(500)
}

if count > MaxRequest {
if count > rule.SlidingWindowCounterRule.MaxRequests {
return utils.BuildRateLimitErrorResponse(429)
}

// TODO -> Change 999 from actual data when rule is configured.
return utils.BuildRateLimitSuccessResponse(999, 999)
pipe = s.redisClient.TxPipeline()

pipe.ZAdd(ctx, key, redis.Z{
Member: now,
Score: float64(now),
})

pipe.Expire(ctx, key, windowSize)

_, err = pipe.Exec(ctx)
if err != nil {
return utils.BuildRateLimitErrorResponse(500)
}

return utils.BuildRateLimitSuccessResponse(rule.SlidingWindowCounterRule.MaxRequests, rule.SlidingWindowCounterRule.MaxRequests-count)
}
9 changes: 8 additions & 1 deletion rate_shield/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ func main() {
log.Fatal().Err(err)
}

redisSlidingWindowClient, err := redisClient.NewSlidingWindowClient()
if err != nil {
log.Fatal().Err(err)
}

slackSvc := service.NewSlackService(slackToken, slackChannelID)

errorNotificationSvc := service.NewErrorNotificationSVC(*slackSvc)
Expand All @@ -52,7 +57,9 @@ func main() {

redisRulesSvc := service.NewRedisRulesService(redisRulesClient)

limiter := limiter.NewRateLimiterService(&tokenBucketSvc, &fixedWindowSvc, redisRulesSvc)
slidingWindowSvc := limiter.NewSlidingWindowService(redisSlidingWindowClient)

limiter := limiter.NewRateLimiterService(&tokenBucketSvc, &fixedWindowSvc, &slidingWindowSvc, redisRulesSvc)
limiter.StartRateLimiter()

server := api.NewServer(8080, limiter)
Expand Down
18 changes: 12 additions & 6 deletions rate_shield/models/rules.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package models

type Rule struct {
Strategy string `json:"strategy"`
APIEndpoint string `json:"endpoint"`
HTTPMethod string `json:"http_method"`
AllowOnError bool `json:"allow_on_error"`
TokenBucketRule *TokenBucketRule `json:"token_bucket_rule,omitempty"`
FixedWindowCounterRule *FixedWindowCounterRule `json:"fixed_window_counter_rule,omitempty"`
Strategy string `json:"strategy"`
APIEndpoint string `json:"endpoint"`
HTTPMethod string `json:"http_method"`
AllowOnError bool `json:"allow_on_error"`
TokenBucketRule *TokenBucketRule `json:"token_bucket_rule,omitempty"`
FixedWindowCounterRule *FixedWindowCounterRule `json:"fixed_window_counter_rule,omitempty"`
SlidingWindowCounterRule *SlidingWindowCounterRule `json:"sliding_window_counter_rule,omitempty"`
}

type TokenBucketRule struct {
Expand All @@ -29,3 +30,8 @@ type PaginatedRules struct {
HasNextPage bool `json:"has_next_page"`
Rules []Rule `json:"rules"`
}

type SlidingWindowCounterRule struct {
MaxRequests int64 `json:"max_requests"`
WindowSize int `json:"window"`
}
8 changes: 8 additions & 0 deletions rate_shield/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ func NewFixedWindowClient() (RedisFixedWindow, error) {
}, nil
}

func NewSlidingWindowClient() (*redis.Client, error) {
client, err := createNewRedisConnection(getRedisConnectionStr(), 3)
if err != nil {
return nil, err
}
return client, nil
}

func NewRulesClient() (RedisRules, error) {

client, err := createNewRedisConnection(getRedisConnectionStr(), 0)
Expand Down

0 comments on commit 4ad511a

Please sign in to comment.