diff --git a/cmd/expiration-mailer/main.go b/cmd/expiration-mailer/main.go index 0e45f5cda5c..1d7ccde3ff1 100644 --- a/cmd/expiration-mailer/main.go +++ b/cmd/expiration-mailer/main.go @@ -304,7 +304,7 @@ func (m *mailer) updateLastNagTimestampsChunk(ctx context.Context, certs []*x509 } func (m *mailer) certIsRenewed(ctx context.Context, names []string, issued time.Time) (bool, error) { - namehash := sa.HashNames(names) + namehash := core.HashNames(names) var present bool err := m.dbMap.SelectOne( diff --git a/cmd/expiration-mailer/main_test.go b/cmd/expiration-mailer/main_test.go index 3c0ad588cf8..244eaf0b328 100644 --- a/cmd/expiration-mailer/main_test.go +++ b/cmd/expiration-mailer/main_test.go @@ -353,7 +353,7 @@ func TestNoContactCertIsRenewed(t *testing.T) { setupDBMap, err := sa.DBMapForTest(vars.DBConnSAFullPerms) test.AssertNotError(t, err, "setting up DB") err = setupDBMap.Insert(ctx, &core.FQDNSet{ - SetHash: sa.HashNames(names), + SetHash: core.HashNames(names), Serial: core.SerialToString(serial2), Issued: testCtx.fc.Now().Add(time.Hour), Expires: expires.Add(time.Hour), @@ -580,13 +580,13 @@ func addExpiringCerts(t *testing.T, ctx *testCtx) []certDERWithRegID { test.AssertNotError(t, err, "creating cert D") fqdnStatusD := &core.FQDNSet{ - SetHash: sa.HashNames(certDNames), + SetHash: core.HashNames(certDNames), Serial: serial4String, Issued: ctx.fc.Now().AddDate(0, 0, -87), Expires: ctx.fc.Now().AddDate(0, 0, 3), } fqdnStatusDRenewed := &core.FQDNSet{ - SetHash: sa.HashNames(certDNames), + SetHash: core.HashNames(certDNames), Serial: serial5String, Issued: ctx.fc.Now().AddDate(0, 0, -3), Expires: ctx.fc.Now().AddDate(0, 0, 87), @@ -747,7 +747,7 @@ func TestCertIsRenewed(t *testing.T) { t.Fatal(err) } fqdnStatus := &core.FQDNSet{ - SetHash: sa.HashNames(testData.DNS), + SetHash: core.HashNames(testData.DNS), Serial: testData.stringSerial, Issued: testData.NotBefore, Expires: testData.NotAfter, diff --git a/core/util.go b/core/util.go index 
3a6f7c62b43..d7fe0266895 100644 --- a/core/util.go +++ b/core/util.go @@ -242,6 +242,14 @@ func UniqueLowerNames(names []string) (unique []string) { return } +// HashNames returns a hash of the names requested. This is intended for use +// when interacting with the orderFqdnSets table and rate limiting. +func HashNames(names []string) []byte { + names = UniqueLowerNames(names) + hash := sha256.Sum256([]byte(strings.Join(names, ","))) + return hash[:] +} + // LoadCert loads a PEM certificate specified by filename or returns an error func LoadCert(filename string) (*x509.Certificate, error) { certPEM, err := os.ReadFile(filename) diff --git a/core/util_test.go b/core/util_test.go index c1186faf467..211ee89ceeb 100644 --- a/core/util_test.go +++ b/core/util_test.go @@ -1,6 +1,7 @@ package core import ( + "bytes" "encoding/json" "fmt" "math" @@ -206,3 +207,30 @@ func TestRetryBackoff(t *testing.T) { assertBetween(float64(backoff), float64(expected)*0.8, float64(expected)*1.2) } + +func TestHashNames(t *testing.T) { + // Test that it is deterministic + h1 := HashNames([]string{"a"}) + h2 := HashNames([]string{"a"}) + test.AssertByteEquals(t, h1, h2) + + // Test that it differentiates + h1 = HashNames([]string{"a"}) + h2 = HashNames([]string{"b"}) + test.Assert(t, !bytes.Equal(h1, h2), "Should have been different") + + // Test that it is not subject to ordering + h1 = HashNames([]string{"a", "b"}) + h2 = HashNames([]string{"b", "a"}) + test.AssertByteEquals(t, h1, h2) + + // Test that it is not subject to case + h1 = HashNames([]string{"a", "b"}) + h2 = HashNames([]string{"A", "B"}) + test.AssertByteEquals(t, h1, h2) + + // Test that it is not subject to duplication + h1 = HashNames([]string{"a", "a"}) + h2 = HashNames([]string{"a"}) + test.AssertByteEquals(t, h1, h2) +} diff --git a/policy/pa.go b/policy/pa.go index 75e387a527b..ff497a24052 100644 --- a/policy/pa.go +++ b/policy/pa.go @@ -196,7 +196,7 @@ var ( errWildcardNotSupported = berrors.MalformedError("Wildcard 
domain names are not supported") ) -// validDomain checks that a domain isn't: +// ValidDomain checks that a domain isn't: // // * empty // * prefixed with the wildcard label `*.` @@ -210,7 +210,7 @@ var ( // * exactly equal to an IANA registered TLD // // It does _not_ check that the domain isn't on any PA blocked lists. -func validDomain(domain string) error { +func ValidDomain(domain string) error { if domain == "" { return errEmptyName } @@ -323,7 +323,7 @@ func ValidEmail(address string) error { } splitEmail := strings.SplitN(email.Address, "@", -1) domain := strings.ToLower(splitEmail[len(splitEmail)-1]) - err = validDomain(domain) + err = ValidDomain(domain) if err != nil { return berrors.InvalidEmailError( "contact email %q has invalid domain : %s", @@ -363,7 +363,7 @@ func (pa *AuthorityImpl) willingToIssue(id identifier.ACMEIdentifier) error { } domain := id.Value - err := validDomain(domain) + err := ValidDomain(domain) if err != nil { return err } diff --git a/ratelimits/README.md b/ratelimits/README.md new file mode 100644 index 00000000000..7e500aa4931 --- /dev/null +++ b/ratelimits/README.md @@ -0,0 +1,190 @@ +# Configuring and Storing Key-Value Rate Limits + +## Rate Limit Structure + +All rate limits use a token-bucket model. The metaphor is that each limit is +represented by a bucket which holds tokens. Each request removes some number of +tokens from the bucket, or is denied if there aren't enough tokens to remove. +Over time, new tokens are added to the bucket at a steady rate, until the bucket +is full. The _burst_ parameter of a rate limit indicates the maximum capacity of +a bucket: how many tokens can it hold before new ones stop being added. +Therefore, this also indicates how many requests can be made in a single burst +before a full bucket is completely emptied. The _count_ and _period_ parameters +indicate the rate at which new tokens are added to a bucket: every period, count +tokens will be added. 
Therefore, these also indicate the steady-state rate at +which a client which has exhausted its quota can make requests: one token every +(period / count) duration. + +## Default Limit Settings + +Each key directly corresponds to a `Name` enumeration as detailed in `//ratelimits/names.go`. +The `Name` enum is used to identify the particular limit. The parameters of a +default limit are the values that will be used for all buckets that do not have +an explicit override (see below). + +```yaml +NewRegistrationsPerIPAddress: + burst: 20 + count: 20 + period: 1s +NewOrdersPerAccount: + burst: 300 + count: 300 + period: 180m +``` + +## Override Limit Settings + +Each override key represents a specific bucket, consisting of two elements: +_name_ and _id_. The name here refers to the Name of the particular limit, while +the id is a client identifier. The format of the id is dependent on the limit. +For example, the id for 'NewRegistrationsPerIPAddress' is a subscriber IP +address, while the id for 'NewOrdersPerAccount' is the subscriber's registration +ID. + +```yaml +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 20 + count: 40 + period: 1s +NewOrdersPerAccount:12345678: + burst: 300 + count: 600 + period: 180m +``` + +The above example overrides the default limits for specific subscribers. In both +cases the count of requests per period are doubled, but the burst capacity is +explicitly configured to match the default rate limit. + +### Id Formats in Limit Override Settings + +Id formats vary based on the `Name` enumeration. Below are examples for each +format: + +#### ipAddress + +A valid IPv4 or IPv6 address. + +Examples: + - `NewRegistrationsPerIPAddress:10.0.0.1` + - `NewRegistrationsPerIPAddress:2001:0db8:0000:0000:0000:ff00:0042:8329` + +#### ipv6RangeCIDR + +A valid IPv6 range in CIDR notation with a /48 mask. A /48 range is typically +assigned to a single subscriber. 
+ +Example: `NewRegistrationsPerIPv6Range:2001:0db8:0000::/48` + +#### regId + +The registration ID of the account. + +Example: `NewOrdersPerAccount:12345678` + +#### regId:domain + +A combination of registration ID and domain, formatted 'regId:domain'. + +Example: `CertificatesPerDomainPerAccount:12345678:example.com` + +#### regId:fqdnSet + +A combination of registration ID and a comma-separated list of domain names, +formatted 'regId:fqdnSet'. + +Example: `CertificatesPerFQDNSetPerAccount:12345678:example.com,example.org` + +## Bucket Key Definitions + +A bucket key is used to look up the bucket for a given limit and +subscriber. Bucket keys are formatted similarly to the overrides but with a +slight difference: the limit Names do not carry the string form of each limit. +Instead, they apply the `Name` enum equivalent for every limit. + +So, instead of: + +``` +NewOrdersPerAccount:12345678 +``` + +The corresponding bucket key for regId 12345678 would look like this: + +``` +6:12345678 +``` + +When loaded from a file, the keys for the default/override limits undergo the +same interning process as the aforementioned subscriber bucket keys. This +eliminates the need for redundant conversions when fetching each +default/override limit. + +## How Limits are Applied + +Although rate limit buckets are configured in terms of tokens, we do not +actually keep track of the number of tokens in each bucket. Instead, we track +the Theoretical Arrival Time (TAT) at which the bucket will be full again. If +the TAT is in the past, the bucket is full. If the TAT is in the future, some +number of tokens have been spent and the bucket is slowly refilling. If the TAT +is far enough in the future (specifically, more than `burst * (period / count)` +in the future), then the bucket is completely empty and requests will be denied. + +Additional terminology: + + - **burst offset** is the duration of time it takes for a bucket to go from + empty to full (`burst * (period / count)`). 
+ - **emission interval** is the interval at which tokens are added to a bucket + (`period / count`). This is also the steady-state rate at which requests can + be made without being denied even once the burst has been exhausted. + - **cost** is the number of tokens removed from a bucket for a single request. + - **cost increment** is the duration of time the TAT is advanced to account + for the cost of the request (`cost * emission interval`). + +For the purposes of this example, subscribers originating from a specific IPv4 +address are allowed 20 requests to the newFoo endpoint per second, with a +maximum burst of 20 requests at any point-in-time, or: + +```yaml +NewFoosPerIPAddress:172.23.45.22: + burst: 20 + count: 20 + period: 1s +``` + +A subscriber calls the newFoo endpoint for the first time with an IP address of +172.23.45.22. Here's what happens: + +1. The subscriber's IP address is used to generate a bucket key in the form of + 'NewFoosPerIPAddress:172.23.45.22'. + +2. The request is approved and the 'NewFoosPerIPAddress:172.23.45.22' bucket is + initialized with 19 tokens, as 1 token has been removed to account for the + cost of the current request. To accomplish this, the initial TAT is set to + the current time plus the _cost increment_ (which is 1/20th of a second if we + are limiting to 20 requests per second). + +3. Bucket 'NewFoosPerIPAddress:172.23.45.22': + - will reset to full in 50ms (1/20th of a second), + - will allow another newFoo request immediately, + - will allow between 1 and 19 more requests in the next 50ms, + - will reject the 20th request made in the next 50ms, + - and will allow 1 request every 50ms, indefinitely. + +The subscriber makes another request 5ms later: + +4. The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is compared against + the current time and the _burst offset_. The current time is greater than the + TAT minus the burst offset. Therefore, the request is approved. + +5. 
The TAT at bucket key 'NewFoosPerIPAddress:172.23.45.22' is advanced by the + cost increment to account for the cost of the request. + +The subscriber makes a total of 18 requests over the next 44ms: + +6. The current time is less than the TAT at bucket key + 'NewFoosPerIPAddress:172.23.45.22' minus the burst offset, thus the request + is rejected. + +This mechanism allows for bursts of traffic but also ensures that the average +rate of requests stays within the prescribed limits over time. diff --git a/ratelimits/gcra.go b/ratelimits/gcra.go new file mode 100644 index 00000000000..81a5d0d05e6 --- /dev/null +++ b/ratelimits/gcra.go @@ -0,0 +1,110 @@ +package ratelimits + +import ( + "time" + + "github.com/jmhodges/clock" +) + +// maybeSpend uses the GCRA algorithm to decide whether to allow a request. It +// returns a Decision struct with the result of the decision and the updated +// TAT. The cost must be 0 or greater and <= the burst capacity of the limit. +func maybeSpend(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision { + if cost < 0 || cost > rl.Burst { + // The condition above is the union of the conditions checked in Check + // and Spend methods of Limiter. If this panic is reached, it means that + // the caller has introduced a bug. + panic("invalid cost for maybeSpend") + } + nowUnix := clk.Now().UnixNano() + tatUnix := tat.UnixNano() + + // If the TAT is in the future, use it as the starting point for the + // calculation. Otherwise, use the current time. This is to prevent the + // bucket from being filled with capacity from the past. + if nowUnix > tatUnix { + tatUnix = nowUnix + } + + // Compute the cost increment. + costIncrement := rl.emissionInterval * cost + + // Deduct the cost to find the new TAT and residual capacity. + newTAT := tatUnix + costIncrement + difference := nowUnix - (newTAT - rl.burstOffset) + + if difference < 0 { + // Too little capacity to satisfy the cost, deny the request. 
+ residual := (nowUnix - (tatUnix - rl.burstOffset)) / rl.emissionInterval + return &Decision{ + Allowed: false, + Remaining: residual, + RetryIn: -time.Duration(difference), + ResetIn: time.Duration(tatUnix - nowUnix), + newTAT: time.Unix(0, tatUnix).UTC(), + } + } + + // There is enough capacity to satisfy the cost, allow the request. + var retryIn time.Duration + residual := difference / rl.emissionInterval + if difference < costIncrement { + retryIn = time.Duration(costIncrement - difference) + } + return &Decision{ + Allowed: true, + Remaining: residual, + RetryIn: retryIn, + ResetIn: time.Duration(newTAT - nowUnix), + newTAT: time.Unix(0, newTAT).UTC(), + } +} + +// maybeRefund uses the Generic Cell Rate Algorithm (GCRA) to attempt to refund +// the cost of a request which was previously spent. The refund cost must be 0 +// or greater. A cost will only be refunded up to the burst capacity of the +// limit. A partial refund is still considered successful. +func maybeRefund(clk clock.Clock, rl limit, tat time.Time, cost int64) *Decision { + if cost <= 0 || cost > rl.Burst { + // The condition above is checked in the Refund method of Limiter. If + // this panic is reached, it means that the caller has introduced a bug. + panic("invalid cost for maybeRefund") + } + nowUnix := clk.Now().UnixNano() + tatUnix := tat.UnixNano() + + // The TAT must be in the future to refund capacity. + if nowUnix > tatUnix { + // The TAT is in the past, therefore the bucket is full. + return &Decision{ + Allowed: false, + Remaining: rl.Burst, + RetryIn: time.Duration(0), + ResetIn: time.Duration(0), + newTAT: tat, + } + } + + // Compute the refund increment. + refundIncrement := rl.emissionInterval * cost + + // Subtract the refund increment from the TAT to find the new TAT. + newTAT := tatUnix - refundIncrement + + // Ensure the new TAT is not earlier than now. + if newTAT < nowUnix { + newTAT = nowUnix + } + + // Calculate the new capacity. 
+ difference := nowUnix - (newTAT - rl.burstOffset) + residual := difference / rl.emissionInterval + + return &Decision{ + Allowed: (newTAT != tatUnix), + Remaining: residual, + RetryIn: time.Duration(0), + ResetIn: time.Duration(newTAT - nowUnix), + newTAT: time.Unix(0, newTAT).UTC(), + } +} diff --git a/ratelimits/gcra_test.go b/ratelimits/gcra_test.go new file mode 100644 index 00000000000..4fb4dedc41a --- /dev/null +++ b/ratelimits/gcra_test.go @@ -0,0 +1,221 @@ +package ratelimits + +import ( + "testing" + "time" + + "github.com/jmhodges/clock" + "github.com/letsencrypt/boulder/config" + "github.com/letsencrypt/boulder/test" +) + +func Test_decide(t *testing.T) { + clk := clock.NewFake() + limit := precomputeLimit( + limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}}, + ) + + // Begin by using 1 of our 10 requests. + d := maybeSpend(clk, limit, clk.Now(), 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(9)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Immediately use another 9 of our remaining requests. + d = maybeSpend(clk, limit, d.newTAT, 9) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + // We should have to wait 1 second before we can use another request but we + // used 9 so we should have to wait 9 seconds to make an identical request. + test.AssertEquals(t, d.RetryIn, time.Second*9) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Our new TAT should be 10 seconds (limit.Burst) in the future. + test.AssertEquals(t, d.newTAT, clk.Now().Add(time.Second*10)) + + // Let's try using just 1 more request without waiting. 
+ d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.RetryIn, time.Second) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Let's try being exactly as patient as we're told to be. + clk.Add(d.RetryIn) + d = maybeSpend(clk, limit, d.newTAT, 0) + test.AssertEquals(t, d.Remaining, int64(1)) + + // We are 1 second in the future, we should have 1 new request. + d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.RetryIn, time.Second) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Let's try waiting (10 seconds) for our whole bucket to refill. + clk.Add(d.ResetIn) + + // We should have 10 new requests. If we use 1 we should have 9 remaining. + d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(9)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Wait just shy of how long we're told to wait for refilling. + clk.Add(d.ResetIn - time.Millisecond) + + // We should still have 9 remaining because we're still 1ms shy of the + // refill time. + d = maybeSpend(clk, limit, d.newTAT, 0) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(9)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Millisecond) + + // Spending 0 simply informed us that we still have 9 remaining, let's see + // what we have after waiting 20 hours. + clk.Add(20 * time.Hour) + + // C'mon, big money, no whammies, no whammies, STOP! 
+ d = maybeSpend(clk, limit, d.newTAT, 0) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Turns out that the most we can accrue is 10 (limit.Burst). Let's empty + // this bucket out so we can try something else. + d = maybeSpend(clk, limit, d.newTAT, 10) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + // We should have to wait 1 second before we can use another request but we + // used 10 so we should have to wait 10 seconds to make an identical + // request. + test.AssertEquals(t, d.RetryIn, time.Second*10) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // If you spend 0 while you have 0 you should get 0. + d = maybeSpend(clk, limit, d.newTAT, 0) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // We don't play by the rules, we spend 1 when we have 0. + d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.RetryIn, time.Second) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Okay, maybe we should play by the rules if we want to get anywhere. + clk.Add(d.RetryIn) + + // Our patience pays off, we should have 1 new request. Let's use it. + d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.RetryIn, time.Second) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Refill from empty to 5. + clk.Add(d.ResetIn / 2) + + // Attempt to spend 7 when we only have 5. 
We should be denied but the + // decision should reflect a retry of 2 seconds, the time it would take to + // refill from 5 to 7. + d = maybeSpend(clk, limit, d.newTAT, 7) + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(5)) + test.AssertEquals(t, d.RetryIn, time.Second*2) + test.AssertEquals(t, d.ResetIn, time.Second*5) +} + +func Test_maybeRefund(t *testing.T) { + clk := clock.NewFake() + limit := precomputeLimit( + limit{Burst: 10, Count: 1, Period: config.Duration{Duration: time.Second}}, + ) + + // Begin by using 1 of our 10 requests. + d := maybeSpend(clk, limit, clk.Now(), 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(9)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Refund back to 10. + d = maybeRefund(clk, limit, d.newTAT, 1) + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Spend 1 more of our 10 requests. + d = maybeSpend(clk, limit, d.newTAT, 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(9)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Wait for our bucket to refill. + clk.Add(d.ResetIn) + + // Attempt to refund from 10 to 11. + d = maybeRefund(clk, limit, d.newTAT, 1) + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Spend 10 all 10 of our requests. 
+	d = maybeSpend(clk, limit, d.newTAT, 10) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + // We should have to wait 1 second before we can use another request but we + // used 10 so we should have to wait 10 seconds to make an identical + // request. + test.AssertEquals(t, d.RetryIn, time.Second*10) + test.AssertEquals(t, d.ResetIn, time.Second*10) + + // Attempt a refund of 10. + d = maybeRefund(clk, limit, d.newTAT, 10) + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Wait 11 seconds to catch up to the TAT. + clk.Add(11 * time.Second) + + // Attempt to refund to 11, then ensure it's still 10. + d = maybeRefund(clk, limit, d.newTAT, 1) + test.Assert(t, !d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Spend 5 of our 10 requests, then refund 1. + d = maybeSpend(clk, limit, d.newTAT, 5) + d = maybeRefund(clk, limit, d.newTAT, 1) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(6)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + + // Wait 2.5 seconds to refill to 8.5 requests. + clk.Add(time.Millisecond * 2500) + + // Ensure we have 8.5 requests. + d = maybeSpend(clk, limit, d.newTAT, 0) + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(8)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + // Check that ResetIn represents the fractional earned request. + test.AssertEquals(t, d.ResetIn, time.Millisecond*1500) + + // Refund 2 requests, we should only have 10, not 10.5. 
+ d = maybeRefund(clk, limit, d.newTAT, 2) + test.AssertEquals(t, d.Remaining, int64(10)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) +} diff --git a/ratelimits/limit.go b/ratelimits/limit.go new file mode 100644 index 00000000000..b1d685c7e81 --- /dev/null +++ b/ratelimits/limit.go @@ -0,0 +1,160 @@ +package ratelimits + +import ( + "fmt" + "os" + "strings" + + "github.com/letsencrypt/boulder/config" + "github.com/letsencrypt/boulder/core" + "github.com/letsencrypt/boulder/strictyaml" +) + +type limit struct { + // Burst specifies maximum concurrent allowed requests at any given time. It + // must be greater than zero. + Burst int64 + + // Count is the number of requests allowed per period. It must be greater + // than zero. + Count int64 + + // Period is the duration of time in which the count (of requests) is + // allowed. It must be greater than zero. + Period config.Duration + + // emissionInterval is the interval, in nanoseconds, at which tokens are + // added to a bucket (period / count). This is also the steady-state rate at + // which requests can be made without being denied even once the burst has + // been exhausted. This is precomputed to avoid doing the same calculation + // on every request. + emissionInterval int64 + + // burstOffset is the duration of time, in nanoseconds, it takes for a + // bucket to go from empty to full (burst * (period / count)). This is + // precomputed to avoid doing the same calculation on every request. 
+	burstOffset int64 +} + +func precomputeLimit(l limit) limit { + l.emissionInterval = l.Period.Nanoseconds() / l.Count + l.burstOffset = l.emissionInterval * l.Burst + return l +} + +func validateLimit(l limit) error { + if l.Burst <= 0 { + return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst) + } + if l.Count <= 0 { + return fmt.Errorf("invalid count '%d', must be > 0", l.Count) + } + if l.Period.Duration <= 0 { + return fmt.Errorf("invalid period '%s', must be > 0", l.Period) + } + return nil +} + +type limits map[string]limit + +// loadLimits unmarshals the YAML file at path into a map of limits. +func loadLimits(path string) (limits, error) { + lm := make(limits) + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + err = strictyaml.Unmarshal(data, &lm) + if err != nil { + return nil, err + } + return lm, nil +} + +// parseOverrideNameId is broken out for ease of testing. +func parseOverrideNameId(key string) (Name, string, error) { + if !strings.Contains(key, ":") { + // Avoids a potential panic in strings.SplitN below. + return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key) + } + nameAndId := strings.SplitN(key, ":", 2) + nameStr := nameAndId[0] + if nameStr == "" { + return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key) + } + + name, ok := stringToName[nameStr] + if !ok { + return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, limitNames) + } + id := nameAndId[1] + if id == "" { + return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key) + } + return name, id, nil +} + +// loadAndParseOverrideLimits loads override limits from YAML, validates them, +// and parses them into a map of limits keyed by 'Name:id'. 
+func loadAndParseOverrideLimits(path string) (limits, error) { + fromFile, err := loadLimits(path) + if err != nil { + return nil, err + } + parsed := make(limits, len(fromFile)) + + for k, v := range fromFile { + err = validateLimit(v) + if err != nil { + return nil, fmt.Errorf("validating override limit %q: %w", k, err) + } + name, id, err := parseOverrideNameId(k) + if err != nil { + return nil, fmt.Errorf("parsing override limit %q: %w", k, err) + } + err = validateIdForName(name, id) + if err != nil { + return nil, fmt.Errorf( + "validating name %s and id %q for override limit %q: %w", nameToString[name], id, k, err) + } + if name == CertificatesPerFQDNSetPerAccount { + // FQDNSet hashes are not a nice thing to ask for in a config file, + // so we allow the user to specify a comma-separated list of FQDNs + // and compute the hash here. + regIdDomains := strings.SplitN(id, ":", 2) + if len(regIdDomains) != 2 { + // Should never happen, the Id format was validated above. + return nil, fmt.Errorf("invalid override limit %q, must be formatted 'name:id'", k) + } + regId := regIdDomains[0] + domains := strings.Split(regIdDomains[1], ",") + fqdnSet := core.HashNames(domains) + id = fmt.Sprintf("%s:%s", regId, fqdnSet) + } + parsed[bucketKey(name, id)] = precomputeLimit(v) + } + return parsed, nil +} + +// loadAndParseDefaultLimits loads default limits from YAML, validates them, and +// parses them into a map of limits keyed by 'Name'. 
+func loadAndParseDefaultLimits(path string) (limits, error) { + fromFile, err := loadLimits(path) + if err != nil { + return nil, err + } + parsed := make(limits, len(fromFile)) + + for k, v := range fromFile { + err := validateLimit(v) + if err != nil { + return nil, fmt.Errorf("parsing default limit %q: %w", k, err) + } + name, ok := stringToName[k] + if !ok { + return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, limitNames) + } + parsed[nameToEnumString(name)] = precomputeLimit(v) + } + return parsed, nil +} diff --git a/ratelimits/limit_test.go b/ratelimits/limit_test.go new file mode 100644 index 00000000000..026b888c794 --- /dev/null +++ b/ratelimits/limit_test.go @@ -0,0 +1,333 @@ +package ratelimits + +import ( + "os" + "testing" + "time" + + "github.com/letsencrypt/boulder/config" + "github.com/letsencrypt/boulder/core" + "github.com/letsencrypt/boulder/test" +) + +func Test_parseOverrideNameId(t *testing.T) { + newRegistrationsPerIPAddressStr := nameToString[NewRegistrationsPerIPAddress] + newRegistrationsPerIPv6RangeStr := nameToString[NewRegistrationsPerIPv6Range] + + // 'enum:ipv4' + // Valid IPv4 address. + name, id, err := parseOverrideNameId(newRegistrationsPerIPAddressStr + ":10.0.0.1") + test.AssertNotError(t, err, "should not error") + test.AssertEquals(t, name, NewRegistrationsPerIPAddress) + test.AssertEquals(t, id, "10.0.0.1") + + // 'enum:ipv6range' + // Valid IPv6 address range. + name, id, err = parseOverrideNameId(newRegistrationsPerIPv6RangeStr + ":2001:0db8:0000::/48") + test.AssertNotError(t, err, "should not error") + test.AssertEquals(t, name, NewRegistrationsPerIPv6Range) + test.AssertEquals(t, id, "2001:0db8:0000::/48") + + // Missing colon (this should never happen but we should avoid panicking). + _, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + "10.0.0.1") + test.AssertError(t, err, "missing colon") + + // Empty string. 
+ _, _, err = parseOverrideNameId("") + test.AssertError(t, err, "empty string") + + // Only a colon. + _, _, err = parseOverrideNameId(newRegistrationsPerIPAddressStr + ":") + test.AssertError(t, err, "only a colon") + + // Invalid enum. + _, _, err = parseOverrideNameId("lol:noexist") + test.AssertError(t, err, "invalid enum") +} + +func Test_validateLimit(t *testing.T) { + err := validateLimit(limit{Burst: 1, Count: 1, Period: config.Duration{Duration: time.Second}}) + test.AssertNotError(t, err, "valid limit") + + // All of the following are invalid. + for _, l := range []limit{ + {Burst: 0, Count: 1, Period: config.Duration{Duration: time.Second}}, + {Burst: 1, Count: 0, Period: config.Duration{Duration: time.Second}}, + {Burst: 1, Count: 1, Period: config.Duration{Duration: 0}}, + } { + err = validateLimit(l) + test.AssertError(t, err, "limit should be invalid") + } +} + +func Test_validateIdForName(t *testing.T) { + // 'enum:ipAddress' + // Valid IPv4 address. + err := validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.1") + test.AssertNotError(t, err, "valid ipv4 address") + + // 'enum:ipAddress' + // Valid IPv6 address. + err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334") + test.AssertNotError(t, err, "valid ipv6 address") + + // 'enum:ipv6rangeCIDR' + // Valid IPv6 address range. + err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48") + test.AssertNotError(t, err, "should not error") + + // 'enum:regId' + // Valid regId. + err = validateIdForName(NewOrdersPerAccount, "1234567890") + test.AssertNotError(t, err, "valid regId") + + // 'enum:regId:domain' + // Valid regId and domain. + err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com") + test.AssertNotError(t, err, "valid regId and domain") + + // 'enum:regId:fqdnSet' + // Valid regId and FQDN set containing a single domain. 
+ err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com") + test.AssertNotError(t, err, "valid regId and FQDN set containing a single domain") + + // 'enum:regId:fqdnSet' + // Valid regId and FQDN set containing multiple domains. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org") + test.AssertNotError(t, err, "valid regId and FQDN set containing multiple domains") + + // Empty string. + err = validateIdForName(NewRegistrationsPerIPAddress, "") + test.AssertError(t, err, "Id is an empty string") + + // One space. + err = validateIdForName(NewRegistrationsPerIPAddress, " ") + test.AssertError(t, err, "Id is a single space") + + // Invalid IPv4 address. + err = validateIdForName(NewRegistrationsPerIPAddress, "10.0.0.9000") + test.AssertError(t, err, "invalid IPv4 address") + + // Invalid IPv6 address. + err = validateIdForName(NewRegistrationsPerIPAddress, "2001:0db8:85a3:0000:0000:8a2e:0370:7334:9000") + test.AssertError(t, err, "invalid IPv6 address") + + // Invalid IPv6 CIDR range. + err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/128") + test.AssertError(t, err, "invalid IPv6 CIDR range") + + // Invalid IPv6 CIDR. + err = validateIdForName(NewRegistrationsPerIPv6Range, "2001:0db8:0000::/48/48") + test.AssertError(t, err, "invalid IPv6 CIDR") + + // IPv4 CIDR when we expect IPv6 CIDR range. + err = validateIdForName(NewRegistrationsPerIPv6Range, "10.0.0.0/16") + test.AssertError(t, err, "ipv4 cidr when we expect ipv6 cidr range") + + // Invalid regId. + err = validateIdForName(NewOrdersPerAccount, "lol") + test.AssertError(t, err, "invalid regId") + + // Invalid regId with good domain. + err = validateIdForName(CertificatesPerDomainPerAccount, "lol:example.com") + test.AssertError(t, err, "invalid regId with good domain") + + // Valid regId with bad domain. 
+ err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:lol") + test.AssertError(t, err, "valid regId with bad domain") + + // Empty regId with good domain. + err = validateIdForName(CertificatesPerDomainPerAccount, ":lol") + test.AssertError(t, err, "valid regId with bad domain") + + // Valid regId with empty domain. + err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:") + test.AssertError(t, err, "valid regId with empty domain") + + // Empty regId with empty domain, no separator. + err = validateIdForName(CertificatesPerDomainPerAccount, "") + test.AssertError(t, err, "empty regId with empty domain, no separator") + + // Instead of anything we would expect, we get lol. + err = validateIdForName(CertificatesPerDomainPerAccount, "lol") + test.AssertError(t, err, "instead of anything we would expect, just lol") + + // Valid regId with good domain and a secret third separator. + err = validateIdForName(CertificatesPerDomainPerAccount, "1234567890:example.com:lol") + test.AssertError(t, err, "valid regId with good domain and a secret third separator") + + // Valid regId with bad FQDN set. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:lol..99") + test.AssertError(t, err, "valid regId with bad FQDN set") + + // Bad regId with good FQDN set. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol:example.com,example.org") + test.AssertError(t, err, "bad regId with good FQDN set") + + // Empty regId with good FQDN set. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, ":example.com,example.org") + test.AssertError(t, err, "empty regId with good FQDN set") + + // Good regId with empty FQDN set. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:") + test.AssertError(t, err, "good regId with empty FQDN set") + + // Empty regId with empty FQDN set, no separator. 
+ err = validateIdForName(CertificatesPerFQDNSetPerAccount, "") + test.AssertError(t, err, "empty regId with empty FQDN set, no separator") + + // Instead of anything we would expect, just lol. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "lol") + test.AssertError(t, err, "instead of anything we would expect, just lol") + + // Valid regId with good FQDN set and a secret third separator. + err = validateIdForName(CertificatesPerFQDNSetPerAccount, "1234567890:example.com,example.org:lol") + test.AssertError(t, err, "valid regId with good FQDN set and a secret third separator") +} + +func Test_loadAndParseOverrideLimits(t *testing.T) { + newRegistrationsPerIPAddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) + newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) + + // Load a single valid override limit with Id formatted as 'enum:RegId'. + l, err := loadAndParseOverrideLimits("testdata/working_override.yml") + test.AssertNotError(t, err, "valid single override limit") + expectKey := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + test.AssertEquals(t, l[expectKey].Burst, int64(40)) + test.AssertEquals(t, l[expectKey].Count, int64(40)) + test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) + + // Load single valid override limit with Id formatted as 'regId:domain'. + l, err = loadAndParseOverrideLimits("testdata/working_override_regid_domain.yml") + test.AssertNotError(t, err, "valid single override limit with Id of regId:domain") + expectKey = nameToEnumString(CertificatesPerDomainPerAccount) + ":" + "12345678:example.com" + test.AssertEquals(t, l[expectKey].Burst, int64(40)) + test.AssertEquals(t, l[expectKey].Count, int64(40)) + test.AssertEquals(t, l[expectKey].Period.Duration, time.Second) + + // Load multiple valid override limits with 'enum:RegId' Ids. 
+ l, err = loadAndParseOverrideLimits("testdata/working_overrides.yml") + expectKey1 := newRegistrationsPerIPAddressEnumStr + ":" + "10.0.0.2" + test.AssertNotError(t, err, "multiple valid override limits") + test.AssertEquals(t, l[expectKey1].Burst, int64(40)) + test.AssertEquals(t, l[expectKey1].Count, int64(40)) + test.AssertEquals(t, l[expectKey1].Period.Duration, time.Second) + expectKey2 := newRegistrationsPerIPv6RangeEnumStr + ":" + "2001:0db8:0000::/48" + test.AssertEquals(t, l[expectKey2].Burst, int64(50)) + test.AssertEquals(t, l[expectKey2].Count, int64(50)) + test.AssertEquals(t, l[expectKey2].Period.Duration, time.Second*2) + + // Load multiple valid override limits with 'regID:fqdnSet' Ids as follows: + // - CertificatesPerFQDNSetPerAccount:12345678:example.com + // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net + // - CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org + firstEntryFQDNSetHash := string(core.HashNames([]string{"example.com"})) + secondEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net"})) + thirdEntryFQDNSetHash := string(core.HashNames([]string{"example.com", "example.net", "example.org"})) + firstEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + firstEntryFQDNSetHash + secondEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + secondEntryFQDNSetHash + thirdEntryKey := nameToEnumString(CertificatesPerFQDNSetPerAccount) + ":" + "12345678:" + thirdEntryFQDNSetHash + l, err = loadAndParseOverrideLimits("testdata/working_overrides_regid_fqdnset.yml") + test.AssertNotError(t, err, "multiple valid override limits with Id of regId:fqdnSets") + test.AssertEquals(t, l[firstEntryKey].Burst, int64(40)) + test.AssertEquals(t, l[firstEntryKey].Count, int64(40)) + test.AssertEquals(t, l[firstEntryKey].Period.Duration, time.Second) + test.AssertEquals(t, l[secondEntryKey].Burst, int64(50)) + 
test.AssertEquals(t, l[secondEntryKey].Count, int64(50)) + test.AssertEquals(t, l[secondEntryKey].Period.Duration, time.Second*2) + test.AssertEquals(t, l[thirdEntryKey].Burst, int64(60)) + test.AssertEquals(t, l[thirdEntryKey].Count, int64(60)) + test.AssertEquals(t, l[thirdEntryKey].Period.Duration, time.Second*3) + + // Path is empty string. + _, err = loadAndParseOverrideLimits("") + test.AssertError(t, err, "path is empty string") + test.Assert(t, os.IsNotExist(err), "path is empty string") + + // Path to file which does not exist. + _, err = loadAndParseOverrideLimits("testdata/file_does_not_exist.yml") + test.AssertError(t, err, "a file that does not exist ") + test.Assert(t, os.IsNotExist(err), "test file should not exist") + + // Burst cannot be 0. + _, err = loadAndParseOverrideLimits("testdata/busted_override_burst_0.yml") + test.AssertError(t, err, "single override limit with burst=0") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Id cannot be empty. + _, err = loadAndParseOverrideLimits("testdata/busted_override_empty_id.yml") + test.AssertError(t, err, "single override limit with empty id") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Name cannot be empty. + _, err = loadAndParseOverrideLimits("testdata/busted_override_empty_name.yml") + test.AssertError(t, err, "single override limit with empty name") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Name must be a string representation of a valid Name enumeration. + _, err = loadAndParseOverrideLimits("testdata/busted_override_invalid_name.yml") + test.AssertError(t, err, "single override limit with invalid name") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Multiple entries, second entry has a bad name. 
+ _, err = loadAndParseOverrideLimits("testdata/busted_overrides_second_entry_bad_name.yml") + test.AssertError(t, err, "multiple override limits, second entry is bad") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Multiple entries, third entry has id of "lol", instead of an IPv4 address. + _, err = loadAndParseOverrideLimits("testdata/busted_overrides_third_entry_bad_id.yml") + test.AssertError(t, err, "multiple override limits, third entry has bad Id value") + test.Assert(t, !os.IsNotExist(err), "test file should exist") +} + +func Test_loadAndParseDefaultLimits(t *testing.T) { + newRestistrationsPerIPv4AddressEnumStr := nameToEnumString(NewRegistrationsPerIPAddress) + newRegistrationsPerIPv6RangeEnumStr := nameToEnumString(NewRegistrationsPerIPv6Range) + + // Load a single valid default limit. + l, err := loadAndParseDefaultLimits("testdata/working_default.yml") + test.AssertNotError(t, err, "valid single default limit") + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) + + // Load multiple valid default limits. + l, err = loadAndParseDefaultLimits("testdata/working_defaults.yml") + test.AssertNotError(t, err, "multiple valid default limits") + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Burst, int64(20)) + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Count, int64(20)) + test.AssertEquals(t, l[newRestistrationsPerIPv4AddressEnumStr].Period.Duration, time.Second) + test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Burst, int64(30)) + test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Count, int64(30)) + test.AssertEquals(t, l[newRegistrationsPerIPv6RangeEnumStr].Period.Duration, time.Second*2) + + // Path is empty string. 
+ _, err = loadAndParseDefaultLimits("") + test.AssertError(t, err, "path is empty string") + test.Assert(t, os.IsNotExist(err), "path is empty string") + + // Path to file which does not exist. + _, err = loadAndParseDefaultLimits("testdata/file_does_not_exist.yml") + test.AssertError(t, err, "a file that does not exist") + test.Assert(t, os.IsNotExist(err), "test file should not exist") + + // Burst cannot be 0. + _, err = loadAndParseDefaultLimits("testdata/busted_default_burst_0.yml") + test.AssertError(t, err, "single default limit with burst=0") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Name cannot be empty. + _, err = loadAndParseDefaultLimits("testdata/busted_default_empty_name.yml") + test.AssertError(t, err, "single default limit with empty name") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Name must be a string representation of a valid Name enumeration. + _, err = loadAndParseDefaultLimits("testdata/busted_default_invalid_name.yml") + test.AssertError(t, err, "single default limit with invalid name") + test.Assert(t, !os.IsNotExist(err), "test file should exist") + + // Multiple entries, second entry has a bad name. + _, err = loadAndParseDefaultLimits("testdata/busted_defaults_second_entry_bad_name.yml") + test.AssertError(t, err, "multiple default limits, one is bad") + test.Assert(t, !os.IsNotExist(err), "test file should exist") +} diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go new file mode 100644 index 00000000000..c255df13a78 --- /dev/null +++ b/ratelimits/limiter.go @@ -0,0 +1,234 @@ +package ratelimits + +import ( + "errors" + "fmt" + "time" + + "github.com/jmhodges/clock" +) + +// ErrInvalidCost indicates that the cost specified was <= 0. +var ErrInvalidCost = fmt.Errorf("invalid cost, must be > 0") + +// ErrInvalidCostForCheck indicates that the check cost specified was < 0. 
+var ErrInvalidCostForCheck = fmt.Errorf("invalid check cost, must be >= 0") + +// ErrInvalidCostOverLimit indicates that the cost specified was > limit.Burst. +var ErrInvalidCostOverLimit = fmt.Errorf("invalid cost, must be <= limit.Burst") + +// ErrBucketAlreadyFull indicates that the bucket already has reached its +// maximum capacity. +var ErrBucketAlreadyFull = fmt.Errorf("bucket already full") + +// Limiter provides a high-level interface for rate limiting requests by +// utilizing a leaky bucket-style approach. +type Limiter struct { + // defaults stores default limits by 'name'. + defaults limits + + // overrides stores override limits by 'name:id'. + overrides limits + + // source is used to store buckets. It must be safe for concurrent use. + source source + clk clock.Clock +} + +// NewLimiter returns a new *Limiter. The provided source must be safe for +// concurrent use. The defaults and overrides paths are expected to be paths to +// YAML files that contain the default and override limits, respectively. The +// overrides file is optional, all other arguments are required. +func NewLimiter(clk clock.Clock, source source, defaults, overrides string) (*Limiter, error) { + limiter := &Limiter{source: source, clk: clk} + + var err error + limiter.defaults, err = loadAndParseDefaultLimits(defaults) + if err != nil { + return nil, err + } + + if overrides == "" { + // No overrides specified, initialize an empty map. + limiter.overrides = make(limits) + return limiter, nil + } + + limiter.overrides, err = loadAndParseOverrideLimits(overrides) + if err != nil { + return nil, err + } + + return limiter, nil +} + +type Decision struct { + // Allowed is true if the bucket possessed enough capacity to allow the + // request given the cost. + Allowed bool + + // Remaining is the number of requests the client is allowed to make before + // they're rate limited. 
+ Remaining int64 + + // RetryIn is the duration the client MUST wait before they're allowed to + // make a request. + RetryIn time.Duration + + // ResetIn is the duration the bucket will take to refill to its maximum + // capacity, assuming no further requests are made. + ResetIn time.Duration + + // newTAT indicates the time at which the bucket will be full. It is the + // theoretical arrival time (TAT) of next request. It must be no more than + // (burst * (period / count)) in the future at any single point in time. + newTAT time.Time +} + +// Check returns a *Decision that indicates whether there's enough capacity to +// allow the request, given the cost, for the specified limit Name and client +// id. However, it DOES NOT deduct the cost of the request from the bucket's +// capacity. Hence, the returned *Decision represents the hypothetical state of +// the bucket if the cost WERE to be deducted. The returned *Decision will +// always include the number of remaining requests in the bucket, the required +// wait time before the client can make another request, and the time until the +// bucket refills to its maximum capacity (resets). If no bucket exists for the +// given limit Name and client id, a new one will be created WITHOUT the +// request's cost deducted from its initial capacity. +func (l *Limiter) Check(name Name, id string, cost int64) (*Decision, error) { + if cost < 0 { + return nil, ErrInvalidCostForCheck + } + + limit, err := l.getLimit(name, id) + if err != nil { + return nil, err + } + + if cost > limit.Burst { + return nil, ErrInvalidCostOverLimit + } + + tat, err := l.source.Get(bucketKey(name, id)) + if err != nil { + if !errors.Is(err, ErrBucketNotFound) { + return nil, err + } + // First request from this client. The cost is not deducted from the + // initial capacity because this is only a check. 
+ d, err := l.initialize(limit, name, id, 0) + if err != nil { + return nil, err + } + return maybeSpend(l.clk, limit, d.newTAT, cost), nil + } + return maybeSpend(l.clk, limit, tat, cost), nil +} + +// Spend returns a *Decision that indicates if enough capacity was available to +// process the request, given the cost, for the specified limit Name and client +// id. If capacity existed, the cost of the request HAS been deducted from the +// bucket's capacity, otherwise no cost was deducted. The returned *Decision +// will always include the number of remaining requests in the bucket, the +// required wait time before the client can make another request, and the time +// until the bucket refills to its maximum capacity (resets). If no bucket +// exists for the given limit Name and client id, a new one will be created WITH +// the request's cost deducted from its initial capacity. +func (l *Limiter) Spend(name Name, id string, cost int64) (*Decision, error) { + if cost <= 0 { + return nil, ErrInvalidCost + } + + limit, err := l.getLimit(name, id) + if err != nil { + return nil, err + } + + if cost > limit.Burst { + return nil, ErrInvalidCostOverLimit + } + + tat, err := l.source.Get(bucketKey(name, id)) + if err != nil { + if errors.Is(err, ErrBucketNotFound) { + // First request from this client. + return l.initialize(limit, name, id, cost) + } + return nil, err + } + + d := maybeSpend(l.clk, limit, tat, cost) + + if !d.Allowed { + return d, nil + } + return d, l.source.Set(bucketKey(name, id), d.newTAT) +} + +// Refund attempts to refund the cost to the bucket identified by limit name and +// client id. The returned *Decision indicates whether the refund was successful +// or not. If the refund was successful, the cost of the request was added back +// to the bucket's capacity. If the refund is not possible (i.e., the bucket is +// already full or the refund amount is invalid), no cost is refunded. 
+// +// Note: The amount refunded cannot cause the bucket to exceed its maximum +// capacity. However, partial refunds are allowed and are considered successful. +// For instance, if a bucket has a maximum capacity of 10 and currently has 5 +// requests remaining, a refund request of 7 will result in the bucket reaching +// its maximum capacity of 10, not 12. +func (l *Limiter) Refund(name Name, id string, cost int64) (*Decision, error) { + if cost <= 0 { + return nil, ErrInvalidCost + } + + limit, err := l.getLimit(name, id) + if err != nil { + return nil, err + } + + tat, err := l.source.Get(bucketKey(name, id)) + if err != nil { + return nil, err + } + d := maybeRefund(l.clk, limit, tat, cost) + if !d.Allowed { + return d, ErrBucketAlreadyFull + } + return d, l.source.Set(bucketKey(name, id), d.newTAT) + +} + +// Reset resets the specified bucket. +func (l *Limiter) Reset(name Name, id string) error { + return l.source.Delete(bucketKey(name, id)) +} + +// initialize creates a new bucket, specified by limit name and id, with the +// cost of the request factored into the initial state. +func (l *Limiter) initialize(rl limit, name Name, id string, cost int64) (*Decision, error) { + d := maybeSpend(l.clk, rl, l.clk.Now(), cost) + err := l.source.Set(bucketKey(name, id), d.newTAT) + if err != nil { + return nil, err + } + return d, nil + +} + +// GetLimit returns the limit for the specified by name and id, name is +// required, id is optional. If id is left unspecified, the default limit for +// the limit specified by name is returned. +func (l *Limiter) getLimit(name Name, id string) (limit, error) { + if id != "" { + // Check for override. 
+ ol, ok := l.overrides[bucketKey(name, id)] + if ok { + return ol, nil + } + } + dl, ok := l.defaults[nameToEnumString(name)] + if ok { + return dl, nil + } + return limit{}, fmt.Errorf("limit %q does not exist", name) +} diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go new file mode 100644 index 00000000000..2b791f5938d --- /dev/null +++ b/ratelimits/limiter_test.go @@ -0,0 +1,315 @@ +package ratelimits + +import ( + "testing" + "time" + + "github.com/jmhodges/clock" + "github.com/letsencrypt/boulder/test" +) + +const ( + tenZeroZeroOne = "10.0.0.1" + tenZeroZeroTwo = "10.0.0.2" +) + +// newTestLimiter makes a new limiter with the following configuration: +// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s +func newTestLimiter(t *testing.T) (*Limiter, clock.FakeClock) { + clk := clock.NewFake() + l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "") + test.AssertNotError(t, err, "should not error") + return l, clk +} + +// newTestLimiterWithOverrides makes a new limiter with the following +// configuration: +// - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s +// - 'NewRegistrationsPerIPAddress:10.0.0.2' burst: 40 count: 40 period: 1s +func newTestLimiterWithOverrides(t *testing.T) (*Limiter, clock.FakeClock) { + clk := clock.NewFake() + l, err := NewLimiter(clk, newInmem(), "testdata/working_default.yml", "testdata/working_override.yml") + test.AssertNotError(t, err, "should not error") + return l, clk +} + +func Test_Limiter_initialization_via_Check_and_Spend(t *testing.T) { + l, _ := newTestLimiter(t) + + // Check on an empty bucket should initialize it and return the theoretical + // next state of that bucket if the cost were spent. + d, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + // Verify our ResetIn timing is correct. 
1 second == 1000 milliseconds and + // 1000/20 = 50 milliseconds per request. + test.AssertEquals(t, d.ResetIn, time.Millisecond*50) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + + // However, that cost should not be spent yet, a 0 cost check should tell us + // that we actually have 20 remaining. + d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(20)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + + // Reset our bucket. + err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne) + test.AssertNotError(t, err, "should not error") + + // Similar to above, but we'll use Spend() instead of Check() to initialize + // the bucket. Spend should return the same result as Check. + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + // Verify our ResetIn timing is correct. 1 second == 1000 milliseconds and + // 1000/20 = 50 milliseconds per request. + test.AssertEquals(t, d.ResetIn, time.Millisecond*50) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + + // However, that cost should not be spent yet, a 0 cost check should tell us + // that we actually have 19 remaining. + d, err = l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + // Verify our ResetIn is correct. 1 second == 1000 milliseconds and + // 1000/20 = 50 milliseconds per request. 
+ test.AssertEquals(t, d.ResetIn, time.Millisecond*50) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) +} + +func Test_Limiter_Refund_and_Spend_cost_err(t *testing.T) { + l, _ := newTestLimiter(t) + + // Spend a cost of 0, which should fail. + _, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0) + test.AssertErrorIs(t, err, ErrInvalidCost) + + // Spend a negative cost, which should fail. + _, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1) + test.AssertErrorIs(t, err, ErrInvalidCost) + + // Refund a cost of 0, which should fail. + _, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 0) + test.AssertErrorIs(t, err, ErrInvalidCost) + + // Refund a negative cost, which should fail. + _, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1) + test.AssertErrorIs(t, err, ErrInvalidCost) +} + +func Test_Limiter_with_bad_limits_path(t *testing.T) { + _, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "") + test.AssertError(t, err, "should error") + + _, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/defaults.yml", "testdata/does-not-exist.yml") + test.AssertError(t, err, "should error") +} + +func Test_Limiter_Check_bad_cost(t *testing.T) { + l, _ := newTestLimiter(t) + _, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, -1) + test.AssertErrorIs(t, err, ErrInvalidCostForCheck) +} + +func Test_Limiter_Check_limit_no_exist(t *testing.T) { + l, _ := newTestLimiter(t) + _, err := l.Check(Name(9999), tenZeroZeroOne, 1) + test.AssertError(t, err, "should error") +} + +func Test_Limiter_getLimit_no_exist(t *testing.T) { + l, _ := newTestLimiter(t) + _, err := l.getLimit(Name(9999), "") + test.AssertError(t, err, "should error") +} + +func Test_Limiter_with_defaults(t *testing.T) { + l, clk := newTestLimiter(t) + + // Attempt to spend 21 requests (a cost > the limit burst capacity), this + // should fail with a specific error. 
+ _, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 21) + test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) + + // Attempt to spend all 20 requests, this should succeed. + d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Attempting to spend 1 more, this should fail. + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Verify our ResetIn is correct. 1 second == 1000 milliseconds and + // 1000/20 = 50 milliseconds per request. + test.AssertEquals(t, d.RetryIn, time.Millisecond*50) + + // Wait 50 milliseconds and try again. + clk.Add(d.RetryIn) + + // We should be allowed to spend 1 more request. + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Wait 1 second for a full bucket reset. + clk.Add(d.ResetIn) + + // Quickly spend 20 requests in a row. + for i := 0; i < 20; i++ { + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19-i)) + } + + // Attempting to spend 1 more, this should fail. 
+	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, !d.Allowed, "should not be allowed")
+	test.AssertEquals(t, d.Remaining, int64(0))
+	test.AssertEquals(t, d.ResetIn, time.Second)
+}
+
+func Test_Limiter_with_limit_overrides(t *testing.T) {
+	l, clk := newTestLimiterWithOverrides(t)
+
+	// Attempt to check a spend of 41 requests (a cost > the limit burst
+	// capacity), this should fail with a specific error.
+	_, err := l.Check(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
+	test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
+
+	// Attempt to spend 41 requests (a cost > the limit burst capacity), this
+	// should fail with a specific error.
+	_, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
+	test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
+
+	// Attempt to spend all 40 requests, this should succeed.
+	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, d.Allowed, "should be allowed")
+
+	// Attempting to spend 1 more, this should fail.
+	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, !d.Allowed, "should not be allowed")
+	test.AssertEquals(t, d.Remaining, int64(0))
+	test.AssertEquals(t, d.ResetIn, time.Second)
+
+	// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
+	// 1000/40 = 25 milliseconds per request.
+	test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
+
+	// Wait 25 milliseconds and try again.
+	clk.Add(d.RetryIn)
+
+	// We should be allowed to spend 1 more request.
+ d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Wait 1 second for a full bucket reset. + clk.Add(d.ResetIn) + + // Quickly spend 40 requests in a row. + for i := 0; i < 40; i++ { + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(39-i)) + } + + // Attempting to spend 1 more, this should fail. + d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, !d.Allowed, "should not be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) +} + +func Test_Limiter_with_new_clients(t *testing.T) { + l, _ := newTestLimiter(t) + + // Attempt to spend all 20 requests, this should succeed. + d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(0)) + test.AssertEquals(t, d.ResetIn, time.Second) + + // Another new client, spend 1 and check our remaining. + d, err = l.Spend(NewRegistrationsPerIPAddress, "10.0.0.100", 1) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + + // 1 second == 1000 milliseconds and 1000/20 = 50 milliseconds per request. + test.AssertEquals(t, d.ResetIn, time.Millisecond*50) +} + +func Test_Limiter_Refund_and_Reset(t *testing.T) { + l, clk := newTestLimiter(t) + + // Attempt to spend all 20 requests, this should succeed. 
+	d, err := l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, d.Allowed, "should be allowed")
+	test.AssertEquals(t, d.Remaining, int64(0))
+	test.AssertEquals(t, d.ResetIn, time.Second)
+
+	// Refund 10 requests.
+	d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
+	test.AssertNotError(t, err, "should not error")
+	test.AssertEquals(t, d.Remaining, int64(10))
+
+	// Spend 10 requests, this should succeed.
+	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 10)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, d.Allowed, "should be allowed")
+	test.AssertEquals(t, d.Remaining, int64(0))
+	test.AssertEquals(t, d.ResetIn, time.Second)
+
+	err = l.Reset(NewRegistrationsPerIPAddress, tenZeroZeroOne)
+	test.AssertNotError(t, err, "should not error")
+
+	// Attempt to spend 20 more requests, this should succeed.
+	d, err = l.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 20)
+	test.AssertNotError(t, err, "should not error")
+	test.Assert(t, d.Allowed, "should be allowed")
+	test.AssertEquals(t, d.Remaining, int64(0))
+	test.AssertEquals(t, d.ResetIn, time.Second)
+
+	// Reset to full.
+	clk.Add(d.ResetIn)
+
+	// Refund 1 request above our limit, this should fail.
+ d, err = l.Refund(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertErrorIs(t, err, ErrBucketAlreadyFull) + test.AssertEquals(t, d.Remaining, int64(20)) +} + +func Test_Limiter_Check_Spend_parity(t *testing.T) { + il, _ := newTestLimiter(t) + jl, _ := newTestLimiter(t) + i, err := il.Check(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + j, err := jl.Spend(NewRegistrationsPerIPAddress, tenZeroZeroOne, 1) + test.AssertNotError(t, err, "should not error") + test.AssertDeepEquals(t, i.Remaining, j.Remaining) +} diff --git a/ratelimits/names.go b/ratelimits/names.go new file mode 100644 index 00000000000..bb827e5e8e6 --- /dev/null +++ b/ratelimits/names.go @@ -0,0 +1,202 @@ +package ratelimits + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/letsencrypt/boulder/policy" +) + +// Name is an enumeration of all rate limit names. It is used to intern rate +// limit names as strings and to provide a type-safe way to refer to rate +// limits. +// +// IMPORTANT: If you add a new limit Name, you MUST add it to the 'nameToString' +// mapping and idValidForName function below. +type Name int + +const ( + // Unknown is the zero value of Name and is used to indicate an unknown + // limit name. + Unknown Name = iota + + // NewRegistrationsPerIPAddress uses bucket key 'enum:ipAddress'. + NewRegistrationsPerIPAddress + + // NewRegistrationsPerIPv6Range uses bucket key 'enum:ipv6rangeCIDR'. The + // address range must be a /48. + NewRegistrationsPerIPv6Range + + // NewOrdersPerAccount uses bucket key 'enum:regId'. + NewOrdersPerAccount + + // FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId + // is the registration id of the account. + FailedAuthorizationsPerAccount + + // CertificatesPerDomainPerAccount uses bucket key 'enum:regId:domain', + // where name is the a name in a certificate issued to the account matching + // regId. 
+ CertificatesPerDomainPerAccount + + // CertificatesPerFQDNSetPerAccount uses bucket key 'enum:regId:fqdnSet', + // where nameSet is a set of names in a certificate issued to the account + // matching regId. + CertificatesPerFQDNSetPerAccount +) + +// nameToString is a map of Name values to string names. +var nameToString = map[Name]string{ + Unknown: "Unknown", + NewRegistrationsPerIPAddress: "NewRegistrationsPerIPAddress", + NewRegistrationsPerIPv6Range: "NewRegistrationsPerIPv6Range", + NewOrdersPerAccount: "NewOrdersPerAccount", + FailedAuthorizationsPerAccount: "FailedAuthorizationsPerAccount", + CertificatesPerDomainPerAccount: "CertificatesPerDomainPerAccount", + CertificatesPerFQDNSetPerAccount: "CertificatesPerFQDNSetPerAccount", +} + +// validIPAddress validates that the provided string is a valid IP address. +func validIPAddress(id string) error { + ip := net.ParseIP(id) + if ip == nil { + return fmt.Errorf("invalid IP address, %q must be an IP address", id) + } + return nil +} + +// validIPv6RangeCIDR validates that the provided string is formatted is an IPv6 +// CIDR range with a /48 mask. +func validIPv6RangeCIDR(id string) error { + _, ipNet, err := net.ParseCIDR(id) + if err != nil { + return fmt.Errorf( + "invalid CIDR, %q must be an IPv6 CIDR range", id) + } + ones, _ := ipNet.Mask.Size() + if ones != 48 { + // This also catches the case where the range is an IPv4 CIDR, since an + // IPv4 CIDR can't have a /48 subnet mask - the maximum is /32. + return fmt.Errorf( + "invalid CIDR, %q must be /48", id) + } + return nil +} + +// validateRegId validates that the provided string is a valid ACME regId. 
+func validateRegId(id string) error { + _, err := strconv.ParseUint(id, 10, 64) + if err != nil { + return fmt.Errorf("invalid regId, %q must be an ACME registration Id", id) + } + return nil +} + +// validateRegIdDomain validates that the provided string is formatted +// 'regId:domain', where regId is an ACME registration Id and domain is a single +// domain name. +func validateRegIdDomain(id string) error { + parts := strings.SplitN(id, ":", 2) + if len(parts) != 2 { + return fmt.Errorf( + "invalid regId:domain, %q must be formatted 'regId:domain'", id) + } + if validateRegId(parts[0]) != nil { + return fmt.Errorf( + "invalid regId, %q must be formatted 'regId:domain'", id) + } + if policy.ValidDomain(parts[1]) != nil { + return fmt.Errorf( + "invalid domain, %q must be formatted 'regId:domain'", id) + } + return nil +} + +// validateRegIdFQDNSet validates that the provided string is formatted +// 'regId:fqdnSet', where regId is an ACME registration Id and fqdnSet is a +// comma-separated list of domain names. 
+func validateRegIdFQDNSet(id string) error { + parts := strings.SplitN(id, ":", 2) + if len(parts) != 2 { + return fmt.Errorf( + "invalid regId:fqdnSet, %q must be formatted 'regId:fqdnSet'", id) + } + if validateRegId(parts[0]) != nil { + return fmt.Errorf( + "invalid regId, %q must be formatted 'regId:fqdnSet'", id) + } + domains := strings.Split(parts[1], ",") + if len(domains) == 0 { + return fmt.Errorf( + "invalid fqdnSet, %q must be formatted 'regId:fqdnSet'", id) + } + for _, domain := range domains { + if policy.ValidDomain(domain) != nil { + return fmt.Errorf( + "invalid domain, %q must be formatted 'regId:fqdnSet'", id) + } + } + return nil +} + +func validateIdForName(name Name, id string) error { + switch name { + case NewRegistrationsPerIPAddress: + // 'enum:ipaddress' + return validIPAddress(id) + + case NewRegistrationsPerIPv6Range: + // 'enum:ipv6rangeCIDR' + return validIPv6RangeCIDR(id) + + case NewOrdersPerAccount, FailedAuthorizationsPerAccount: + // 'enum:regId' + return validateRegId(id) + + case CertificatesPerDomainPerAccount: + // 'enum:regId:domain' + return validateRegIdDomain(id) + + case CertificatesPerFQDNSetPerAccount: + // 'enum:regId:fqdnSet' + return validateRegIdFQDNSet(id) + + case Unknown: + fallthrough + + default: + // This should never happen. + return fmt.Errorf("unknown limit enum %q", name) + } +} + +// stringToName is a map of string names to Name values. +var stringToName = func() map[string]Name { + m := make(map[string]Name, len(nameToString)) + for k, v := range nameToString { + m[v] = k + } + return m +}() + +// limitNames is a slice of all rate limit names. +var limitNames = func() []string { + names := make([]string, len(nameToString)) + for _, v := range nameToString { + names = append(names, v) + } + return names +}() + +// nameToEnumString converts the integer value of the Name enumeration to its +// string representation. 
+func nameToEnumString(s Name) string { + return strconv.Itoa(int(s)) +} + +// bucketKey returns the key used to store a rate limit bucket. +func bucketKey(name Name, id string) string { + return nameToEnumString(name) + ":" + id +} diff --git a/ratelimits/source.go b/ratelimits/source.go new file mode 100644 index 00000000000..0ad6e0f0f2a --- /dev/null +++ b/ratelimits/source.go @@ -0,0 +1,57 @@ +package ratelimits + +import ( + "fmt" + "sync" + "time" +) + +// ErrBucketNotFound indicates that the bucket was not found. +var ErrBucketNotFound = fmt.Errorf("bucket not found") + +// source is an interface for creating and modifying TATs. +type source interface { + // Set stores the TAT at the specified bucketKey ('name:id'). + Set(bucketKey string, tat time.Time) error + + // Get retrieves the TAT at the specified bucketKey ('name:id'). + Get(bucketKey string) (time.Time, error) + + // Delete deletes the TAT at the specified bucketKey ('name:id'). + Delete(bucketKey string) error +} + +// inmem is an in-memory implementation of the source interface used for +// testing. 
+type inmem struct { + sync.RWMutex + m map[string]time.Time +} + +func newInmem() *inmem { + return &inmem{m: make(map[string]time.Time)} +} + +func (in *inmem) Set(bucketKey string, tat time.Time) error { + in.Lock() + defer in.Unlock() + in.m[bucketKey] = tat + return nil +} + +func (in *inmem) Get(bucketKey string) (time.Time, error) { + in.RLock() + defer in.RUnlock() + tat, ok := in.m[bucketKey] + if !ok { + return time.Time{}, ErrBucketNotFound + } + return tat, nil +} + +func (in *inmem) Delete(bucketKey string) error { + in.Lock() + defer in.Unlock() + delete(in.m, bucketKey) + return nil +} diff --git a/ratelimits/testdata/busted_default_burst_0.yml b/ratelimits/testdata/busted_default_burst_0.yml new file mode 100644 index 00000000000..26a2466ad02 --- /dev/null +++ b/ratelimits/testdata/busted_default_burst_0.yml @@ -0,0 +1,4 @@ +NewRegistrationsPerIPAddress: + burst: 0 + count: 20 + period: 1s diff --git a/ratelimits/testdata/busted_default_empty_name.yml b/ratelimits/testdata/busted_default_empty_name.yml new file mode 100644 index 00000000000..981c58536f0 --- /dev/null +++ b/ratelimits/testdata/busted_default_empty_name.yml @@ -0,0 +1,4 @@ +"": + burst: 20 + count: 20 + period: 1s diff --git a/ratelimits/testdata/busted_default_invalid_name.yml b/ratelimits/testdata/busted_default_invalid_name.yml new file mode 100644 index 00000000000..bf41b326d7e --- /dev/null +++ b/ratelimits/testdata/busted_default_invalid_name.yml @@ -0,0 +1,4 @@ +UsageRequestsPerIPv10Address: + burst: 20 + count: 20 + period: 1s diff --git a/ratelimits/testdata/busted_defaults_second_entry_bad_name.yml b/ratelimits/testdata/busted_defaults_second_entry_bad_name.yml new file mode 100644 index 00000000000..cc276a869b9 --- /dev/null +++ b/ratelimits/testdata/busted_defaults_second_entry_bad_name.yml @@ -0,0 +1,8 @@ +NewRegistrationsPerIPAddress: + burst: 20 + count: 20 + period: 1s +UsageRequestsPerIPv10Address: + burst: 20 + count: 20 + period: 1s diff --git 
a/ratelimits/testdata/busted_override_burst_0.yml b/ratelimits/testdata/busted_override_burst_0.yml new file mode 100644 index 00000000000..66ec594c337 --- /dev/null +++ b/ratelimits/testdata/busted_override_burst_0.yml @@ -0,0 +1,4 @@ +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 0 + count: 40 + period: 1s diff --git a/ratelimits/testdata/busted_override_empty_id.yml b/ratelimits/testdata/busted_override_empty_id.yml new file mode 100644 index 00000000000..bbaa5b8cd3e --- /dev/null +++ b/ratelimits/testdata/busted_override_empty_id.yml @@ -0,0 +1,4 @@ +"UsageRequestsPerIPv10Address:": + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/busted_override_empty_name.yml b/ratelimits/testdata/busted_override_empty_name.yml new file mode 100644 index 00000000000..fe31bf1a912 --- /dev/null +++ b/ratelimits/testdata/busted_override_empty_name.yml @@ -0,0 +1,4 @@ +":10.0.0.2": + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/busted_override_invalid_name.yml b/ratelimits/testdata/busted_override_invalid_name.yml new file mode 100644 index 00000000000..473ddeb8fae --- /dev/null +++ b/ratelimits/testdata/busted_override_invalid_name.yml @@ -0,0 +1,4 @@ +UsageRequestsPerIPv10Address:10.0.0.2: + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/busted_overrides_second_entry_bad_name.yml b/ratelimits/testdata/busted_overrides_second_entry_bad_name.yml new file mode 100644 index 00000000000..8c96d8acc87 --- /dev/null +++ b/ratelimits/testdata/busted_overrides_second_entry_bad_name.yml @@ -0,0 +1,8 @@ +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 40 + count: 40 + period: 1s +UsageRequestsPerIPv10Address:10.0.0.5: + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/busted_overrides_third_entry_bad_id.yml b/ratelimits/testdata/busted_overrides_third_entry_bad_id.yml new file mode 100644 index 00000000000..645dc397bf9 --- /dev/null +++ b/ratelimits/testdata/busted_overrides_third_entry_bad_id.yml @@ -0,0 
+1,12 @@ +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 40 + count: 40 + period: 1s +NewRegistrationsPerIPAddress:10.0.0.5: + burst: 40 + count: 40 + period: 1s +NewRegistrationsPerIPAddress:lol: + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/working_default.yml b/ratelimits/testdata/working_default.yml new file mode 100644 index 00000000000..1c0c63bce5e --- /dev/null +++ b/ratelimits/testdata/working_default.yml @@ -0,0 +1,4 @@ +NewRegistrationsPerIPAddress: + burst: 20 + count: 20 + period: 1s diff --git a/ratelimits/testdata/working_defaults.yml b/ratelimits/testdata/working_defaults.yml new file mode 100644 index 00000000000..be5988b7a2c --- /dev/null +++ b/ratelimits/testdata/working_defaults.yml @@ -0,0 +1,8 @@ +NewRegistrationsPerIPAddress: + burst: 20 + count: 20 + period: 1s +NewRegistrationsPerIPv6Range: + burst: 30 + count: 30 + period: 2s diff --git a/ratelimits/testdata/working_override.yml b/ratelimits/testdata/working_override.yml new file mode 100644 index 00000000000..ea13553a87d --- /dev/null +++ b/ratelimits/testdata/working_override.yml @@ -0,0 +1,4 @@ +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/working_override_regid_domain.yml b/ratelimits/testdata/working_override_regid_domain.yml new file mode 100644 index 00000000000..bd4d3eb67a0 --- /dev/null +++ b/ratelimits/testdata/working_override_regid_domain.yml @@ -0,0 +1,4 @@ +CertificatesPerDomainPerAccount:12345678:example.com: + burst: 40 + count: 40 + period: 1s diff --git a/ratelimits/testdata/working_overrides.yml b/ratelimits/testdata/working_overrides.yml new file mode 100644 index 00000000000..d13704ce3bc --- /dev/null +++ b/ratelimits/testdata/working_overrides.yml @@ -0,0 +1,8 @@ +NewRegistrationsPerIPAddress:10.0.0.2: + burst: 40 + count: 40 + period: 1s +NewRegistrationsPerIPv6Range:2001:0db8:0000::/48: + burst: 50 + count: 50 + period: 2s diff --git 
a/ratelimits/testdata/working_overrides_regid_fqdnset.yml b/ratelimits/testdata/working_overrides_regid_fqdnset.yml new file mode 100644 index 00000000000..093ac976e7f --- /dev/null +++ b/ratelimits/testdata/working_overrides_regid_fqdnset.yml @@ -0,0 +1,12 @@ +CertificatesPerFQDNSetPerAccount:12345678:example.com: + burst: 40 + count: 40 + period: 1s +CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net: + burst: 50 + count: 50 + period: 2s +CertificatesPerFQDNSetPerAccount:12345678:example.com,example.net,example.org: + burst: 60 + count: 60 + period: 3s diff --git a/sa/model.go b/sa/model.go index 695539cb748..712691b7e10 100644 --- a/sa/model.go +++ b/sa/model.go @@ -12,7 +12,6 @@ import ( "net" "net/url" "strconv" - "strings" "time" "golang.org/x/exp/slices" @@ -873,14 +872,6 @@ type crlEntryModel struct { RevokedDate time.Time `db:"revokedDate"` } -// HashNames returns a hash of the names requested. This is intended for use -// when interacting with the orderFqdnSets table. -func HashNames(names []string) []byte { - names = core.UniqueLowerNames(names) - hash := sha256.Sum256([]byte(strings.Join(names, ","))) - return hash[:] -} - // orderFQDNSet contains the SHA256 hash of the lowercased, comma joined names // from a new-order request, along with the corresponding orderID, the // registration ID, and the order expiry. 
This is used to find @@ -895,7 +886,7 @@ type orderFQDNSet struct { func addFQDNSet(ctx context.Context, db db.Inserter, names []string, serial string, issued time.Time, expires time.Time) error { return db.Insert(ctx, &core.FQDNSet{ - SetHash: HashNames(names), + SetHash: core.HashNames(names), Serial: serial, Issued: issued, Expires: expires, @@ -914,7 +905,7 @@ func addOrderFQDNSet( regID int64, expires time.Time) error { return db.Insert(ctx, &orderFQDNSet{ - SetHash: HashNames(names), + SetHash: core.HashNames(names), OrderID: orderID, RegistrationID: regID, Expires: expires, diff --git a/sa/sa_test.go b/sa/sa_test.go index 8553c3947be..0361ce6ed6e 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -2893,33 +2893,6 @@ func TestBlockedKeyRevokedBy(t *testing.T) { test.AssertNotError(t, err, "AddBlockedKey failed") } -func TestHashNames(t *testing.T) { - // Test that it is deterministic - h1 := HashNames([]string{"a"}) - h2 := HashNames([]string{"a"}) - test.AssertByteEquals(t, h1, h2) - - // Test that it differentiates - h1 = HashNames([]string{"a"}) - h2 = HashNames([]string{"b"}) - test.Assert(t, !bytes.Equal(h1, h2), "Should have been different") - - // Test that it is not subject to ordering - h1 = HashNames([]string{"a", "b"}) - h2 = HashNames([]string{"b", "a"}) - test.AssertByteEquals(t, h1, h2) - - // Test that it is not subject to case - h1 = HashNames([]string{"a", "b"}) - h2 = HashNames([]string{"A", "B"}) - test.AssertByteEquals(t, h1, h2) - - // Test that it is not subject to duplication - h1 = HashNames([]string{"a", "a"}) - h2 = HashNames([]string{"a"}) - test.AssertByteEquals(t, h1, h2) -} - func TestIncidentsForSerial(t *testing.T) { sa, _, cleanUp := initSA(t) defer cleanUp() diff --git a/sa/saro.go b/sa/saro.go index 436c1b13a06..d3f77ae6ba1 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -516,7 +516,7 @@ func (ssa *SQLStorageAuthorityRO) CountFQDNSets(ctx context.Context, req *sapb.C `SELECT COUNT(*) FROM fqdnSets WHERE setHash = ? 
AND issued > ?`, - HashNames(req.Domains), + core.HashNames(req.Domains), ssa.clk.Now().Add(-time.Duration(req.Window)), ) return &sapb.Count{Count: count}, err @@ -544,7 +544,7 @@ func (ssa *SQLStorageAuthorityRO) FQDNSetTimestampsForWindow(ctx context.Context WHERE setHash = ? AND issued > ? ORDER BY issued DESC`, - HashNames(req.Domains), + core.HashNames(req.Domains), ssa.clk.Now().Add(-time.Duration(req.Window)), ) if err != nil { @@ -586,7 +586,7 @@ type oneSelectorFunc func(ctx context.Context, holder interface{}, query string, // checkFQDNSetExists uses the given oneSelectorFunc to check whether an fqdnSet // for the given names exists. func (ssa *SQLStorageAuthorityRO) checkFQDNSetExists(ctx context.Context, selector oneSelectorFunc, names []string) (bool, error) { - namehash := HashNames(names) + namehash := core.HashNames(names) var exists bool err := selector( ctx, @@ -761,7 +761,7 @@ func (ssa *SQLStorageAuthorityRO) GetOrderForNames(ctx context.Context, req *sap } // Hash the names requested for lookup in the orderFqdnSets table - fqdnHash := HashNames(req.Names) + fqdnHash := core.HashNames(req.Names) // Find a possibly-suitable order. We don't include the account ID or order // status in this query because there's no index that includes those, so