diff --git a/rueidislimiter/README.md b/rueidislimiter/README.md new file mode 100644 index 00000000..20c80366 --- /dev/null +++ b/rueidislimiter/README.md @@ -0,0 +1,141 @@ + +# rueidislimiter + +This module provides an interface for token bucket rate limiting with precise control over limits and time windows. Inspired by GitHub's approach to scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)). + +## Features + +- **Token Bucket Algorithm**: Implements a token bucket algorithm to control the number of actions (e.g., API requests) a user can perform within a specified time window. +- **Customizable Limits**: Allows configuration of request limits and time windows to suit various application requirements. +- **Distributed Rate Limiting**: Leverages Redis to maintain rate limit counters, ensuring consistency across distributed environments. +- **Reset Information**: Provides `ResetAtMs` timestamps to inform clients when they can retry requests. + +## Installation + +To install the `rueidislimiter` module, run: + +```bash +go get github.com/redis/rueidis/rueidislimiter +``` + +## Usage + +### Basic Rate Limiting Example + +The following example demonstrates how to initialize a rate limiter with a custom request limit and time window, and how to check and allow requests based on an identifier (e.g., a user ID or IP address): + +```go +package main + +import ( + "context" + "fmt" + "time" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/rueidislimiter" +) + +func main() { + client, err := rueidis.NewClient(rueidis.ClientOption{ + InitAddress: []string{"localhost:6379"}, + }) + if err != nil { + panic(err) + } + + // Initialize a new rate limiter with a limit of 5 requests per minute + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, + KeyPrefix: "api_rate_limit", + Limit: 5, + Window: time.Minute, + }) + if err != nil { + panic(err) + } + + identifier := "user_123" + + // Check if a request is allowed + result, err := limiter.Check(context.Background(), identifier) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) + + // Allow a request + result, err = limiter.Allow(context.Background(), identifier) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) + + // Allow multiple requests + result, err = limiter.AllowN(context.Background(), identifier, 3) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) +} +``` + +### API + +#### `NewRateLimiter` + +Creates a new rate limiter with the specified options: + +- `ClientOption`: Options to connect to Redis. +- `KeyPrefix`: Prefix for Redis keys used by this limiter. +- `Limit`: Maximum number of allowed requests per window. +- `Window`: Time window duration for rate limiting. Must be greater than 1 millisecond. + +```go +limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, + KeyPrefix: "api_rate_limit", + Limit: 5, + Window: time.Second, +}) +``` + +#### `Check` + +Checks if a request is allowed under the rate limit without incrementing the count. + +```go +result, err := limiter.Check(ctx, "user_identifier") +``` + +Returns a `Result` struct: + +- `Allowed`: Whether the request is allowed. +- `Remaining`: Number of remaining requests in the current window. +- `ResetAtMs`: Unix timestamp in milliseconds at which the rate limit will reset. + +#### `Allow` + +Allows a single request, incrementing the counter if allowed. + +```go +result, err := limiter.Allow(ctx, "user_identifier") +``` + +#### `AllowN` + +Allows `n` requests, incrementing the counter accordingly if allowed. + +```go +result, err := limiter.AllowN(ctx, "user_identifier", 3) +``` + +- `n`: The number of requests to allow. + +## Implementation Details + +The `rueidislimiter` module employs Lua scripts executed within Redis to ensure atomic operations for checking and updating rate limits. This approach minimizes race conditions and maintains consistency across distributed systems. + +By utilizing Redis's expiration capabilities, the module automatically resets rate limits after the specified time window, ensuring efficient memory usage and accurate rate limiting behavior. + +For more information on the design and implementation of Redis-based rate limiters, refer to GitHub's detailed account of scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)). diff --git a/rueidislimiter/go.mod b/rueidislimiter/go.mod new file mode 100644 index 00000000..a4cf456f --- /dev/null +++ b/rueidislimiter/go.mod @@ -0,0 +1,9 @@ +module github.com/redis/rueidis/rueidislimiter + +go 1.21 + +replace github.com/redis/rueidis => ../ + +require github.com/redis/rueidis v1.0.48 + +require golang.org/x/sys v0.24.0 // indirect diff --git a/rueidislimiter/go.sum b/rueidislimiter/go.sum new file mode 100644 index 00000000..8c3c04f5 --- /dev/null +++ b/rueidislimiter/go.sum @@ -0,0 +1,14 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rueidislimiter/limiter.go b/rueidislimiter/limiter.go new file mode 100644 index 00000000..7417bae8 --- /dev/null +++ b/rueidislimiter/limiter.go @@ -0,0 +1,153 @@ +package rueidislimiter + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/redis/rueidis" +) + +var ( + ErrInvalidTokens = errors.New("number of tokens must be non-negative") + ErrInvalidResponse = errors.New("invalid response from Redis") +) + +type Result struct { + Allowed bool + Remaining int64 + ResetAtMs int64 +} + +type RateLimiterClient interface { + Check(ctx context.Context, identifier string) (Result, error) + Allow(ctx context.Context, identifier string) (Result, error) + AllowN(ctx context.Context, identifier string, n int64) (Result, error) +} + +const PlaceholderPrefix = "rueidislimiter" + +type rateLimiter struct { + client rueidis.Client + keyPrefix string + limit int + window time.Duration +} + +type RateLimiterOption struct { + ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) + ClientOption rueidis.ClientOption + KeyPrefix string + Limit int + Window time.Duration +} + +func NewRateLimiter(option RateLimiterOption) (RateLimiterClient, error) { + if option.Window < time.Millisecond { + option.Window = time.Millisecond + } + if option.Limit <= 0 { + option.Limit = 1 + } + if option.KeyPrefix == "" { + option.KeyPrefix = PlaceholderPrefix + } + + rl := &rateLimiter{ + limit: option.Limit, + window: option.Window, + } + + var err error + if option.ClientBuilder != nil { + rl.client, err = option.ClientBuilder(option.ClientOption) + } else { + rl.client, err = rueidis.NewClient(option.ClientOption) + } + if err != nil { + return nil, err + } + rl.keyPrefix = option.KeyPrefix + return rl, nil +} + +func (l *rateLimiter) Limit() int { + return l.limit +} + +func (l *rateLimiter) Check(ctx context.Context, identifier string) (Result, error) { + return l.AllowN(ctx, identifier, 0) +} + +func (l *rateLimiter) Allow(ctx context.Context, identifier string) (Result, error) { + return l.AllowN(ctx, identifier, 1) +} + +func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (Result, error) { + if n < 0 { + return Result{}, ErrInvalidTokens + } + + now := time.Now().UTC() + keys := []string{l.getKey(identifier)} + args := []string{ + strconv.FormatInt(n, 10), + strconv.FormatInt(now.Add(l.window).UnixMilli(), 10), + strconv.FormatInt(now.UnixMilli(), 10), + } + + resp := rateLimitScript.Exec(ctx, l.client, keys, args) + if err := resp.Error(); err != nil { + return Result{}, err + } + + data, err := resp.AsIntSlice() + if err != nil || len(data) != 2 { + return Result{}, ErrInvalidResponse + } + + current := data[0] + remaining := int64(l.limit) - current + if remaining < 0 { + remaining = 0 + } + + allowed := current <= int64(l.limit) + if n == 0 { + allowed = current < int64(l.limit) + } + + return Result{ + Allowed: allowed, + Remaining: remaining, + ResetAtMs: data[1], + }, nil +} + +func (l *rateLimiter) getKey(identifier string) string { + sb := strings.Builder{} + sb.Grow(len(l.keyPrefix) + len(identifier) + 3) + sb.WriteString(l.keyPrefix) + sb.WriteString(":{") + sb.WriteString(identifier) + sb.WriteString("}") + return sb.String() +} + +var rateLimitScript = rueidis.NewLuaScript(` +local rate_limit_key = KEYS[1] +local increment_amount = tonumber(ARGV[1]) +local next_expires_at = tonumber(ARGV[2]) +local current_time = tonumber(ARGV[3]) +local expires_at_key = rate_limit_key .. ":ex" +local expires_at = tonumber(redis.call("get", expires_at_key)) +if not expires_at or expires_at < current_time then + redis.call("set", rate_limit_key, 0, "pxat", next_expires_at + 1000) + redis.call("set", expires_at_key, next_expires_at, "pxat", next_expires_at + 1000) + expires_at = next_expires_at +end +local current = redis.call("incrby", rate_limit_key, increment_amount) +return { current, expires_at } +`) diff --git a/rueidislimiter/limiter_test.go b/rueidislimiter/limiter_test.go new file mode 100644 index 00000000..a22cc8ed --- /dev/null +++ b/rueidislimiter/limiter_test.go @@ -0,0 +1,209 @@ +package rueidislimiter_test + +import ( + "context" + "encoding/binary" + "encoding/hex" + "math/rand" + "testing" + "time" + "unsafe" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/rueidislimiter" +) + +func setup(t testing.TB) rueidis.Client { + client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) + if err != nil { + t.Fatal(err) + } + return client +} + +func TestRateLimiter(t *testing.T) { + client := setup(t) + t.Cleanup(client.Close) + + now := time.Now() + window := 100 * time.Millisecond + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + Limit: 3, + Window: window, + }) + if err != nil { + t.Fatal(err) + } + + t.Run("Check defaults", func(t *testing.T) { + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + t.Fatal(err) + } + result, err := limiter.Check(context.Background(), randStr()) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 1 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=true, Remaining=1, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("Check allowed within limit", func(t *testing.T) { + result, err := limiter.Check(context.Background(), randStr()) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 3 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("Check denied after exceeding limit", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + result, err := limiter.Check(context.Background(), key) + if err != nil { + t.Fatal(err) + } + if result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("Check allowed after window reset", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + // Sleep for slightly longer than window duration to ensure reset + time.Sleep(window * 2) + result, err := limiter.Check(context.Background(), key) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 3 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt >= now after reset; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("AllowN defaults", func(t *testing.T) { + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + t.Fatal(err) + } + result, err := limiter.AllowN(context.Background(), randStr(), 1) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=true, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("AllowN with tokens within limit", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, 1) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 2 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=true, Remaining=2, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("AllowN denied after exceeding limit", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + result, err := limiter.AllowN(context.Background(), key, 1) + if err != nil { + t.Fatal(err) + } + if result.Allowed || result.Remaining != 0 || result.ResetAtMs < now.UnixMilli() { + t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs) + } + }) + + t.Run("AllowN with zero tokens", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, 0) + if err != nil { + t.Fatal(err) + } + if !result.Allowed { + t.Fatalf("Expected Allowed=true when allowing zero tokens, but got false") + } + }) + + t.Run("AllowN with negative tokens", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, -1) + if err == nil { + t.Fatalf("Expected error for negative tokens, but got nil") + } + if result.Allowed { + t.Fatalf("Expected Allowed=false when allowing negative tokens, but got true") + } + }) +} + +func BenchmarkRateLimiter(b *testing.B) { + client := setup(b) + defer client.Close() + + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + b.Fatal(err) + } + key := randStr() + + b.ResetTimer() + b.ReportAllocs() + + b.Run("Check", func(b *testing.B) { + for i := 0; i < b.N; i++ { + limiter.Check(context.Background(), key) + } + }) + + b.Run("AllowN", func(b *testing.B) { + for i := 0; i < b.N; i++ { + limiter.AllowN(context.Background(), key, 1) + } + }) +} + +func generateLoad(t *testing.T, limiter rueidislimiter.RateLimiterClient, key string, n int) { + for i := 0; i < n; i++ { + _, err := limiter.Allow(context.Background(), key) + if err != nil { + t.Fatal(err) + } + } +} + +// randStr generates a 24-byte long, random string. +func randStr() string { + b := make([]byte, 24) + binary.LittleEndian.PutUint64(b[12:], rand.Uint64()) + binary.LittleEndian.PutUint32(b[20:], rand.Uint32()) + hex.Encode(b, b[12:]) + + return unsafe.String(unsafe.SliceData(b), len(b)) +}