Skip to content

Commit

Permalink
feat: Introduce RateLimitOption to RateLimiterClient of rueidislimit…
Browse files Browse the repository at this point in the history
…er (#681)

RateLimitOption is currently a struct but in the future it can be converted to another type.
  • Loading branch information
altanozlu authored Nov 28, 2024
1 parent 06ed4fa commit c256ef2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 19 deletions.
15 changes: 15 additions & 0 deletions rueidislimiter/limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package rueidislimiter

import "time"

type RateLimitOption struct {
limit int64
window time.Duration
}

func WithCustomRateLimit(limit int, window time.Duration) RateLimitOption {
return RateLimitOption{
limit: int64(limit),
window: window,
}
}
43 changes: 24 additions & 19 deletions rueidislimiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,17 @@ type Result struct {
}

type RateLimiterClient interface {
Check(ctx context.Context, identifier string) (Result, error)
Allow(ctx context.Context, identifier string) (Result, error)
AllowN(ctx context.Context, identifier string, n int64) (Result, error)
Check(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error)
Allow(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error)
AllowN(ctx context.Context, identifier string, n int64, options ...RateLimitOption) (Result, error)
}

const PlaceholderPrefix = "rueidislimiter"

type rateLimiter struct {
client rueidis.Client
keyPrefix string
limit int
window time.Duration
client rueidis.Client
keyPrefix string
defaultRateLimit RateLimitOption
}

type RateLimiterOption struct {
Expand All @@ -56,8 +55,10 @@ func NewRateLimiter(option RateLimiterOption) (RateLimiterClient, error) {
}

rl := &rateLimiter{
limit: option.Limit,
window: option.Window,
defaultRateLimit: RateLimitOption{
limit: int64(option.Limit),
window: option.Window,
},
}

var err error
Expand All @@ -74,27 +75,31 @@ func NewRateLimiter(option RateLimiterOption) (RateLimiterClient, error) {
}

func (l *rateLimiter) Limit() int {
return l.limit
return int(l.defaultRateLimit.limit)
}

func (l *rateLimiter) Check(ctx context.Context, identifier string) (Result, error) {
return l.AllowN(ctx, identifier, 0)
func (l *rateLimiter) Check(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error) {
return l.AllowN(ctx, identifier, 0, options...)
}

func (l *rateLimiter) Allow(ctx context.Context, identifier string) (Result, error) {
return l.AllowN(ctx, identifier, 1)
func (l *rateLimiter) Allow(ctx context.Context, identifier string, options ...RateLimitOption) (Result, error) {
return l.AllowN(ctx, identifier, 1, options...)
}

func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (Result, error) {
func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64, options ...RateLimitOption) (Result, error) {
if n < 0 {
return Result{}, ErrInvalidTokens
}
rl := l.defaultRateLimit
if len(options) > 0 {
rl = options[len(options)-1]
}

now := time.Now().UTC()
keys := []string{l.getKey(identifier)}
args := []string{
strconv.FormatInt(n, 10),
strconv.FormatInt(now.Add(l.window).UnixMilli(), 10),
strconv.FormatInt(now.Add(rl.window).UnixMilli(), 10),
strconv.FormatInt(now.UnixMilli(), 10),
}

Expand All @@ -109,14 +114,14 @@ func (l *rateLimiter) AllowN(ctx context.Context, identifier string, n int64) (R
}

current := data[0]
remaining := int64(l.limit) - current
remaining := rl.limit - current
if remaining < 0 {
remaining = 0
}

allowed := current <= int64(l.limit)
allowed := current <= rl.limit
if n == 0 {
allowed = current < int64(l.limit)
allowed = current < rl.limit
}

return Result{
Expand Down
21 changes: 21 additions & 0 deletions rueidislimiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ func TestRateLimiter(t *testing.T) {
}
})

t.Run("Check allowed with limit option", func(t *testing.T) {
key := randStr()
generateLoad(t, limiter, key, 3)

result, err := limiter.Check(context.Background(), key)
if err != nil {
t.Fatal(err)
}
if result.Allowed {
t.Fatalf("Expected Allowed=false; got Allowed=%v", result.Allowed)
}

result, err = limiter.Check(context.Background(), key, rueidislimiter.WithCustomRateLimit(10, time.Millisecond*100))
if err != nil {
t.Fatal(err)
}
if !result.Allowed || result.Remaining != 7 || result.ResetAtMs < now.UnixMilli() {
t.Fatalf("Expected Allowed=true, Remaining=7, ResetAt >= now after reset; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAtMs)
}
})

t.Run("AllowN defaults", func(t *testing.T) {
limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{
ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) {
Expand Down

0 comments on commit c256ef2

Please sign in to comment.