From 76a020acfab75bfd860c419e7e75fa7ac1bf492d Mon Sep 17 00:00:00 2001 From: Samantha Date: Fri, 13 Oct 2023 17:26:41 -0400 Subject: [PATCH 1/8] ratelimits: API improvements necessary for batches and limit fixes --- ratelimits/README.md | 15 +- ratelimits/bucket.go | 56 +++++++ ratelimits/limit.go | 16 +- ratelimits/limit_test.go | 120 ++++---------- ratelimits/limiter.go | 147 ++++++++---------- ratelimits/limiter_test.go | 105 +++++++------ ratelimits/names.go | 124 +++++++-------- ratelimits/source_redis.go | 11 +- .../working_override_regid_domain.yml | 2 +- .../working_overrides_regid_fqdnset.yml | 6 +- ratelimits/utilities.go | 33 ++++ ratelimits/utilities_test.go | 27 ++++ sa/rate_limits_test.go | 5 - wfe2/wfe.go | 40 +++-- 14 files changed, 374 insertions(+), 333 deletions(-) create mode 100644 ratelimits/bucket.go create mode 100644 ratelimits/utilities.go create mode 100644 ratelimits/utilities_test.go diff --git a/ratelimits/README.md b/ratelimits/README.md index 7e500aa4931..838edc65605 100644 --- a/ratelimits/README.md +++ b/ratelimits/README.md @@ -79,22 +79,21 @@ Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48` #### regId -The registration ID of the account. +An ACME account registration ID. Example: `NewOrdersPerAccount:12345678` -#### regId:domain +#### domain -A combination of registration ID and domain, formatted 'regId:domain'. +A valid eTLD+1 domain name. -Example: `CertificatesPerDomainPerAccount:12345678:example.com` +Example: `CertificatesPerDomain:example.com` -#### regId:fqdnSet +#### fqdnSet -A combination of registration ID and a comma-separated list of domain names, -formatted 'regId:fqdnSet'. +A comma-separated list of domain names. -Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org` +Example: `CertificatesPerFQDNSet:example.com,example.org` ## Bucket Key Definitions diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go new file mode 100644 index 00000000000..a8901e10ed3 --- /dev/null +++ b/ratelimits/bucket.go @@ -0,0 +1,56 @@ +package ratelimits + +import ( + "fmt" + "net" +) + +// Bucket identifies a specific subscriber rate limit bucket to the Limiter. +type Bucket struct { + name Name + key string +} + +// BucketWithCost is a bucket with an associated cost. +type BucketWithCost struct { + Bucket + cost int64 +} + +// WithCost returns a BucketWithCost for the provided cost. +func (b Bucket) WithCost(cost int64) BucketWithCost { + return BucketWithCost{b, cost} +} + +// NewRegistrationsPerIPAddressBucket returns a Bucket for the provided IP +// address. +func NewRegistrationsPerIPAddressBucket(ip net.IP) (Bucket, error) { + id := ip.String() + err := validateIdForName(NewRegistrationsPerIPAddress, id) + if err != nil { + return Bucket{}, err + } + return Bucket{ + name: NewRegistrationsPerIPAddress, + key: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), + }, nil +} + +// NewRegistrationsPerIPv6RangeBucket returns a Bucket for the /48 IPv6 range +// containing the provided IPv6 address. +func NewRegistrationsPerIPv6RangeBucket(ip net.IP) (Bucket, error) { + if ip.To4() != nil { + return Bucket{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) + } + ipMask := net.CIDRMask(48, 128) + ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} + id := ipNet.String() + err := validateIdForName(NewRegistrationsPerIPv6Range, id) + if err != nil { + return Bucket{}, err + } + return Bucket{ + name: NewRegistrationsPerIPv6Range, + key: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), + }, nil +} diff --git a/ratelimits/limit.go b/ratelimits/limit.go index 7fb166bce60..261f66e67f3 100644 --- a/ratelimits/limit.go +++ b/ratelimits/limit.go @@ -121,22 +121,14 @@ func loadAndParseOverrideLimits(path string) (limits, error) { return nil, fmt.Errorf( "validating name %s and id %q for override limit %q: %w", name, id, k, err) } - if name == CertificatesPerFQDNSetPerAccount { + if name == CertificatesPerFQDNSet { // FQDNSet hashes are not a nice thing to ask for in a config file, // so we allow the user to specify a comma-separated list of FQDNs // and compute the hash here. - regIdDomains := strings.SplitN(id, ":", 2) - if len(regIdDomains) != 2 { - // Should never happen, the Id format was validated above. - return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k) - } - regId := regIdDomains[0] - domains := strings.Split(regIdDomains[1], ",") - fqdnSet := core.HashNames(domains) - id = fmt.Sprintf("%s:%s", regId, fqdnSet) + id = string(core.HashNames(strings.Split(id, ","))) } v.isOverride = true - parsed[bucketKey(name, id)] = precomputeLimit(v) + parsed[joinWithColon(name.EnumString(), id)] = precomputeLimit(v) } return parsed, nil } @@ -159,7 +151,7 @@ func loadAndParseDefaultLimits(path string) (limits, error) { if !ok { return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames) } - parsed[nameToEnumString(name)] = precomputeLimit(v) + parsed[name.EnumString()] = precomputeLimit(v) } return parsed, nil } diff --git a/ratelimits/limit_test.go b/ratelimits/limit_test.go index acabe9a51a0..50aae8c80c0 100644 --- a/ratelimits/limit_test.go +++ b/ratelimits/limit_test.go @@ -78,19 +78,19 @@ func Test_validateIdForName(t *testing.T) { err = validateIdForName(NewOrdersPerAccount, "1234567890") test.AssertNotError(t, err, "valid regId") - // 'enum:regId:domain' + // 'enum:domain' // Valid regId and domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com") + err = validateIdForName(CertificatesPerDomain, "example.com") test.AssertNotError(t, err, "valid regId and domain") - // 'enum:regId:fqdnSet' - // Valid regId and FQDN set containing a single domain. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com") + // 'enum:fqdnSet' + // Valid fqdnSet containing a single domain. + err = validateIdForName(CertificatesPerFQDNSet, "example.com") test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain") - // 'enum:regId:fqdnSet' - // Valid regId and FQDN set containing multiple domains. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org") + // 'enum:fqdnSet' + // Valid fqdnSet containing multiple domains. + err = validateIdForName(CertificatesPerFQDNSet, "example.com,example.org") test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains") // Empty string. @@ -125,71 +125,20 @@ func Test_validateIdForName(t *testing.T) { err = validateIdForName(NewOrdersPerAccount, "lol") test.AssertError(t, err, "invalid regId") - // Invalid regId with good domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com") - test.AssertError(t, err, "invalid regId with good domain") - - // Valid regId with bad domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol") - test.AssertError(t, err, "valid regId with bad domain") - - // Empty regId with good domain. - err = validateIdForName(CertificatesPerDomainPerAccount, ":lol") + // Invalid domain, malformed. + err = validateIdForName(CertificatesPerDomain, "example:.com") test.AssertError(t, err, "valid regId with bad domain") - // Valid regId with empty domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:") + // Invalid domain, empty. + err = validateIdForName(CertificatesPerDomain, "") test.AssertError(t, err, "valid regId with empty domain") - - // Empty regId with empty domain, no separator. - err = validateIdForName(CertificatesPerDomainPerAccount, "") - test.AssertError(t, err, "empty regId with empty domain, no separator") - - // Instead of anything we would expect, we get lol. - err = validateIdForName(CertificatesPerDomainPerAccount, "lol") - test.AssertError(t, err, "instead of anything we would expect, just lol") - - // Valid regId with good domain and a secret third separator. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol") - test.AssertError(t, err, "valid regId with good domain and a secret third separator") - - // Valid regId with bad FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99") - test.AssertError(t, err, "valid regId with bad FQDN set") - - // Bad regId with good FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org") - test.AssertError(t, err, "bad regId with good FQDN set") - - // Empty regId with good FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org") - test.AssertError(t, err, "empty regId with good FQDN set") - - // Good regId with empty FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:") - test.AssertError(t, err, "good regId with empty FQDN set") - - // Empty regId with empty FQDN set, no separator. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "") - test.AssertError(t, err, "empty regId with empty FQDN set, no separator") - - // Instead of anything we would expect, just lol. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol") - test.AssertError(t, err, "instead of anything we would expect, just lol") - - // Valid regId with good FQDN set and a secret third separator. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol") - test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator") } func Test_loadAndParseOverrideLimits(t *testing.T) { - newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) - newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) - // Load a single valid override limit with Id formatted as 'enum:RegId'. l, err := loadAndParseOverrideLimits("testdata/working_override.yml") test.AssertNotError(t, err, "valid single override limit") - expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + expectKey := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "10.0.0.2") test.AssertEquals(t, l[expectKey].Burst, int64(40)) test.AssertEquals(t, l[expectKey].Count, int64(40)) test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) @@ -197,35 +146,35 @@ func Test_loadAndParseOverrideLimits(t *testing.T) { // Load single valid override limit with Id formatted as 'regId:domain'. l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml") test.AssertNotError(t, err, "valid single override limit with Id of regId:domain") - expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com" + expectKey = joinWithColon(CertificatesPerDomain.EnumString(), "example.com") test.AssertEquals(t, l[expectKey].Burst, int64(40)) test.AssertEquals(t, l[expectKey].Count, int64(40)) test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) // Load multiple valid override limits with 'enum:RegId' Ids. l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml") - expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + expectKey1 := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "10.0.0.2") test.AssertNotError(t, err, "multiple valid override limits") test.AssertEquals(t, l[expectKey1].Burst, int64(40)) test.AssertEquals(t, l[expectKey1].Count, int64(40)) test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second) - expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48" + expectKey2 := joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), "2001:0db8:0000::/48") test.AssertEquals(t, l[expectKey2].Burst, int64(50)) test.AssertEquals(t, l[expectKey2].Count, int64(50)) test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2) - // Load multiple valid override limits with 'regID:fqdnSet' Ids as follows: - // - CertificatesPerFQDNSetPerAccount:12345678:example.com - // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net - // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org + // Load multiple valid override limits with 'fqdnSet' Ids, as follows: + // - CertificatesPerFQDNSet:example.com + // - CertificatesPerFQDNSet:example.com,example.net + // - CertificatesPerFQDNSet:example.com,example.net,example.org firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"})) secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"})) thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"})) - firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash - secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash - thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash + firstEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), firstEntryFQDNSetHash) + secondEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), secondEntryFQDNSetHash) + thirdEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), thirdEntryFQDNSetHash) l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml") - test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets") + test.AssertNotError(t, err, "multiple valid override limits with 'fqdnSet' Ids") test.AssertEquals(t, l[firstEntryKey].Burst, int64(40)) test.AssertEquals(t, l[firstEntryKey].Count, int64(40)) test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second) @@ -278,25 +227,22 @@ func Test_loadAndParseOverrideLimits(t *testing.T) { } func Test_loadAndParseDefaultLimits(t *testing.T) { - newRestistrationsPerIPv4AddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) - newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) - // Load a single valid default limit. l, err := loadAndParseDefaultLimits("testdata/working_default.yml") test.AssertNotError(t, err, "valid single default limit") - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) // Load multiple valid default limits. l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml") test.AssertNotError(t, err, "multiple valid default limits") - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Burst, int64(30)) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Count, int64(30)) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Period.Duration, time.Second*2) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Burst, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Count, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Period.Duration, time.Second*2) // Path is empty string. _, err = loadAndParseDefaultLimits("") diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index b518ef5ffba..6e4b074199e 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -88,8 +88,8 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat limiter.overrideUsageGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: "ratelimits_override_usage", - Help: "Proportion of override limit used, by limit name and client id.", - }, []string{"limit", "client_id"}) + Help: "Proportion of override limit used, by limit name and bucket key.", + }, []string{"limit", "bucket_key"}) stats.MustRegister(limiter.overrideUsageGauge) return limiter, nil @@ -118,23 +118,17 @@ type Decision struct { newTAT time.Time } -// Check returns a *Decision that indicates whether there's enough capacity to -// allow the request, given the cost, for the specified limit Name and client -// id. However, it DOES NOT deduct the cost of the request from the bucket's -// capacity. Hence, the returned *Decision represents the hypothetical state of -// the bucket if the cost WERE to be deducted. The returned *Decision will -// always include the number of remaining requests in the bucket, the required -// wait time before the client can make another request, and the time until the -// bucket refills to its maximum capacity (resets). If no bucket exists for the -// given limit Name and client id, a new one will be created WITHOUT the -// request's cost deducted from its initial capacity. If the specified limit is -// disabled, ErrLimitDisabled is returned. -func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost < 0 { +// Check DOES NOT deduct the cost of the request from the provided bucket's +// capacity. The returned *Decision indicates whether the capacity exists to +// satisfy the cost and represents the hypothetical state of the bucket IF the +// cost WERE to be deducted. If no bucket exists it will NOT be created. No +// state is persisted to the underlying datastore. +func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, error) { + if bucket.cost < 0 { return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(bucket.name, bucket.key) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -142,45 +136,34 @@ func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - if cost > limit.Burst { - return nil, ErrInvalidCostOverLimit - } - // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, bucket.key) if err != nil { if !errors.Is(err, ErrBucketNotFound) { return nil, err } - // First request from this client. The cost is not deducted from the - // initial capacity because this is only a check. - d, err := l.initialize(ctx, limit, name, id, 0) - if err != nil { - return nil, err - } - return maybeSpend(l.clk, limit, d.newTAT, cost), nil + // First request from this client. No need to initialize the bucket + // because this is a check, not a spend. A TAT of "now" is equivalent to + // a full bucket. + return maybeSpend(l.clk, limit, l.clk.Now(), bucket.cost), nil } - return maybeSpend(l.clk, limit, tat, cost), nil + return maybeSpend(l.clk, limit, tat, bucket.cost), nil } -// Spend returns a *Decision that indicates if enough capacity was available to -// process the request, given the cost, for the specified limit Name and client -// id. If capacity existed, the cost of the request HAS been deducted from the -// bucket's capacity, otherwise no cost was deducted. The returned *Decision -// will always include the number of remaining requests in the bucket, the -// required wait time before the client can make another request, and the time -// until the bucket refills to its maximum capacity (resets). If no bucket -// exists for the given limit Name and client id, a new one will be created WITH -// the request's cost deducted from its initial capacity. If the specified limit -// is disabled, ErrLimitDisabled is returned. -func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost <= 0 { +// Spend attempts to deduct the cost from the provided bucket's capacity. The +// returned *Decision The returned *Decision indicates whether the capacity +// existed to satisfy the cost and represents the current state of the bucket. +// If no bucket exists it WILL be created WITH the cost factored into its +// initial state. The new bucket state is persisted to the underlying datastore, +// if applicable, before returning. +func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, error) { + if bucket.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(bucket.name, bucket.key) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -188,24 +171,20 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - if cost > limit.Burst { - return nil, ErrInvalidCostOverLimit - } - start := l.clk.Now() status := Denied defer func() { - l.spendLatency.WithLabelValues(name.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(bucket.name.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, bucket.key) if err != nil { if errors.Is(err, ErrBucketNotFound) { // First request from this client. - d, err := l.initialize(ctx, limit, name, id, cost) + d, err := l.initialize(ctx, limit, bucket) if err != nil { return nil, err } @@ -217,20 +196,19 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - d := maybeSpend(l.clk, limit, tat, cost) + d := maybeSpend(l.clk, limit, tat, bucket.cost) if limit.isOverride { - // Calculate the current utilization of the override limit for the - // specified client id. + // Calculate the current utilization of the override limit. utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(name.String(), id).Set(utilization) + l.overrideUsageGauge.WithLabelValues(bucket.name.String(), bucket.key).Set(utilization) } if !d.Allowed { return d, nil } - err = l.source.Set(ctx, bucketKey(name, id), d.newTAT) + err = l.source.Set(ctx, bucket.key, d.newTAT) if err != nil { return nil, err } @@ -238,23 +216,23 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return d, nil } -// Refund attempts to refund the cost to the bucket identified by limit name and -// client id. The returned *Decision indicates whether the refund was successful -// or not. If the refund was successful, the cost of the request was added back -// to the bucket's capacity. If the refund is not possible (i.e., the bucket is -// already full or the refund amount is invalid), no cost is refunded. +// Refund attempts to refund all of the cost to the capacity of the specified +// bucket. The returned *Decision indicates whether the refund was successful +// and represents the current state of the bucket. The new bucket state is +// persisted to the underlying datastore, if applicable, before returning. If no +// bucket exists it will NOT be created. // // Note: The amount refunded cannot cause the bucket to exceed its maximum -// capacity. However, partial refunds are allowed and are considered successful. -// For instance, if a bucket has a maximum capacity of 10 and currently has 5 +// capacity. Partial refunds are allowed and are considered successful. For +// instance, if a bucket has a maximum capacity of 10 and currently has 5 // requests remaining, a refund request of 7 will result in the bucket reaching // its maximum capacity of 10, not 12. -func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost <= 0 { +func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, error) { + if bucket.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(bucket.name, bucket.key) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -265,36 +243,37 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, bucket.key) if err != nil { return nil, err } - d := maybeRefund(l.clk, limit, tat, cost) + d := maybeRefund(l.clk, limit, tat, bucket.cost) if !d.Allowed { // The bucket is already at maximum capacity. return d, nil } - return d, l.source.Set(ctx, bucketKey(name, id), d.newTAT) - + return d, l.source.Set(ctx, bucket.key, d.newTAT) } -// Reset resets the specified bucket. -func (l *Limiter) Reset(ctx context.Context, name Name, id string) error { +// Reset resets the specified bucket to its maximum capacity. The new bucket +// state is persisted to the underlying datastore before returning. +func (l *Limiter) Reset(ctx context.Context, bucket Bucket) error { // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - return l.source.Delete(ctx, bucketKey(name, id)) + return l.source.Delete(ctx, bucket.key) } -// initialize creates a new bucket, specified by limit name and id, with the -// cost of the request factored into the initial state. -func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string, cost int64) (*Decision, error) { - d := maybeSpend(l.clk, rl, l.clk.Now(), cost) +// initialize creates a new bucket and sets its TAT to now, which is equivalent +// to a full bucket. The new bucket state is persisted to the underlying +// datastore before returning. +func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCost) (*Decision, error) { + d := maybeSpend(l.clk, rl, l.clk.Now(), bucket.cost) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - err := l.source.Set(ctx, bucketKey(name, id), d.newTAT) + err := l.source.Set(ctx, bucket.key, d.newTAT) if err != nil { return nil, err } @@ -302,24 +281,24 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string } -// GetLimit returns the limit for the specified by name and id, name is -// required, id is optional. If id is left unspecified, the default limit for -// the limit specified by name is returned. If no default limit exists for the -// specified name, ErrLimitDisabled is returned. -func (l *Limiter) getLimit(name Name, id string) (limit, error) { +// GetLimit returns the limit for the specified by name and bucketKey, name is +// required, bucketKey is optional. If bucketKey is left unspecified, the +// default limit for the limit specified by name is returned. If no default +// limit exists for the specified name, errLimitDisabled is returned. +func (l *Limiter) getLimit(name Name, bucketKey string) (limit, error) { if !name.isValid() { // This should never happen. Callers should only be specifying the limit // Name enums defined in this package. return limit{}, fmt.Errorf("specified name enum %q, is invalid", name) } - if id != "" { + if bucketKey != "" { // Check for override. - ol, ok := l.overrides[bucketKey(name, id)] + ol, ok := l.overrides[bucketKey] if ok { return ol, nil } } - dl, ok := l.defaults[nameToEnumString(name)] + dl, ok := l.defaults[name.EnumString()] if ok { return dl, nil } diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index bb58dfff41c..40dd7fa9e26 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -53,21 +53,13 @@ func Test_Limiter_WithBadLimitsPath(t *testing.T) { test.AssertError(t, err, "should error") } -func Test_Limiter_getLimitNoExist(t *testing.T) { - t.Parallel() - l, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/working_default.yml", "", metrics.NoopRegisterer) - test.AssertNotError(t, err, "should not error") - _, err = l.getLimit(Name(9999), "") - test.AssertError(t, err, "should error") - -} - func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { t.Parallel() testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - _, err := l.Check(testCtx, Name(9999), testIP, 1) + bucket := Bucket{name: Name(9999), key: testIP} + _, err := l.Check(testCtx, bucket.WithCost(1)) test.AssertError(t, err, "should error") }) } @@ -81,25 +73,29 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Verify our overrideUsageGauge is being set correctly. 0.0 == 0% of // the bucket has been consumed. test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{ - "limit": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 0) + "limit": NewRegistrationsPerIPAddress.String(), + "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 0) + + overridenBucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(tenZeroZeroTwo)) + test.AssertNotError(t, err, "should not error") // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41) + _, err = l.Check(testCtx, overridenBucket.WithCost(41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41) + _, err = l.Spend(testCtx, overridenBucket.WithCost(41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40) + d, err := l.Spend(testCtx, overridenBucket.WithCost(40)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -108,7 +104,8 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Verify our overrideUsageGauge is being set correctly. 1.0 == 100% of // the bucket has been consumed. test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{ - "limit_name": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 1.0) + "limit_name": NewRegistrationsPerIPAddress.String(), + "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 1.0) // Verify our RetryIn is correct. 1 second == 1000 milliseconds and // 1000/40 = 25 milliseconds per request. @@ -118,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -129,21 +126,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Reset between tests. - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo) + err = l.Reset(testCtx, overridenBucket) test.AssertNotError(t, err, "should not error") }) } @@ -154,9 +151,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - // Check on an empty bucket should initialize it and return the - // theoretical next state of that bucket if the cost were spent. - d, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + + // Check on an empty bucket should return the theoretical next state + // of that bucket if the cost were spent. + d, err := l.Check(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -167,7 +167,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 20 remaining. - d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + d, err = l.Check(testCtx, bucket.WithCost(0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(20)) @@ -175,13 +175,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { test.AssertEquals(t, d.RetryIn, time.Duration(0)) // Reset our bucket. - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP) + err = l.Reset(testCtx, bucket) test.AssertNotError(t, err, "should not error") - // Similar to above, but we'll use Spend() instead of Check() to - // initialize the bucket. Spend should return the same result as - // Check. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + // Similar to above, but we'll use Spend() to actually initialize + // the bucket. Spend should return the same result as Check. + d, err = l.Spend(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -192,7 +191,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 19 remaining. - d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + d, err = l.Check(testCtx, bucket.WithCost(0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -209,20 +208,23 @@ func Test_Limiter_RefundAndSpendCostErr(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Spend a cost of 0, which should fail. - _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + _, err = l.Spend(testCtx, bucket.WithCost(0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Spend a negative cost, which should fail. - _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + _, err = l.Spend(testCtx, bucket.WithCost(-1)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a cost of 0, which should fail. - _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + _, err = l.Refund(testCtx, bucket.WithCost(0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a negative cost, which should fail. - _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + _, err = l.Refund(testCtx, bucket.WithCost(-1)) test.AssertErrorIs(t, err, ErrInvalidCost) }) } @@ -233,7 +235,10 @@ func Test_Limiter_CheckWithBadCost(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + + _, err = l.Check(testCtx, bucket.WithCost(-1)) test.AssertErrorIs(t, err, ErrInvalidCostForCheck) }) } @@ -244,20 +249,23 @@ func Test_Limiter_DefaultLimits(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Attempt to spend 21 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 21) + _, err = l.Spend(testCtx, bucket.WithCost(21)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err := l.Spend(testCtx, bucket.WithCost(20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -271,7 +279,7 @@ func Test_Limiter_DefaultLimits(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -282,14 +290,14 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Quickly spend 20 requests in a row. for i := 0; i < 20; i++ { - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -303,30 +311,33 @@ func Test_Limiter_RefundAndReset(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err := l.Spend(testCtx, bucket.WithCost(20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Refund 10 requests. - d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 10) + d, err = l.Refund(testCtx, bucket.WithCost(10)) test.AssertNotError(t, err, "should not error") test.AssertEquals(t, d.Remaining, int64(10)) // Spend 10 requests, this should succeed. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 10) + d, err = l.Spend(testCtx, bucket.WithCost(10)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP) + err = l.Reset(testCtx, bucket) test.AssertNotError(t, err, "should not error") // Attempt to spend 20 more requests, this should succeed. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err = l.Spend(testCtx, bucket.WithCost(20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -336,7 +347,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { clk.Add(d.ResetIn) // Refund 1 requests above our limit, this should fail. - d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Refund(testCtx, bucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(20)) diff --git a/ratelimits/names.go b/ratelimits/names.go index b2663982b7e..28eef008f5a 100644 --- a/ratelimits/names.go +++ b/ratelimits/names.go @@ -13,8 +13,10 @@ import ( // limit names as strings and to provide a type-safe way to refer to rate // limits. // -// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString' -// mapping and idValidForName function below. +// IMPORTANT: If you add a new limit Name, you MUST add: +// - it to the nameToString mapping, +// - an entry for it in the validateIdForName(), and +// - provide a Bucket constructor in bucket.go. type Name int const ( @@ -39,18 +41,25 @@ const ( NewOrdersPerAccount // FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId - // is the registration id of the account. + // is the ACME registration Id of the account. FailedAuthorizationsPerAccount - // CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain', - // where name is the a name in a certificate issued to the account matching - // regId. + // CertificatesPerDomain uses bucket key 'enum:domain', where domain is a + // domain name in the issued certificate. + CertificatesPerDomain + + // CertificatesPerDomainPerAccount uses bucket key 'enum:regId', where regId + // is the ACME registration Id of the account. This limit is never checked + // or enforced by the Limiter. Instead, it is used to override the + // CertificatesPerDomain limit for the specified account. CertificatesPerDomainPerAccount - // CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet', - // where nameSet is a set of names in a certificate issued to the account - // matching regId. - CertificatesPerFQDNSetPerAccount + // CertificatesPerFQDNSet uses bucket key 'enum:fqdnSet', where fqdnSet is a + // hashed set of unique eTLD+1 domain names in the issued certificate. + // + // Note: When this referenced in an overrides file, the fqdnSet MUST be + // passed as a comma-separated list of domain names. + CertificatesPerFQDNSet ) // isValid returns true if the Name is a valid rate limit name. @@ -67,15 +76,24 @@ func (n Name) String() string { return nameToString[n] } +// EnumString returns the string representation of the Name enumeration. +func (n Name) EnumString() string { + if !n.isValid() { + return nameToString[Unknown] + } + return strconv.Itoa(int(n)) +} + // nameToString is a map of Name values to string names. var nameToString = map[Name]string{ - Unknown: "Unknown", - NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress", - NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range", - NewOrdersPerAccount: "NewOrdersPerAccount", - FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount", - CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount", - CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount", + Unknown: "Unknown", + NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress", + NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range", + NewOrdersPerAccount: "NewOrdersPerAccount", + FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount", + CertificatesPerDomain: "CertificatesPerDomain", + CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount", + CertificatesPerFQDNSet: "CertificatesPerFQDNSet", } // validIPAddress validates that the provided string is a valid IP address. @@ -114,48 +132,29 @@ func validateRegId(id string) error { return nil } -// validateRegIdDomain validates that the provided string is formatted -// 'regId:domain', where regId is an ACME registration Id and domain is a single -// domain name. -func validateRegIdDomain(id string) error { - parts := strings.SplitN(id, ":", 2) - if len(parts) != 2 { - return fmt.Errorf( - "invalid regId:domain, %q must be formatted 'regId:domain'", id) - } - if validateRegId(parts[0]) != nil { - return fmt.Errorf( - "invalid regId, %q must be formatted 'regId:domain'", id) - } - if policy.ValidDomain(parts[1]) != nil { - return fmt.Errorf( - "invalid domain, %q must be formatted 'regId:domain'", id) +// validateDomain validates that the provided string is formatted 'domain', +// where domain is a domain name. +func validateDomain(id string) error { + err := policy.ValidDomain(id) + if err != nil { + return fmt.Errorf("invalid domain, %q must be formatted 'domain'", id) } return nil } -// validateRegIdFQDNSet validates that the provided string is formatted -// 'regId:fqdnSet', where regId is an ACME registration Id and fqdnSet is a -// comma-separated list of domain names. -func validateRegIdFQDNSet(id string) error { - parts := strings.SplitN(id, ":", 2) - if len(parts) != 2 { - return fmt.Errorf( - "invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id) - } - if validateRegId(parts[0]) != nil { - return fmt.Errorf( - "invalid regId, %q must be formatted 'regId:fqdnSet'", id) - } - domains := strings.Split(parts[1], ",") +// validateFQDNSet validates that the provided string is formatted 'fqdnSet', +// where fqdnSet is a comma-separated list of domain names. +func validateFQDNSet(id string) error { + domains := strings.Split(id, ",") if len(domains) == 0 { return fmt.Errorf( - "invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id) + "invalid fqdnSet, %q must be formatted 'fqdnSet'", id) } for _, domain := range domains { - if policy.ValidDomain(domain) != nil { + err := policy.ValidDomain(domain) + if err != nil { return fmt.Errorf( - "invalid domain, %q must be formatted 'regId:fqdnSet'", id) + "invalid domain, %q must be formatted 'fqdnSet'", id) } } return nil @@ -171,17 +170,17 @@ func validateIdForName(name Name, id string) error { // 'enum:ipv6rangeCIDR' return validIPv6RangeCIDR(id) - case NewOrdersPerAccount, FailedAuthorizationsPerAccount: + case NewOrdersPerAccount, FailedAuthorizationsPerAccount, CertificatesPerDomainPerAccount: // 'enum:regId' return validateRegId(id) - case CertificatesPerDomainPerAccount: - // 'enum:regId:domain' - return validateRegIdDomain(id) + case CertificatesPerDomain: + // 'enum:domain' + return validateDomain(id) - case CertificatesPerFQDNSetPerAccount: - // 'enum:regId:fqdnSet' - return validateRegIdFQDNSet(id) + case CertificatesPerFQDNSet: + // 'enum:fqdnSet' + return validateFQDNSet(id) case Unknown: fallthrough @@ -209,14 +208,3 @@ var limitNames = func() []string { } return names }() - -// nameToEnumString converts the integer value of the Name enumeration to its -// string representation. -func nameToEnumString(s Name) string { - return strconv.Itoa(int(s)) -} - -// bucketKey returns the key used to store a rate limit bucket. -func bucketKey(name Name, id string) string { - return nameToEnumString(name) + ":" + id -} diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index bdd7a5fb1ea..5664058fdf0 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -68,9 +68,8 @@ func resultForError(err error) string { return "failed" } -// Set stores the TAT at the specified bucketKey ('name:id'). It returns an -// error if the operation failed and nil otherwise. If the bucketKey does not -// exist, it will be created. +// Set stores the TAT at the specified bucketKey. It returns an error if the +// operation failed and nil otherwise. func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) error { start := r.clk.Now() @@ -84,9 +83,9 @@ func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) return nil } -// Get retrieves the TAT at the specified bucketKey ('name:id'). It returns the -// TAT and nil if the operation succeeded, or an error if the operation failed. -// If the bucketKey does not exist, it returns ErrBucketNotFound. +// Get retrieves the TAT at the specified bucketKey. An error is returned if the +// operation failed and nil otherwise. If the bucketKey does not exist, +// ErrBucketNotFound is returned. func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) { start := r.clk.Now() diff --git a/ratelimits/testdata/working_override_regid_domain.yml b/ratelimits/testdata/working_override_regid_domain.yml index bd4d3eb67a0..9c745290488 100644 --- a/ratelimits/testdata/working_override_regid_domain.yml +++ b/ratelimits/testdata/working_override_regid_domain.yml @@ -1,4 +1,4 @@ -CertificatesPerDomainPerAccount:12345678:example.com: +CertificatesPerDomain:example.com: burst: 40 count: 40 period: 1s diff --git a/ratelimits/testdata/working_overrides_regid_fqdnset.yml b/ratelimits/testdata/working_overrides_regid_fqdnset.yml index 093ac976e7f..352f01694f0 100644 --- a/ratelimits/testdata/working_overrides_regid_fqdnset.yml +++ b/ratelimits/testdata/working_overrides_regid_fqdnset.yml @@ -1,12 +1,12 @@ -CertificatesPerFQDNSetPerAccount:12345678:example.com: +CertificatesPerFQDNSet:example.com: burst: 40 count: 40 period: 1s -CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net: +CertificatesPerFQDNSet:example.com,example.net: burst: 50 count: 50 period: 2s -CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org: +CertificatesPerFQDNSet:example.com,example.net,example.org: burst: 60 count: 60 period: 3s diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go new file mode 100644 index 00000000000..dd5a1167eca --- /dev/null +++ b/ratelimits/utilities.go @@ -0,0 +1,33 @@ +package ratelimits + +import ( + "strings" + + "github.com/letsencrypt/boulder/core" + "github.com/weppos/publicsuffix-go/publicsuffix" +) + +// joinWithColon joins the provided args with a colon. +func joinWithColon(args ...string) string { + return strings.Join(args, ":") +} + +// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's +// for the purpose of rate limiting. It also de-duplicates the output +// domains. Exact public suffix matches are included. +func DomainsForRateLimiting(names []string) []string { + var domains []string + for _, name := range names { + domain, err := publicsuffix.Domain(name) + if err != nil { + // The only possible errors are: + // (1) publicsuffix.Domain is giving garbage values + // (2) the public suffix is the domain itself + // We assume 2 and include the original name in the result. + domains = append(domains, name) + } else { + domains = append(domains, domain) + } + } + return core.UniqueLowerNames(domains) +} diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go new file mode 100644 index 00000000000..9c68d3a6e89 --- /dev/null +++ b/ratelimits/utilities_test.go @@ -0,0 +1,27 @@ +package ratelimits + +import ( + "testing" + + "github.com/letsencrypt/boulder/test" +) + +func TestDomainsForRateLimiting(t *testing.T) { + domains := DomainsForRateLimiting([]string{}) + test.AssertEquals(t, len(domains), 0) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) + test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) + test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) + test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) +} diff --git a/sa/rate_limits_test.go b/sa/rate_limits_test.go index 3cce97ef4a8..0ec081b79d8 100644 --- a/sa/rate_limits_test.go +++ b/sa/rate_limits_test.go @@ -8,7 +8,6 @@ import ( sapb "github.com/letsencrypt/boulder/sa/proto" "github.com/letsencrypt/boulder/test" - "google.golang.org/protobuf/types/known/timestamppb" ) func TestCertsPerNameRateLimitTable(t *testing.T) { @@ -80,9 +79,7 @@ func TestCertsPerNameRateLimitTable(t *testing.T) { t.Run(tc.caseName, func(t *testing.T) { timeRange := &sapb.Range{ EarliestNS: aprilFirst.Add(-1 * time.Second).UnixNano(), - Earliest: timestamppb.New(aprilFirst.Add(-1 * time.Second)), LatestNS: aprilFirst.Add(aWeek).UnixNano(), - Latest: timestamppb.New(aprilFirst.Add(aWeek)), } count, earliest, err := sa.countCertificatesByName(ctx, sa.dbMap, tc.domainName, timeRange) if err != nil { @@ -111,9 +108,7 @@ func TestNewOrdersRateLimitTable(t *testing.T) { AccountID: 1, Range: &sapb.Range{ EarliestNS: start.UnixNano(), - Earliest: timestamppb.New(start), LatestNS: start.Add(time.Minute * 10).UnixNano(), - Latest: timestamppb.New(start.Add(time.Minute * 10)), }, } diff --git a/wfe2/wfe.go b/wfe2/wfe.go index f47b2f27d1b..d7ae9a29306 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -638,7 +638,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP wfe.log.Warningf("checking %s rate limit: %s", limit, err) } - decision, err := wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) + bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPAddress) + return + } + + decision, err := wfe.limiter.Spend(ctx, bucket.WithCost(1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -649,11 +655,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return } - // See docs for ratelimits.NewRegistrationsPerIPv6Range for more information - // on the selection of a /48 block size for IPv6 ranges. - ipMask := net.CIDRMask(48, 128) - ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} - _, err = wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) + bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPv6Range) + return + } + + _, err = wfe.limiter.Spend(ctx, bucket.WithCost(1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } @@ -678,7 +686,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I wfe.log.Warningf("refunding %s rate limit: %s", limit, err) } - _, err := wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) + bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPAddress) + return + } + + _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -688,11 +702,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - // See docs for ratelimits.NewRegistrationsPerIPv6Range for more information - // on the selection of a /48 block size for IPv6 ranges. - ipMask := net.CIDRMask(48, 128) - ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} - _, err = wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) + bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPv6Range) + return + } + + _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } From 3ce97166292621ab3178579606d7a69dd7456aed Mon Sep 17 00:00:00 2001 From: Samantha Date: Wed, 25 Oct 2023 12:33:29 -0400 Subject: [PATCH 2/8] Revert accidental removal. --- ratelimits/limiter.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 6e4b074199e..51f8119dcc5 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -136,6 +136,10 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } + if bucket.cost > limit.Burst { + return nil, ErrInvalidCostOverLimit + } + // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) @@ -171,6 +175,10 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } + if bucket.cost > limit.Burst { + return nil, ErrInvalidCostOverLimit + } + start := l.clk.Now() status := Denied defer func() { From af12f17ed993b3da413743901a016fd1686b02ca Mon Sep 17 00:00:00 2001 From: Samantha Date: Wed, 25 Oct 2023 17:16:07 -0400 Subject: [PATCH 3/8] Fix merge issue. --- sa/rate_limits_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sa/rate_limits_test.go b/sa/rate_limits_test.go index 0ec081b79d8..3cce97ef4a8 100644 --- a/sa/rate_limits_test.go +++ b/sa/rate_limits_test.go @@ -8,6 +8,7 @@ import ( sapb "github.com/letsencrypt/boulder/sa/proto" "github.com/letsencrypt/boulder/test" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestCertsPerNameRateLimitTable(t *testing.T) { @@ -79,7 +80,9 @@ func TestCertsPerNameRateLimitTable(t *testing.T) { t.Run(tc.caseName, func(t *testing.T) { timeRange := &sapb.Range{ EarliestNS: aprilFirst.Add(-1 * time.Second).UnixNano(), + Earliest: timestamppb.New(aprilFirst.Add(-1 * time.Second)), LatestNS: aprilFirst.Add(aWeek).UnixNano(), + Latest: timestamppb.New(aprilFirst.Add(aWeek)), } count, earliest, err := sa.countCertificatesByName(ctx, sa.dbMap, tc.domainName, timeRange) if err != nil { @@ -108,7 +111,9 @@ func TestNewOrdersRateLimitTable(t *testing.T) { AccountID: 1, Range: &sapb.Range{ EarliestNS: start.UnixNano(), + Earliest: timestamppb.New(start), LatestNS: start.Add(time.Minute * 10).UnixNano(), + Latest: timestamppb.New(start.Add(time.Minute * 10)), }, } From 376f822e4e8d7b31f0b9ef0c1b2f3cddec9279d7 Mon Sep 17 00:00:00 2001 From: Samantha Date: Thu, 26 Oct 2023 15:24:05 -0400 Subject: [PATCH 4/8] Addressing second round comments. --- ratelimits/limiter.go | 2 +- ratelimits/limiter_test.go | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 51f8119dcc5..49d76f9059d 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -289,7 +289,7 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCos } -// GetLimit returns the limit for the specified by name and bucketKey, name is +// getLimit returns the limit for the specified by name and bucketKey, name is // required, bucketKey is optional. If bucketKey is left unspecified, the // default limit for the limit specified by name is returned. If no default // limit exists for the specified name, errLimitDisabled is returned. diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index 40dd7fa9e26..3b33215db0d 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -76,26 +76,26 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { "limit": NewRegistrationsPerIPAddress.String(), "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 0) - overridenBucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(tenZeroZeroTwo)) + overriddenBucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(tenZeroZeroTwo)) test.AssertNotError(t, err, "should not error") // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err = l.Check(testCtx, overridenBucket.WithCost(41)) + _, err = l.Check(testCtx, overriddenBucket.WithCost(41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, overridenBucket.WithCost(41)) + _, err = l.Spend(testCtx, overriddenBucket.WithCost(41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. - d, err := l.Spend(testCtx, overridenBucket.WithCost(40)) + d, err := l.Spend(testCtx, overriddenBucket.WithCost(40)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) + d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -115,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) + d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -126,21 +126,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) + d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overridenBucket.WithCost(1)) + d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Reset between tests. - err = l.Reset(testCtx, overridenBucket) + err = l.Reset(testCtx, overriddenBucket) test.AssertNotError(t, err, "should not error") }) } From 8276711f032e7f32aaa48aaf74ecec84922a24ce Mon Sep 17 00:00:00 2001 From: Samantha Date: Fri, 27 Oct 2023 12:26:21 -0400 Subject: [PATCH 5/8] Fix typo. --- ratelimits/limiter.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 49d76f9059d..4471a9ee5c7 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -157,11 +157,11 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, } // Spend attempts to deduct the cost from the provided bucket's capacity. The -// returned *Decision The returned *Decision indicates whether the capacity -// existed to satisfy the cost and represents the current state of the bucket. -// If no bucket exists it WILL be created WITH the cost factored into its -// initial state. The new bucket state is persisted to the underlying datastore, -// if applicable, before returning. +// returned *Decision indicates whether the capacity existed to satisfy the cost +// and represents the current state of the bucket. If no bucket exists it WILL +// be created WITH the cost factored into its initial state. The new bucket +// state is persisted to the underlying datastore, if applicable, before +// returning. func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, error) { if bucket.cost <= 0 { return nil, ErrInvalidCost From 81df1d7431cf36a555229eab22aead164dff6211 Mon Sep 17 00:00:00 2001 From: Samantha Date: Fri, 27 Oct 2023 14:50:44 -0400 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Aaron Gable --- ratelimits/names.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ratelimits/names.go b/ratelimits/names.go index 28eef008f5a..b0d581e7657 100644 --- a/ratelimits/names.go +++ b/ratelimits/names.go @@ -57,7 +57,7 @@ const ( // CertificatesPerFQDNSet uses bucket key 'enum:fqdnSet', where fqdnSet is a // hashed set of unique eTLD+1 domain names in the issued certificate. // - // Note: When this referenced in an overrides file, the fqdnSet MUST be + // Note: When this is referenced in an overrides file, the fqdnSet MUST be // passed as a comma-separated list of domain names. CertificatesPerFQDNSet ) From 45d18494c5d568631a29108e41661369934f1d44 Mon Sep 17 00:00:00 2001 From: Samantha Date: Mon, 6 Nov 2023 17:28:12 -0500 Subject: [PATCH 7/8] Addressing comments. --- ratelimits/bucket.go | 52 +++++++++++----------- ratelimits/limiter.go | 84 +++++++++++++++++++++--------------- ratelimits/limiter_test.go | 76 ++++++++++++++++---------------- ratelimits/utilities.go | 23 ---------- ratelimits/utilities_test.go | 27 ------------ wfe2/wfe.go | 16 +++---- 6 files changed, 119 insertions(+), 159 deletions(-) delete mode 100644 ratelimits/utilities_test.go diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go index a8901e10ed3..501d1fd2c44 100644 --- a/ratelimits/bucket.go +++ b/ratelimits/bucket.go @@ -5,52 +5,48 @@ import ( "net" ) -// Bucket identifies a specific subscriber rate limit bucket to the Limiter. -type Bucket struct { - name Name - key string -} - -// BucketWithCost is a bucket with an associated cost. -type BucketWithCost struct { - Bucket - cost int64 -} +// BucketId should only be created using the New*BucketId functions. It is used +// by the Limiter to look up the bucket and limit overrides for a specific +// subscriber and limit. +type BucketId struct { + // limit is the name of the associated rate limit. It is used for looking up + // default limits. + limit Name -// WithCost returns a BucketWithCost for the provided cost. -func (b Bucket) WithCost(cost int64) BucketWithCost { - return BucketWithCost{b, cost} + // bucketKey is the limit Name enum (e.g. "1") concatenated with the + // subscriber identifier specific to the associate limit Name type. + bucketKey string } -// NewRegistrationsPerIPAddressBucket returns a Bucket for the provided IP +// NewRegistrationsPerIPAddressBucketId returns a BucketId for the provided IP // address. -func NewRegistrationsPerIPAddressBucket(ip net.IP) (Bucket, error) { +func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) { id := ip.String() err := validateIdForName(NewRegistrationsPerIPAddress, id) if err != nil { - return Bucket{}, err + return BucketId{}, err } - return Bucket{ - name: NewRegistrationsPerIPAddress, - key: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), + return BucketId{ + limit: NewRegistrationsPerIPAddress, + bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), }, nil } -// NewRegistrationsPerIPv6RangeBucket returns a Bucket for the /48 IPv6 range -// containing the provided IPv6 address. -func NewRegistrationsPerIPv6RangeBucket(ip net.IP) (Bucket, error) { +// NewRegistrationsPerIPv6RangeBucketId returns a BucketId for the /48 IPv6 +// range containing the provided IPv6 address. +func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) { if ip.To4() != nil { - return Bucket{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) + return BucketId{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) } ipMask := net.CIDRMask(48, 128) ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} id := ipNet.String() err := validateIdForName(NewRegistrationsPerIPv6Range, id) if err != nil { - return Bucket{}, err + return BucketId{}, err } - return Bucket{ - name: NewRegistrationsPerIPv6Range, - key: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), + return BucketId{ + limit: NewRegistrationsPerIPv6Range, + bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), }, nil } diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 4471a9ee5c7..eab087bf5df 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -95,6 +95,20 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat return limiter, nil } +// Transaction is a cost to be spent or refunded from a specific BucketId. +type Transaction struct { + BucketId + cost int64 +} + +// NewTransaction creates a new Transaction for the provided BucketId and cost. +func NewTransaction(b BucketId, cost int64) Transaction { + return Transaction{ + BucketId: b, + cost: cost, + } +} + type Decision struct { // Allowed is true if the bucket possessed enough capacity to allow the // request given the cost. @@ -123,12 +137,12 @@ type Decision struct { // satisfy the cost and represents the hypothetical state of the bucket IF the // cost WERE to be deducted. If no bucket exists it will NOT be created. No // state is persisted to the underlying datastore. -func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost < 0 { +func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost < 0 { return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -136,14 +150,14 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - if bucket.cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if !errors.Is(err, ErrBucketNotFound) { return nil, err @@ -151,9 +165,9 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, // First request from this client. No need to initialize the bucket // because this is a check, not a spend. A TAT of "now" is equivalent to // a full bucket. - return maybeSpend(l.clk, limit, l.clk.Now(), bucket.cost), nil + return maybeSpend(l.clk, limit, l.clk.Now(), txn.cost), nil } - return maybeSpend(l.clk, limit, tat, bucket.cost), nil + return maybeSpend(l.clk, limit, tat, txn.cost), nil } // Spend attempts to deduct the cost from the provided bucket's capacity. The @@ -162,12 +176,12 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, // be created WITH the cost factored into its initial state. The new bucket // state is persisted to the underlying datastore, if applicable, before // returning. -func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost <= 0 { +func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -175,24 +189,24 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - if bucket.cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } start := l.clk.Now() status := Denied defer func() { - l.spendLatency.WithLabelValues(bucket.name.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if errors.Is(err, ErrBucketNotFound) { // First request from this client. - d, err := l.initialize(ctx, limit, bucket) + d, err := l.initialize(ctx, limit, txn) if err != nil { return nil, err } @@ -204,19 +218,19 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - d := maybeSpend(l.clk, limit, tat, bucket.cost) + d := maybeSpend(l.clk, limit, tat, txn.cost) if limit.isOverride { // Calculate the current utilization of the override limit. utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(bucket.name.String(), bucket.key).Set(utilization) + l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization) } if !d.Allowed { return d, nil } - err = l.source.Set(ctx, bucket.key, d.newTAT) + err = l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } @@ -235,12 +249,12 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, // instance, if a bucket has a maximum capacity of 10 and currently has 5 // requests remaining, a refund request of 7 will result in the bucket reaching // its maximum capacity of 10, not 12. -func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost <= 0 { +func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -251,37 +265,37 @@ func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { return nil, err } - d := maybeRefund(l.clk, limit, tat, bucket.cost) + d := maybeRefund(l.clk, limit, tat, txn.cost) if !d.Allowed { // The bucket is already at maximum capacity. return d, nil } - return d, l.source.Set(ctx, bucket.key, d.newTAT) + return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) } // Reset resets the specified bucket to its maximum capacity. The new bucket // state is persisted to the underlying datastore before returning. -func (l *Limiter) Reset(ctx context.Context, bucket Bucket) error { +func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error { // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - return l.source.Delete(ctx, bucket.key) + return l.source.Delete(ctx, bucketId.bucketKey) } // initialize creates a new bucket and sets its TAT to now, which is equivalent // to a full bucket. The new bucket state is persisted to the underlying // datastore before returning. -func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCost) (*Decision, error) { - d := maybeSpend(l.clk, rl, l.clk.Now(), bucket.cost) +func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*Decision, error) { + d := maybeSpend(l.clk, rl, l.clk.Now(), txn.cost) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - err := l.source.Set(ctx, bucket.key, d.newTAT) + err := l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } @@ -289,19 +303,19 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCos } -// getLimit returns the limit for the specified by name and bucketKey, name is -// required, bucketKey is optional. If bucketKey is left unspecified, the -// default limit for the limit specified by name is returned. If no default -// limit exists for the specified name, errLimitDisabled is returned. -func (l *Limiter) getLimit(name Name, bucketKey string) (limit, error) { +// getLimit returns the limit for the specified by name and id, name is +// required, id is optional. If id is left unspecified, the default limit for +// the limit specified by name is returned. If no default limit exists for the +// specified name, errLimitDisabled is returned. +func (l *Limiter) getLimit(name Name, id string) (limit, error) { if !name.isValid() { // This should never happen. Callers should only be specifying the limit // Name enums defined in this package. return limit{}, fmt.Errorf("specified name enum %q, is invalid", name) } - if bucketKey != "" { + if id != "" { // Check for override. - ol, ok := l.overrides[bucketKey] + ol, ok := l.overrides[id] if ok { return ol, nil } diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index 3b33215db0d..7f108810279 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -58,8 +58,8 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket := Bucket{name: Name(9999), key: testIP} - _, err := l.Check(testCtx, bucket.WithCost(1)) + bucketId := BucketId{limit: Name(9999), bucketKey: testIP} + _, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertError(t, err, "should error") }) } @@ -76,26 +76,26 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { "limit": NewRegistrationsPerIPAddress.String(), "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 0) - overriddenBucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(tenZeroZeroTwo)) + overriddenBucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(tenZeroZeroTwo)) test.AssertNotError(t, err, "should not error") // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err = l.Check(testCtx, overriddenBucket.WithCost(41)) + _, err = l.Check(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, overriddenBucket.WithCost(41)) + _, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. - d, err := l.Spend(testCtx, overriddenBucket.WithCost(40)) + d, err := l.Spend(testCtx, NewTransaction(overriddenBucketId, 40)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -115,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -126,21 +126,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Reset between tests. - err = l.Reset(testCtx, overriddenBucket) + err = l.Reset(testCtx, overriddenBucketId) test.AssertNotError(t, err, "should not error") }) } @@ -151,12 +151,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Check on an empty bucket should return the theoretical next state // of that bucket if the cost were spent. - d, err := l.Check(testCtx, bucket.WithCost(1)) + d, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -167,7 +167,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 20 remaining. - d, err = l.Check(testCtx, bucket.WithCost(0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(20)) @@ -175,12 +175,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { test.AssertEquals(t, d.RetryIn, time.Duration(0)) // Reset our bucket. - err = l.Reset(testCtx, bucket) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") // Similar to above, but we'll use Spend() to actually initialize // the bucket. Spend should return the same result as Check. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -191,7 +191,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 19 remaining. - d, err = l.Check(testCtx, bucket.WithCost(0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -208,23 +208,23 @@ func Test_Limiter_RefundAndSpendCostErr(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Spend a cost of 0, which should fail. - _, err = l.Spend(testCtx, bucket.WithCost(0)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Spend a negative cost, which should fail. - _, err = l.Spend(testCtx, bucket.WithCost(-1)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a cost of 0, which should fail. - _, err = l.Refund(testCtx, bucket.WithCost(0)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a negative cost, which should fail. - _, err = l.Refund(testCtx, bucket.WithCost(-1)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) }) } @@ -235,10 +235,10 @@ func Test_Limiter_CheckWithBadCost(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") - _, err = l.Check(testCtx, bucket.WithCost(-1)) + _, err = l.Check(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCostForCheck) }) } @@ -249,23 +249,23 @@ func Test_Limiter_DefaultLimits(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Attempt to spend 21 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, bucket.WithCost(21)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 21)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, bucket.WithCost(20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -279,7 +279,7 @@ func Test_Limiter_DefaultLimits(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -290,14 +290,14 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Quickly spend 20 requests in a row. for i := 0; i < 20; i++ { - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -311,33 +311,33 @@ func Test_Limiter_RefundAndReset(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, bucket.WithCost(20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Refund 10 requests. - d, err = l.Refund(testCtx, bucket.WithCost(10)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.AssertEquals(t, d.Remaining, int64(10)) // Spend 10 requests, this should succeed. - d, err = l.Spend(testCtx, bucket.WithCost(10)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) - err = l.Reset(testCtx, bucket) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") // Attempt to spend 20 more requests, this should succeed. - d, err = l.Spend(testCtx, bucket.WithCost(20)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -347,7 +347,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { clk.Add(d.ResetIn) // Refund 1 requests above our limit, this should fail. - d, err = l.Refund(testCtx, bucket.WithCost(1)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(20)) diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go index dd5a1167eca..8a7cbca7087 100644 --- a/ratelimits/utilities.go +++ b/ratelimits/utilities.go @@ -2,32 +2,9 @@ package ratelimits import ( "strings" - - "github.com/letsencrypt/boulder/core" - "github.com/weppos/publicsuffix-go/publicsuffix" ) // joinWithColon joins the provided args with a colon. func joinWithColon(args ...string) string { return strings.Join(args, ":") } - -// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's -// for the purpose of rate limiting. It also de-duplicates the output -// domains. Exact public suffix matches are included. -func DomainsForRateLimiting(names []string) []string { - var domains []string - for _, name := range names { - domain, err := publicsuffix.Domain(name) - if err != nil { - // The only possible errors are: - // (1) publicsuffix.Domain is giving garbage values - // (2) the public suffix is the domain itself - // We assume 2 and include the original name in the result. - domains = append(domains, name) - } else { - domains = append(domains, domain) - } - } - return core.UniqueLowerNames(domains) -} diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go deleted file mode 100644 index 9c68d3a6e89..00000000000 --- a/ratelimits/utilities_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package ratelimits - -import ( - "testing" - - "github.com/letsencrypt/boulder/test" -) - -func TestDomainsForRateLimiting(t *testing.T) { - domains := DomainsForRateLimiting([]string{}) - test.AssertEquals(t, len(domains), 0) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) - test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) - test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) - - domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) - test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) -} diff --git a/wfe2/wfe.go b/wfe2/wfe.go index d7ae9a29306..4ff3e68cedf 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -638,13 +638,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP wfe.log.Warningf("checking %s rate limit: %s", limit, err) } - bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return } - decision, err := wfe.limiter.Spend(ctx, bucket.WithCost(1)) + decision, err := wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -655,13 +655,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return } - bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) return } - _, err = wfe.limiter.Spend(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } @@ -686,13 +686,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I wfe.log.Warningf("refunding %s rate limit: %s", limit, err) } - bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return } - _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -702,13 +702,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) return } - _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } From 8269947b3f9022fdc53f1a515f5789fed425dddd Mon Sep 17 00:00:00 2001 From: Samantha Date: Mon, 6 Nov 2023 17:38:17 -0500 Subject: [PATCH 8/8] Remove newline that somehow got added. --- ratelimits/limiter.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index eab087bf5df..46c727b1d39 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -300,7 +300,6 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*D return nil, err } return d, nil - } // getLimit returns the limit for the specified by name and id, name is