Skip to content

Commit

Permalink
Rate limiter: add an alternative implementation based on Redis. (obse…
Browse files Browse the repository at this point in the history
…rvatorium#620)

* Add a Redis rate limiter

* Add useful comments

* Allow Redis rate limiter to be used

* Regenerate readme

* Fix warning

* Fix tests in machine in CircleCI

* Tidy up go modules

Signed-off-by: Douglas Camata <[email protected]>

* Address PR review comments

* Regenerate readme

* Fix test assertion to account for negative time delta

* Fix Redis limiter tests

* Adjust test case comment

* Fix/improve code comments

* Keep Go 1.20

* Tidy with go1.20

Signed-off-by: Douglas Camata <[email protected]>

---------

Signed-off-by: Douglas Camata <[email protected]>
  • Loading branch information
douglascamata committed Feb 29, 2024
1 parent 1132fdb commit f2995f3
Show file tree
Hide file tree
Showing 9 changed files with 369 additions and 14 deletions.
8 changes: 6 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ jobs:
make lint --always-make
test:
docker:
- image: golang:1.20
machine:
image: ubuntu-2004:2023.04.2
steps:
- checkout
- run: |
apt-get update && apt-get -y install xz-utils unzip openssl
sudo rm -rf /usr/local/go
wget -qO- https://golang.org/dl/go1.20.linux-amd64.tar.gz | sudo tar -C /usr/local -xzf -
export PATH=$PATH:/usr/local/go/bin
go version
make test --always-make
test-e2e:
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,12 @@ Usage of ./observatorium-api:
The number of concurrent requests that can buffered.
-middleware.concurrent-request-limit int
The limit that controls the number of concurrently processed requests across all tenants. (default 10000)
-middleware.rate-limiter.address value
The address of the rate limiter. Only used when not using the gRPC nor "local" rate limiters. Can be repeated to specify multiple addresses (i.e. Redis Cluster).
-middleware.rate-limiter.grpc-address string
The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant. If not specified, local, non-shared rate limiting will be used.
The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant. If not specified, local, non-shared rate limiting will be used. Has precedence over other rate limiter options.
-middleware.rate-limiter.type string
The type of rate limiter to use when not using a gRPC rate limiter. Options: 'local' (default), 'redis' (leaky bucket algorithm). (default "local")
-rbac.config string
Path to the RBAC configuration file. (default "rbac.yaml")
-server.read-header-timeout duration
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
github.com/prometheus/client_golang v1.18.0
github.com/prometheus/common v0.48.0
github.com/prometheus/prometheus v0.50.1
github.com/redis/rueidis v1.0.31
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.48.0
go.opentelemetry.io/contrib/propagators/jaeger v1.24.0
go.opentelemetry.io/otel v1.24.0
Expand Down
4 changes: 3 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DV
github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo=
github.com/open-policy-agent/opa v0.61.0 h1:nhncQ2CAYtQTV/SMBhDDPsCpCQsUW+zO/1j+T5V7oZg=
github.com/open-policy-agent/opa v0.61.0/go.mod h1:7OUuzJnsS9yHf8lw0ApfcbrnaRG1EkN3J2fuuqi4G/E=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
Expand Down Expand Up @@ -522,6 +522,8 @@ github.com/prometheus/prometheus v0.50.1 h1:N2L+DYrxqPh4WZStU+o1p/gQlBaqFbcLBTjl
github.com/prometheus/prometheus v0.50.1/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU=
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ=
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/redis/rueidis v1.0.31 h1:S2NlrMB1N+yB+QEKD4o0lV+5GNIeLo/ZMpN42ONcwg0=
github.com/redis/rueidis v1.0.31/go.mod h1:g8nPmgR4C68N3abFiOc/gUOSEKw3Tom6/teYMehg4RE=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
Expand Down
42 changes: 35 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ type tracesConfig struct {
}

type middlewareConfig struct {
rateLimiterAddress string
grpcRateLimiterAddress string
rateLimiterType string
rateLimiterAddress multiStringFlag
concurrentRequestLimit int
backLogLimitConcurrentRequests int
backLogDurationConcurrentRequests time.Duration
Expand Down Expand Up @@ -434,14 +436,21 @@ func main() {

defer undo()

var rateLimitClient *ratelimit.Client
var rateLimitClient ratelimit.SharedRateLimiter

if cfg.middleware.rateLimiterAddress != "" {
switch {
case cfg.middleware.grpcRateLimiterAddress != "":
ctx, cancel := context.WithTimeout(context.Background(), grpcDialTimeout)
defer cancel()

rateLimitClient = ratelimit.NewClient(reg)
if err := rateLimitClient.Dial(ctx, cfg.middleware.rateLimiterAddress); err != nil {
grpcRateLimiter := ratelimit.NewClient(reg)
if err := grpcRateLimiter.Dial(ctx, cfg.middleware.grpcRateLimiterAddress); err != nil {
stdlog.Fatal(err)
}
rateLimitClient = grpcRateLimiter
case cfg.middleware.rateLimiterType == "redis":
rateLimitClient, err = ratelimit.NewRedisRateLimiter([]string(cfg.middleware.rateLimiterAddress))
if err != nil {
stdlog.Fatal(err)
}
}
Expand Down Expand Up @@ -1012,6 +1021,20 @@ func (d *duration) UnmarshalJSON(b []byte) error {
}
}

// multiStringFlag is a repeatable string flag. It implements the flag.Value
// interface; every occurrence of the flag on the command line appends one more
// entry to the slice, preserving the order in which they were given.
type multiStringFlag []string

// Set appends value to the slice. It never fails, so it always returns nil.
func (m *multiStringFlag) Set(value string) error {
	*m = append(*m, value)
	return nil
}

// String returns the collected values joined by ", ". An empty string is
// returned when no value was set.
func (m *multiStringFlag) String() string {
	// The flag package documents that String may be called on a zero-valued
	// receiver, such as a nil pointer, so guard the dereference.
	if m == nil {
		return ""
	}
	return strings.Join(*m, ", ")
}

//nolint:funlen,gocognit
func parseFlags() (config, error) {
var (
Expand Down Expand Up @@ -1157,9 +1180,14 @@ func parseFlags() (config, error) {
"Policy for TLS client-side authentication. Values are from ClientAuthType constants in https://pkg.go.dev/crypto/tls#ClientAuthType")
flag.DurationVar(&cfg.tls.reloadInterval, "tls.reload-interval", time.Minute,
"The interval at which to watch for TLS certificate changes.")
flag.StringVar(&cfg.middleware.rateLimiterAddress, "middleware.rate-limiter.grpc-address", "",
flag.StringVar(&cfg.middleware.grpcRateLimiterAddress, "middleware.rate-limiter.grpc-address", "",
"The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant."+
" If not specified, local, non-shared rate limiting will be used.")
" If not specified, local, non-shared rate limiting will be used. Has precedence over other rate limiter options.")
flag.StringVar(&cfg.middleware.rateLimiterType, "middleware.rate-limiter.type", "local",
"The type of rate limiter to use when not using a gRPC rate limiter. Options: 'local' (default), 'redis' (leaky bucket algorithm).")
flag.Var(&cfg.middleware.rateLimiterAddress, "middleware.rate-limiter.address",
"The address of the rate limiter. Only used when not using the gRPC nor \"local\" rate limiters. "+
"Can be repeated to specify multiple addresses (i.e. Redis Cluster).")
flag.IntVar(&cfg.middleware.concurrentRequestLimit, "middleware.concurrent-request-limit", 10_000,
"The limit that controls the number of concurrently processed requests across all tenants.")
flag.IntVar(&cfg.middleware.backLogLimitConcurrentRequests, "middleware.backlog-limit-concurrent-requests", 0,
Expand Down
10 changes: 7 additions & 3 deletions ratelimit/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (
var errOverLimit = errors.New("over limit")

type request struct {
name string
key string
limit int64
name string
key string
limit int64
// duration is the duration of the rate limit window in milliseconds.
duration int64
failOpen bool
retryAfterMin time.Duration
Expand All @@ -34,6 +35,9 @@ type Client struct {
}

// SharedRateLimiter is the interface through which rate limit checks are made
// against a limiter backend. Both the gRPC Client and RedisRateLimiter are used
// through it.
type SharedRateLimiter interface {
	// GetRateLimits checks the rate limit described by req.
	// It returns the number of remaining requests, the reset time as Unix time
	// (milliseconds since the epoch), and any error that occurred.
	// When the rate limit is exceeded, the error errOverLimit is returned.
	GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error)
}

Expand Down
66 changes: 66 additions & 0 deletions ratelimit/gcra_rate_limit.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
-- Rate limit check for a single bucket, using a stored "tat" value per key
-- (the theoretical arrival time, in generic cell rate algorithm terms).
-- All times are expressed in milliseconds.
--
-- this script has side-effects, so it requires replicate commands mode
redis.replicate_commands()

local rate_limit_key = KEYS[1] -- The key to the rate limit bucket.
local now = ARGV[1] -- Current time (Unix time in milliseconds), supplied by the caller.
local burst = ARGV[2] -- This represents the total capacity of the bucket.
local rate = ARGV[3] -- This represents the amount that leaks from the bucket per period.
local period = ARGV[4] -- This represents how often the "rate" leaks from the bucket (in milliseconds).
local cost = ARGV[5] -- This represents the cost of the request. Often 1 is used per request.
-- It allows some requests to be assigned a higher cost.

-- emission_interval: time it takes for a single token to leak from the bucket.
-- NOTE: this divides by rate, so the caller must pass a non-zero rate.
local emission_interval = period / rate
-- increment: how far this request pushes the tat forward.
local increment = emission_interval * cost
-- burst_offset: time span covered by a completely full bucket.
local burst_offset = emission_interval * burst

-- The stored value is the tat left behind by the last allowed request, if any.
local tat = redis.call("GET", rate_limit_key)

if not tat then
-- No stored state (never written, or the key expired): start from now.
tat = now
else
tat = tonumber(tat)
end
-- A tat in the past means the bucket has fully drained; never go below now.
tat = math.max(tat, now)

-- new_tat: the tat if this request is allowed.
local new_tat = tat + increment
-- allow_at: earliest time at which this request fits into the bucket.
local allow_at = new_tat - burst_offset
-- diff: how far now is past allow_at; negative means the request came too early.
local diff = now - allow_at

local limited
local retry_in
local reset_in

-- Tokens left after this request, rounded down (negative when over the limit).
local remaining = math.floor(diff / emission_interval) -- poor man's round

if remaining < 0 then
limited = 1
-- The request is rejected, so its cost must not be counted: recompute
-- remaining from the current tat (without the increment) to report how
-- many tokens are actually available.
remaining = math.floor((now - (tat - burst_offset)) / emission_interval)
reset_in = math.ceil(tat - now)
retry_in = math.ceil(diff * -1)
elseif remaining == 0 and increment <= 0 then
-- request with cost of 0
-- cost of 0 with remaining 0 is still limited
limited = 1
remaining = 0
reset_in = math.ceil(tat - now)
retry_in = 0 -- retry in is meaningless when cost is 0
else
limited = 0
reset_in = math.ceil(new_tat - now)
retry_in = 0
if increment > 0 then
-- Persist the new tat; the key expires once the bucket would be empty again.
redis.call("SET", rate_limit_key, new_tat, "PX", reset_in)
end
end

-- return values (in order):
-- limited = integer-encoded boolean, 1 if limited, 0 if not
-- remaining = number of tokens remaining
-- retry_in = milliseconds until the next request will be allowed
-- reset_in = milliseconds until the rate limit window resets
-- diff = milliseconds between now and the earliest time this request is allowed
--        (now - allow_at); negative when the request is limited
-- emission_interval = milliseconds between token emissions
-- diff and emission_interval are returned as strings (presumably to keep their
-- fractional part -- TODO confirm).
return {limited, remaining, retry_in, reset_in, tostring(diff), tostring(emission_interval)}
82 changes: 82 additions & 0 deletions ratelimit/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package ratelimit

import (
	"context"
	_ "embed"
	"fmt"
	"strconv"
	"time"

	"github.com/redis/rueidis"
)

// gcraRateLimitScript holds the source of the Lua script executed in Redis for
// every rate limit check. It is embedded from gcra_rate_limit.lua at build time.
//
//go:embed gcra_rate_limit.lua
var gcraRateLimitScript string

// RedisRateLimiter is a rate limiter that uses Redis as its backend.
// The rate limiting is a leaky bucket implementation using the generic cell rate algorithm,
// evaluated inside Redis by the embedded gcra_rate_limit.lua script.
// See https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm for details on how this algorithm works.
type RedisRateLimiter struct {
	// client is the rueidis client used to run the rate limit script.
	client rueidis.Client
}

// Ensure RedisRateLimiter implements the SharedRateLimiter interface.
var _ SharedRateLimiter = (*RedisRateLimiter)(nil)

// NewRedisRateLimiter builds a RedisRateLimiter backed by a rueidis client that
// is initialized with the given addresses. The error reported by the client
// constructor is returned unchanged.
func NewRedisRateLimiter(addresses []string) (*RedisRateLimiter, error) {
	options := rueidis.ClientOption{InitAddress: addresses}

	redisClient, err := rueidis.NewClient(options)
	if err != nil {
		return nil, err
	}

	limiter := &RedisRateLimiter{client: redisClient}
	return limiter, nil
}

// GetRateLimits retrieves the rate limits for a given request using a Redis Rate Limiter.
// It returns the amount of remaining requests, the reset time in milliseconds, and any error that occurred.
func (r *RedisRateLimiter) GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error) {
inspectScript := rueidis.NewLuaScript(gcraRateLimitScript)
rateLimitParameters := []string{
strconv.FormatInt(time.Now().UnixMilli(), 10), // now
strconv.FormatInt(req.limit, 10), // burst
strconv.FormatInt(req.limit, 10), // rate
strconv.FormatInt(req.duration, 10), // period
"1", // cost
}
result := inspectScript.Exec(ctx, r.client, []string{req.key}, rateLimitParameters)
limited, remaining, resetIn, err := r.parseRateLimitResult(&result)
if err != nil {
return 0, 0, err
}
resetTime = time.Now().Add(time.Duration(resetIn) * time.Millisecond).UnixMilli()
if limited {
return remaining, resetTime, errOverLimit
}
return remaining, resetTime, nil
}

// parseRateLimitResult parses the reply of the rate limit script.
//
// The script replies with an array whose elements are, in order: limited,
// remaining, retry_in, reset_in, diff and emission_interval. Only limited
// (index 0), remaining (index 1) and reset_in (index 3) are consumed here.
//
// It returns whether the request is limited, the number of remaining requests,
// the number of milliseconds until the limit resets, and any error that
// occurred while reading the reply. A reply that is not an array, or that is
// too short to contain reset_in, is reported as an error.
func (r *RedisRateLimiter) parseRateLimitResult(result *rueidis.RedisResult) (limited bool, remaining, resetIn int64, err error) {
	const (
		limitedIdx   = 0
		remainingIdx = 1
		resetInIdx   = 3
		minValues    = resetInIdx + 1
	)

	values, err := result.ToArray()
	if err != nil {
		return false, 0, 0, fmt.Errorf("reading rate limit script result: %w", err)
	}

	// Guard the fixed-index accesses below: a truncated reply must surface as an
	// error rather than as an index-out-of-range panic.
	if len(values) < minValues {
		return false, 0, 0, fmt.Errorf("unexpected rate limit script result: got %d values, want at least %d", len(values), minValues)
	}

	limited, err = values[limitedIdx].AsBool()
	if err != nil {
		return false, 0, 0, fmt.Errorf("parsing limited flag: %w", err)
	}

	remaining, err = values[remainingIdx].AsInt64()
	if err != nil {
		return false, 0, 0, fmt.Errorf("parsing remaining requests: %w", err)
	}

	resetIn, err = values[resetInIdx].AsInt64()
	if err != nil {
		return false, 0, 0, fmt.Errorf("parsing reset interval: %w", err)
	}

	return limited, remaining, resetIn, nil
}
Loading

0 comments on commit f2995f3

Please sign in to comment.