-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement rueidislimiter module for distributed rate limiting
Signed-off-by: Ernesto Alejandro Santana Hidalgo <[email protected]>
- Loading branch information
Showing
5 changed files
with
526 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
|
||
# rueidislimiter | ||
|
||
This module provides an interface for token bucket rate limiting with precise control over limits and time windows. Inspired by GitHub's approach to scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)). | ||
|
||
## Features | ||
|
||
- **Token Bucket Algorithm**: Implements a token bucket algorithm to control the number of actions (e.g., API requests) a user can perform within a specified time window. | ||
- **Customizable Limits**: Allows configuration of request limits and time windows to suit various application requirements. | ||
- **Distributed Rate Limiting**: Leverages Redis to maintain rate limit counters, ensuring consistency across distributed environments. | ||
- **Reset Information**: Provides `ResetAtMs` timestamps to inform clients when they can retry requests. | ||
|
||
## Installation | ||
|
||
To install the `rueidislimiter` module, run: | ||
|
||
```bash | ||
go get github.com/redis/rueidis/rueidislimiter | ||
``` | ||
|
||
## Usage | ||
|
||
### Basic Rate Limiting Example | ||
|
||
The following example demonstrates how to initialize a rate limiter with a custom request limit and time window, and how to check and allow requests based on an identifier (e.g., a user ID or IP address): | ||
|
||
```go | ||
package main | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/redis/rueidis" | ||
"github.com/redis/rueidis/rueidislimiter" | ||
) | ||
|
||
func main() { | ||
client, err := rueidis.NewClient(rueidis.ClientOption{ | ||
InitAddress: []string{"localhost:6379"}, | ||
}) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
// Initialize a new rate limiter with a limit of 5 requests per minute | ||
limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ | ||
ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, | ||
KeyPrefix: "api_rate_limit", | ||
Limit: 5, | ||
Window: time.Minute, | ||
}) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
identifier := "user_123" | ||
|
||
// Check if a request is allowed | ||
result, err := limiter.Check(context.Background(), identifier) | ||
if err != nil { | ||
panic(err) | ||
} | ||
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) | ||
|
||
// Allow a request | ||
result, err = limiter.Allow(context.Background(), identifier) | ||
if err != nil { | ||
panic(err) | ||
} | ||
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) | ||
|
||
// Allow multiple requests | ||
result, err = limiter.AllowN(context.Background(), identifier, 3) | ||
if err != nil { | ||
panic(err) | ||
} | ||
fmt.Printf("Allowed: %v, Remaining: %d, RetryAfter: %v\n", result.Allowed, result.Remaining, result.RetryAfter) | ||
} | ||
``` | ||
|
||
### API | ||
|
||
#### `NewRateLimiter` | ||
|
||
Creates a new rate limiter with the specified options: | ||
|
||
- `ClientOption`: Options to connect to Redis. | ||
- `KeyPrefix`: Prefix for Redis keys used by this limiter. | ||
- `Limit`: Maximum number of allowed requests per window. | ||
- `Window`: Time window duration for rate limiting. Must be greater than 1 millisecond. | ||
|
||
```go | ||
limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ | ||
ClientOption: rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, | ||
KeyPrefix: "api_rate_limit", | ||
Limit: 5, | ||
Window: time.Second, | ||
}) | ||
``` | ||
|
||
#### `Check` | ||
|
||
Checks if a request is allowed under the rate limit without incrementing the count. | ||
|
||
```go | ||
result, err := limiter.Check(ctx, "user_identifier") | ||
``` | ||
|
||
Returns a `Result` struct: | ||
|
||
- `Allowed`: Whether the request is allowed. | ||
- `Remaining`: Number of remaining requests in the current window. | ||
- `ResetAtMs`: Unix timestamp in milliseconds at which the rate limit will reset. | ||
|
||
#### `Allow` | ||
|
||
Allows a single request, incrementing the counter if allowed. | ||
|
||
```go | ||
result, err := limiter.Allow(ctx, "user_identifier") | ||
``` | ||
|
||
#### `AllowN` | ||
|
||
Allows `n` requests, incrementing the counter accordingly if allowed. | ||
|
||
```go | ||
result, err := limiter.AllowN(ctx, "user_identifier", 3) | ||
``` | ||
|
||
- `n`: The number of requests to allow. | ||
|
||
## Implementation Details | ||
|
||
The `rueidislimiter` module employs Lua scripts executed within Redis to ensure atomic operations for checking and updating rate limits. This approach minimizes race conditions and maintains consistency across distributed systems. | ||
|
||
By utilizing Redis's expiration capabilities, the module automatically resets rate limits after the specified time window, ensuring efficient memory usage and accurate rate limiting behavior. | ||
|
||
For more information on the design and implementation of Redis-based rate limiters, refer to GitHub's detailed account of scaling their API with a sharded, replicated rate limiter in Redis ([github.blog](https://github.blog/engineering/infrastructure/how-we-scaled-github-api-sharded-replicated-rate-limiter-redis/)). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module github.com/redis/rueidis/rueidislimiter | ||
|
||
go 1.21 | ||
|
||
replace github.com/redis/rueidis => ../ | ||
|
||
require github.com/redis/rueidis v1.0.48 | ||
|
||
require golang.org/x/sys v0.24.0 // indirect |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= | ||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= | ||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= | ||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= | ||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= | ||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= | ||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= | ||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= | ||
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= | ||
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package rueidislimiter | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/redis/rueidis" | ||
) | ||
|
||
var ( | ||
ErrInvalidTokens = errors.New("number of tokens must be non-negative") | ||
ErrInvalidResponse = errors.New("invalid response from Redis") | ||
) | ||
|
||
type Result struct { | ||
Allowed bool | ||
Remaining int64 | ||
ResetAtMs int64 | ||
} | ||
|
||
type RateLimiterClient interface { | ||
Check(ctx context.Context, identifier string) (Result, error) | ||
Allow(ctx context.Context, identifier string) (Result, error) | ||
AllowN(ctx context.Context, identifier string, n int64) (Result, error) | ||
} | ||
|
||
const PlaceholderPrefix = "rueidislimiter" | ||
|
||
type rateLimiter struct { | ||
client rueidis.Client | ||
keyPrefix string | ||
limit int | ||
window time.Duration | ||
} | ||
|
||
type RateLimiterOption struct { | ||
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) | ||
ClientOption rueidis.ClientOption | ||
KeyPrefix string | ||
Limit int | ||
Window time.Duration | ||
} | ||
|
||
func NewRateLimiter(option RateLimiterOption) (RateLimiterClient, error) { | ||
if option.Window < time.Millisecond { | ||
option.Window = time.Millisecond | ||
} | ||
if option.Limit <= 0 { | ||
option.Limit = 1 | ||
} | ||
if option.KeyPrefix == "" { | ||
option.KeyPrefix = PlaceholderPrefix | ||
} | ||
|
||
rl := &rateLimiter{ | ||
limit: option.Limit, | ||
window: option.Window, | ||
} | ||
|
||
var err error | ||
if option.ClientBuilder != nil { | ||
rl.client, err = option.ClientBuilder(option.ClientOption) | ||
} else { | ||
rl.client, err = rueidis.NewClient(option.ClientOption) | ||
} | ||
if err != nil { | ||
return nil, err | ||
} | ||
rl.keyPrefix = option.KeyPrefix | ||
return rl, nil | ||
} | ||
|
||
func (l *rateLimiter) Limit() int { | ||
return l.limit | ||
} | ||
|
||
func (l *rateLimiter) Check(ctx context.Context, identifier string) (Result, error) { | ||
return l.AllowN(ctx, identifier, 0) | ||
} | ||
|
||
func (l *rateLimiter) Allow(ctx context.Context, identifier string) (Result, error) { | ||
return l.AllowN(ctx, identifier, 1) | ||
} | ||
|
||
func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (Result, error) { | ||
if n < 0 { | ||
return Result{}, ErrInvalidTokens | ||
} | ||
|
||
now := time.Now().UTC() | ||
keys := []string{l.getKey(identifier)} | ||
args := []string{ | ||
strconv.FormatInt(n, 10), | ||
strconv.FormatInt(now.Add(l.window).UnixMilli(), 10), | ||
strconv.FormatInt(now.UnixMilli(), 10), | ||
} | ||
|
||
resp := rateLimitScript.Exec(ctx, l.client, keys, args) | ||
if err := resp.Error(); err != nil { | ||
return Result{}, err | ||
} | ||
|
||
data, err := resp.AsIntSlice() | ||
if err != nil || len(data) != 2 { | ||
return Result{}, ErrInvalidResponse | ||
} | ||
|
||
current := data[0] | ||
remaining := int64(l.limit) - current | ||
if remaining < 0 { | ||
remaining = 0 | ||
} | ||
|
||
allowed := current <= int64(l.limit) | ||
if n == 0 { | ||
allowed = current < int64(l.limit) | ||
} | ||
|
||
return Result{ | ||
Allowed: allowed, | ||
Remaining: remaining, | ||
ResetAtMs: data[1], | ||
}, nil | ||
} | ||
|
||
func (l *rateLimiter) getKey(identifier string) string { | ||
sb := strings.Builder{} | ||
sb.Grow(len(l.keyPrefix) + len(identifier) + 3) | ||
sb.WriteString(l.keyPrefix) | ||
sb.WriteString(":{") | ||
sb.WriteString(identifier) | ||
sb.WriteString("}") | ||
return sb.String() | ||
} | ||
|
||
var rateLimitScript = rueidis.NewLuaScript(` | ||
local rate_limit_key = KEYS[1] | ||
local increment_amount = tonumber(ARGV[1]) | ||
local next_expires_at = tonumber(ARGV[2]) | ||
local current_time = tonumber(ARGV[3]) | ||
local expires_at_key = rate_limit_key .. ":ex" | ||
local expires_at = tonumber(redis.call("get", expires_at_key)) | ||
if not expires_at or expires_at < current_time then | ||
redis.call("set", rate_limit_key, 0, "pxat", next_expires_at + 1000) | ||
redis.call("set", expires_at_key, next_expires_at, "pxat", next_expires_at + 1000) | ||
expires_at = next_expires_at | ||
end | ||
local current = redis.call("incrby", rate_limit_key, increment_amount) | ||
return { current, expires_at } | ||
`) |
Oops, something went wrong.