-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package slidingwindowratelimiter | ||
|
||
import ( | ||
"github.com/go-redis/redis/v7" | ||
) | ||
|
||
var ( | ||
_ rediser = (*redis.Client)(nil) | ||
_ rediser = (*redis.Ring)(nil) | ||
_ rediser = (*redis.ClusterClient)(nil) | ||
) | ||
|
||
type rediser interface { | ||
EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd | ||
ScriptExists(hashes ...string) *redis.BoolSliceCmd | ||
ScriptLoad(script string) *redis.StringCmd | ||
} |
86 changes: 86 additions & 0 deletions
86
redispattern/slidingwindowratelimiter/sliding_window_rate_limiter.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package slidingwindowratelimiter | ||
|
||
import ( | ||
"crypto/sha1" | ||
"encoding/hex" | ||
"fmt" | ||
"io" | ||
"time" | ||
) | ||
|
||
const ( | ||
script = ` | ||
local key = KEYS[1] | ||
local now = tonumber(ARGV[1]) | ||
local window = tonumber(ARGV[2]) | ||
local limit = tonumber(ARGV[3]) | ||
local pivot = now - window | ||
redis.call('ZREMRANGEBYSCORE', key, 0, pivot) | ||
local count = redis.call('ZCARD', key) | ||
if count < limit then | ||
redis.call('ZADD', key, now, now) | ||
end | ||
redis.call('EXPIRE', key, window / 1000000000) | ||
return limit - count | ||
` | ||
) | ||
|
||
func scriptDigest() (string, error) { | ||
s := sha1.New() | ||
_, err := io.WriteString(s, script) | ||
if err != nil { | ||
return "", err | ||
} | ||
return hex.EncodeToString(s.Sum(nil)), nil | ||
} | ||
|
||
// SlidingWindowRateLimiter represents a sliding window rate limiter. | ||
type SlidingWindowRateLimiter struct { | ||
redis rediser | ||
key string | ||
window time.Duration | ||
limit int64 | ||
} | ||
|
||
// New generates a rate limiter. | ||
func New(redis rediser, key string, window time.Duration, limit int64) (*SlidingWindowRateLimiter, error) { | ||
sl := &SlidingWindowRateLimiter{ | ||
redis: redis, | ||
key: key, | ||
window: window, | ||
limit: limit, | ||
} | ||
return sl, nil | ||
} | ||
|
||
// Allow returns the sliding window rate limiter status. | ||
func (sl *SlidingWindowRateLimiter) Allow() (bool, error) { | ||
digest, err := scriptDigest() | ||
if err != nil { | ||
return false, err | ||
} | ||
exist, err := sl.redis.ScriptExists(digest).Result() | ||
if err != nil { | ||
return false, err | ||
} | ||
if !exist[0] { | ||
_, err := sl.redis.ScriptLoad(script).Result() | ||
if err != nil { | ||
return false, err | ||
} | ||
} | ||
ret, err := sl.redis.EvalSha(digest, []string{sl.key}, time.Now().UnixNano(), sl.window.Nanoseconds(), sl.limit).Result() | ||
if err != nil { | ||
return false, err | ||
} | ||
switch v := ret.(type) { | ||
case int64: | ||
return v > 0, nil | ||
default: | ||
return false, fmt.Errorf("sliding window rate limiter err: %#v, key = %s, window = %s, limit = %d", ret, sl.key, sl.window, sl.limit) | ||
} | ||
} |