diff --git a/httprate.go b/httprate.go index 48f8fca..e83f7e8 100644 --- a/httprate.go +++ b/httprate.go @@ -1,6 +1,7 @@ package httprate import ( + "context" "net" "net/http" "strings" @@ -64,7 +65,27 @@ func WithLimitHandler(h http.HandlerFunc) Option { } } +// noContextLimitCounter wraps the context-less LimitCounter to implement the ContextLimitCounter interface. +// Exists to maintain compatiblity. +type noContextLimitCounter struct { + LimitCounter +} + +func (l *noContextLimitCounter) Increment(_ context.Context, key string, currentWindow time.Time) error { + return l.LimitCounter.Increment(key, currentWindow) +} + +func (l *noContextLimitCounter) Get(_ context.Context, key string, previousWindow, currentWindow time.Time) (int, int, error) { + return l.LimitCounter.Get(key, previousWindow, currentWindow) +} + func WithLimitCounter(c LimitCounter) Option { + return func(rl *rateLimiter) { + rl.limitCounter = &noContextLimitCounter{LimitCounter: c} + } +} + +func WithContextLimitCounter(c ContextLimitCounter) Option { return func(rl *rateLimiter) { rl.limitCounter = c } diff --git a/limiter.go b/limiter.go index a040270..bb50c26 100644 --- a/limiter.go +++ b/limiter.go @@ -1,6 +1,7 @@ package httprate import ( + "context" "fmt" "math" "net/http" @@ -15,6 +16,11 @@ type LimitCounter interface { Get(key string, previousWindow, currentWindow time.Time) (int, int, error) } +type ContextLimitCounter interface { + Increment(ctx context.Context, key string, currentWindow time.Time) error + Get(ctx context.Context, key string, previousWindow, currentWindow time.Time) (int, int, error) +} + func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { return newRateLimiter(requestLimit, windowLength, options...) } @@ -58,24 +64,42 @@ func LimitCounterKey(key string, window time.Time) uint64 { return h.Sum64() } +// limitCounterWrap implements the LimitCounter interface without context. +// Calls ContextLimitCounter with context.Background(), exists to maintain compatibility. +type limitCounterWrap struct { + ContextLimitCounter +} + +func (l *limitCounterWrap) Increment(key string, currentWindow time.Time) error { + return l.ContextLimitCounter.Increment(context.Background(), key, currentWindow) +} + +func (l *limitCounterWrap) Get(key string, previousWindow, currentWindow time.Time) (int, int, error) { + return l.ContextLimitCounter.Get(context.Background(), key, previousWindow, currentWindow) +} + type rateLimiter struct { requestLimit int windowLength time.Duration keyFn KeyFunc - limitCounter LimitCounter + limitCounter ContextLimitCounter onRequestLimit http.HandlerFunc } func (r *rateLimiter) Counter() LimitCounter { + return &limitCounterWrap{ContextLimitCounter: r.limitCounter} +} + +func (r *rateLimiter) ContextCounter() ContextLimitCounter { return r.limitCounter } -func (r *rateLimiter) Status(key string) (bool, float64, error) { +func (r *rateLimiter) ContextStatus(ctx context.Context, key string) (bool, float64, error) { t := time.Now().UTC() currentWindow := t.Truncate(r.windowLength) previousWindow := currentWindow.Add(-r.windowLength) - currCount, prevCount, err := r.limitCounter.Get(key, currentWindow, previousWindow) + currCount, prevCount, err := r.limitCounter.Get(ctx, key, currentWindow, previousWindow) if err != nil { return false, 0, err } @@ -89,8 +113,14 @@ func (r *rateLimiter) Status(key string) (bool, float64, error) { return true, rate, nil } +func (r *rateLimiter) Status(key string) (bool, float64, error) { + return r.ContextStatus(context.Background(), key) +} + func (l *rateLimiter) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + key, err := l.keyFn(r) if err != nil { http.Error(w, err.Error(), http.StatusPreconditionRequired) @@ -103,7 +133,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", 0)) w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) - _, rate, err := l.Status(key) + _, rate, err := l.ContextStatus(ctx, key) if err != nil { http.Error(w, err.Error(), http.StatusPreconditionRequired) return @@ -120,7 +150,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { return } - err = l.limitCounter.Increment(key, currentWindow) + err = l.limitCounter.Increment(ctx, key, currentWindow) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -137,14 +167,17 @@ type localCounter struct { mu sync.Mutex } -var _ LimitCounter = &localCounter{} +var ( + _ LimitCounter = &limitCounterWrap{ContextLimitCounter: &localCounter{}} + _ ContextLimitCounter = &localCounter{} +) type count struct { value int updatedAt time.Time } -func (c *localCounter) Increment(key string, currentWindow time.Time) error { +func (c *localCounter) Increment(_ context.Context, key string, currentWindow time.Time) error { c.evict() c.mu.Lock() @@ -163,7 +196,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error { return nil } -func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) { +func (c *localCounter) Get(_ context.Context, key string, currentWindow, previousWindow time.Time) (int, int, error) { c.mu.Lock() defer c.mu.Unlock()