From 1dc9d23d8493b6966968fcad42adbb16d6c83cca Mon Sep 17 00:00:00 2001 From: Tobias Sommer Date: Mon, 30 Sep 2024 11:12:48 +0200 Subject: [PATCH 1/2] feat: add rate limit unit multiplier Signed-off-by: Tobias Sommer --- README.md | 5 + src/config/config_impl.go | 36 +++++-- src/limiter/base_limiter.go | 5 +- src/limiter/cache_key.go | 3 +- src/memcached/cache_impl.go | 2 +- src/redis/fixed_cache_impl.go | 2 +- src/service/ratelimit.go | 5 +- src/utils/utilities.go | 18 +++- test/config/basic_config.yaml | 1 + test/config/config_test.go | 39 +++++++ test/config/unit_multiplier.yaml | 14 +++ test/config/zero_unit_multiplier.yaml | 8 ++ test/limiter/base_limiter_test.go | 34 ++++-- test/limiter/cache_key_test.go | 98 +++++++++++++++++ test/memcached/cache_impl_test.go | 122 ++++++++++++++------- test/redis/bench_test.go | 2 +- test/redis/fixed_cache_impl_test.go | 150 +++++++++++++++++--------- test/service/ratelimit_test.go | 24 ++--- 18 files changed, 440 insertions(+), 128 deletions(-) create mode 100644 test/config/unit_multiplier.yaml create mode 100644 test/config/zero_unit_multiplier.yaml create mode 100644 test/limiter/cache_key_test.go diff --git a/README.md b/README.md index fff18bae..167f6e39 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,7 @@ descriptors: - name: (optional) unit: requests_per_unit: + unit_multiplier: shadow_mode: (optional) detailed_metric: (optional) descriptors: (optional block) @@ -268,11 +269,15 @@ effectively whitelisted. Otherwise, nested descriptors allow more complex matchi rate_limit: unit: requests_per_unit: + unit_multiplier: ``` The rate limit block specifies the actual rate limit that will be used when there is a match. Currently the service supports per second, minute, hour, and day limits. More types of limits may be added in the future based on user demand. +The `unit_multiplier` allows for creating custom rate limit durations in combination with `unit`. +This allows for rate limit durations such as 30 seconds or 5 minutes. +A `unit_multiplier` of 0 is invalid and leaving out the field means the duration is equal to the unit (e.g. 1 minute). ### Replaces diff --git a/src/config/config_impl.go b/src/config/config_impl.go index 45c276b4..d13bd3ec 100644 --- a/src/config/config_impl.go +++ b/src/config/config_impl.go @@ -20,7 +20,8 @@ type yamlReplaces struct { type YamlRateLimit struct { RequestsPerUnit uint32 `yaml:"requests_per_unit"` Unit string - Unlimited bool `yaml:"unlimited"` + UnitMultiplier *uint32 `yaml:"unit_multiplier"` + Unlimited bool `yaml:"unlimited"` Name string Replaces []yamlReplaces } @@ -68,16 +69,18 @@ var validKeys = map[string]bool{ "name": true, "replaces": true, "detailed_metric": true, + "unit_multiplier": true, } // Create a new rate limit config entry. // @param requestsPerUnit supplies the requests per unit of time for the entry. // @param unit supplies the unit of time for the entry. +// @param unitMultiplier supplies the multiplier for the unit of time for the entry. // @param rlStats supplies the stats structure associated with the RateLimit // @param unlimited supplies whether the rate limit is unlimited // @return the new config entry. func NewRateLimit(requestsPerUnit uint32, unit pb.RateLimitResponse_RateLimit_Unit, rlStats stats.RateLimitStats, - unlimited bool, shadowMode bool, name string, replaces []string, detailedMetric bool, + unlimited bool, shadowMode bool, name string, replaces []string, detailedMetric bool, unitMultiplier uint32, ) *RateLimit { return &RateLimit{ FullKey: rlStats.GetKey(), @@ -86,6 +89,7 @@ func NewRateLimit(requestsPerUnit uint32, unit pb.RateLimitResponse_RateLimit_Un RequestsPerUnit: requestsPerUnit, Unit: unit, Name: name, + UnitMultiplier: unitMultiplier, }, Unlimited: unlimited, ShadowMode: shadowMode, @@ -100,8 +104,8 @@ func (this *rateLimitDescriptor) dump() string { ret := "" if this.limit != nil { ret += fmt.Sprintf( - "%s: unit=%s requests_per_unit=%d, shadow_mode: %t\n", this.limit.FullKey, - this.limit.Limit.Unit.String(), this.limit.Limit.RequestsPerUnit, this.limit.ShadowMode) + "%s: unit=%s, unit_multiplier=%d, requests_per_unit=%d, shadow_mode: %t\n", this.limit.FullKey, + this.limit.Limit.Unit.String(), this.limit.Limit.UnitMultiplier, this.limit.Limit.RequestsPerUnit, this.limit.ShadowMode) } for _, descriptor := range this.descriptors { ret += descriptor.dump() @@ -159,6 +163,18 @@ func (this *rateLimitDescriptor) loadDescriptors(config RateLimitConfigToLoad, p fmt.Sprintf("invalid rate limit unit '%s'", descriptorConfig.RateLimit.Unit))) } + var unitMultiplier uint32 + if descriptorConfig.RateLimit.UnitMultiplier == nil { + unitMultiplier = 1 + } else { + unitMultiplier = *descriptorConfig.RateLimit.UnitMultiplier + if unitMultiplier == 0 { + panic(newRateLimitConfigError( + config.Name, + "invalid unit multiplier of 0")) + } + } + replaces := make([]string, len(descriptorConfig.RateLimit.Replaces)) for i, e := range descriptorConfig.RateLimit.Replaces { replaces[i] = e.Name @@ -168,10 +184,12 @@ func (this *rateLimitDescriptor) loadDescriptors(config RateLimitConfigToLoad, p descriptorConfig.RateLimit.RequestsPerUnit, pb.RateLimitResponse_RateLimit_Unit(value), statsManager.NewStats(newParentKey), unlimited, descriptorConfig.ShadowMode, descriptorConfig.RateLimit.Name, replaces, descriptorConfig.DetailedMetric, + unitMultiplier, ) + rateLimitDebugString = fmt.Sprintf( - " ratelimit={requests_per_unit=%d, unit=%s, unlimited=%t, shadow_mode=%t}", rateLimit.Limit.RequestsPerUnit, - rateLimit.Limit.Unit.String(), rateLimit.Unlimited, rateLimit.ShadowMode) + " ratelimit={requests_per_unit=%d, unit=%s, unit_multiplier=%d, unlimited=%t, shadow_mode=%t}", rateLimit.Limit.RequestsPerUnit, + rateLimit.Limit.Unit.String(), unitMultiplier, rateLimit.Unlimited, rateLimit.ShadowMode) for _, replaces := range descriptorConfig.RateLimit.Replaces { if replaces.Name == "" { @@ -302,6 +320,7 @@ func (this *rateLimitConfigImpl) GetLimit( "", []string{}, false, + 1, ) return rateLimit } @@ -354,7 +373,10 @@ func (this *rateLimitConfigImpl) GetLimit( descriptorsMap = nextDescriptor.descriptors } else { if rateLimit != nil && rateLimit.DetailedMetric { - rateLimit = NewRateLimit(rateLimit.Limit.RequestsPerUnit, rateLimit.Limit.Unit, this.statsManager.NewStats(rateLimit.FullKey), rateLimit.Unlimited, rateLimit.ShadowMode, rateLimit.Name, rateLimit.Replaces, rateLimit.DetailedMetric) + rateLimit = NewRateLimit(rateLimit.Limit.RequestsPerUnit, rateLimit.Limit.Unit, + this.statsManager.NewStats(rateLimit.FullKey), rateLimit.Unlimited, + rateLimit.ShadowMode, rateLimit.Name, rateLimit.Replaces, + rateLimit.DetailedMetric, rateLimit.Limit.UnitMultiplier) } break diff --git a/src/limiter/base_limiter.go b/src/limiter/base_limiter.go index 18f0b83d..69d66960 100644 --- a/src/limiter/base_limiter.go +++ b/src/limiter/base_limiter.go @@ -116,7 +116,8 @@ func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo * // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. - err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDivider(limitInfo.limit.Limit.Unit))) + + err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDividerWithMultiplier(limitInfo.limit.Limit.Unit, limitInfo.limit.Limit.UnitMultiplier))) if err != nil { logger.Errorf("Failing to set local cache key: %s", key) } @@ -205,7 +206,7 @@ func (this *BaseRateLimiter) generateResponseDescriptorStatus(responseCode pb.Ra Code: responseCode, CurrentLimit: limit, LimitRemaining: limitRemaining, - DurationUntilReset: utils.CalculateReset(&limit.Unit, this.timeSource), + DurationUntilReset: utils.CalculateReset(&limit.Unit, this.timeSource, limit.UnitMultiplier), } } else { return &pb.RateLimitResponse_DescriptorStatus{ diff --git a/src/limiter/cache_key.go b/src/limiter/cache_key.go index 7beadd54..d3778ae4 100644 --- a/src/limiter/cache_key.go +++ b/src/limiter/cache_key.go @@ -70,7 +70,8 @@ func (this *CacheKeyGenerator) GenerateCacheKey( b.WriteByte('_') } - divider := utils.UnitToDivider(limit.Limit.Unit) + divider := utils.UnitToDividerWithMultiplier(limit.Limit.Unit, limit.Limit.UnitMultiplier) + b.WriteString(strconv.FormatInt((now/divider)*divider, 10)) return CacheKey{ diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index d4f65f75..68ccfc55 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -160,7 +160,7 @@ func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, i _, err := this.client.Increment(cacheKey.Key, hitsAddend) if err == memcache.ErrCacheMiss { - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + expirationSeconds := utils.UnitToDividerWithMultiplier(limits[i].Limit.Unit, limits[i].Limit.UnitMultiplier) if this.expirationJitterMaxSeconds > 0 { expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) } diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index 8c551e0c..1ecde682 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -152,7 +152,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( logger.Debugf("looking up cache key: %s", cacheKey.Key) - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + expirationSeconds := utils.UnitToDividerWithMultiplier(limits[i].Limit.Unit, limits[i].Limit.UnitMultiplier) if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) } diff --git a/src/service/ratelimit.go b/src/service/ratelimit.go index ea26bf0d..997da5ae 100644 --- a/src/service/ratelimit.go +++ b/src/service/ratelimit.go @@ -138,8 +138,9 @@ func (this *service) constructLimitsToCheck(request *pb.RateLimitRequest, ctx co logger.Debugf("descriptor is unlimited, not passing to the cache") } else { logger.Debugf( - "applying limit: %d requests per %s, shadow_mode: %t", + "applying limit: %d requests per %d %s, shadow_mode: %t", limitsToCheck[i].Limit.RequestsPerUnit, + limitsToCheck[i].Limit.UnitMultiplier, limitsToCheck[i].Limit.Unit.String(), limitsToCheck[i].ShadowMode, ) @@ -262,7 +263,7 @@ func (this *service) rateLimitResetHeader( ) *core.HeaderValue { return &core.HeaderValue{ Key: this.customHeaderResetHeader, - Value: strconv.FormatInt(utils.CalculateReset(&descriptor.CurrentLimit.Unit, this.customHeaderClock).GetSeconds(), 10), + Value: strconv.FormatInt(utils.CalculateReset(&descriptor.CurrentLimit.Unit, this.customHeaderClock, descriptor.CurrentLimit.UnitMultiplier).GetSeconds(), 10), } } diff --git a/src/utils/utilities.go b/src/utils/utilities.go index abe787c0..c1c57b8f 100644 --- a/src/utils/utilities.go +++ b/src/utils/utilities.go @@ -31,8 +31,22 @@ func UnitToDivider(unit pb.RateLimitResponse_RateLimit_Unit) int64 { panic("should not get here") } -func CalculateReset(unit *pb.RateLimitResponse_RateLimit_Unit, timeSource TimeSource) *durationpb.Duration { - sec := UnitToDivider(*unit) +// Convert a rate limit into a time divider reflecting the unit multiplier. +// @param unit supplies the unit to convert. +// @param unitMultiplier supplies the unit multiplier to scale the unit with. +// @return the scaled divider to use in time computations. +func UnitToDividerWithMultiplier(unit pb.RateLimitResponse_RateLimit_Unit, unitMultiplier uint32) int64 { + // TODO: Should this stay as a safe-guard? + // Already validated in rateLimitDescriptor::loadDescriptors, here redundantly to avoid the risk of multiplying with 0 + if unitMultiplier == 0 { + unitMultiplier = 1 + } + + return UnitToDivider(unit) * int64(unitMultiplier) +} + +func CalculateReset(unit *pb.RateLimitResponse_RateLimit_Unit, timeSource TimeSource, unitMultiplier uint32) *durationpb.Duration { + sec := UnitToDividerWithMultiplier(*unit, unitMultiplier) now := timeSource.UnixNow() return &durationpb.Duration{Seconds: sec - now%sec} } diff --git a/test/config/basic_config.yaml b/test/config/basic_config.yaml index 8992966f..61690436 100644 --- a/test/config/basic_config.yaml +++ b/test/config/basic_config.yaml @@ -69,3 +69,4 @@ descriptors: rate_limit: unit: minute requests_per_unit: 70 + unit_multiplier: 5 diff --git a/test/config/config_test.go b/test/config/config_test.go index 3c06b088..d8c62c76 100644 --- a/test/config/config_test.go +++ b/test/config/config_test.go @@ -952,3 +952,42 @@ func TestDetailedMetric(t *testing.T) { }) } } + +func TestUnitMultiplier(t *testing.T) { + assert := assert.New(t) + stats := stats.NewStore(stats.NewNullSink(), false) + + rlConfig := config.NewRateLimitConfigImpl(loadFile("unit_multiplier.yaml"), mockstats.NewMockStatManager(stats), false) + rlConfig.Dump() + + rl := rlConfig.GetLimit( + context.TODO(), "test-domain", + &pb_struct.RateLimitDescriptor{ + Entries: []*pb_struct.RateLimitDescriptor_Entry{{Key: "key1"}}, + }) + + assert.EqualValues(20, rl.Limit.RequestsPerUnit) + assert.Equal(pb.RateLimitResponse_RateLimit_MINUTE, rl.Limit.Unit) + assert.EqualValues(5, rl.Limit.UnitMultiplier) + + rl = rlConfig.GetLimit( + context.TODO(), "test-domain", + &pb_struct.RateLimitDescriptor{ + Entries: []*pb_struct.RateLimitDescriptor_Entry{{Key: "key2"}}, + }) + + assert.EqualValues(25, rl.Limit.RequestsPerUnit) + assert.Equal(pb.RateLimitResponse_RateLimit_MINUTE, rl.Limit.Unit) + assert.EqualValues(1, rl.Limit.UnitMultiplier) +} + +func TestZeroUnitMultiplier(t *testing.T) { + expectConfigPanic( + t, + func() { + config.NewRateLimitConfigImpl( + loadFile("zero_unit_multiplier.yaml"), + mockstats.NewMockStatManager(stats.NewStore(stats.NewNullSink(), false)), false) + }, + "zero_unit_multiplier.yaml: invalid unit multiplier of 0") +} diff --git a/test/config/unit_multiplier.yaml b/test/config/unit_multiplier.yaml new file mode 100644 index 00000000..983199a2 --- /dev/null +++ b/test/config/unit_multiplier.yaml @@ -0,0 +1,14 @@ +domain: test-domain +descriptors: + # some unit_multiplier + - key: key1 + rate_limit: + unit: minute + requests_per_unit: 20 + unit_multiplier: 5 + + # no unit_multiplier + - key: key2 + rate_limit: + unit: minute + requests_per_unit: 25 diff --git a/test/config/zero_unit_multiplier.yaml b/test/config/zero_unit_multiplier.yaml new file mode 100644 index 00000000..091ca3f2 --- /dev/null +++ b/test/config/zero_unit_multiplier.yaml @@ -0,0 +1,8 @@ +domain: test-domain +descriptors: + # 0 unit_multiplier + - key: key1 + rate_limit: + unit: second + requests_per_unit: 5 + unit_multiplier: 0 diff --git a/test/limiter/base_limiter_test.go b/test/limiter/base_limiter_test.go index b3babcee..43164e43 100644 --- a/test/limiter/base_limiter_test.go +++ b/test/limiter/base_limiter_test.go @@ -29,7 +29,7 @@ func TestGenerateCacheKeys(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) assert.Equal(1, len(cacheKeys)) @@ -48,7 +48,7 @@ func TestGenerateCacheKeysPrefix(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "prefix:", sm) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) assert.Equal(1, len(cacheKeys)) @@ -56,6 +56,24 @@ func TestGenerateCacheKeysPrefix(t *testing.T) { assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) } +func TestGenerateCacheKeysWithUnitModifier(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + sm := mockstats.NewMockStatManager(statsStore) + timeSource.EXPECT().UnixNow().Return(int64(1234)) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 30)} + assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) + cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) + assert.Equal(1, len(cacheKeys)) + assert.Equal("domain_key_value_1230", cacheKeys[0].Key) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) +} + func TestOverLimitWithLocalCache(t *testing.T) { assert := assert.New(t) controller := gomock.NewController(t) @@ -102,7 +120,7 @@ func TestGetResponseStatusOverLimitWithLocalCache(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) // As `isOverLimitWithLocalCache` is passed as `true`, immediate response is returned with no checks of the limits. responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, true, 2) @@ -125,7 +143,7 @@ func TestGetResponseStatusOverLimitWithLocalCacheShadowMode(t *testing.T) { sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) // This limit is in ShadowMode - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) // As `isOverLimitWithLocalCache` is passed as `true`, immediate response is returned with no checks of the limits. responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, true, 2) @@ -149,7 +167,7 @@ func TestGetResponseStatusOverLimit(t *testing.T) { localCache := freecache.NewCache(100) sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm) - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) assert.Equal(pb.RateLimitResponse_OVER_LIMIT, responseStatus.GetCode()) @@ -175,7 +193,7 @@ func TestGetResponseStatusOverLimitShadowMode(t *testing.T) { sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm) // Key is in shadow_mode: true - limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) @@ -197,7 +215,7 @@ func TestGetResponseStatusBelowLimit(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) @@ -218,7 +236,7 @@ func TestGetResponseStatusBelowLimitShadowMode(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, "", nil, false, 1)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) diff --git a/test/limiter/cache_key_test.go b/test/limiter/cache_key_test.go new file mode 100644 index 00000000..56eb523b --- /dev/null +++ b/test/limiter/cache_key_test.go @@ -0,0 +1,98 @@ +package limiter + +import ( + "strconv" + "strings" + "testing" + "time" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/golang/mock/gomock" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" + + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/test/common" + mockstats "github.com/envoyproxy/ratelimit/test/mocks/stats" +) + +func TestCacheKeyGenerator(t *testing.T) { + cacheKeyGenerator := limiter.NewCacheKeyGenerator("") + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + statsStore := stats.NewStore(stats.NewNullSink(), false) + sm := mockstats.NewMockStatManager(statsStore) + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + timestamp := time.Date(2024, 5, 4, 12, 30, 15, 30, time.UTC) + + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1) + cacheKey := cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, timestamp.Unix()) + assert.Equal("domain_key_value_1714825815", cacheKey.Key) // Rounded down to the nearest seconds (2024-05-04 12:30:15) + assert.True(cacheKey.PerSecond) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key_value"), false, false, "", nil, false, 1) + cacheKey = cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, timestamp.Unix()) + assert.Equal("domain_key_value_1714825800", cacheKey.Key) // Rounded down to the nearest minute (2024-05-04 12:30:00) + assert.False(cacheKey.PerSecond) + + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key_value"), false, false, "", nil, false, 1) + cacheKey = cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, timestamp.Unix()) + assert.Equal("domain_key_value_1714824000", cacheKey.Key) // Rounded down to the nearest hour (2024-05-04 12:00:00) + assert.False(cacheKey.PerSecond) + + timestamp = time.Date(2024, 5, 4, 12, 59, 59, 30, time.UTC) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key_value"), false, false, "", nil, false, 1) + cacheKey = cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, timestamp.Unix()) + assert.Equal("domain_key_value_1714824000", cacheKey.Key) // Also rounded down to 2024-05-04 12:00:00 + assert.False(cacheKey.PerSecond) +} + +func TestCacheKeyGeneratorWithUnitMultiplier(t *testing.T) { + cacheKeyGenerator := limiter.NewCacheKeyGenerator("") + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + statsStore := stats.NewStore(stats.NewNullSink(), false) + sm := mockstats.NewMockStatManager(statsStore) + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + + // Rounding down to nearest 5 minutes (creating 5-minute long buckets) + limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key_value"), false, false, "", nil, false, 5) + testsRoundingDownToNearestFiveMinutes := []struct { + input time.Time + expectedTimestamp time.Time + }{ + {time.Date(2024, 5, 4, 12, 0, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 12, 1, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 12, 4, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 12, 5, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 5, 0, 0, time.UTC)}, + } + + for _, testCase := range testsRoundingDownToNearestFiveMinutes { + cacheKey := cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, testCase.input.Unix()) + unixTime, _ := strconv.ParseInt(strings.Replace(cacheKey.Key, "domain_key_value_", "", 1), 10, 64) + assert.Equal(testCase.expectedTimestamp, time.Unix(unixTime, 0).UTC()) + } + + // Rounding down to nearest 2 hours (creating 2-hour long buckets) + limit = config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key_value"), false, false, "", nil, false, 2) + testsRoundingDownToNearestThreeHours := []struct { + input time.Time + expectedTimestamp time.Time + }{ + {time.Date(2024, 5, 4, 12, 0, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 13, 0, 15, 30, time.UTC), time.Date(2024, 5, 4, 12, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 14, 4, 15, 30, time.UTC), time.Date(2024, 5, 4, 14, 0, 0, 0, time.UTC)}, + {time.Date(2024, 5, 4, 15, 5, 15, 30, time.UTC), time.Date(2024, 5, 4, 14, 0, 0, 0, time.UTC)}, + } + + for _, testCase := range testsRoundingDownToNearestThreeHours { + cacheKey := cacheKeyGenerator.GenerateCacheKey("domain", request.Descriptors[0], limit, testCase.input.Unix()) + unixTime, _ := strconv.ParseInt(strings.Replace(cacheKey.Key, "domain_key_value_", "", 1), 10, 64) + assert.Equal(testCase.expectedTimestamp, time.Unix(unixTime, 0).UTC()) + } +} diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go index 15d90fd4..8c931c98 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/cache_impl_test.go @@ -53,10 +53,10 @@ func TestMemcached(t *testing.T) { client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -77,12 +77,12 @@ func TestMemcached(t *testing.T) { }, 1) limits = []*config.RateLimit{ nil, - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2_subkey2_subvalue2"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2_subkey2_subvalue2"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ {Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) @@ -111,13 +111,13 @@ func TestMemcached(t *testing.T) { {{"key3", "value3"}, {"subkey3", "subvalue3"}}, }, 1) limits = []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key3_value3"), false, false, "", nil, false), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, sm.NewStats("key3_value3_subkey3_subvalue3"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key3_value3"), false, false, "", nil, false, 1), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, sm.NewStats("key3_value3_subkey3_subvalue3"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -150,10 +150,10 @@ func TestMemcachedGetError(t *testing.T) { client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -168,10 +168,10 @@ func TestMemcachedGetError(t *testing.T) { client.EXPECT().Increment("domain_key_value1_1234", uint64(1)).Return(uint64(5), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value1"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value1"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value1"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -242,12 +242,12 @@ func TestOverLimitWithLocalCache(t *testing.T) { request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -268,7 +268,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -289,7 +289,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) @@ -307,7 +307,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Times(0) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) @@ -343,12 +343,12 @@ func TestNearLimit(t *testing.T) { request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -365,7 +365,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -383,7 +383,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) @@ -400,10 +400,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key5_value5_1234", uint64(3)).Return(uint64(5), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -418,10 +418,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key6_value6_1234", uint64(2)).Return(uint64(7), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -436,10 +436,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key7_value7_1234", uint64(3)).Return(uint64(19), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -454,10 +454,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key8_value8_1234", uint64(3)).Return(uint64(22), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -472,10 +472,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key9_value9_1234", uint64(7)).Return(uint64(22), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -490,10 +490,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().Increment("domain_key10_value10_1234", uint64(3)).Return(uint64(30), nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) @@ -534,10 +534,10 @@ func TestMemcacheWithJitter(t *testing.T) { ).Return(nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -577,10 +577,10 @@ func TestMemcacheAdd(t *testing.T) { uint64(2), nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -601,10 +601,10 @@ func TestMemcacheAdd(t *testing.T) { ).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -674,7 +674,7 @@ func TestMemcachedTracer(t *testing.T) { client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} cache.DoLimit(context.Background(), request, limits) @@ -695,3 +695,45 @@ func getMultiResult(vals map[string]int) map[string]*memcache.Item { } return result } + +func TestMemcacheWithUnitMultiplier(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + client := mock_memcached.NewMockClient(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + sm := mockstats.NewMockStatManager(statsStore) + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + + // Key is not found in memcache + client.EXPECT().GetMulti([]string{"domain_key_value_1230"}).Return(nil, nil) + // First increment attempt will fail + client.EXPECT().Increment("domain_key_value_1230", uint64(1)).Return( + uint64(0), memcache.ErrCacheMiss) + // Add succeeds + client.EXPECT().Add( + &memcache.Item{ + Key: "domain_key_value_1230", + Value: []byte(strconv.FormatUint(1, 10)), + // 30 seconds due to the unit multiplier + Expiration: int32(30), + }, + ).Return(nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 30)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, + cache.DoLimit(context.Background(), request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + cache.Flush() +} diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go index b0b2a215..a7d45e18 100644 --- a/test/redis/bench_test.go +++ b/test/redis/bench_test.go @@ -49,7 +49,7 @@ func BenchmarkParallelDoLimit(b *testing.B) { cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm, true) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} // wait for the pool to fill up for { diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index 45fe87b2..fe418ca9 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -69,10 +69,10 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -94,12 +94,12 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { }, 1) limits = []*config.RateLimit{ nil, - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2_subkey2_subvalue2"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2_subkey2_subvalue2"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ {Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) @@ -124,13 +124,13 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { {{"key3", "value3"}, {"subkey3", "subvalue3"}}, }, 1) limits = []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key3_value3"), false, false, "", nil, false), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, sm.NewStats("key3_value3_subkey3_subvalue3"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key3_value3"), false, false, "", nil, false, 1), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, sm.NewStats("key3_value3_subkey3_subvalue3"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -197,20 +197,20 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_993600", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + "EXPIRE", "domain_key4_value4_993600", int64(7200)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false, 2), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -224,14 +224,14 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Near Limit Stats. At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_993600", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + "EXPIRE", "domain_key4_value4_993600", int64(7200)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -245,14 +245,14 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_993600", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) + "EXPIRE", "domain_key4_value4_993600", int64(7200)).DoAndReturn(pipeAppend) client.EXPECT().PipeDo(gomock.Any()).Return(nil) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) @@ -271,7 +271,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { "EXPIRE", "domain_key4_value4_997200", int64(3600)).Times(0) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) @@ -305,12 +305,12 @@ func TestNearLimit(t *testing.T) { request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -327,7 +327,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -345,7 +345,7 @@ func TestNearLimit(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) @@ -361,10 +361,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -378,10 +378,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -395,10 +395,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -412,10 +412,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -429,10 +429,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) @@ -446,10 +446,10 @@ func TestNearLimit(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"), false, false, "", nil, false)} + limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) @@ -457,6 +457,54 @@ func TestNearLimit(t *testing.T) { assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) } +func TestRedisWithUnitMultiplier(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_redis.NewMockClient(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := gostats.NewStore(gostats.NewNullSink(), false) + sm := stats.NewMockStatManager(statsStore) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false) + + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_999990", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_999990", int64(30)).DoAndReturn(pipeAppend) + client.EXPECT().PipeDo(gomock.Any()).Return(nil) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 30), + } + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, + cache.DoLimit(context.Background(), request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key2_value2_999600", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) + client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key2_value2_999600", int64(600)).DoAndReturn(pipeAppend) + client.EXPECT().PipeDo(gomock.Any()).Return(nil) + + request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) + limits = []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2"), false, false, "", nil, false, 10), + } + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, + cache.DoLimit(context.Background(), request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) +} + func TestRedisWithJitter(t *testing.T) { assert := assert.New(t) controller := gomock.NewController(t) @@ -476,10 +524,10 @@ func TestRedisWithJitter(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}}, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) @@ -511,12 +559,12 @@ func TestOverLimitWithLocalCacheShadowRule(t *testing.T) { request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, true, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, true, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -537,7 +585,7 @@ func TestOverLimitWithLocalCacheShadowRule(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -559,7 +607,7 @@ func TestOverLimitWithLocalCacheShadowRule(t *testing.T) { // The result should be OK since limit is in ShadowMode assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) @@ -581,7 +629,7 @@ func TestOverLimitWithLocalCacheShadowRule(t *testing.T) { // The result should be OK since limit is in ShadowMode assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) @@ -619,7 +667,7 @@ func TestRedisTracer(t *testing.T) { client.EXPECT().PipeDo(gomock.Any()).Return(nil) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, "", nil, false, 1)} cache.DoLimit(context.Background(), request, limits) spanStubs := testSpanExporter.GetSpans() @@ -659,14 +707,14 @@ func TestOverLimitWithStopCacheKeyIncrementWhenOverlimitConfig(t *testing.T) { request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}, {{"key5", "value5"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false), - config.NewRateLimit(14, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key5_value5"), false, false, "", nil, false), + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"), false, false, "", nil, false, 1), + config.NewRateLimit(14, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key5_value5"), false, false, "", nil, false, 1), } assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 3, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 3, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) @@ -697,8 +745,8 @@ func TestOverLimitWithStopCacheKeyIncrementWhenOverlimitConfig(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) @@ -729,8 +777,8 @@ func TestOverLimitWithStopCacheKeyIncrementWhenOverlimitConfig(t *testing.T) { assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource)}, - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(&limits[0].Limit.Unit, timeSource, limits[0].Limit.UnitMultiplier)}, + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(&limits[1].Limit.Unit, timeSource, limits[1].Limit.UnitMultiplier)}, }, cache.DoLimit(context.Background(), request, limits)) assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) diff --git a/test/service/ratelimit_test.go b/test/service/ratelimit_test.go index 18c3cafc..bf7e800e 100644 --- a/test/service/ratelimit_test.go +++ b/test/service/ratelimit_test.go @@ -153,7 +153,7 @@ func TestService(test *testing.T) { request = common.NewRateLimitRequest( "different-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "key_name", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "key_name", nil, false, 1), nil, } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) @@ -187,7 +187,7 @@ func TestService(test *testing.T) { // Config should still be valid. Also make sure order does not affect results. limits = []*config.RateLimit{ nil, - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1), } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[1]).Return(limits[1]) @@ -241,7 +241,7 @@ func TestServiceGlobalShadowMode(test *testing.T) { // Global Shadow mode limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1), nil, } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) @@ -281,8 +281,8 @@ func TestRuleShadowMode(test *testing.T) { request := common.NewRateLimitRequest( "different-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false, 1), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false, 1), } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[1]).Return(limits[1]) @@ -314,8 +314,8 @@ func TestMixedRuleShadowMode(test *testing.T) { request := common.NewRateLimitRequest( "different-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, true, "", nil, false, 1), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1), } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[1]).Return(limits[1]) @@ -374,7 +374,7 @@ func TestServiceWithCustomRatelimitHeaders(test *testing.T) { request := common.NewRateLimitRequest( "different-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1), nil, } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) @@ -427,7 +427,7 @@ func TestServiceWithDefaultRatelimitHeaders(test *testing.T) { request := common.NewRateLimitRequest( "different-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1), nil, } t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) @@ -487,7 +487,7 @@ func TestCacheError(test *testing.T) { service := t.setupBasicService() request := common.NewRateLimitRequest("different-domain", [][][2]string{{{"foo", "bar"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false)} + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("key"), false, false, "", nil, false, 1)} t.config.EXPECT().GetLimit(context.Background(), "different-domain", request.Descriptors[0]).Return(limits[0]) t.cache.EXPECT().DoLimit(context.Background(), request, limits).Do( func(context.Context, *pb.RateLimitRequest, []*config.RateLimit) { @@ -529,9 +529,9 @@ func TestUnlimited(test *testing.T) { request := common.NewRateLimitRequest( "some-domain", [][][2]string{{{"foo", "bar"}}, {{"hello", "world"}}, {{"baz", "qux"}}}, 1) limits := []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("foo_bar"), false, false, "", nil, false), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, t.statsManager.NewStats("foo_bar"), false, false, "", nil, false, 1), nil, - config.NewRateLimit(55, pb.RateLimitResponse_RateLimit_SECOND, t.statsManager.NewStats("baz_qux"), true, false, "", nil, false), + config.NewRateLimit(55, pb.RateLimitResponse_RateLimit_SECOND, t.statsManager.NewStats("baz_qux"), true, false, "", nil, false, 1), } t.config.EXPECT().GetLimit(context.Background(), "some-domain", request.Descriptors[0]).Return(limits[0]) t.config.EXPECT().GetLimit(context.Background(), "some-domain", request.Descriptors[1]).Return(limits[1]) From 46aae3bb6c82577fe1c0fff199de36e23a7f3c32 Mon Sep 17 00:00:00 2001 From: Tobias Sommer Date: Mon, 30 Sep 2024 11:12:59 +0200 Subject: [PATCH 2/2] feat(xds): add rate limit unit multiplier Signed-off-by: Tobias Sommer --- .../config/ratelimit/v3/rls_conf.proto | 4 + src/config/config_xds.go | 1 + test/provider/xds_grpc_sotw_provider_test.go | 75 +++++++++++++++---- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/api/ratelimit/config/ratelimit/v3/rls_conf.proto b/api/ratelimit/config/ratelimit/v3/rls_conf.proto index a1b90fc4..02edc339 100644 --- a/api/ratelimit/config/ratelimit/v3/rls_conf.proto +++ b/api/ratelimit/config/ratelimit/v3/rls_conf.proto @@ -66,6 +66,10 @@ message RateLimitPolicy { // For more information: https://github.com/envoyproxy/ratelimit/tree/0b2f4d5fb04bf55e1873e2c5e2bb28da67c0643f#replaces // Example: https://github.com/envoyproxy/ratelimit/tree/0b2f4d5fb04bf55e1873e2c5e2bb28da67c0643f#example-7 repeated RateLimitReplace replaces = 5; + + // Multiplier for the unit of time for the rate limit. + // Used to create custom periods e.g. 10 seconds or 5 minutes. + optional uint32 unit_multiplier = 6; } // Replace specifies the rate limit policy that should be replaced (dropped evaluation). diff --git a/src/config/config_xds.go b/src/config/config_xds.go index f6c67ce2..2ec89106 100644 --- a/src/config/config_xds.go +++ b/src/config/config_xds.go @@ -35,6 +35,7 @@ func rateLimitPolicyPbToYaml(pb *rls_conf_v3.RateLimitPolicy) *YamlRateLimit { return &YamlRateLimit{ RequestsPerUnit: pb.RequestsPerUnit, Unit: pb.Unit.String(), + UnitMultiplier: pb.UnitMultiplier, Unlimited: pb.Unlimited, Name: pb.Name, Replaces: rateLimitReplacesPbToYaml(pb.Replaces), diff --git a/test/provider/xds_grpc_sotw_provider_test.go b/test/provider/xds_grpc_sotw_provider_test.go index 49a124bd..c2435259 100644 --- a/test/provider/xds_grpc_sotw_provider_test.go +++ b/test/provider/xds_grpc_sotw_provider_test.go @@ -13,6 +13,7 @@ import ( "github.com/envoyproxy/go-control-plane/pkg/cache/v3" "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/provider" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/test/common" @@ -27,6 +28,7 @@ const ( ) func TestXdsProvider(t *testing.T) { + var unitMultiplier uint32 = 4 intSnapshot, _ := cache.NewSnapshot("1", map[resource.Type][]types.Resource{ resource.RateLimitConfigType: { @@ -40,6 +42,7 @@ func TestXdsProvider(t *testing.T) { RateLimit: &rls_config.RateLimitPolicy{ Unit: rls_config.RateLimitUnit_MINUTE, RequestsPerUnit: 3, + UnitMultiplier: &unitMultiplier, }, }, }, @@ -73,7 +76,9 @@ func TestXdsProvider(t *testing.T) { if err != nil { t.Error("Error setting 'MERGE_DOMAIN_CONFIG' environment variable", err) } - t.Run("Test same domain multiple times xDS config update", testSameDomainMultipleXdsConfigUpdate(setSnapshotFunc, providerEventChan)) + t.Run("Test same domain multiple times xDS config update", testSameDomainMultipleXdsConfigUpdate(&snapVersion, setSnapshotFunc, providerEventChan)) + + t.Run("Test setting unit multiplier to 0", testZeroUnitMultiplier(&snapVersion, setSnapshotFunc, providerEventChan)) } func testInitialXdsConfig(snapVersion *int, setSnapshotFunc common.SetSnapshotFunc, providerEventChan <-chan provider.ConfigUpdateEvent) func(t *testing.T) { @@ -86,7 +91,7 @@ func testInitialXdsConfig(snapVersion *int, setSnapshotFunc common.SetSnapshotFu config, err := configEvent.GetConfig() assert.Nil(err) - assert.Equal("foo.k1_v1: unit=MINUTE requests_per_unit=3, shadow_mode: false\n", config.Dump()) + assert.Equal("foo.k1_v1: unit=MINUTE, unit_multiplier=4, requests_per_unit=3, shadow_mode: false\n", config.Dump()) } } @@ -122,7 +127,7 @@ func testNewXdsConfigUpdate(snapVersion *int, setSnapshotFunc common.SetSnapshot config, err := configEvent.GetConfig() assert.Nil(err) - assert.Equal("foo.k2_v2: unit=MINUTE requests_per_unit=5, shadow_mode: false\n", config.Dump()) + assert.Equal("foo.k2_v2: unit=MINUTE, unit_multiplier=1, requests_per_unit=5, shadow_mode: false\n", config.Dump()) } } @@ -173,8 +178,8 @@ func testMultiDomainXdsConfigUpdate(snapVersion *int, setSnapshotFunc common.Set config, err := configEvent.GetConfig() assert.Nil(err) assert.ElementsMatch([]string{ - "foo.k1_v1: unit=MINUTE requests_per_unit=10, shadow_mode: false", - "bar.k1_v1: unit=MINUTE requests_per_unit=100, shadow_mode: false", + "foo.k1_v1: unit=MINUTE, unit_multiplier=1, requests_per_unit=10, shadow_mode: false", + "bar.k1_v1: unit=MINUTE, unit_multiplier=1, requests_per_unit=100, shadow_mode: false", }, strings.Split(strings.TrimSuffix(config.Dump(), "\n"), "\n")) } } @@ -266,22 +271,23 @@ func testDeeperLimitsXdsConfigUpdate(snapVersion *int, setSnapshotFunc common.Se config, err := configEvent.GetConfig() assert.Nil(err) assert.ElementsMatch([]string{ - "foo.k1_v1: unit=MINUTE requests_per_unit=10, shadow_mode: false", - "foo.k1_v1.k2: unit=UNKNOWN requests_per_unit=0, shadow_mode: false", - "foo.k1_v1.k2_v2: unit=HOUR requests_per_unit=15, shadow_mode: false", - "foo.j1_v2: unit=UNKNOWN requests_per_unit=0, shadow_mode: false", - "foo.j1_v2.j2: unit=UNKNOWN requests_per_unit=0, shadow_mode: false", - "foo.j1_v2.j2_v2: unit=DAY requests_per_unit=15, shadow_mode: true", - "bar.k1_v1: unit=MINUTE requests_per_unit=100, shadow_mode: false", + "foo.k1_v1: unit=MINUTE, unit_multiplier=1, requests_per_unit=10, shadow_mode: false", + "foo.k1_v1.k2: unit=UNKNOWN, unit_multiplier=1, requests_per_unit=0, shadow_mode: false", + "foo.k1_v1.k2_v2: unit=HOUR, unit_multiplier=1, requests_per_unit=15, shadow_mode: false", + "foo.j1_v2: unit=UNKNOWN, unit_multiplier=1, requests_per_unit=0, shadow_mode: false", + "foo.j1_v2.j2: unit=UNKNOWN, unit_multiplier=1, requests_per_unit=0, shadow_mode: false", + "foo.j1_v2.j2_v2: unit=DAY, unit_multiplier=1, requests_per_unit=15, shadow_mode: true", + "bar.k1_v1: unit=MINUTE, unit_multiplier=1, requests_per_unit=100, shadow_mode: false", }, strings.Split(strings.TrimSuffix(config.Dump(), "\n"), "\n")) } } -func testSameDomainMultipleXdsConfigUpdate(setSnapshotFunc common.SetSnapshotFunc, providerEventChan <-chan provider.ConfigUpdateEvent) func(t *testing.T) { +func testSameDomainMultipleXdsConfigUpdate(snapVersion *int, setSnapshotFunc common.SetSnapshotFunc, providerEventChan <-chan provider.ConfigUpdateEvent) func(t *testing.T) { + *snapVersion += 1 return func(t *testing.T) { assert := assert.New(t) - snapshot, _ := cache.NewSnapshot("3", + snapshot, _ := cache.NewSnapshot(fmt.Sprint(*snapVersion), map[resource.Type][]types.Resource{ resource.RateLimitConfigType: { &rls_config.RateLimitConfig{ @@ -323,8 +329,45 @@ func testSameDomainMultipleXdsConfigUpdate(setSnapshotFunc common.SetSnapshotFun config, err := configEvent.GetConfig() assert.Nil(err) assert.ElementsMatch([]string{ - "foo.k1_v2: unit=MINUTE requests_per_unit=100, shadow_mode: false", - "foo.k1_v1: unit=MINUTE requests_per_unit=10, shadow_mode: false", + "foo.k1_v2: unit=MINUTE, unit_multiplier=1, requests_per_unit=100, shadow_mode: false", + "foo.k1_v1: unit=MINUTE, unit_multiplier=1, requests_per_unit=10, shadow_mode: false", }, strings.Split(strings.TrimSuffix(config.Dump(), "\n"), "\n")) } } + +func testZeroUnitMultiplier(snapVersion *int, setSnapshotFunc common.SetSnapshotFunc, providerEventChan <-chan provider.ConfigUpdateEvent) func(t *testing.T) { + *snapVersion += 1 + return func(t *testing.T) { + assert := assert.New(t) + + snapshot, _ := cache.NewSnapshot(fmt.Sprint(*snapVersion), + map[resource.Type][]types.Resource{ + resource.RateLimitConfigType: { + &rls_config.RateLimitConfig{ + Name: "foo-1", + Domain: "foo", + Descriptors: []*rls_config.RateLimitDescriptor{ + { + Key: "k1", + Value: "v1", + RateLimit: &rls_config.RateLimitPolicy{ + Unit: rls_config.RateLimitUnit_MINUTE, + RequestsPerUnit: 10, + UnitMultiplier: new(uint32), + }, + }, + }, + }, + }, + }, + ) + setSnapshotFunc(snapshot) + + configEvent := <-providerEventChan + assert.NotNil(configEvent) + + _, err := configEvent.GetConfig() + configError, _ := err.(config.RateLimitConfigError) + assert.Equal("foo-1: invalid unit multiplier of 0", configError.Error()) + } +}