diff --git a/.circleci/config.yml b/.circleci/config.yml
index 691e6ba2d..f2ee05801 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -19,12 +19,16 @@ jobs:
           make lint --always-make
 
   test:
-    docker:
-      - image: golang:1.20
+    machine:
+      image: ubuntu-2004:2023.04.2
     steps:
       - checkout
      - run: |
           apt-get update && apt-get -y install xz-utils unzip openssl
+          sudo rm -rf /usr/local/go
+          wget -qO- https://golang.org/dl/go1.20.linux-amd64.tar.gz | sudo tar -C /usr/local -xzf -
+          export PATH=$PATH:/usr/local/go/bin
+          go version
           make test --always-make
 
   test-e2e:
diff --git a/README.md b/README.md
index d4252b984..44bb83ed6 100644
--- a/README.md
+++ b/README.md
@@ -143,8 +143,12 @@ Usage of ./observatorium-api:
       The number of concurrent requests that can buffered.
   -middleware.concurrent-request-limit int
       The limit that controls the number of concurrently processed requests across all tenants. (default 10000)
+  -middleware.rate-limiter.address value
+      The address of the rate limiter. Only used when not using the gRPC or "local" rate limiters. Can be repeated to specify multiple addresses (e.g. Redis Cluster).
   -middleware.rate-limiter.grpc-address string
-      The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant. If not specified, local, non-shared rate limiting will be used.
+      The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant. If not specified, local, non-shared rate limiting will be used. Has precedence over other rate limiter options.
+  -middleware.rate-limiter.type string
+      The type of rate limiter to use when not using a gRPC rate limiter. Options: 'local' (default), 'redis' (leaky bucket algorithm). (default "local")
   -rbac.config string
       Path to the RBAC configuration file. (default "rbac.yaml")
   -server.read-header-timeout duration
diff --git a/go.mod b/go.mod
index cf2778c56..713b59c37 100644
--- a/go.mod
+++ b/go.mod
@@ -32,6 +32,7 @@ require (
     github.com/prometheus/client_golang v1.18.0
     github.com/prometheus/common v0.48.0
     github.com/prometheus/prometheus v0.50.1
+    github.com/redis/rueidis v1.0.31
     go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.48.0
     go.opentelemetry.io/contrib/propagators/jaeger v1.24.0
     go.opentelemetry.io/otel v1.24.0
diff --git a/go.sum b/go.sum
index d67dc6909..e1189fa26 100644
--- a/go.sum
+++ b/go.sum
@@ -462,7 +462,7 @@ github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DV
 github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
 github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
 github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
-github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
+github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo=
 github.com/open-policy-agent/opa v0.61.0 h1:nhncQ2CAYtQTV/SMBhDDPsCpCQsUW+zO/1j+T5V7oZg=
 github.com/open-policy-agent/opa v0.61.0/go.mod h1:7OUuzJnsS9yHf8lw0ApfcbrnaRG1EkN3J2fuuqi4G/E=
 github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -522,6 +522,8 @@ github.com/prometheus/prometheus v0.50.1 h1:N2L+DYrxqPh4WZStU+o1p/gQlBaqFbcLBTjl
 github.com/prometheus/prometheus v0.50.1/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU=
 github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ=
 github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
+github.com/redis/rueidis v1.0.31 h1:S2NlrMB1N+yB+QEKD4o0lV+5GNIeLo/ZMpN42ONcwg0=
+github.com/redis/rueidis v1.0.31/go.mod h1:g8nPmgR4C68N3abFiOc/gUOSEKw3Tom6/teYMehg4RE=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
 github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
diff --git a/main.go b/main.go
index df90652ad..1a1d63cc2 100644
--- a/main.go
+++ b/main.go
@@ -190,7 +190,9 @@ type tracesConfig struct {
 }
 
 type middlewareConfig struct {
-    rateLimiterAddress                 string
+    grpcRateLimiterAddress             string
+    rateLimiterType                    string
+    rateLimiterAddress                 multiStringFlag
     concurrentRequestLimit             int
     backLogLimitConcurrentRequests     int
     backLogDurationConcurrentRequests  time.Duration
@@ -434,14 +436,21 @@ func main() {
 
     defer undo()
 
-    var rateLimitClient *ratelimit.Client
+    var rateLimitClient ratelimit.SharedRateLimiter
 
-    if cfg.middleware.rateLimiterAddress != "" {
+    switch {
+    case cfg.middleware.grpcRateLimiterAddress != "":
         ctx, cancel := context.WithTimeout(context.Background(), grpcDialTimeout)
         defer cancel()
 
-        rateLimitClient = ratelimit.NewClient(reg)
-        if err := rateLimitClient.Dial(ctx, cfg.middleware.rateLimiterAddress); err != nil {
+        grpcRateLimiter := ratelimit.NewClient(reg)
+        if err := grpcRateLimiter.Dial(ctx, cfg.middleware.grpcRateLimiterAddress); err != nil {
+            stdlog.Fatal(err)
+        }
+        rateLimitClient = grpcRateLimiter
+    case cfg.middleware.rateLimiterType == "redis":
+        rateLimitClient, err = ratelimit.NewRedisRateLimiter([]string(cfg.middleware.rateLimiterAddress))
+        if err != nil {
             stdlog.Fatal(err)
         }
     }
@@ -1012,6 +1021,20 @@ func (d *duration) UnmarshalJSON(b []byte) error {
     }
 }
 
+// multiStringFlag is a type that implements the flag.Value interface.
+type multiStringFlag []string
+
+// Set appends a value to the slice.
+func (m *multiStringFlag) Set(value string) error {
+    *m = append(*m, value)
+    return nil
+}
+
+// String returns a string representation of the slice.
+func (m *multiStringFlag) String() string {
+    return strings.Join(*m, ", ")
+}
+
 //nolint:funlen,gocognit
 func parseFlags() (config, error) {
     var (
@@ -1157,9 +1180,14 @@ func parseFlags() (config, error) {
         "Policy for TLS client-side authentication. Values are from ClientAuthType constants in https://pkg.go.dev/crypto/tls#ClientAuthType")
     flag.DurationVar(&cfg.tls.reloadInterval, "tls.reload-interval", time.Minute,
         "The interval at which to watch for TLS certificate changes.")
-    flag.StringVar(&cfg.middleware.rateLimiterAddress, "middleware.rate-limiter.grpc-address", "",
+    flag.StringVar(&cfg.middleware.grpcRateLimiterAddress, "middleware.rate-limiter.grpc-address", "",
         "The gRPC Server Address against which to run rate limit checks when the rate limits are specified for a given tenant."+
-            " If not specified, local, non-shared rate limiting will be used.")
+            " If not specified, local, non-shared rate limiting will be used. Has precedence over other rate limiter options.")
+    flag.StringVar(&cfg.middleware.rateLimiterType, "middleware.rate-limiter.type", "local",
+        "The type of rate limiter to use when not using a gRPC rate limiter. Options: 'local' (default), 'redis' (leaky bucket algorithm).")
+    flag.Var(&cfg.middleware.rateLimiterAddress, "middleware.rate-limiter.address",
+        "The address of the rate limiter. Only used when not using the gRPC or \"local\" rate limiters. "+
+            "Can be repeated to specify multiple addresses (e.g. Redis Cluster).")
     flag.IntVar(&cfg.middleware.concurrentRequestLimit, "middleware.concurrent-request-limit", 10_000,
         "The limit that controls the number of concurrently processed requests across all tenants.")
     flag.IntVar(&cfg.middleware.backLogLimitConcurrentRequests, "middleware.backlog-limit-concurrent-requests", 0,
diff --git a/ratelimit/client.go b/ratelimit/client.go
index dde02c6c7..37bc902ef 100644
--- a/ratelimit/client.go
+++ b/ratelimit/client.go
@@ -18,9 +18,10 @@ import (
 var errOverLimit = errors.New("over limit")
 
 type request struct {
-    name          string
-    key           string
-    limit         int64
+    name  string
+    key   string
+    limit int64
+    // duration is the duration of the rate limit window in milliseconds.
     duration      int64
     failOpen      bool
     retryAfterMin time.Duration
@@ -34,6 +35,9 @@ type Client struct {
 }
 
 type SharedRateLimiter interface {
+    // GetRateLimits retrieves the rate limits for a given request.
+    // It returns the number of remaining requests, the reset time as Unix time (milliseconds from epoch), and any error that occurred.
+    // When a rate limit is exceeded, the error errOverLimit is returned.
     GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error)
 }
 
diff --git a/ratelimit/gcra_rate_limit.lua b/ratelimit/gcra_rate_limit.lua
new file mode 100644
index 000000000..fb4749f5c
--- /dev/null
+++ b/ratelimit/gcra_rate_limit.lua
@@ -0,0 +1,66 @@
+-- this script has side-effects, so it requires replicate commands mode
+redis.replicate_commands()
+
+local rate_limit_key = KEYS[1] -- The key to the rate limit bucket.
+local now = ARGV[1]            -- Current time (Unix time in milliseconds).
+local burst = ARGV[2]          -- This represents the total capacity of the bucket.
+local rate = ARGV[3]           -- This represents the amount that leaks from the bucket.
+local period = ARGV[4]         -- This represents how often the "rate" leaks from the bucket (in milliseconds).
+local cost = ARGV[5]           -- This represents the cost of the request. Often 1 is used per request.
+                               -- It allows some requests to be assigned a higher cost.
+
+local emission_interval = period / rate
+local increment = emission_interval * cost
+local burst_offset = emission_interval * burst
+
+local tat = redis.call("GET", rate_limit_key)
+
+if not tat then
+  tat = now
+else
+  tat = tonumber(tat)
+end
+tat = math.max(tat, now)
+
+local new_tat = tat + increment
+local allow_at = new_tat - burst_offset
+local diff = now - allow_at
+
+local limited
+local retry_in
+local reset_in
+
+local remaining = math.floor(diff / emission_interval) -- poor man's round
+
+if remaining < 0 then
+  limited = 1
+  -- calculate how many tokens there actually are, since
+  -- remaining is how many there would have been if we had been able to limit
+  -- and we did not limit
+  remaining = math.floor((now - (tat - burst_offset)) / emission_interval)
+  reset_in = math.ceil(tat - now)
+  retry_in = math.ceil(diff * -1)
+elseif remaining == 0 and increment <= 0 then
+  -- request with cost of 0
+  -- cost of 0 with remaining 0 is still limited
+  limited = 1
+  remaining = 0
+  reset_in = math.ceil(tat - now)
+  retry_in = 0 -- retry in is meaningless when cost is 0
+else
+  limited = 0
+  reset_in = math.ceil(new_tat - now)
+  retry_in = 0
+  if increment > 0 then
+    redis.call("SET", rate_limit_key, new_tat, "PX", reset_in)
+  end
+end
+
+-- return values (in order):
+-- limited = integer-encoded boolean, 1 if limited, 0 if not
+-- remaining = number of tokens remaining
+-- retry_in = milliseconds until the next request will be allowed
+-- reset_in = milliseconds until the rate limit window resets
+-- diff = milliseconds since the last request
+-- emission_interval = milliseconds between token emissions
+return {limited, remaining, retry_in, reset_in, tostring(diff), tostring(emission_interval)}
diff --git a/ratelimit/redis.go b/ratelimit/redis.go
new file mode 100644
index 000000000..93289f3e8
--- /dev/null
+++ b/ratelimit/redis.go
@@ -0,0 +1,82 @@
+package ratelimit
+
+import (
+    "context"
+    _ "embed"
+    "strconv"
+    "time"
+
+    "github.com/redis/rueidis"
+)
+
+//go:embed gcra_rate_limit.lua
+var gcraRateLimitScript string
+
+// RedisRateLimiter is a type that represents a rate limiter that uses Redis as its backend.
+// The rate limiting is a leaky bucket implementation using the generic cell rate algorithm.
+// See https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm for details on how this algorithm works.
+type RedisRateLimiter struct {
+    client rueidis.Client
+}
+
+// Ensure RedisRateLimiter implements the SharedRateLimiter interface.
+var _ SharedRateLimiter = (*RedisRateLimiter)(nil)
+
+// NewRedisRateLimiter creates a new instance of RedisRateLimiter.
+func NewRedisRateLimiter(addresses []string) (*RedisRateLimiter, error) {
+    client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: addresses})
+    if err != nil {
+        return nil, err
+    }
+    return &RedisRateLimiter{client: client}, nil
+}
+
+// GetRateLimits retrieves the rate limits for a given request using a Redis Rate Limiter.
+// It returns the number of remaining requests, the reset time as Unix time in milliseconds, and any error that occurred.
+func (r *RedisRateLimiter) GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error) {
+    inspectScript := rueidis.NewLuaScript(gcraRateLimitScript)
+    rateLimitParameters := []string{
+        strconv.FormatInt(time.Now().UnixMilli(), 10), // now
+        strconv.FormatInt(req.limit, 10),              // burst
+        strconv.FormatInt(req.limit, 10),              // rate
+        strconv.FormatInt(req.duration, 10),           // period
+        "1",                                           // cost
+    }
+    result := inspectScript.Exec(ctx, r.client, []string{req.key}, rateLimitParameters)
+    limited, remaining, resetIn, err := r.parseRateLimitResult(&result)
+    if err != nil {
+        return 0, 0, err
+    }
+    resetTime = time.Now().Add(time.Duration(resetIn) * time.Millisecond).UnixMilli()
+    if limited {
+        return remaining, resetTime, errOverLimit
+    }
+    return remaining, resetTime, nil
+}
+
+// parseRateLimitResult parses the result of a rate limit check from Redis.
+// It takes a RedisResult as input and returns the parsed rate limit values: whether the request is limited,
+// the number of remaining requests, the reset time in milliseconds, and any error that occurred during parsing.
+func (r *RedisRateLimiter) parseRateLimitResult(result *rueidis.RedisResult) (limited bool, remaining, resetIn int64, err error) {
+    values, err := result.ToArray()
+    if err != nil {
+        return false, 0, 0, err
+    }
+
+    limited, err = values[0].AsBool()
+    if err != nil {
+        return false, 0, 0, err
+    }
+
+    remaining, err = values[1].AsInt64()
+    if err != nil {
+        return false, 0, 0, err
+    }
+
+    resetIn, err = values[3].AsInt64()
+    if err != nil {
+        return false, 0, 0, err
+    }
+
+    return limited, remaining, resetIn, nil
+}
diff --git a/ratelimit/redis_test.go b/ratelimit/redis_test.go
new file mode 100644
index 000000000..b7f3b6d89
--- /dev/null
+++ b/ratelimit/redis_test.go
@@ -0,0 +1,164 @@
+package ratelimit
+
+import (
+    "context"
+    "testing"
+    "time"
+
+    "github.com/efficientgo/core/backoff"
+    "github.com/efficientgo/core/testutil"
+    "github.com/efficientgo/e2e"
+)
+
+func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
+    t.Parallel()
+    // Start isolated environment with given ref.
+    e, err := e2e.New(e2e.WithName("redis-rate-li"))
+    testutil.Ok(t, err)
+    t.Cleanup(e.Close)
+
+    redis := createRedisContainer(e)
+    t.Cleanup(func() { _ = redis.Stop() })
+    err = e2e.StartAndWaitReady(redis)
+    testutil.Ok(t, err)
+
+    type args struct {
+        ctx context.Context
+        req *request
+    }
+    tests := []struct {
+        name          string
+        args          args
+        totalHits     int64
+        wantRemaining int64
+        // wantResetTimeFunc is used to calculate the expected reset time just before the hits are sent to the rate limiter.
+        wantResetTimeFunc func() time.Time
+        wantErr           error
+        // waitBeforeLastHit is used to wait the given amount of time and then make a last hit on the rate limiter.
+        waitBeforeLastHit time.Duration
+    }{
+        {
+            name: "Single hit, far from limit",
+            args: args{
+                ctx: context.Background(),
+                req: &request{
+                    key:      "single-hit",
+                    limit:    10,
+                    duration: (10 * time.Second).Milliseconds(),
+                },
+            },
+            totalHits:     1,
+            wantRemaining: 9,
+            wantResetTimeFunc: func() time.Time {
+                return time.Now().Add(1 * time.Second)
+            },
+        },
+        {
+            name: "At the edge of the limit",
+            args: args{
+                ctx: context.Background(),
+                req: &request{
+                    key:      "edge-hit",
+                    limit:    10,
+                    duration: (10 * time.Second).Milliseconds(),
+                },
+            },
+            totalHits:     10,
+            wantRemaining: 0,
+            wantResetTimeFunc: func() time.Time {
+                return time.Now().Add(10 * time.Second)
+            },
+        },
+        {
+            name: "Beyond the limit",
+            args: args{
+                ctx: context.Background(),
+                req: &request{
+                    key:      "beyond-limit",
+                    limit:    10,
+                    duration: (10 * time.Second).Milliseconds(),
+                },
+            },
+            totalHits:     11,
+            wantRemaining: 0,
+            wantErr:       errOverLimit,
+            wantResetTimeFunc: func() time.Time {
+                return time.Now().Add(10 * time.Second)
+            },
+        },
+        {
+            // The test scenario is:
+            // 1. Hit the rate limiter 2 times. No significant amount of time should pass between the hits.
+            // This ensures the bucket doesn't leak.
+            // 2. Wait for 2 seconds. This means the bucket will leak 2 tokens.
+            // 3. Hit the rate limiter 1 time. This should succeed.
+            // If the bucket didn't leak, this would get total remaining of 7.
+            // The reset time should be 3 seconds from the first hit.
+            name: "Wait for 1 leak",
+            args: args{
+                ctx: context.Background(),
+                req: &request{
+                    key:      "wait-for-leak",
+                    limit:    10,
+                    duration: (10 * time.Second).Milliseconds(),
+                },
+            },
+            totalHits: 2,
+            // Waits for 2 seconds instead of 1 because of rounding in the algorithm.
+            waitBeforeLastHit: 2 * time.Second,
+            wantRemaining:     9,
+            wantResetTimeFunc: func() time.Time {
+                return time.Now().Add(3 * time.Second)
+            },
+        },
+    }
+
+    for _, tt := range tests {
+        tt := tt // Can be removed when Go version >= 1.22 is set in the go.mod file.
+        t.Run(tt.name, func(t *testing.T) {
+            t.Parallel()
+            b := backoff.New(context.Background(), backoff.Config{
+                Min:        100 * time.Millisecond,
+                Max:        1 * time.Second,
+                MaxRetries: 5,
+            })
+
+            var (
+                err error
+                r   *RedisRateLimiter
+            )
+            for b.Reset(); b.Ongoing(); b.Wait() {
+                r, err = NewRedisRateLimiter([]string{redis.Endpoint("http")})
+            }
+            testutil.Ok(t, err)
+            testutil.Assert(t, r != nil)
+
+            var gotRemaining, gotResetTime int64
+            wantResetTime := tt.wantResetTimeFunc()
+            for i := int64(0); i < tt.totalHits; i++ {
+                gotRemaining, gotResetTime, err = r.GetRateLimits(tt.args.ctx, tt.args.req)
+            }
+            if tt.waitBeforeLastHit > 0 {
+                time.Sleep(tt.waitBeforeLastHit)
+                gotRemaining, gotResetTime, err = r.GetRateLimits(tt.args.ctx, tt.args.req)
+            }
+
+            testutil.Equals(t, tt.wantErr, err)
+            testutil.Equals(t, tt.wantRemaining, gotRemaining)
+
+            parsedGotResetTime := time.UnixMilli(gotResetTime)
+            timeDifference := parsedGotResetTime.Sub(wantResetTime).Seconds()
+
+            testutil.Assert(t, -1 <= timeDifference && timeDifference <= 1, "gotResetTime should be within 1 second of wantResetTime, it was %f seconds off", timeDifference)
+        })
+    }
+}
+
+func createRedisContainer(env e2e.Environment) e2e.Runnable {
+    return env.Runnable("redis").WithPorts(map[string]int{"http": 6379}).Init(
+        e2e.StartOptions{
+            Image:     "redis",
+            Readiness: e2e.NewCmdReadinessProbe(e2e.Command{Cmd: "redis-cli", Args: []string{"ping"}}),
+        },
+    )
+}
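
For reference only, and not part of the patch: the standalone Go sketch below mirrors the single GCRA decision step performed by gcra_rate_limit.lua, keeping the script's variable names but holding the theoretical arrival time (TAT) in memory instead of a Redis key. The cost-0 branch is omitted because GetRateLimits always passes a cost of 1; all times are in milliseconds, and the numbers match the "Beyond the limit" test case above.

// Illustrative sketch only; assumes an in-memory TAT instead of the Redis key used by the Lua script.
package main

import (
    "fmt"
    "math"
    "time"
)

// gcraDecide performs one GCRA check, mirroring gcra_rate_limit.lua (cost > 0 case only).
func gcraDecide(tat, now, burst, rate, period, cost float64) (limited bool, remaining, retryIn, resetIn, newTAT float64) {
    emissionInterval := period / rate
    increment := emissionInterval * cost
    burstOffset := emissionInterval * burst

    tat = math.Max(tat, now)
    newTAT = tat + increment
    allowAt := newTAT - burstOffset
    diff := now - allowAt

    remaining = math.Floor(diff / emissionInterval)
    if remaining < 0 {
        // Over the limit: report the tokens that are actually left and when to retry.
        remaining = math.Floor((now - (tat - burstOffset)) / emissionInterval)
        return true, remaining, math.Ceil(-diff), math.Ceil(tat - now), tat
    }
    // Allowed: newTAT is the value the Lua script stores back with SET ... PX reset_in.
    return false, remaining, 0, math.Ceil(newTAT - now), newTAT
}

func main() {
    now := float64(time.Now().UnixMilli())
    var tat float64 // zero behaves like a missing Redis key: the first call sets TAT to now

    // 10 requests per 10s window with cost 1, matching how GetRateLimits maps req.limit and req.duration.
    for hit := 1; hit <= 12; hit++ {
        limited, remaining, retryIn, _, next := gcraDecide(tat, now, 10, 10, 10_000, 1)
        fmt.Printf("hit %2d: limited=%t remaining=%.0f retryIn=%.0fms\n", hit, limited, remaining, retryIn)
        if !limited {
            tat = next // only persist the new TAT when the request is allowed
        }
    }
}

Hits 1 through 10 are allowed with remaining counting down from 9 to 0, and hits 11 and 12 report limited=true with a 1000ms retry interval, which is the same behaviour the Redis-backed tests assert.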