Skip to content

Commit

Permalink
chore: sliding window rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
zenghur committed Nov 25, 2020
1 parent 7176808 commit 22ada07
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
17 changes: 17 additions & 0 deletions redispattern/slidingwindowratelimiter/rediser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package slidingwindowratelimiter

import (
"github.com/go-redis/redis/v7"
)

var (
_ rediser = (*redis.Client)(nil)
_ rediser = (*redis.Ring)(nil)
_ rediser = (*redis.ClusterClient)(nil)
)

type rediser interface {
EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd
ScriptExists(hashes ...string) *redis.BoolSliceCmd
ScriptLoad(script string) *redis.StringCmd
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package slidingwindowratelimiter

import (
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"time"
)

const (
script = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local pivot = now - window
redis.call('ZREMRANGEBYSCORE', key, 0, pivot)
local count = redis.call('ZCARD', key)
if count < limit then
redis.call('ZADD', key, now, now)
end
redis.call('EXPIRE', key, window / 1000000000)
return limit - count
`
)

func scriptDigest() (string, error) {
s := sha1.New()
_, err := io.WriteString(s, script)
if err != nil {
return "", err
}
return hex.EncodeToString(s.Sum(nil)), nil
}

// SlidingWindowRateLimiter represents a sliding window rate limiter.
type SlidingWindowRateLimiter struct {
redis rediser
key string
window time.Duration
limit int64
}

// New generates a rate limiter.
func New(redis rediser, key string, window time.Duration, limit int64) (*SlidingWindowRateLimiter, error) {
sl := &SlidingWindowRateLimiter{
redis: redis,
key: key,
window: window,
limit: limit,
}
return sl, nil
}

// Allow returns the sliding window rate limiter status.
func (sl *SlidingWindowRateLimiter) Allow() (bool, error) {
digest, err := scriptDigest()
if err != nil {
return false, err
}
exist, err := sl.redis.ScriptExists(digest).Result()
if err != nil {
return false, err
}
if !exist[0] {
_, err := sl.redis.ScriptLoad(script).Result()
if err != nil {
return false, err
}
}
ret, err := sl.redis.EvalSha(digest, []string{sl.key}, time.Now().UnixNano(), sl.window.Nanoseconds(), sl.limit).Result()
if err != nil {
return false, err
}
switch v := ret.(type) {
case int64:
return v > 0, nil
default:
return false, fmt.Errorf("sliding window rate limiter err: %#v, key = %s, window = %s, limit = %d", ret, sl.key, sl.window, sl.limit)
}
}

0 comments on commit 22ada07

Please sign in to comment.