diff --git a/ratelimits/README.md b/ratelimits/README.md index 7e500aa4931..838edc65605 100644 --- a/ratelimits/README.md +++ b/ratelimits/README.md @@ -79,22 +79,21 @@ Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48` #### regId -The registration ID of the account. +An ACME account registration ID. Example: `NewOrdersPerAccount:12345678` -#### regId:domain +#### domain -A combination of registration ID and domain, formatted 'regId:domain'. +A valid eTLD+1 domain name. -Example: `CertificatesPerDomainPerAccount:12345678:example.com` +Example: `CertificatesPerDomain:example.com` -#### regId:fqdnSet +#### fqdnSet -A combination of registration ID and a comma-separated list of domain names, -formatted 'regId:fqdnSet'. +A comma-separated list of domain names. -Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org` +Example: `CertificatesPerFQDNSet:example.com,example.org` ## Bucket Key Definitions diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go new file mode 100644 index 00000000000..501d1fd2c44 --- /dev/null +++ b/ratelimits/bucket.go @@ -0,0 +1,52 @@ +package ratelimits + +import ( + "fmt" + "net" +) + +// BucketId should only be created using the New*BucketId functions. It is used +// by the Limiter to look up the bucket and limit overrides for a specific +// subscriber and limit. +type BucketId struct { + // limit is the name of the associated rate limit. It is used for looking up + // default limits. + limit Name + + // bucketKey is the limit Name enum (e.g. "1") concatenated with the + // subscriber identifier specific to the associate limit Name type. + bucketKey string +} + +// NewRegistrationsPerIPAddressBucketId returns a BucketId for the provided IP +// address. +func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) { + id := ip.String() + err := validateIdForName(NewRegistrationsPerIPAddress, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limit: NewRegistrationsPerIPAddress, + bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), + }, nil +} + +// NewRegistrationsPerIPv6RangeBucketId returns a BucketId for the /48 IPv6 +// range containing the provided IPv6 address. +func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) { + if ip.To4() != nil { + return BucketId{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) + } + ipMask := net.CIDRMask(48, 128) + ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} + id := ipNet.String() + err := validateIdForName(NewRegistrationsPerIPv6Range, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limit: NewRegistrationsPerIPv6Range, + bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), + }, nil +} diff --git a/ratelimits/limit.go b/ratelimits/limit.go index 7fb166bce60..261f66e67f3 100644 --- a/ratelimits/limit.go +++ b/ratelimits/limit.go @@ -121,22 +121,14 @@ func loadAndParseOverrideLimits(path string) (limits, error) { return nil, fmt.Errorf( "validating name %s and id %q for override limit %q: %w", name, id, k, err) } - if name == CertificatesPerFQDNSetPerAccount { + if name == CertificatesPerFQDNSet { // FQDNSet hashes are not a nice thing to ask for in a config file, // so we allow the user to specify a comma-separated list of FQDNs // and compute the hash here. - regIdDomains := strings.SplitN(id, ":", 2) - if len(regIdDomains) != 2 { - // Should never happen, the Id format was validated above. - return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k) - } - regId := regIdDomains[0] - domains := strings.Split(regIdDomains[1], ",") - fqdnSet := core.HashNames(domains) - id = fmt.Sprintf("%s:%s", regId, fqdnSet) + id = string(core.HashNames(strings.Split(id, ","))) } v.isOverride = true - parsed[bucketKey(name, id)] = precomputeLimit(v) + parsed[joinWithColon(name.EnumString(), id)] = precomputeLimit(v) } return parsed, nil } @@ -159,7 +151,7 @@ func loadAndParseDefaultLimits(path string) (limits, error) { if !ok { return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames) } - parsed[nameToEnumString(name)] = precomputeLimit(v) + parsed[name.EnumString()] = precomputeLimit(v) } return parsed, nil } diff --git a/ratelimits/limit_test.go b/ratelimits/limit_test.go index acabe9a51a0..50aae8c80c0 100644 --- a/ratelimits/limit_test.go +++ b/ratelimits/limit_test.go @@ -78,19 +78,19 @@ func Test_validateIdForName(t *testing.T) { err = validateIdForName(NewOrdersPerAccount, "1234567890") test.AssertNotError(t, err, "valid regId") - // 'enum:regId:domain' + // 'enum:domain' // Valid regId and domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com") + err = validateIdForName(CertificatesPerDomain, "example.com") test.AssertNotError(t, err, "valid regId and domain") - // 'enum:regId:fqdnSet' - // Valid regId and FQDN set containing a single domain. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com") + // 'enum:fqdnSet' + // Valid fqdnSet containing a single domain. + err = validateIdForName(CertificatesPerFQDNSet, "example.com") test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain") - // 'enum:regId:fqdnSet' - // Valid regId and FQDN set containing multiple domains. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org") + // 'enum:fqdnSet' + // Valid fqdnSet containing multiple domains. + err = validateIdForName(CertificatesPerFQDNSet, "example.com,example.org") test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains") // Empty string. @@ -125,71 +125,20 @@ func Test_validateIdForName(t *testing.T) { err = validateIdForName(NewOrdersPerAccount, "lol") test.AssertError(t, err, "invalid regId") - // Invalid regId with good domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com") - test.AssertError(t, err, "invalid regId with good domain") - - // Valid regId with bad domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol") - test.AssertError(t, err, "valid regId with bad domain") - - // Empty regId with good domain. - err = validateIdForName(CertificatesPerDomainPerAccount, ":lol") + // Invalid domain, malformed. + err = validateIdForName(CertificatesPerDomain, "example:.com") test.AssertError(t, err, "valid regId with bad domain") - // Valid regId with empty domain. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:") + // Invalid domain, empty. + err = validateIdForName(CertificatesPerDomain, "") test.AssertError(t, err, "valid regId with empty domain") - - // Empty regId with empty domain, no separator. - err = validateIdForName(CertificatesPerDomainPerAccount, "") - test.AssertError(t, err, "empty regId with empty domain, no separator") - - // Instead of anything we would expect, we get lol. - err = validateIdForName(CertificatesPerDomainPerAccount, "lol") - test.AssertError(t, err, "instead of anything we would expect, just lol") - - // Valid regId with good domain and a secret third separator. - err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol") - test.AssertError(t, err, "valid regId with good domain and a secret third separator") - - // Valid regId with bad FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99") - test.AssertError(t, err, "valid regId with bad FQDN set") - - // Bad regId with good FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org") - test.AssertError(t, err, "bad regId with good FQDN set") - - // Empty regId with good FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org") - test.AssertError(t, err, "empty regId with good FQDN set") - - // Good regId with empty FQDN set. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:") - test.AssertError(t, err, "good regId with empty FQDN set") - - // Empty regId with empty FQDN set, no separator. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "") - test.AssertError(t, err, "empty regId with empty FQDN set, no separator") - - // Instead of anything we would expect, just lol. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol") - test.AssertError(t, err, "instead of anything we would expect, just lol") - - // Valid regId with good FQDN set and a secret third separator. - err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol") - test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator") } func Test_loadAndParseOverrideLimits(t *testing.T) { - newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) - newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) - // Load a single valid override limit with Id formatted as 'enum:RegId'. l, err := loadAndParseOverrideLimits("testdata/working_override.yml") test.AssertNotError(t, err, "valid single override limit") - expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + expectKey := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "10.0.0.2") test.AssertEquals(t, l[expectKey].Burst, int64(40)) test.AssertEquals(t, l[expectKey].Count, int64(40)) test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) @@ -197,35 +146,35 @@ func Test_loadAndParseOverrideLimits(t *testing.T) { // Load single valid override limit with Id formatted as 'regId:domain'. l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml") test.AssertNotError(t, err, "valid single override limit with Id of regId:domain") - expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com" + expectKey = joinWithColon(CertificatesPerDomain.EnumString(), "example.com") test.AssertEquals(t, l[expectKey].Burst, int64(40)) test.AssertEquals(t, l[expectKey].Count, int64(40)) test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) // Load multiple valid override limits with 'enum:RegId' Ids. l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml") - expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + expectKey1 := joinWithColon(NewRegistrationsPerIPAddress.EnumString(), "10.0.0.2") test.AssertNotError(t, err, "multiple valid override limits") test.AssertEquals(t, l[expectKey1].Burst, int64(40)) test.AssertEquals(t, l[expectKey1].Count, int64(40)) test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second) - expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48" + expectKey2 := joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), "2001:0db8:0000::/48") test.AssertEquals(t, l[expectKey2].Burst, int64(50)) test.AssertEquals(t, l[expectKey2].Count, int64(50)) test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2) - // Load multiple valid override limits with 'regID:fqdnSet' Ids as follows: - // - CertificatesPerFQDNSetPerAccount:12345678:example.com - // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net - // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org + // Load multiple valid override limits with 'fqdnSet' Ids, as follows: + // - CertificatesPerFQDNSet:example.com + // - CertificatesPerFQDNSet:example.com,example.net + // - CertificatesPerFQDNSet:example.com,example.net,example.org firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"})) secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"})) thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"})) - firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash - secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash - thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash + firstEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), firstEntryFQDNSetHash) + secondEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), secondEntryFQDNSetHash) + thirdEntryKey := joinWithColon(CertificatesPerFQDNSet.EnumString(), thirdEntryFQDNSetHash) l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml") - test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets") + test.AssertNotError(t, err, "multiple valid override limits with 'fqdnSet' Ids") test.AssertEquals(t, l[firstEntryKey].Burst, int64(40)) test.AssertEquals(t, l[firstEntryKey].Count, int64(40)) test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second) @@ -278,25 +227,22 @@ func Test_loadAndParseOverrideLimits(t *testing.T) { } func Test_loadAndParseDefaultLimits(t *testing.T) { - newRestistrationsPerIPv4AddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) - newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) - // Load a single valid default limit. l, err := loadAndParseDefaultLimits("testdata/working_default.yml") test.AssertNotError(t, err, "valid single default limit") - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) // Load multiple valid default limits. l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml") test.AssertNotError(t, err, "multiple valid default limits") - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) - test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Burst, int64(30)) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Count, int64(30)) - test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Period.Duration, time.Second*2) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Burst, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Count, int64(20)) + test.AssertEquals(t, l[NewRegistrationsPerIPAddress.EnumString()].Period.Duration, time.Second) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Burst, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Count, int64(30)) + test.AssertEquals(t, l[NewRegistrationsPerIPv6Range.EnumString()].Period.Duration, time.Second*2) // Path is empty string. _, err = loadAndParseDefaultLimits("") diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index b518ef5ffba..46c727b1d39 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -88,13 +88,27 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat limiter.overrideUsageGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: "ratelimits_override_usage", - Help: "Proportion of override limit used, by limit name and client id.", - }, []string{"limit", "client_id"}) + Help: "Proportion of override limit used, by limit name and bucket key.", + }, []string{"limit", "bucket_key"}) stats.MustRegister(limiter.overrideUsageGauge) return limiter, nil } +// Transaction is a cost to be spent or refunded from a specific BucketId. +type Transaction struct { + BucketId + cost int64 +} + +// NewTransaction creates a new Transaction for the provided BucketId and cost. +func NewTransaction(b BucketId, cost int64) Transaction { + return Transaction{ + BucketId: b, + cost: cost, + } +} + type Decision struct { // Allowed is true if the bucket possessed enough capacity to allow the // request given the cost. @@ -118,23 +132,17 @@ type Decision struct { newTAT time.Time } -// Check returns a *Decision that indicates whether there's enough capacity to -// allow the request, given the cost, for the specified limit Name and client -// id. However, it DOES NOT deduct the cost of the request from the bucket's -// capacity. Hence, the returned *Decision represents the hypothetical state of -// the bucket if the cost WERE to be deducted. The returned *Decision will -// always include the number of remaining requests in the bucket, the required -// wait time before the client can make another request, and the time until the -// bucket refills to its maximum capacity (resets). If no bucket exists for the -// given limit Name and client id, a new one will be created WITHOUT the -// request's cost deducted from its initial capacity. If the specified limit is -// disabled, ErrLimitDisabled is returned. -func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost < 0 { +// Check DOES NOT deduct the cost of the request from the provided bucket's +// capacity. The returned *Decision indicates whether the capacity exists to +// satisfy the cost and represents the hypothetical state of the bucket IF the +// cost WERE to be deducted. If no bucket exists it will NOT be created. No +// state is persisted to the underlying datastore. +func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost < 0 { return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -142,45 +150,38 @@ func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - if cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if !errors.Is(err, ErrBucketNotFound) { return nil, err } - // First request from this client. The cost is not deducted from the - // initial capacity because this is only a check. - d, err := l.initialize(ctx, limit, name, id, 0) - if err != nil { - return nil, err - } - return maybeSpend(l.clk, limit, d.newTAT, cost), nil + // First request from this client. No need to initialize the bucket + // because this is a check, not a spend. A TAT of "now" is equivalent to + // a full bucket. + return maybeSpend(l.clk, limit, l.clk.Now(), txn.cost), nil } - return maybeSpend(l.clk, limit, tat, cost), nil + return maybeSpend(l.clk, limit, tat, txn.cost), nil } -// Spend returns a *Decision that indicates if enough capacity was available to -// process the request, given the cost, for the specified limit Name and client -// id. If capacity existed, the cost of the request HAS been deducted from the -// bucket's capacity, otherwise no cost was deducted. The returned *Decision -// will always include the number of remaining requests in the bucket, the -// required wait time before the client can make another request, and the time -// until the bucket refills to its maximum capacity (resets). If no bucket -// exists for the given limit Name and client id, a new one will be created WITH -// the request's cost deducted from its initial capacity. If the specified limit -// is disabled, ErrLimitDisabled is returned. -func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost <= 0 { +// Spend attempts to deduct the cost from the provided bucket's capacity. The +// returned *Decision indicates whether the capacity existed to satisfy the cost +// and represents the current state of the bucket. If no bucket exists it WILL +// be created WITH the cost factored into its initial state. The new bucket +// state is persisted to the underlying datastore, if applicable, before +// returning. +func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -188,24 +189,24 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - if cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } start := l.clk.Now() status := Denied defer func() { - l.spendLatency.WithLabelValues(name.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if errors.Is(err, ErrBucketNotFound) { // First request from this client. - d, err := l.initialize(ctx, limit, name, id, cost) + d, err := l.initialize(ctx, limit, txn) if err != nil { return nil, err } @@ -217,20 +218,19 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return nil, err } - d := maybeSpend(l.clk, limit, tat, cost) + d := maybeSpend(l.clk, limit, tat, txn.cost) if limit.isOverride { - // Calculate the current utilization of the override limit for the - // specified client id. + // Calculate the current utilization of the override limit. utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(name.String(), id).Set(utilization) + l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization) } if !d.Allowed { return d, nil } - err = l.source.Set(ctx, bucketKey(name, id), d.newTAT) + err = l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } @@ -238,23 +238,23 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return d, nil } -// Refund attempts to refund the cost to the bucket identified by limit name and -// client id. The returned *Decision indicates whether the refund was successful -// or not. If the refund was successful, the cost of the request was added back -// to the bucket's capacity. If the refund is not possible (i.e., the bucket is -// already full or the refund amount is invalid), no cost is refunded. +// Refund attempts to refund all of the cost to the capacity of the specified +// bucket. The returned *Decision indicates whether the refund was successful +// and represents the current state of the bucket. The new bucket state is +// persisted to the underlying datastore, if applicable, before returning. If no +// bucket exists it will NOT be created. // // Note: The amount refunded cannot cause the bucket to exceed its maximum -// capacity. However, partial refunds are allowed and are considered successful. -// For instance, if a bucket has a maximum capacity of 10 and currently has 5 +// capacity. Partial refunds are allowed and are considered successful. For +// instance, if a bucket has a maximum capacity of 10 and currently has 5 // requests remaining, a refund request of 7 will result in the bucket reaching // its maximum capacity of 10, not 12. -func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) (*Decision, error) { - if cost <= 0 { +func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(name, id) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -265,47 +265,47 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucketKey(name, id)) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { return nil, err } - d := maybeRefund(l.clk, limit, tat, cost) + d := maybeRefund(l.clk, limit, tat, txn.cost) if !d.Allowed { // The bucket is already at maximum capacity. return d, nil } - return d, l.source.Set(ctx, bucketKey(name, id), d.newTAT) - + return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) } -// Reset resets the specified bucket. -func (l *Limiter) Reset(ctx context.Context, name Name, id string) error { +// Reset resets the specified bucket to its maximum capacity. The new bucket +// state is persisted to the underlying datastore before returning. +func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error { // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - return l.source.Delete(ctx, bucketKey(name, id)) + return l.source.Delete(ctx, bucketId.bucketKey) } -// initialize creates a new bucket, specified by limit name and id, with the -// cost of the request factored into the initial state. -func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string, cost int64) (*Decision, error) { - d := maybeSpend(l.clk, rl, l.clk.Now(), cost) +// initialize creates a new bucket and sets its TAT to now, which is equivalent +// to a full bucket. The new bucket state is persisted to the underlying +// datastore before returning. +func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*Decision, error) { + d := maybeSpend(l.clk, rl, l.clk.Now(), txn.cost) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - err := l.source.Set(ctx, bucketKey(name, id), d.newTAT) + err := l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } return d, nil - } -// GetLimit returns the limit for the specified by name and id, name is +// getLimit returns the limit for the specified by name and id, name is // required, id is optional. If id is left unspecified, the default limit for // the limit specified by name is returned. If no default limit exists for the -// specified name, ErrLimitDisabled is returned. +// specified name, errLimitDisabled is returned. func (l *Limiter) getLimit(name Name, id string) (limit, error) { if !name.isValid() { // This should never happen. Callers should only be specifying the limit @@ -314,12 +314,12 @@ func (l *Limiter) getLimit(name Name, id string) (limit, error) { } if id != "" { // Check for override. - ol, ok := l.overrides[bucketKey(name, id)] + ol, ok := l.overrides[id] if ok { return ol, nil } } - dl, ok := l.defaults[nameToEnumString(name)] + dl, ok := l.defaults[name.EnumString()] if ok { return dl, nil } diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index bb58dfff41c..7f108810279 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -53,21 +53,13 @@ func Test_Limiter_WithBadLimitsPath(t *testing.T) { test.AssertError(t, err, "should error") } -func Test_Limiter_getLimitNoExist(t *testing.T) { - t.Parallel() - l, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/working_default.yml", "", metrics.NoopRegisterer) - test.AssertNotError(t, err, "should not error") - _, err = l.getLimit(Name(9999), "") - test.AssertError(t, err, "should error") - -} - func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { t.Parallel() testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - _, err := l.Check(testCtx, Name(9999), testIP, 1) + bucketId := BucketId{limit: Name(9999), bucketKey: testIP} + _, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertError(t, err, "should error") }) } @@ -81,25 +73,29 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Verify our overrideUsageGauge is being set correctly. 0.0 == 0% of // the bucket has been consumed. test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{ - "limit": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 0) + "limit": NewRegistrationsPerIPAddress.String(), + "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 0) + + overriddenBucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(tenZeroZeroTwo)) + test.AssertNotError(t, err, "should not error") // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41) + _, err = l.Check(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41) + _, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40) + d, err := l.Spend(testCtx, NewTransaction(overriddenBucketId, 40)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -108,7 +104,8 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Verify our overrideUsageGauge is being set correctly. 1.0 == 100% of // the bucket has been consumed. test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{ - "limit_name": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 1.0) + "limit_name": NewRegistrationsPerIPAddress.String(), + "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 1.0) // Verify our RetryIn is correct. 1 second == 1000 milliseconds and // 1000/40 = 25 milliseconds per request. @@ -118,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -129,21 +126,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Reset between tests. - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo) + err = l.Reset(testCtx, overriddenBucketId) test.AssertNotError(t, err, "should not error") }) } @@ -154,9 +151,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - // Check on an empty bucket should initialize it and return the - // theoretical next state of that bucket if the cost were spent. - d, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + + // Check on an empty bucket should return the theoretical next state + // of that bucket if the cost were spent. + d, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -167,7 +167,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 20 remaining. - d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(20)) @@ -175,13 +175,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { test.AssertEquals(t, d.RetryIn, time.Duration(0)) // Reset our bucket. - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") - // Similar to above, but we'll use Spend() instead of Check() to - // initialize the bucket. Spend should return the same result as - // Check. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + // Similar to above, but we'll use Spend() to actually initialize + // the bucket. Spend should return the same result as Check. + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -192,7 +191,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 19 remaining. - d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -209,20 +208,23 @@ func Test_Limiter_RefundAndSpendCostErr(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Spend a cost of 0, which should fail. - _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Spend a negative cost, which should fail. - _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + _, err = l.Spend(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a cost of 0, which should fail. - _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 0) + _, err = l.Refund(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a negative cost, which should fail. - _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + _, err = l.Refund(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) }) } @@ -233,7 +235,10 @@ func Test_Limiter_CheckWithBadCost(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, -1) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + + _, err = l.Check(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCostForCheck) }) } @@ -244,20 +249,23 @@ func Test_Limiter_DefaultLimits(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Attempt to spend 21 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 21) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 21)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -271,7 +279,7 @@ func Test_Limiter_DefaultLimits(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -282,14 +290,14 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Quickly spend 20 requests in a row. for i := 0; i < 20; i++ { - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -303,30 +311,33 @@ func Test_Limiter_RefundAndReset(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) + test.AssertNotError(t, err, "should not error") + // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Refund 10 requests. - d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 10) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.AssertEquals(t, d.Remaining, int64(10)) // Spend 10 requests, this should succeed. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 10) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) - err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") // Attempt to spend 20 more requests, this should succeed. - d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -336,7 +347,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { clk.Add(d.ResetIn) // Refund 1 requests above our limit, this should fail. - d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 1) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(20)) diff --git a/ratelimits/names.go b/ratelimits/names.go index b2663982b7e..b0d581e7657 100644 --- a/ratelimits/names.go +++ b/ratelimits/names.go @@ -13,8 +13,10 @@ import ( // limit names as strings and to provide a type-safe way to refer to rate // limits. // -// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString' -// mapping and idValidForName function below. +// IMPORTANT: If you add a new limit Name, you MUST add: +// - it to the nameToString mapping, +// - an entry for it in the validateIdForName(), and +// - provide a Bucket constructor in bucket.go. type Name int const ( @@ -39,18 +41,25 @@ const ( NewOrdersPerAccount // FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId - // is the registration id of the account. + // is the ACME registration Id of the account. FailedAuthorizationsPerAccount - // CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain', - // where name is the a name in a certificate issued to the account matching - // regId. + // CertificatesPerDomain uses bucket key 'enum:domain', where domain is a + // domain name in the issued certificate. + CertificatesPerDomain + + // CertificatesPerDomainPerAccount uses bucket key 'enum:regId', where regId + // is the ACME registration Id of the account. This limit is never checked + // or enforced by the Limiter. Instead, it is used to override the + // CertificatesPerDomain limit for the specified account. CertificatesPerDomainPerAccount - // CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet', - // where nameSet is a set of names in a certificate issued to the account - // matching regId. - CertificatesPerFQDNSetPerAccount + // CertificatesPerFQDNSet uses bucket key 'enum:fqdnSet', where fqdnSet is a + // hashed set of unique eTLD+1 domain names in the issued certificate. + // + // Note: When this is referenced in an overrides file, the fqdnSet MUST be + // passed as a comma-separated list of domain names. + CertificatesPerFQDNSet ) // isValid returns true if the Name is a valid rate limit name. @@ -67,15 +76,24 @@ func (n Name) String() string { return nameToString[n] } +// EnumString returns the string representation of the Name enumeration. +func (n Name) EnumString() string { + if !n.isValid() { + return nameToString[Unknown] + } + return strconv.Itoa(int(n)) +} + // nameToString is a map of Name values to string names. var nameToString = map[Name]string{ - Unknown: "Unknown", - NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress", - NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range", - NewOrdersPerAccount: "NewOrdersPerAccount", - FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount", - CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount", - CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount", + Unknown: "Unknown", + NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress", + NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range", + NewOrdersPerAccount: "NewOrdersPerAccount", + FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount", + CertificatesPerDomain: "CertificatesPerDomain", + CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount", + CertificatesPerFQDNSet: "CertificatesPerFQDNSet", } // validIPAddress validates that the provided string is a valid IP address. @@ -114,48 +132,29 @@ func validateRegId(id string) error { return nil } -// validateRegIdDomain validates that the provided string is formatted -// 'regId:domain', where regId is an ACME registration Id and domain is a single -// domain name. -func validateRegIdDomain(id string) error { - parts := strings.SplitN(id, ":", 2) - if len(parts) != 2 { - return fmt.Errorf( - "invalid regId:domain, %q must be formatted 'regId:domain'", id) - } - if validateRegId(parts[0]) != nil { - return fmt.Errorf( - "invalid regId, %q must be formatted 'regId:domain'", id) - } - if policy.ValidDomain(parts[1]) != nil { - return fmt.Errorf( - "invalid domain, %q must be formatted 'regId:domain'", id) +// validateDomain validates that the provided string is formatted 'domain', +// where domain is a domain name. +func validateDomain(id string) error { + err := policy.ValidDomain(id) + if err != nil { + return fmt.Errorf("invalid domain, %q must be formatted 'domain'", id) } return nil } -// validateRegIdFQDNSet validates that the provided string is formatted -// 'regId:fqdnSet', where regId is an ACME registration Id and fqdnSet is a -// comma-separated list of domain names. -func validateRegIdFQDNSet(id string) error { - parts := strings.SplitN(id, ":", 2) - if len(parts) != 2 { - return fmt.Errorf( - "invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id) - } - if validateRegId(parts[0]) != nil { - return fmt.Errorf( - "invalid regId, %q must be formatted 'regId:fqdnSet'", id) - } - domains := strings.Split(parts[1], ",") +// validateFQDNSet validates that the provided string is formatted 'fqdnSet', +// where fqdnSet is a comma-separated list of domain names. +func validateFQDNSet(id string) error { + domains := strings.Split(id, ",") if len(domains) == 0 { return fmt.Errorf( - "invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id) + "invalid fqdnSet, %q must be formatted 'fqdnSet'", id) } for _, domain := range domains { - if policy.ValidDomain(domain) != nil { + err := policy.ValidDomain(domain) + if err != nil { return fmt.Errorf( - "invalid domain, %q must be formatted 'regId:fqdnSet'", id) + "invalid domain, %q must be formatted 'fqdnSet'", id) } } return nil @@ -171,17 +170,17 @@ func validateIdForName(name Name, id string) error { // 'enum:ipv6rangeCIDR' return validIPv6RangeCIDR(id) - case NewOrdersPerAccount, FailedAuthorizationsPerAccount: + case NewOrdersPerAccount, FailedAuthorizationsPerAccount, CertificatesPerDomainPerAccount: // 'enum:regId' return validateRegId(id) - case CertificatesPerDomainPerAccount: - // 'enum:regId:domain' - return validateRegIdDomain(id) + case CertificatesPerDomain: + // 'enum:domain' + return validateDomain(id) - case CertificatesPerFQDNSetPerAccount: - // 'enum:regId:fqdnSet' - return validateRegIdFQDNSet(id) + case CertificatesPerFQDNSet: + // 'enum:fqdnSet' + return validateFQDNSet(id) case Unknown: fallthrough @@ -209,14 +208,3 @@ var limitNames = func() []string { } return names }() - -// nameToEnumString converts the integer value of the Name enumeration to its -// string representation. -func nameToEnumString(s Name) string { - return strconv.Itoa(int(s)) -} - -// bucketKey returns the key used to store a rate limit bucket. -func bucketKey(name Name, id string) string { - return nameToEnumString(name) + ":" + id -} diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index bdd7a5fb1ea..5664058fdf0 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -68,9 +68,8 @@ func resultForError(err error) string { return "failed" } -// Set stores the TAT at the specified bucketKey ('name:id'). It returns an -// error if the operation failed and nil otherwise. If the bucketKey does not -// exist, it will be created. +// Set stores the TAT at the specified bucketKey. It returns an error if the +// operation failed and nil otherwise. func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) error { start := r.clk.Now() @@ -84,9 +83,9 @@ func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) return nil } -// Get retrieves the TAT at the specified bucketKey ('name:id'). It returns the -// TAT and nil if the operation succeeded, or an error if the operation failed. -// If the bucketKey does not exist, it returns ErrBucketNotFound. +// Get retrieves the TAT at the specified bucketKey. An error is returned if the +// operation failed and nil otherwise. If the bucketKey does not exist, +// ErrBucketNotFound is returned. func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) { start := r.clk.Now() diff --git a/ratelimits/testdata/working_override_regid_domain.yml b/ratelimits/testdata/working_override_regid_domain.yml index bd4d3eb67a0..9c745290488 100644 --- a/ratelimits/testdata/working_override_regid_domain.yml +++ b/ratelimits/testdata/working_override_regid_domain.yml @@ -1,4 +1,4 @@ -CertificatesPerDomainPerAccount:12345678:example.com: +CertificatesPerDomain:example.com: burst: 40 count: 40 period: 1s diff --git a/ratelimits/testdata/working_overrides_regid_fqdnset.yml b/ratelimits/testdata/working_overrides_regid_fqdnset.yml index 093ac976e7f..352f01694f0 100644 --- a/ratelimits/testdata/working_overrides_regid_fqdnset.yml +++ b/ratelimits/testdata/working_overrides_regid_fqdnset.yml @@ -1,12 +1,12 @@ -CertificatesPerFQDNSetPerAccount:12345678:example.com: +CertificatesPerFQDNSet:example.com: burst: 40 count: 40 period: 1s -CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net: +CertificatesPerFQDNSet:example.com,example.net: burst: 50 count: 50 period: 2s -CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org: +CertificatesPerFQDNSet:example.com,example.net,example.org: burst: 60 count: 60 period: 3s diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go new file mode 100644 index 00000000000..8a7cbca7087 --- /dev/null +++ b/ratelimits/utilities.go @@ -0,0 +1,10 @@ +package ratelimits + +import ( + "strings" +) + +// joinWithColon joins the provided args with a colon. +func joinWithColon(args ...string) string { + return strings.Join(args, ":") +} diff --git a/wfe2/wfe.go b/wfe2/wfe.go index f47b2f27d1b..4ff3e68cedf 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -638,7 +638,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP wfe.log.Warningf("checking %s rate limit: %s", limit, err) } - decision, err := wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPAddress) + return + } + + decision, err := wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -649,11 +655,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return } - // See docs for ratelimits.NewRegistrationsPerIPv6Range for more information - // on the selection of a /48 block size for IPv6 ranges. - ipMask := net.CIDRMask(48, 128) - ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} - _, err = wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPv6Range) + return + } + + _, err = wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } @@ -678,7 +686,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I wfe.log.Warningf("refunding %s rate limit: %s", limit, err) } - _, err := wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPAddress) + return + } + + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -688,11 +702,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - // See docs for ratelimits.NewRegistrationsPerIPv6Range for more information - // on the selection of a /48 block size for IPv6 ranges. - ipMask := net.CIDRMask(48, 128) - ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} - _, err = wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPv6Range) + return + } + + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) }