From ccecc0cb3700f6fef34ebab24e62c6eb85c4248a Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Fri, 23 Aug 2024 17:23:18 +0200 Subject: [PATCH] Introduce RespondOnLimit() vs. OnLimit() --- README.md | 2 +- _example/main.go | 2 +- limiter.go | 23 +++++++++++++++++------ limiter_test.go | 2 +- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1d8974f..fcb7eb3 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ r.Post("/login", func(w http.ResponseWriter, r *http.Request) { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return } diff --git a/_example/main.go b/_example/main.go index cf69e0a..8f51510 100644 --- a/_example/main.go +++ b/_example/main.go @@ -56,7 +56,7 @@ func main() { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return } diff --git a/limiter.go b/limiter.go index 0be05b8..dc69002 100644 --- a/limiter.go +++ b/limiter.go @@ -66,10 +66,10 @@ type RateLimiter struct { mu sync.Mutex } -// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true -// and automatically sends HTTP response. The caller should halt further request processing. -// If the limit is not reached, it increments the request count and returns false, allowing -// the request to proceed. +// OnLimit checks the rate limit for the given key and updates the response headers accordingly. +// If the limit is reached, it returns true, indicating that the request should be halted. Otherwise, +// it increments the request count and returns false. This method does not send an HTTP response, +// so the caller must handle the response themselves or use the RespondOnLimit() method instead. func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { currentWindow := time.Now().UTC().Truncate(l.windowLength) ctx := r.Context() @@ -100,7 +100,6 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string l.mu.Unlock() setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 - l.onRateLimited(w, r) return true } @@ -116,6 +115,18 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string return false } +// RespondOnLimit checks the rate limit for the given key and updates the response headers accordingly. +// If the limit is reached, it automatically sends an HTTP response and returns true, signaling the +// caller to halt further request processing. If the limit is not reached, it increments the request +// count and returns false, allowing the request to proceed. +func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key string) bool { + onLimit := l.OnLimit(w, r, key) + if onLimit { + l.onRateLimited(w, r) + } + return onLimit +} + func (l *RateLimiter) Counter() LimitCounter { return l.limitCounter } @@ -132,7 +143,7 @@ func (l *RateLimiter) Handler(next http.Handler) http.Handler { return } - if l.OnLimit(w, r, key) { + if l.RespondOnLimit(w, r, key) { return } diff --git a/limiter_test.go b/limiter_test.go index 689074a..5ac41c1 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -454,7 +454,7 @@ func TestRateLimitPayload(t *testing.T) { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return }