feat: add rate limit unit multiplier
Signed-off-by: Tobias Sommer <[email protected]>
Tobias Sommer authored and harpunius committed Apr 6, 2024
1 parent 4537d29 commit 80786c0
Showing 18 changed files with 484 additions and 162 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -245,6 +245,7 @@ descriptors:
- name: (optional)
unit: <see below: required>
requests_per_unit: <see below: required>
unit_multiplier: <see below: optional>
shadow_mode: (optional)
detailed_metric: (optional)
descriptors: (optional block)
@@ -262,11 +263,15 @@ effectively whitelisted. Otherwise, nested descriptors allow more complex matchi
rate_limit:
unit: <second, minute, hour, day>
requests_per_unit: <uint>
unit_multiplier: <uint>
```
The rate limit block specifies the actual rate limit that will be used when there is a match.
Currently the service supports per second, minute, hour, and day limits. More types of limits may be added in the
future based on user demand.
The optional `unit_multiplier` scales `unit` to create custom rate limit durations, such as 30 seconds or 5 minutes.
A `unit_multiplier` of 0 is invalid; omitting the field means the duration equals the unit (e.g. 1 minute).
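For example, a limit of 10 requests every 30 seconds (values chosen here purely for illustration) combines `unit: second` with `unit_multiplier: 30`:

```yaml
rate_limit:
  unit: second
  requests_per_unit: 10
  unit_multiplier: 30
```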
### Replaces
53 changes: 37 additions & 16 deletions src/config/config_impl.go
@@ -20,7 +20,8 @@ type yamlReplaces struct {
type YamlRateLimit struct {
RequestsPerUnit uint32 `yaml:"requests_per_unit"`
Unit string
Unlimited bool `yaml:"unlimited"`
UnitMultiplier *uint32 `yaml:"unit_multiplier"`
Unlimited bool `yaml:"unlimited"`
Name string
Replaces []yamlReplaces
}
@@ -68,23 +69,26 @@ var validKeys = map[string]bool{
"name": true,
"replaces": true,
"detailed_metric": true,
"unit_multiplier": true,
}

// Create a new rate limit config entry.
// @param requestsPerUnit supplies the requests per unit of time for the entry.
// @param unit supplies the unit of time for the entry.
// @param unitMultiplier supplies the multiplier for the unit of time for the entry.
// @param rlStats supplies the stats structure associated with the RateLimit
// @param unlimited supplies whether the rate limit is unlimited
// @return the new config entry.
func NewRateLimit(requestsPerUnit uint32, unit pb.RateLimitResponse_RateLimit_Unit, rlStats stats.RateLimitStats,
unlimited bool, shadowMode bool, name string, replaces []string, detailedMetric bool) *RateLimit {

unlimited bool, shadowMode bool, name string, replaces []string, detailedMetric bool, unitMultiplier uint32,
) *RateLimit {
return &RateLimit{
FullKey: rlStats.GetKey(),
Stats: rlStats,
Limit: &pb.RateLimitResponse_RateLimit{
RequestsPerUnit: requestsPerUnit,
Unit: unit,
UnitMultiplier: unitMultiplier,
},
Unlimited: unlimited,
ShadowMode: shadowMode,
@@ -99,8 +103,8 @@ func (this *rateLimitDescriptor) dump() string {
ret := ""
if this.limit != nil {
ret += fmt.Sprintf(
"%s: unit=%s requests_per_unit=%d, shadow_mode: %t\n", this.limit.FullKey,
this.limit.Limit.Unit.String(), this.limit.Limit.RequestsPerUnit, this.limit.ShadowMode)
"%s: unit=%s, unit_multiplier=%d, requests_per_unit=%d, shadow_mode: %t\n", this.limit.FullKey,
this.limit.Limit.Unit.String(), this.limit.Limit.UnitMultiplier, this.limit.Limit.RequestsPerUnit, this.limit.ShadowMode)
}
for _, descriptor := range this.descriptors {
ret += descriptor.dump()
@@ -143,8 +147,7 @@ func (this *rateLimitDescriptor) loadDescriptors(config RateLimitConfigToLoad, p
if descriptorConfig.RateLimit != nil {
unlimited := descriptorConfig.RateLimit.Unlimited

value, present :=
pb.RateLimitResponse_RateLimit_Unit_value[strings.ToUpper(descriptorConfig.RateLimit.Unit)]
value, present := pb.RateLimitResponse_RateLimit_Unit_value[strings.ToUpper(descriptorConfig.RateLimit.Unit)]
validUnit := present && value != int32(pb.RateLimitResponse_RateLimit_UNKNOWN)

if unlimited {
@@ -159,6 +162,18 @@
fmt.Sprintf("invalid rate limit unit '%s'", descriptorConfig.RateLimit.Unit)))
}

var unitMultiplier uint32
if descriptorConfig.RateLimit.UnitMultiplier == nil {
unitMultiplier = 1
} else {
unitMultiplier = *descriptorConfig.RateLimit.UnitMultiplier
if unitMultiplier == 0 {
panic(newRateLimitConfigError(
config.Name,
"invalid unit multiplier of 0"))
}
}

replaces := make([]string, len(descriptorConfig.RateLimit.Replaces))
for i, e := range descriptorConfig.RateLimit.Replaces {
replaces[i] = e.Name
@@ -168,10 +183,12 @@ func (this *rateLimitDescriptor) loadDescriptors(config RateLimitConfigToLoad, p
descriptorConfig.RateLimit.RequestsPerUnit, pb.RateLimitResponse_RateLimit_Unit(value),
statsManager.NewStats(newParentKey), unlimited, descriptorConfig.ShadowMode,
descriptorConfig.RateLimit.Name, replaces, descriptorConfig.DetailedMetric,
unitMultiplier,
)

rateLimitDebugString = fmt.Sprintf(
" ratelimit={requests_per_unit=%d, unit=%s, unlimited=%t, shadow_mode=%t}", rateLimit.Limit.RequestsPerUnit,
rateLimit.Limit.Unit.String(), rateLimit.Unlimited, rateLimit.ShadowMode)
" ratelimit={requests_per_unit=%d, unit=%s, unit_multiplier=%d, unlimited=%t, shadow_mode=%t}", rateLimit.Limit.RequestsPerUnit,
rateLimit.Limit.Unit.String(), unitMultiplier, rateLimit.Unlimited, rateLimit.ShadowMode)

for _, replaces := range descriptorConfig.RateLimit.Replaces {
if replaces.Name == "" {
@@ -277,8 +294,8 @@ func (this *rateLimitConfigImpl) Dump() string {
}

func (this *rateLimitConfigImpl) GetLimit(
ctx context.Context, domain string, descriptor *pb_struct.RateLimitDescriptor) *RateLimit {

ctx context.Context, domain string, descriptor *pb_struct.RateLimitDescriptor,
) *RateLimit {
logger.Debugf("starting get limit lookup")
var rateLimit *RateLimit = nil
value := this.domains[domain]
@@ -300,6 +317,7 @@ func (this *rateLimitConfigImpl) GetLimit(
"",
[]string{},
false,
1,
)
return rateLimit
}
@@ -352,7 +370,10 @@ func (this *rateLimitConfigImpl) GetLimit(
descriptorsMap = nextDescriptor.descriptors
} else {
if rateLimit != nil && rateLimit.DetailedMetric {
rateLimit = NewRateLimit(rateLimit.Limit.RequestsPerUnit, rateLimit.Limit.Unit, this.statsManager.NewStats(rateLimit.FullKey), rateLimit.Unlimited, rateLimit.ShadowMode, rateLimit.Name, rateLimit.Replaces, rateLimit.DetailedMetric)
rateLimit = NewRateLimit(rateLimit.Limit.RequestsPerUnit, rateLimit.Limit.Unit,
this.statsManager.NewStats(rateLimit.FullKey), rateLimit.Unlimited,
rateLimit.ShadowMode, rateLimit.Name, rateLimit.Replaces,
rateLimit.DetailedMetric, rateLimit.Limit.UnitMultiplier)
}

break
@@ -417,8 +438,8 @@ func ConfigFileContentToYaml(fileName, content string) *YamlRoot {
// @param mergeDomainConfigs defines whether multiple configurations referencing the same domain will be merged or rejected throwing an error.
// @return a new config.
func NewRateLimitConfigImpl(
configs []RateLimitConfigToLoad, statsManager stats.Manager, mergeDomainConfigs bool) RateLimitConfig {

configs []RateLimitConfigToLoad, statsManager stats.Manager, mergeDomainConfigs bool,
) RateLimitConfig {
ret := &rateLimitConfigImpl{map[string]*rateLimitDomain{}, statsManager, mergeDomainConfigs}
for _, config := range configs {
ret.loadConfig(config)
@@ -430,8 +451,8 @@ func NewRateLimitConfigImpl(
type rateLimitConfigLoaderImpl struct{}

func (this *rateLimitConfigLoaderImpl) Load(
configs []RateLimitConfigToLoad, statsManager stats.Manager, mergeDomainConfigs bool) RateLimitConfig {

configs []RateLimitConfigToLoad, statsManager stats.Manager, mergeDomainConfigs bool,
) RateLimitConfig {
return NewRateLimitConfigImpl(configs, statsManager, mergeDomainConfigs)
}

20 changes: 13 additions & 7 deletions src/limiter/base_limiter.go
@@ -33,7 +33,8 @@ type LimitInfo struct {
}

func NewRateLimitInfo(limit *config.RateLimit, limitBeforeIncrease uint32, limitAfterIncrease uint32,
nearLimitThreshold uint32, overLimitThreshold uint32) *LimitInfo {
nearLimitThreshold uint32, overLimitThreshold uint32,
) *LimitInfo {
return &LimitInfo{
limit: limit, limitBeforeIncrease: limitBeforeIncrease, limitAfterIncrease: limitAfterIncrease,
nearLimitThreshold: nearLimitThreshold, overLimitThreshold: overLimitThreshold,
@@ -43,7 +44,8 @@ func NewRateLimitInfo(limit *config.RateLimit, limitBeforeIncrease uint32, limit
// Generates cache keys for given rate limit request. Each cache key is represented by a concatenation of
// domain, descriptor and current timestamp.
func (this *BaseRateLimiter) GenerateCacheKeys(request *pb.RateLimitRequest,
limits []*config.RateLimit, hitsAddend uint32) []CacheKey {
limits []*config.RateLimit, hitsAddend uint32,
) []CacheKey {
assert.Assert(len(request.Descriptors) == len(limits))
cacheKeys := make([]CacheKey, len(request.Descriptors))
now := this.timeSource.UnixNow()
@@ -79,7 +81,8 @@ func (this *BaseRateLimiter) IsOverLimitThresholdReached(limitInfo *LimitInfo) b
// Generates response descriptor status based on cache key, over the limit with local cache, over the limit and
// near the limit thresholds. Thresholds are checked in order and are mutually exclusive.
func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo *LimitInfo,
isOverLimitWithLocalCache bool, hitsAddend uint32) *pb.RateLimitResponse_DescriptorStatus {
isOverLimitWithLocalCache bool, hitsAddend uint32,
) *pb.RateLimitResponse_DescriptorStatus {
if key == "" {
return this.generateResponseDescriptorStatus(pb.RateLimitResponse_OK,
nil, 0)
Expand Down Expand Up @@ -113,7 +116,8 @@ func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo *
// similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start
// to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m).
// In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited.
err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDivider(limitInfo.limit.Limit.Unit)))

err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDividerWithMultiplier(limitInfo.limit.Limit.Unit, limitInfo.limit.Limit.UnitMultiplier)))
if err != nil {
logger.Errorf("Failing to set local cache key: %s", key)
}
@@ -140,7 +144,8 @@ func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo *
}

func NewBaseRateLimit(timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64,
localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager) *BaseRateLimiter {
localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager,
) *BaseRateLimiter {
return &BaseRateLimiter{
timeSource: timeSource,
JitterRand: jitterRand,
@@ -194,13 +199,14 @@ func (this *BaseRateLimiter) increaseShadowModeStats(isOverLimitWithLocalCache b
}

func (this *BaseRateLimiter) generateResponseDescriptorStatus(responseCode pb.RateLimitResponse_Code,
limit *pb.RateLimitResponse_RateLimit, limitRemaining uint32) *pb.RateLimitResponse_DescriptorStatus {
limit *pb.RateLimitResponse_RateLimit, limitRemaining uint32,
) *pb.RateLimitResponse_DescriptorStatus {
if limit != nil {
return &pb.RateLimitResponse_DescriptorStatus{
Code: responseCode,
CurrentLimit: limit,
LimitRemaining: limitRemaining,
DurationUntilReset: utils.CalculateReset(&limit.Unit, this.timeSource),
DurationUntilReset: utils.CalculateReset(&limit.Unit, this.timeSource, limit.UnitMultiplier),
}
} else {
return &pb.RateLimitResponse_DescriptorStatus{
7 changes: 4 additions & 3 deletions src/limiter/cache_key.go
@@ -46,8 +46,8 @@ func isPerSecondLimit(unit pb.RateLimitResponse_RateLimit_Unit) bool {
// @param now supplies the current unix time.
// @return CacheKey struct.
func (this *CacheKeyGenerator) GenerateCacheKey(
domain string, descriptor *pb_struct.RateLimitDescriptor, limit *config.RateLimit, now int64) CacheKey {

domain string, descriptor *pb_struct.RateLimitDescriptor, limit *config.RateLimit, now int64,
) CacheKey {
if limit == nil {
return CacheKey{
Key: "",
@@ -70,7 +70,8 @@ func (this *CacheKeyGenerator) GenerateCacheKey(
b.WriteByte('_')
}

divider := utils.UnitToDivider(limit.Limit.Unit)
divider := utils.UnitToDividerWithMultiplier(limit.Limit.Unit, limit.Limit.UnitMultiplier)

b.WriteString(strconv.FormatInt((now/divider)*divider, 10))

return CacheKey{
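The hunks above rely on `utils.UnitToDividerWithMultiplier`, whose definition is in one of the files not shown on this page; the cache key then embeds the start of the current time window via `(now/divider)*divider`. The sketch below is an illustration only, with a hypothetical stand-in unit type instead of the real `pb` enum, assuming the new helper simply scales the unit's length in seconds by `unit_multiplier`:

```go
package main

import "fmt"

// unit is a hypothetical stand-in for pb.RateLimitResponse_RateLimit_Unit.
type unit int

const (
	second unit = iota
	minute
	hour
	day
)

// unitToDivider mirrors the pre-existing UnitToDivider: the unit's length in seconds.
func unitToDivider(u unit) int64 {
	switch u {
	case second:
		return 1
	case minute:
		return 60
	case hour:
		return 60 * 60
	default: // day
		return 60 * 60 * 24
	}
}

// unitToDividerWithMultiplier sketches the assumed behavior of the new helper:
// scale the unit length by unit_multiplier, e.g. minute with multiplier 5 -> 300s.
func unitToDividerWithMultiplier(u unit, multiplier uint32) int64 {
	return unitToDivider(u) * int64(multiplier)
}

func main() {
	divider := unitToDividerWithMultiplier(second, 30) // a 30-second window
	now := int64(1712345678)
	windowStart := (now / divider) * divider // same alignment as GenerateCacheKey
	fmt.Println(divider, windowStart)        // 30 1712345670
}
```

Under that assumption, all requests inside the same window hash to the same cache key, and the key (and its TTL in the memcached/Redis/local-cache paths) rolls over when the window does.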
15 changes: 9 additions & 6 deletions src/memcached/cache_impl.go
@@ -64,8 +64,8 @@ var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil)
func (this *rateLimitMemcacheImpl) DoLimit(
ctx context.Context,
request *pb.RateLimitRequest,
limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus {

limits []*config.RateLimit,
) []*pb.RateLimitResponse_DescriptorStatus {
logger.Debugf("starting cache lookup")

// request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request.
@@ -148,7 +148,8 @@ func (this *rateLimitMemcacheImpl) DoLimit(
}

func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool,
limits []*config.RateLimit, hitsAddend uint64) {
limits []*config.RateLimit, hitsAddend uint64,
) {
defer this.waitGroup.Done()
for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" || isOverLimitWithLocalCache[i] {
@@ -157,7 +158,7 @@

_, err := this.client.Increment(cacheKey.Key, hitsAddend)
if err == memcache.ErrCacheMiss {
expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit)
expirationSeconds := utils.UnitToDividerWithMultiplier(limits[i].Limit.Unit, limits[i].Limit.UnitMultiplier)
if this.expirationJitterMaxSeconds > 0 {
expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds)
}
@@ -290,7 +291,8 @@ func runAsync(task func()) {
}

func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand,
expirationJitterMaxSeconds int64, localCache *freecache.Cache, statsManager stats.Manager, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache {
expirationJitterMaxSeconds int64, localCache *freecache.Cache, statsManager stats.Manager, nearLimitRatio float32, cacheKeyPrefix string,
) limiter.RateLimitCache {
return &rateLimitMemcacheImpl{
client: client,
timeSource: timeSource,
@@ -303,7 +305,8 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan
}

func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand,
localCache *freecache.Cache, scope gostats.Scope, statsManager stats.Manager) limiter.RateLimitCache {
localCache *freecache.Cache, scope gostats.Scope, statsManager stats.Manager,
) limiter.RateLimitCache {
return NewRateLimitCacheImpl(
CollectStats(newMemcacheFromSettings(s), scope.Scope("memcache")),
timeSource,
9 changes: 5 additions & 4 deletions src/redis/fixed_cache_impl.go
@@ -44,8 +44,8 @@ func pipelineAppendtoGet(client Client, pipeline *Pipeline, key string, result *
func (this *fixedRateLimitCacheImpl) DoLimit(
ctx context.Context,
request *pb.RateLimitRequest,
limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus {

limits []*config.RateLimit,
) []*pb.RateLimitResponse_DescriptorStatus {
logger.Debugf("starting cache lookup")

// request.HitsAddend could be 0 (default value) if not specified by the caller in the RateLimit request.
@@ -152,7 +152,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit(

logger.Debugf("looking up cache key: %s", cacheKey.Key)

expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit)
expirationSeconds := utils.UnitToDividerWithMultiplier(limits[i].Limit.Unit, limits[i].Limit.UnitMultiplier)
if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 {
expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds)
}
@@ -218,7 +218,8 @@ func (this *fixedRateLimitCacheImpl) Flush() {}

func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource,
jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager,
stopCacheKeyIncrementWhenOverlimit bool) limiter.RateLimitCache {
stopCacheKeyIncrementWhenOverlimit bool,
) limiter.RateLimitCache {
return &fixedRateLimitCacheImpl{
client: client,
perSecondClient: perSecondClient,
21 changes: 11 additions & 10 deletions src/service/ratelimit.go
@@ -138,8 +138,9 @@ func (this *service) constructLimitsToCheck(request *pb.RateLimitRequest, ctx co
logger.Debugf("descriptor is unlimited, not passing to the cache")
} else {
logger.Debugf(
"applying limit: %d requests per %s, shadow_mode: %t",
"applying limit: %d requests per %d %s, shadow_mode: %t",
limitsToCheck[i].Limit.RequestsPerUnit,
limitsToCheck[i].Limit.UnitMultiplier,
limitsToCheck[i].Limit.Unit.String(),
limitsToCheck[i].ShadowMode,
)
@@ -177,8 +178,8 @@ func (this *service) constructLimitsToCheck(request *pb.RateLimitRequest, ctx co
const MaxUint32 = uint32(1<<32 - 1)

func (this *service) shouldRateLimitWorker(
ctx context.Context, request *pb.RateLimitRequest) *pb.RateLimitResponse {

ctx context.Context, request *pb.RateLimitRequest,
) *pb.RateLimitResponse {
checkServiceErr(request.Domain != "", "rate limit domain must not be empty")
checkServiceErr(len(request.Descriptors) != 0, "rate limit descriptor list must not be empty")

@@ -258,18 +259,18 @@ func (this *service) rateLimitRemainingHeader(descriptor *pb.RateLimitResponse_D
}

func (this *service) rateLimitResetHeader(
descriptor *pb.RateLimitResponse_DescriptorStatus) *core.HeaderValue {

descriptor *pb.RateLimitResponse_DescriptorStatus,
) *core.HeaderValue {
return &core.HeaderValue{
Key: this.customHeaderResetHeader,
Value: strconv.FormatInt(utils.CalculateReset(&descriptor.CurrentLimit.Unit, this.customHeaderClock).GetSeconds(), 10),
Value: strconv.FormatInt(utils.CalculateReset(&descriptor.CurrentLimit.Unit, this.customHeaderClock, descriptor.CurrentLimit.UnitMultiplier).GetSeconds(), 10),
}
}

func (this *service) ShouldRateLimit(
ctx context.Context,
request *pb.RateLimitRequest) (finalResponse *pb.RateLimitResponse, finalError error) {

request *pb.RateLimitRequest,
) (finalResponse *pb.RateLimitResponse, finalError error) {
// Generate trace
_, span := tracer.Start(ctx, "ShouldRateLimit Execution",
trace.WithAttributes(
@@ -316,8 +317,8 @@ func (this *service) GetCurrentConfig() (config.RateLimitConfig, bool) {
}

func NewService(cache limiter.RateLimitCache, configProvider provider.RateLimitConfigProvider, statsManager stats.Manager,
health *server.HealthChecker, clock utils.TimeSource, shadowMode, forceStart bool, healthyWithAtLeastOneConfigLoad bool) RateLimitServiceServer {

health *server.HealthChecker, clock utils.TimeSource, shadowMode, forceStart bool, healthyWithAtLeastOneConfigLoad bool,
) RateLimitServiceServer {
newService := &service{
configLock: sync.RWMutex{},
configUpdateEvent: configProvider.ConfigUpdateEvent(),