diff --git a/v1/rate_limit.go b/v1/rate_limit.go index 455443c..e3f3f54 100644 --- a/v1/rate_limit.go +++ b/v1/rate_limit.go @@ -1,100 +1,135 @@ package v1 import ( + "hash/fnv" "runtime" "sync" "sync/atomic" "time" ) -// NoopLimiter implements Limiter but doesn't limit anything. +// Use double the CPU count for sharding +const shardsPerCoreMultiplier = 2 + var NoopLimiter Limiter = &noopLimiter{} type token struct { - rps atomic.Uint32 - lastUse atomic.Value + rps uint32 + lastUse int64 // Unix timestamp in nanoseconds } -// Limiter implements some form of rate limiting. +// Limiter interface for rate-limiting. type Limiter interface { - // Obtain the right to send a request. Should lock the execution if current goroutine needs to wait. - Obtain(string) + Obtain(id string) } -// TokensBucket implements basic Limiter with fixed window and fixed amount of tokens per window. +// TokensBucket implements a sharded rate limiter with fixed window and tokens. type TokensBucket struct { maxRPS uint32 - tokens sync.Map - unusedTokenTime time.Duration + unusedTokenTime int64 // in nanoseconds checkTokenTime time.Duration + shards []*tokenShard + shardCount uint32 cancel atomic.Bool sleep sleeper } -// NewTokensBucket constructs TokensBucket with provided parameters. +type tokenShard struct { + tokens map[string]*token + mu sync.Mutex +} + +// NewTokensBucket creates a sharded token bucket limiter. func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) Limiter { + shardCount := uint32(runtime.NumCPU() * shardsPerCoreMultiplier) + shards := make([]*tokenShard, shardCount) + for i := range shards { + shards[i] = &tokenShard{tokens: make(map[string]*token)} + } + bucket := &TokensBucket{ maxRPS: maxRPS, - unusedTokenTime: unusedTokenTime, + unusedTokenTime: unusedTokenTime.Nanoseconds(), checkTokenTime: checkTokenTime, + shards: shards, + shardCount: shardCount, sleep: realSleeper{}, } - go bucket.deleteUnusedToken() - runtime.SetFinalizer(bucket, destructBasket) + go bucket.cleanupRoutine() + runtime.SetFinalizer(bucket, destructBucket) return bucket } +// Obtain request hit. Will throttle RPS. func (m *TokensBucket) Obtain(id string) { - val, ok := m.tokens.Load(id) - if !ok { - token := &token{} - token.lastUse.Store(time.Now()) - token.rps.Store(1) - m.tokens.Store(id, token) + shard := m.getShard(id) + + shard.mu.Lock() + defer shard.mu.Unlock() + + item, exists := shard.tokens[id] + now := time.Now().UnixNano() + + if !exists { + shard.tokens[id] = &token{ + rps: 1, + lastUse: now, + } return } - token := val.(*token) - sleepTime := time.Second - time.Since(token.lastUse.Load().(time.Time)) + sleepTime := int64(time.Second) - (now - item.lastUse) if sleepTime <= 0 { - token.lastUse.Store(time.Now()) - token.rps.Store(0) - } else if token.rps.Load() >= m.maxRPS { - m.sleep.Sleep(sleepTime) - token.lastUse.Store(time.Now()) - token.rps.Store(0) + item.lastUse = now + atomic.StoreUint32(&item.rps, 1) + } else if atomic.LoadUint32(&item.rps) >= m.maxRPS { + m.sleep.Sleep(time.Duration(sleepTime)) + item.lastUse = time.Now().UnixNano() + atomic.StoreUint32(&item.rps, 1) + } else { + atomic.AddUint32(&item.rps, 1) } - token.rps.Add(1) } -func destructBasket(m *TokensBucket) { - m.cancel.Store(true) +func (m *TokensBucket) getShard(id string) *tokenShard { + hash := fnv.New32a() + _, _ = hash.Write([]byte(id)) + return m.shards[hash.Sum32()%m.shardCount] } -func (m *TokensBucket) deleteUnusedToken() { - for { - if m.cancel.Load() { - return - } +func (m *TokensBucket) cleanupRoutine() { + ticker := time.NewTicker(m.checkTokenTime) + defer ticker.Stop() - m.tokens.Range(func(key, value any) bool { - id, token := key.(string), value.(*token) - if time.Since(token.lastUse.Load().(time.Time)) >= m.unusedTokenTime { - m.tokens.Delete(id) + for { + select { + case <-ticker.C: + if m.cancel.Load() { + return } - return false - }) - - m.sleep.Sleep(m.checkTokenTime) + now := time.Now().UnixNano() + for _, shard := range m.shards { + shard.mu.Lock() + for id, token := range shard.tokens { + if now-token.lastUse >= m.unusedTokenTime { + delete(shard.tokens, id) + } + } + shard.mu.Unlock() + } + } } } +func destructBucket(m *TokensBucket) { + m.cancel.Store(true) +} + type noopLimiter struct{} func (l *noopLimiter) Obtain(string) {} -// sleeper sleeps. This thing is necessary for tests. type sleeper interface { Sleep(time.Duration) } diff --git a/v1/rate_limit_test.go b/v1/rate_limit_test.go index 7e9ab72..7a0710c 100644 --- a/v1/rate_limit_test.go +++ b/v1/rate_limit_test.go @@ -24,13 +24,22 @@ func (t *TokensBucketTest) Test_NewTokensBucket() { func (t *TokensBucketTest) new( maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket { + shardCount := uint32(runtime.NumCPU() * 2) // Use double the CPU count for sharding + shards := make([]*tokenShard, shardCount) + for i := range shards { + shards[i] = &tokenShard{tokens: make(map[string]*token)} + } + bucket := &TokensBucket{ maxRPS: maxRPS, - unusedTokenTime: unusedTokenTime, + unusedTokenTime: unusedTokenTime.Nanoseconds(), checkTokenTime: checkTokenTime, + shards: shards, + shardCount: shardCount, sleep: sleeper, } - runtime.SetFinalizer(bucket, destructBasket) + + runtime.SetFinalizer(bucket, destructBucket) return bucket } @@ -46,12 +55,14 @@ func (t *TokensBucketTest) Test_Obtain_NoThrottle() { func (t *TokensBucketTest) Test_Obtain_Sleep() { clock := &fakeSleeper{} tb := t.new(100, time.Hour, time.Minute, clock) + _, exists := tb.getShard("w").tokens["w"] + t.Require().False(exists) var wg sync.WaitGroup wg.Add(1) go func() { for i := 0; i < 301; i++ { - tb.Obtain("a") + tb.Obtain("w") } wg.Done() }() @@ -63,15 +74,15 @@ func (t *TokensBucketTest) Test_Obtain_Sleep() { func (t *TokensBucketTest) Test_Obtain_AddRPS() { clock := clockwork.NewFakeClock() tb := t.new(100, time.Hour, time.Minute, clock) - go tb.deleteUnusedToken() + go tb.cleanupRoutine() tb.Obtain("a") clock.Advance(time.Minute * 2) - item, found := tb.tokens.Load("a") + item, found := tb.getShard("a").tokens["a"] t.Require().True(found) - t.Assert().Equal(1, int(item.(*token).rps.Load())) + t.Assert().Equal(1, int(item.rps)) tb.Obtain("a") - t.Assert().Equal(2, int(item.(*token).rps.Load())) + t.Assert().Equal(2, int(item.rps)) } type fakeSleeper struct {