Skip to content

Commit

Permalink
feat: implement rueidislimiter module for distributed rate limiting
Browse files Browse the repository at this point in the history
Signed-off-by: Ernesto Alejandro Santana Hidalgo <[email protected]>
  • Loading branch information
nesty92 committed Nov 5, 2024
1 parent bf8de44 commit a0cb730
Show file tree
Hide file tree
Showing 5 changed files with 569 additions and 0 deletions.
142 changes: 142 additions & 0 deletions rueidislimiter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@

# rueidislimiter

A Redis-based rate limiter for Go, implemented using the `rueidis` client library. This module provides an interface for token bucket rate limiting with precise control over limits and time windows. Inspired by GitHub's approach to scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)), this module offers a robust solution for managing request rates across distributed systems.

## Features

- **Token Bucket Algorithm**: Implements a token bucket algorithm to control the number of actions (e.g., API requests) a user can perform within a specified time window.
- **Customizable Limits**: Allows configuration of request limits and time windows to suit various application requirements.
- **Distributed Rate Limiting**: Leverages Redis to maintain rate limit counters, ensuring consistency across distributed environments.
- **Retry and Reset Information**: Provides `RetryAfter` and `ResetAfter` durations to inform clients when they can retry requests.

## Installation

To install the `rueidislimiter` module, run:

```bash
go get github.com/redis/rueidis/rueidislimiter
```

## Usage

### Basic Rate Limiting Example

The following example demonstrates how to initialize a rate limiter with a custom request limit and time window, and how to check and allow requests based on an identifier (e.g., a user ID or IP address):

```go
package main

import (
"context"
"fmt"
"time"

"github.com/redis/rueidis"
"github.com/redis/rueidis/rueidislimiter"
)

func main() {
client, err := rueidis.NewClient(rueidis.ClientOption{
InitAddress: []string{"localhost:6379"},
})
if err != nil {
panic(err)
}

// Initialize a new rate limiter with a limit of 5 requests per minute
limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{
ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}},
KeyPrefix: "api_rate_limit",
Limit: 5,
Window: time.Minute,
})
if err != nil {
panic(err)
}

identifier := "user_123"

// Check if a request is allowed
result, err := limiter.Check(context.Background(), identifier)
if err != nil {
panic(err)
}
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter)

// Allow a request
result, err = limiter.Allow(context.Background(), identifier)
if err != nil {
panic(err)
}
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter)

// Allow multiple requests
result, err = limiter.AllowN(context.Background(), identifier, 3)
if err != nil {
panic(err)
}
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter)
}
```

### API

#### `NewRateLimiter`

Creates a new rate limiter with the specified options:

- `ClientOption`: Options to connect to Redis.
- `KeyPrefix`: Prefix for Redis keys used by this limiter.
- `Limit`: Maximum number of allowed requests per window.
- `Window`: Time window duration for rate limiting.

```go
limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{
ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}},
KeyPrefix: "api_rate_limit",
Limit: 5,
Window: time.Minute,
})
```

#### `Check`

Checks if a request is allowed under the rate limit without incrementing the count.

```go
result, err := limiter.Check(ctx, "user_identifier")
```

Returns a `Result` struct:

- `Allowed`: Whether the request is allowed.
- `Remaining`: Number of remaining requests in the current window.
- `RetryAfter`: Duration until the next allowed request (0 if allowed).
- `ResetAfter`: Duration until the current rate limit window resets.

#### `Allow`

Allows a single request, incrementing the counter if allowed.

```go
result, err := limiter.Allow(ctx, "user_identifier")
```

#### `AllowN`

Allows `n` requests, incrementing the counter accordingly if allowed.

```go
result, err := limiter.AllowN(ctx, "user_identifier", 3)
```

- `n`: The number of requests to allow.

## Implementation Details

The `rueidislimiter` module employs Lua scripts executed within Redis to ensure atomic operations for checking and updating rate limits. This approach minimizes race conditions and maintains consistency across distributed systems.

By utilizing Redis's expiration capabilities, the module automatically resets rate limits after the specified time window, ensuring efficient memory usage and accurate rate limiting behavior.

For more information on the design and implementation of Redis-based rate limiters, refer to GitHub's detailed account of scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)).
9 changes: 9 additions & 0 deletions rueidislimiter/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module github.com/redis/rueidis/rueidislimiter

go 1.21

replace github.com/redis/rueidis => ../

require github.com/redis/rueidis v1.0.48

require golang.org/x/sys v0.24.0 // indirect
14 changes: 14 additions & 0 deletions rueidislimiter/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
216 changes: 216 additions & 0 deletions rueidislimiter/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package rueidislimiter

import (
"context"
"fmt"
"strconv"
"strings"
"time"

"github.com/redis/rueidis"
)

type Result struct {
// Allowed indicates if the request is allowed under the rate limit
Allowed bool
// Remaining is the number of remaining tokens in the current window
Remaining int64
// ResetAfter is the duration until the rate limit resets
ResetAfter time.Duration
// RetryAfter is the duration after which the request may be retried (0 if allowed)
RetryAfter time.Duration
}

type RateLimiterClient interface {
Check(ctx context.Context, identifier string) (Result, error)
Allow(ctx context.Context, identifier string) (Result, error)
AllowN(ctx context.Context, identifier string, n int64) (Result, error)
}

const PlaceholderPrefix = "rueidislimiter"

type rateLimiter struct {
client rueidis.Client
key string
limit int
window time.Duration
}

type RateLimiterOption struct {
// ClientBuilder can be used to modify rueidis.Client used by RateLimiter
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error)
ClientOption rueidis.ClientOption
KeyPrefix string
Limit int
Window time.Duration
}

func NewRateLimiter(option RateLimiterOption) (rlc RateLimiterClient, err error) {
if option.Window <= 0 {
option.Window = time.Second
}
if option.Limit <= 0 {
option.Limit = 1
}
if option.KeyPrefix == "" {
option.KeyPrefix = PlaceholderPrefix
}
rl := &rateLimiter{
limit: option.Limit,
window: option.Window,
}
if option.ClientBuilder != nil {
rl.client, err = option.ClientBuilder(option.ClientOption)
} else {
rl.client, err = rueidis.NewClient(option.ClientOption)
}
if err != nil {
return nil, err
}
rl.key = option.KeyPrefix
return rl, nil
}

func (l *rateLimiter) Check(ctx context.Context, identifier string) (Result, error) {
now := time.Now().UTC()
keys := []string{l.getKey(identifier)}
args := []string{strconv.FormatInt(now.Unix(), 10)}

resp := checkLimitScript.Exec(ctx, l.client, keys, args)
if resp.Error() != nil {
return Result{}, fmt.Errorf("failed to execute check limit script: %w", resp.Error())
}

array, err := resp.ToArray()
if err != nil {
return Result{}, fmt.Errorf("failed to parse response array: %w", err)
}

return l.parseResult(array, now)
}

func (l *rateLimiter) Allow(ctx context.Context, identifier string) (Result, error) {
return l.AllowN(ctx, identifier, 1)
}

func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (Result, error) {
if n <= 0 {
return Result{}, fmt.Errorf("number of tokens must be positive, got %d", n)
}

now := time.Now().UTC()
keys := []string{l.getKey(identifier)}
args := []string{
strconv.FormatInt(n, 10),
strconv.FormatInt(now.Add(l.window).Unix(), 10),
strconv.FormatInt(now.Unix(), 10),
}

resp := rateLimitScript.Exec(ctx, l.client, keys, args)
if resp.Error() != nil {
return Result{}, fmt.Errorf("failed to execute rate limit script: %w", resp.Error())
}

array, err := resp.ToArray()
if err != nil {
return Result{}, fmt.Errorf("failed to parse response array: %w", err)
}

return l.parseResult(array, now)
}

func (l *rateLimiter) parseResult(array []rueidis.RedisMessage, now time.Time) (r Result, err error) {
var current int64
var expiresAt int64

if len(array) == 2 {
current, err = array[0].ToInt64()
if err != nil {
if !rueidis.IsRedisNil(err) {
return Result{}, err
}
}

expiresAt, err = array[1].ToInt64()
if err != nil {
if !rueidis.IsRedisNil(err) {
return Result{}, err
}
}
}

if expiresAt == 0 {
current = 0
expiresAt = now.Unix() + int64(l.window.Seconds())
}

remaining := int64(l.limit) - current
if remaining < 0 {
remaining = 0
}

allowed := current < int64(l.limit)
resetAfter := time.Until(time.Unix(expiresAt, 0))
if resetAfter <= 0 {
resetAfter = l.window
}
retryAfter := resetAfter
if allowed {
retryAfter = 0
}

return Result{
Allowed: allowed,
Remaining: remaining,
ResetAfter: resetAfter,
RetryAfter: retryAfter,
}, nil
}

func (l *rateLimiter) getKey(identifier string) string {
// NOTE: https://redis.io/docs/reference/cluster-spec/#hash-tags
sb := strings.Builder{}
sb.Grow(len(l.key) + len(identifier) + 3) // +3 for ":", "{", "}"
sb.WriteString(l.key)
sb.WriteString(":{")
sb.WriteString(identifier)
sb.WriteString("}")
return sb.String()
}

var (
rateLimitScript = rueidis.NewLuaScript(`
local rate_limit_key = KEYS[1]
local increment_amount = tonumber(ARGV[1])
local next_expires_at = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
local expires_at_key = rate_limit_key .. ":ex"
local expires_at = tonumber(redis.call("get", expires_at_key))
if not expires_at or expires_at < current_time then
redis.call("set", rate_limit_key, 0)
redis.call("set", expires_at_key, next_expires_at)
redis.call("expireat", rate_limit_key, next_expires_at + 1)
redis.call("expireat", expires_at_key, next_expires_at + 1)
expires_at = next_expires_at
end
local current = redis.call("incrby", rate_limit_key, increment_amount)
return { current, expires_at }
`)
checkLimitScript = rueidis.NewLuaScriptReadOnly(`
local rate_limit_key = KEYS[1]
local expires_at_key = rate_limit_key .. ":ex"
local current_time = tonumber(ARGV[1])
local tries = tonumber(redis.call("get", rate_limit_key))
local expires_at = nil -- maybe overridden below
if not tries then
tries = nil
else
expires_at = tonumber(redis.call("get", expires_at_key))
if not expires_at or expires_at < current_time then
tries = nil
expires_at = nil
end
end
return { tries, expires_at }
`)
)
Loading

0 comments on commit a0cb730

Please sign in to comment.