diff --git a/rueidislimiter/README.md b/rueidislimiter/README.md index 56643169..686c4c3d 100644 --- a/rueidislimiter/README.md +++ b/rueidislimiter/README.md @@ -112,8 +112,7 @@ Returns a `Result` struct: - `Allowed`: Whether the request is allowed. - `Remaining`: Number of remaining requests in the current window. -- `RetryAfter`: Duration until the next allowed request (0 if allowed). -- `ResetAfter`: Duration until the current rate limit window resets. +- `ResetAt`: Unix timestamp at which the rate limit will reset. #### `Allow` diff --git a/rueidislimiter/limiter.go b/rueidislimiter/limiter.go index c90cbdce..9c605a5f 100644 --- a/rueidislimiter/limiter.go +++ b/rueidislimiter/limiter.go @@ -15,10 +15,8 @@ type Result struct { Allowed bool // Remaining is the number of remaining tokens in the current window Remaining int64 - // ResetAfter is the duration until the rate limit resets - ResetAfter time.Duration - // RetryAfter is the duration after which the request may be retried (0 if allowed) - RetryAfter time.Duration + // ResetAt is the Unix timestamp at which the rate limit will reset + ResetAt int64 } type RateLimiterClient interface { @@ -149,21 +147,10 @@ func (l *rateLimiter) parseResult(array []rueidis.RedisMessage, now time.Time) ( remaining = 0 } - allowed := current < int64(l.limit) - resetAfter := time.Until(time.Unix(expiresAt, 0)) - if resetAfter <= 0 { - resetAfter = l.window - } - retryAfter := resetAfter - if allowed { - retryAfter = 0 - } - return Result{ - Allowed: allowed, - Remaining: remaining, - ResetAfter: resetAfter, - RetryAfter: retryAfter, + Allowed: current <= int64(l.limit), + Remaining: remaining, + ResetAt: expiresAt, }, nil } diff --git a/rueidislimiter/limiter_test.go b/rueidislimiter/limiter_test.go index 204f79e9..91aa47bb 100644 --- a/rueidislimiter/limiter_test.go +++ b/rueidislimiter/limiter_test.go @@ -49,8 +49,8 @@ func TestRateLimiter(t *testing.T) { if err != nil { t.Fatal(err) } - if !result.Allowed || result.Remaining != 1 || result.RetryAfter != 0 { - t.Fatalf("Expected Allowed=true, Remaining=1, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if !result.Allowed || result.Remaining != 1 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=true, Remaining=1, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } }) @@ -59,21 +59,21 @@ func TestRateLimiter(t *testing.T) { if err != nil { t.Fatal(err) } - if !result.Allowed || result.Remaining != 3 || result.RetryAfter != 0 { - t.Fatalf("Expected Allowed=true, Remaining=3, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if !result.Allowed || result.Remaining != 3 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } }) t.Run("Check denied after exceeding limit", func(t *testing.T) { key := randStr() - generateLoad(t, limiter, key, 3) + generateLoad(t, limiter, key, 4) result, err := limiter.Check(context.Background(), key) if err != nil { t.Fatal(err) } - if result.Allowed || result.Remaining != 0 || result.RetryAfter <= 0 { - t.Fatalf("Expected Allowed=false, Remaining=0, RetryAfter > 0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if result.Allowed || result.Remaining != 0 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } }) @@ -86,8 +86,26 @@ func TestRateLimiter(t *testing.T) { if err != nil { t.Fatal(err) } - if !result.Allowed || result.Remaining != 3 || result.RetryAfter != 0 { - t.Fatalf("Expected Allowed=true, Remaining=3, RetryAfter=0 after reset; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if !result.Allowed || result.Remaining != 3 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=true, Remaining=3, ResetAt=0 after reset; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) + } + }) + + t.Run("AllowN defaults", func(t *testing.T) { + limiter, err := rueidislimiter.NewRateLimiter(rueidislimiter.RateLimiterOption{ + ClientBuilder: func(option rueidis.ClientOption) (rueidis.Client, error) { + return client, nil + }, + }) + if err != nil { + t.Fatal(err) + } + result, err := limiter.AllowN(context.Background(), randStr(), 1) + if err != nil { + t.Fatal(err) + } + if !result.Allowed || result.Remaining != 0 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=true, Remaining=0, ResetAt >= now; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } }) @@ -97,8 +115,8 @@ func TestRateLimiter(t *testing.T) { if err != nil { t.Fatal(err) } - if !result.Allowed || result.Remaining != 2 || result.RetryAfter != 0 { - t.Fatalf("Expected Allowed=true, Remaining=2, RetryAfter=0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if !result.Allowed || result.Remaining != 2 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=true, Remaining=2, ResetAt=0; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } }) @@ -110,8 +128,8 @@ func TestRateLimiter(t *testing.T) { if err != nil { t.Fatal(err) } - if result.Allowed || result.Remaining != 0 || result.RetryAfter <= 0 { - t.Fatalf("Expected Allowed=false, Remaining=0, RetryAfter > 0; got Allowed=%v, Remaining=%v, RetryAfter=%v", result.Allowed, result.Remaining, result.RetryAfter) + if result.Allowed || result.Remaining != 0 || result.ResetAt < time.Now().Unix() { + t.Fatalf("Expected Allowed=false, Remaining=0, ResetAt > 0; got Allowed=%v, Remaining=%v, ResetAt=%v", result.Allowed, result.Remaining, result.ResetAt) } })