diff --git a/rueidislimiter/README.md b/rueidislimiter/README.md new file mode 100644 index 00000000..a9d0a770 --- /dev/null +++ b/rueidislimiter/README.md @@ -0,0 +1,142 @@ + +# rueidislimiter + +A Redis-based rate limiter for Go, implemented using the `rueidis` client library. This module provides an interface for token bucket rate limiting with precise control over limits and time windows. Inspired by GitHub's approach to scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)), this module offers a robust solution for managing request rates across distributed systems. + +## Features + +- **Token Bucket Algorithm**: Implements a token bucket algorithm to control the number of actions (e.g., API requests) a user can perform within a specified time window. +- **Customizable Limits**: Allows configuration of request limits and time windows to suit various application requirements. +- **Distributed Rate Limiting**: Leverages Redis to maintain rate limit counters, ensuring consistency across distributed environments. +- **Retry and Reset Information**: Provides `RetryAfter` and `ResetAfter` durations to inform clients when they can retry requests. + +## Installation + +To install the `rueidislimiter` module, run: + +```bash +go get github.com/redis/rueidis/rueidislimiter +``` + +## Usage + +### Basic Rate Limiting Example + +The following example demonstrates how to initialize a rate limiter with a custom request limit and time window, and how to check and allow requests based on an identifier (e.g., a user ID or IP address): + +```go +package main + +import ( + "context" + "fmt" + "time" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/rueidislimiter" +) + +func main() { + client, err := rueidis.NewClient(rueidis.ClientOption{ + InitAddress: []string{"localhost:6379"}, + }) + if err != nil { + panic(err) + } + + // Initialize a new rate limiter with a limit of 5 requests per minute + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, + Key: "api_rate_limit", + Limit: 5, + Window: time.Minute, + }) + if err != nil { + panic(err) + } + + identifier := "user_123" + + // Check if a request is allowed + result, err := limiter.Check(context.Background(), identifier) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) + + // Allow a request + result, err = limiter.Allow(context.Background(), identifier) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) + + // Allow multiple requests + result, err = limiter.AllowN(context.Background(), identifier, 3) + if err != nil { + panic(err) + } + fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) +} +``` + +### API + +#### `NewRateLimiter` + +Creates a new rate limiter with the specified options: + +- `ClientOption`: Options to connect to Redis. +- `Key`: Prefix for Redis keys used by this limiter. +- `Limit`: Maximum number of allowed requests per window. +- `Window`: Time window duration for rate limiting. + +```go +limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, + Key: "api_rate_limit", + Limit: 5, + Window: time.Minute, +}) +``` + +#### `Check` + +Checks if a request is allowed under the rate limit without incrementing the count. + +```go +result, err := limiter.Check(ctx, "user_identifier") +``` + +Returns a `Result` struct: + +- `Allowed`: Whether the request is allowed. +- `Remaining`: Number of remaining requests in the current window. +- `RetryAfter`: Duration until the next allowed request (0 if allowed). +- `ResetAfter`: Duration until the current rate limit window resets. + +#### `Allow` + +Allows a single request, incrementing the counter if allowed. + +```go +result, err := limiter.Allow(ctx, "user_identifier") +``` + +#### `AllowN` + +Allows `n` requests, incrementing the counter accordingly if allowed. + +```go +result, err := limiter.AllowN(ctx, "user_identifier", 3) +``` + +- `n`: The number of requests to allow. + +## Implementation Details + +The `rueidislimiter` module employs Lua scripts executed within Redis to ensure atomic operations for checking and updating rate limits. This approach minimizes race conditions and maintains consistency across distributed systems. + +By utilizing Redis's expiration capabilities, the module automatically resets rate limits after the specified time window, ensuring efficient memory usage and accurate rate limiting behavior. + +For more information on the design and implementation of Redis-based rate limiters, refer to GitHub's detailed account of scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)). diff --git a/rueidislimiter/go.mod b/rueidislimiter/go.mod new file mode 100644 index 00000000..18bbfd22 --- /dev/null +++ b/rueidislimiter/go.mod @@ -0,0 +1,9 @@ +module github.com/redis/rueidis/rueidislimiter + +go 1.22 + +replace github.com/redis/rueidis => ../ + +require github.com/redis/rueidis v1.0.48 + +require golang.org/x/sys v0.24.0 // indirect diff --git a/rueidislimiter/go.sum b/rueidislimiter/go.sum new file mode 100644 index 00000000..8c3c04f5 --- /dev/null +++ b/rueidislimiter/go.sum @@ -0,0 +1,14 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rueidislimiter/limiter.go b/rueidislimiter/limiter.go new file mode 100644 index 00000000..9fddf910 --- /dev/null +++ b/rueidislimiter/limiter.go @@ -0,0 +1,243 @@ +package rueidislimiter + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redis/rueidis" +) + +type Result struct { + // Allowed indicates if the request is allowed under the rate limit + Allowed bool + // Remaining is the number of remaining tokens in the current window + Remaining int64 + // ResetAfter is the duration until the rate limit resets + ResetAfter time.Duration + // RetryAfter is the duration after which the request may be retried (0 if allowed) + RetryAfter time.Duration +} + +type RateLimiterClient interface { + Check(ctx context.Context, identifier string) (Result, error) + Allow(ctx context.Context, identifier string) (Result, error) + AllowN(ctx context.Context, identifier string, n int64) (Result, error) +} + +const PlaceholderPrefix = "rueidislimiter" + +type rateLimiter struct { + client rueidis.Client + key string + limit int + window time.Duration +} + +type RateLimiterOption struct { + // ClientBuilder can be used to modify rueidis.Client used by RateLimiter + ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) + ClientOption rueidis.ClientOption + Key string + Limit int + Window time.Duration +} + +func NewRateLimiter(option RateLimiterOption) (rlc RateLimiterClient, err error) { + if option.Window <= 0 { + option.Window = time.Second + } + if option.Limit <= 0 { + option.Limit = 1 + } + if option.Key == "" { + option.Key = PlaceholderPrefix + } + rl := &rateLimiter{ + limit: option.Limit, + window: option.Window, + } + if option.ClientBuilder != nil { + rl.client, err = option.ClientBuilder(option.ClientOption) + } else { + rl.client, err = rueidis.NewClient(option.ClientOption) + } + if err != nil { + return nil, err + } + rl.key = option.Key + return rl, nil +} + +func (l *rateLimiter) Check(ctx context.Context, identifier string) (Result, error) { + now := time.Now().UTC() + keys := []string{l.getKey(identifier)} + args := []string{strconv.FormatInt(now.Unix(), 10)} + + resp := checkLimitScript.Exec(ctx, l.client, keys, args) + if resp.Error() != nil { + return Result{}, fmt.Errorf("failed to execute check limit script: %w", resp.Error()) + } + + array, err := resp.ToArray() + if err != nil { + return Result{}, fmt.Errorf("failed to parse response array: %w", err) + } + + return l.parseResult(array, now) +} + +func (l *rateLimiter) Allow(ctx context.Context, identifier string) (Result, error) { + return l.AllowN(ctx, identifier, 1) +} + +func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (Result, error) { + if n <= 0 { + return Result{}, fmt.Errorf("number of tokens must be positive, got %d", n) + } + + now := time.Now().UTC() + keys := []string{l.getKey(identifier)} + args := []string{ + strconv.FormatInt(n, 10), + strconv.FormatInt(now.Add(l.window).Unix(), 10), + strconv.FormatInt(now.Unix(), 10), + } + + resp := rateLimitScript.Exec(ctx, l.client, keys, args) + if resp.Error() != nil { + return Result{}, fmt.Errorf("failed to execute rate limit script: %w", resp.Error()) + } + + array, err := resp.ToArray() + if err != nil { + return Result{}, fmt.Errorf("failed to parse response array: %w", err) + } + + return l.parseResult(array, now) +} + +func (l *rateLimiter) parseResult(array []rueidis.RedisMessage, now time.Time) (r Result, err error) { + var current int64 + var expiresAt int64 + + if len(array) == 2 { + current, err = array[0].ToInt64() + if err != nil { + if !rueidis.IsRedisNil(err) { + return Result{}, err + } + } + + expiresAt, err = array[1].ToInt64() + if err != nil { + if !rueidis.IsRedisNil(err) { + return Result{}, err + } + } + } + + if expiresAt == 0 { + current = 0 + expiresAt = now.Unix() + int64(l.window.Seconds()) + } + + remaining := int64(l.limit) - current + if remaining < 0 { + remaining = 0 + } + + allowed := current < int64(l.limit) + resetAfter := time.Until(time.Unix(expiresAt, 0)) + if resetAfter <= 0 { + resetAfter = l.window + } + retryAfter := resetAfter + if allowed { + retryAfter = 0 + } + + return Result{ + Allowed: allowed, + Remaining: remaining, + ResetAfter: resetAfter, + RetryAfter: retryAfter, + }, nil +} + +func (l *rateLimiter) getKey(identifier string) string { + // NOTE: https://redis.io/docs/reference/cluster-spec/#hash-tags + sb := strings.Builder{} + sb.Grow(len(l.key) + len(identifier) + 3) // +3 for ":", "{", "}" + sb.WriteString(l.key) + sb.WriteString(":{") + sb.WriteString(identifier) + sb.WriteString("}") + return sb.String() +} + +var ( + rateLimitScript = rueidis.NewLuaScript(` +-- count a request for a client +-- and return the current state for the client +-- rename the inputs for clarity below +local rate_limit_key = KEYS[1] +local increment_amount = tonumber(ARGV[1]) +local next_expires_at = tonumber(ARGV[2]) +local current_time = tonumber(ARGV[3]) +local expires_at_key = rate_limit_key .. ":ex" +local expires_at = tonumber(redis.call("get", expires_at_key)) +if not expires_at or expires_at < current_time then + -- this is either a brand new window, + -- or this window has closed, but redis hasn't cleaned up the key yet + -- (redis will clean it up in one more second) + -- initialize a new rate limit window + redis.call("set", rate_limit_key, 0) + redis.call("set", expires_at_key, next_expires_at) + -- tell Redis to clean this up _one second after_ the expires-at time. + -- that way, clock differences between the application and Redis won't cause data to disappear. + -- (Redis will only clean up these keys "long after" the window has passed) + redis.call("expireat", rate_limit_key, next_expires_at + 1) + redis.call("expireat", expires_at_key, next_expires_at + 1) + -- since the database was updated, return the new value + expires_at = next_expires_at +end +-- Now that the window is either known to already exist _or_ be freshly initialized, +-- increment the counter ('incrby' returns a number) +local current = redis.call("incrby", rate_limit_key, increment_amount) +return { current, expires_at } +`) + checkLimitScript = rueidis.NewLuaScriptReadOnly(` +-- Getting both the value and the expiration +-- of key as needed by our algorithm needs to be ran +-- in an atomic way, hence the script. + +-- rename the inputs for clarity below +local rate_limit_key = KEYS[1] +local expires_at_key = rate_limit_key .. ":ex" +local current_time = tonumber(ARGV[1]) +local tries = tonumber(redis.call("get", rate_limit_key)) +local expires_at = nil -- maybe overridden below +if not tries then + -- this client hasn't initialized a window yet + -- let this fall through to returning {nil, nil}, + -- where the application will provide defaults + tries = nil +else + -- we found a number of tries, now check + -- if this window is actually expired + expires_at = tonumber(redis.call("get", expires_at_key)) + if not expires_at or expires_at < current_time then + -- this window hasn't been cleaned up by Redis yet, but it has closed. + -- (maybe it was _partly_ cleaned up, if we found 'tries' but not 'expires_at') + -- ignore the data in the database; return a fresh window instead + tries = nil + expires_at = nil + end +end +-- Return {nil, nil} if the window is brand new (or expired) +return { tries, expires_at } +`) +) diff --git a/rueidislimiter/limiter_test.go b/rueidislimiter/limiter_test.go new file mode 100644 index 00000000..efb5e678 --- /dev/null +++ b/rueidislimiter/limiter_test.go @@ -0,0 +1,188 @@ +package rueidislimiter_test + +import ( + "context" + "encoding/binary" + "encoding/hex" + "math/rand/v2" + "testing" + "time" + "unsafe" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/rueidislimiter" +) + +func setup(t testing.TB) rueidis.Client { + client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) + if err != nil { + t.Fatal(err) + } + return client +} + +func TestRateLimiter(t *testing.T) { + client := setup(t) + defer client.Close() + + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + Limit: 3, + Window: time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + t.Run("Check defaults", func(t *testing.T) { + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + t.Fatal(err) + } + result, err := limiter.Check(context.Background(), randStr()) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 1 || result.RetryAfter != 0 { + t.Fatalf("Expected Allowed=true, Remaining=1, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("Check allowed within limit", func(t *testing.T) { + result, err := limiter.Check(context.Background(), randStr()) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 3 || result.RetryAfter != 0 { + t.Fatalf("Expected Allowed=true, Remaining=3, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("Check denied after exceeding limit", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + result, err := limiter.Check(context.Background(), key) + if err != nil { + t.Fatal(err) + } + if result.Allowed || result.Remaining != 0 || result.RetryAfter <= 0 { + t.Fatalf("Expected Allowed=false, Remaining=0, RetryAfter > 0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("Check allowed after window reset", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + time.Sleep(time.Second) // Wait for window to reset + result, err := limiter.Check(context.Background(), key) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 3 || result.RetryAfter != 0 { + t.Fatalf("Expected Allowed=true, Remaining=3, RetryAfter=0 after reset; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("AllowN with tokens within limit", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, 1) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 2 || result.RetryAfter != 0 { + t.Fatalf("Expected Allowed=true, Remaining=2, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("AllowN denied after exceeding limit", func(t *testing.T) { + key := randStr() + generateLoad(t, limiter, key, 3) + + result, err := limiter.AllowN(context.Background(), key, 1) + if err != nil { + t.Fatal(err) + } + if result.Allowed || result.Remaining != 0 || result.RetryAfter <= 0 { + t.Fatalf("Expected Allowed=false, Remaining=0, RetryAfter > 0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + } + }) + + t.Run("AllowN with zero tokens", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, 0) + if err == nil { + t.Fatalf("Expected error for zero tokens, but got nil") + } + if result.Allowed { + t.Fatalf("Expected Allowed=false when allowing zero tokens, but got true") + } + }) + + t.Run("AllowN with negative tokens", func(t *testing.T) { + key := randStr() + result, err := limiter.AllowN(context.Background(), key, -1) + if err == nil { + t.Fatalf("Expected error for negative tokens, but got nil") + } + if result.Allowed { + t.Fatalf("Expected Allowed=false when allowing negative tokens, but got true") + } + }) +} + +func BenchmarkRateLimiter(b *testing.B) { + client := setup(b) + defer client.Close() + + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + b.Fatal(err) + } + key := randStr() + + b.ResetTimer() + b.ReportAllocs() + + b.Run("Check", func(b *testing.B) { + for i := 0; i < b.N; i++ { + limiter.Check(context.Background(), key) + } + }) + + b.Run("AllowN", func(b *testing.B) { + for i := 0; i < b.N; i++ { + limiter.AllowN(context.Background(), key, 1) + } + }) +} + +func generateLoad(t *testing.T, limiter rueidislimiter.RateLimiterClient, key string, n int) { + for i := 0; i < n; i++ { + _, err := limiter.Allow(context.Background(), key) + if err != nil { + t.Fatal(err) + } + } +} + +// randStr generates a 24-byte long, random string. +func randStr() string { + b := make([]byte, 24) + binary.LittleEndian.PutUint64(b[12:], rand.Uint64()) + binary.LittleEndian.PutUint32(b[20:], rand.Uint32()) + hex.Encode(b, b[12:]) + + return unsafe.String(unsafe.SliceData(b), len(b)) +}