From f2e93de144422ec19bbe14113becedd258b27845 Mon Sep 17 00:00:00 2001 From: Samantha Date: Thu, 28 Sep 2023 13:18:49 -0400 Subject: [PATCH] Simplify construction and cleanup. Address context cancellation. --- cmd/boulder-wfe2/main.go | 28 +++++--------------- ratelimits/limiter.go | 24 +++++++++++++++--- ratelimits/source.go | 12 ++++++--- redis/config.go | 55 +++++++++++++++++++++++++--------------- redis/lookup.go | 31 ++++++++++++---------- redis/lookup_test.go | 6 ++--- wfe2/wfe.go | 54 ++++++++++++++++++++++++++------------- wfe2/wfe_test.go | 6 ++--- 8 files changed, 129 insertions(+), 87 deletions(-) diff --git a/cmd/boulder-wfe2/main.go b/cmd/boulder-wfe2/main.go index f3d5b27db75..13e362c8836 100644 --- a/cmd/boulder-wfe2/main.go +++ b/cmd/boulder-wfe2/main.go @@ -13,7 +13,6 @@ import ( "github.com/jmhodges/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/redis/go-redis/v9" "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/config" @@ -341,20 +340,13 @@ func main() { pendingAuthorizationLifetime := time.Duration(c.WFE.PendingAuthorizationLifetimeDays) * 24 * time.Hour var limiter *ratelimits.Limiter - var limiterLookup *bredis.Lookup + var limiterRedis *bredis.Ring if c.WFE.Limiter.Defaults != "" { // Setup rate limiting. - var ring *redis.Ring - if len(c.WFE.Limiter.Redis.Lookups) > 0 { - // Configure a Redis client with periodic SRV lookups. - ring, limiterLookup, err = c.WFE.Limiter.Redis.NewRingWithPeriodicLookups(stats, logger) - cmd.FailOnError(err, "Failed to create Redis SRV Lookup for rate limiting") - } else { - // Configure a Redis client with static shard addresses. - ring, err = c.WFE.Limiter.Redis.NewRing(stats) - cmd.FailOnError(err, "Failed to create Redis client for rate limiting") - } - source := ratelimits.NewRedisSource(ring, clk, stats) + limiterRedis, err = bredis.NewRingFromConfig(*c.WFE.Limiter.Redis, stats, logger) + cmd.FailOnError(err, "Failed to create Redis ring") + + source := ratelimits.NewRedisSource(limiterRedis.Ring, clk, stats) limiter, err = ratelimits.NewLimiter(clk, source, c.WFE.Limiter.Defaults, c.WFE.Limiter.Overrides, stats) cmd.FailOnError(err, "Failed to create rate limiter") } @@ -391,13 +383,6 @@ func main() { ) cmd.FailOnError(err, "Unable to create WFE") - var limiterCtx context.Context - var shutdownLimiterLookup context.CancelFunc = func() {} - if limiterLookup != nil { - limiterCtx, shutdownLimiterLookup = context.WithCancel(context.Background()) - limiterLookup.Start(limiterCtx) - } - wfe.SubscriberAgreementURL = c.WFE.SubscriberAgreementURL wfe.AllowOrigins = c.WFE.AllowOrigins wfe.DirectoryCAAIdentity = c.WFE.DirectoryCAAIdentity @@ -447,12 +432,11 @@ func main() { // ListenAndServe() and ListenAndServeTLS() to immediately return, then waits // for any lingering connection-handling goroutines to finish their work. defer func() { - ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration) defer cancel() _ = srv.Shutdown(ctx) _ = tlsSrv.Shutdown(ctx) - shutdownLimiterLookup() + limiterRedis.StopLookups() oTelShutdown(ctx) }() diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 5204a5d5bbf..2b24ba442ec 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -146,6 +146,9 @@ func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) ( return nil, ErrInvalidCostOverLimit } + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. 
+ ctx = context.WithoutCancel(ctx) tat, err := l.source.Get(ctx, bucketKey(name, id)) if err != nil { if !errors.Is(err, ErrBucketNotFound) { @@ -190,11 +193,14 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( } start := l.clk.Now() - decisionStatus := Denied + status := Denied defer func() { - l.spendLatency.WithLabelValues(nameToString[name], decisionStatus).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(name.String(), status).Observe(l.clk.Since(start).Seconds()) }() + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) tat, err := l.source.Get(ctx, bucketKey(name, id)) if err != nil { if errors.Is(err, ErrBucketNotFound) { @@ -204,7 +210,7 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( return nil, err } if d.Allowed { - decisionStatus = Allowed + status = Allowed } return d, nil } @@ -228,7 +234,7 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) ( if err != nil { return nil, err } - decisionStatus = Allowed + status = Allowed return d, nil } @@ -256,6 +262,9 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) return nil, err } + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) tat, err := l.source.Get(ctx, bucketKey(name, id)) if err != nil { return nil, err @@ -271,6 +280,9 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) // Reset resets the specified bucket. func (l *Limiter) Reset(ctx context.Context, name Name, id string) error { + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) return l.source.Delete(ctx, bucketKey(name, id)) } @@ -278,6 +290,10 @@ func (l *Limiter) Reset(ctx context.Context, name Name, id string) error { // cost of the request factored into the initial state. func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string, cost int64) (*Decision, error) { d := maybeSpend(l.clk, rl, l.clk.Now(), cost) + + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) err := l.source.Set(ctx, bucketKey(name, id), d.newTAT) if err != nil { return nil, err diff --git a/ratelimits/source.go b/ratelimits/source.go index 21516e642a3..be8cef57e1c 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -12,13 +12,19 @@ var ErrBucketNotFound = fmt.Errorf("bucket not found") // source is an interface for creating and modifying TATs. type source interface { - // Set stores the TAT at the specified bucketKey ('name:id'). + // Set stores the TAT at the specified bucketKey ('name:id'). Contexts + // passed to this method should have a timeout or deadline set to prevent + // the operation from blocking indefinitely. Set(ctx context.Context, bucketKey string, tat time.Time) error - // Get retrieves the TAT at the specified bucketKey ('name:id'). + // Get retrieves the TAT at the specified bucketKey ('name:id'). Contexts + // passed to this method should have a timeout or deadline set to prevent + // the operation from blocking indefinitely. Get(ctx context.Context, bucketKey string) (time.Time, error) - // Delete deletes the TAT at the specified bucketKey ('name:id'). 
+ // Delete deletes the TAT at the specified bucketKey ('name:id'). Contexts + // passed to this method should have a timeout or deadline set to prevent + // the operation from blocking indefinitely. Delete(ctx context.Context, bucketKey string) error } diff --git a/redis/config.go b/redis/config.go index 82257bc0e9f..997969373cd 100644 --- a/redis/config.go +++ b/redis/config.go @@ -108,8 +108,18 @@ type Config struct { IdleCheckFrequency config.Duration `validate:"-"` } -// NewRing returns a new Redis ring client. -func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) { +// Ring is a wrapper around the go-redis/v9 Ring client that adds support for +// (optional) periodic SRV lookups. +type Ring struct { + *redis.Ring + lookup *lookup +} + +// NewRingFromConfig returns a new *redis.Ring client. If periodic SRV lookups +// are supplied, a goroutine will be started to periodically perform lookups. +// Callers should defer a call to StopLookups() to ensure that this goroutine is +// gracefully shutdown. +func NewRingFromConfig(c Config, stats prometheus.Registerer, log blog.Logger) (*Ring, error) { password, err := c.Pass() if err != nil { return nil, fmt.Errorf("loading password: %w", err) @@ -120,7 +130,7 @@ func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) { return nil, fmt.Errorf("loading TLS config: %w", err) } - client := redis.NewRing(&redis.RingOptions{ + inner := redis.NewRing(&redis.RingOptions{ Addrs: c.ShardAddrs, Username: c.Username, Password: password, @@ -141,26 +151,31 @@ func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) { }) if len(c.ShardAddrs) > 0 { // Client was statically configured with a list of shards. - MustRegisterClientMetricsCollector(client, stats, c.ShardAddrs, c.Username) + MustRegisterClientMetricsCollector(inner, stats, c.ShardAddrs, c.Username) } - return client, nil -} - -// NewRingWithPeriodicLookups returns a new Redis ring client whose shards are -// periodically updated via SRV lookups. An initial SRV lookup is performed to -// populate the Redis ring shards. If this lookup fails or otherwise results in -// an empty set of resolved shards, an error is returned. -func (c *Config) NewRingWithPeriodicLookups(stats prometheus.Registerer, logger blog.Logger) (*redis.Ring, *Lookup, error) { - ring, err := c.NewRing(stats) - if err != nil { - return nil, nil, err + var lookup *lookup + if len(c.Lookups) != 0 { + lookup, err = newLookup(c.Lookups, c.LookupDNSAuthority, c.LookupFrequency.Duration, inner, log, stats) + if err != nil { + return nil, err + } + lookup.start() } - lookup, err := newLookup(c.Lookups, c.LookupDNSAuthority, c.LookupFrequency.Duration, ring, logger, stats) - if err != nil { - return nil, nil, err - } + return &Ring{ + Ring: inner, + lookup: lookup, + }, nil +} - return ring, lookup, nil +// StopLookups stops the goroutine responsible for keeping the shards of the +// inner *redis.Ring up-to-date. It is a no-op if the Ring was not constructed +// with periodic lookups or if the lookups have already been stopped. +func (r *Ring) StopLookups() { + if r == nil || r.lookup == nil { + // No-op. 
+ return + } + r.lookup.stop() } diff --git a/redis/lookup.go b/redis/lookup.go index f4c2e899027..f66ed7450a3 100644 --- a/redis/lookup.go +++ b/redis/lookup.go @@ -17,9 +17,9 @@ import ( var ErrNoShardsResolved = errors.New("0 shards were resolved") -// Lookup wraps a Redis ring client by reference and keeps the Redis ring shards +// lookup wraps a Redis ring client by reference and keeps the Redis ring shards // up to date via periodic SRV lookups. -type Lookup struct { +type lookup struct { // srvLookups is a list of SRV records to be looked up. srvLookups []cmd.ServiceDomain @@ -38,16 +38,20 @@ type Lookup struct { // will be used for resolution. dnsAuthority string + // stop is a context.CancelFunc that can be used to stop the goroutine + // responsible for performing periodic SRV lookups. + stop context.CancelFunc + resolver *net.Resolver ring *redis.Ring logger blog.Logger stats prometheus.Registerer } -// newLookup constructs and returns a new Lookup instance. An initial SRV lookup +// newLookup constructs and returns a new lookup instance. An initial SRV lookup // is performed to populate the Redis ring shards. If this lookup fails or // otherwise results in an empty set of resolved shards, an error is returned. -func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency time.Duration, ring *redis.Ring, logger blog.Logger, stats prometheus.Registerer) (*Lookup, error) { +func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency time.Duration, ring *redis.Ring, logger blog.Logger, stats prometheus.Registerer) (*lookup, error) { updateFrequency := frequency if updateFrequency <= 0 { // Set default frequency. @@ -56,7 +60,7 @@ func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency ti // Set default timeout to 90% of the update frequency. updateTimeout := updateFrequency - updateFrequency/10 - lookup := &Lookup{ + lookup := &lookup{ srvLookups: srvLookups, ring: ring, logger: logger, @@ -111,7 +115,7 @@ func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency ti // lookup succeeds, the Redis ring is updated, and all errors are discarded. // Non-temporary DNS errors are always logged as they occur, as they're likely // to be indicative of a misconfiguration. -func (look *Lookup) updateNow(ctx context.Context) (tempError, nonTempError error) { +func (look *lookup) updateNow(ctx context.Context) (tempError, nonTempError error) { var tempErrs []error handleDNSError := func(err error, srv cmd.ServiceDomain) { var dnsErr *net.DNSError @@ -176,20 +180,21 @@ func (look *Lookup) updateNow(ctx context.Context) (tempError, nonTempError erro return nil, nil } -// Start starts a goroutine that keeps the Redis ring shards up to date via -// periodic SRV lookups. The goroutine will exit when the provided context is -// cancelled. -func (look *Lookup) Start(ctx context.Context) { +// start starts a goroutine that keeps the Redis ring shards up-to-date by +// periodically performing SRV lookups. +func (look *lookup) start() { + var lookupCtx context.Context + lookupCtx, look.stop = context.WithCancel(context.Background()) go func() { ticker := time.NewTicker(look.updateFrequency) defer ticker.Stop() for { // Check for context cancellation before we do any work. 
- if ctx.Err() != nil { + if lookupCtx.Err() != nil { return } - timeoutCtx, cancel := context.WithTimeout(ctx, look.updateTimeout) + timeoutCtx, cancel := context.WithTimeout(lookupCtx, look.updateTimeout) tempErrs, nonTempErrs := look.updateNow(timeoutCtx) cancel() if tempErrs != nil { @@ -205,7 +210,7 @@ func (look *Lookup) Start(ctx context.Context) { case <-ticker.C: continue - case <-ctx.Done(): + case <-lookupCtx.Done(): return } } diff --git a/redis/lookup_test.go b/redis/lookup_test.go index afe42e5f94f..da726b514ad 100644 --- a/redis/lookup_test.go +++ b/redis/lookup_test.go @@ -76,10 +76,8 @@ func TestStart(t *testing.T) { ) test.AssertNotError(t, err, "Expected newLookup construction to succeed") - testCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - - lookup.Start(testCtx) + lookup.start() + lookup.stop() } func TestNewLookupWithOneFailingSRV(t *testing.T) { diff --git a/wfe2/wfe.go b/wfe2/wfe.go index 0d7cd15f9ef..a751a1de3b9 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -616,8 +616,9 @@ func link(url, relation string) string { } // checkNewAccountLimits checks whether sufficient limit quota exists for the -// creation of a new account from the given IP address. If an error is -// encountered during the check, it is logged but not returned. +// creation of a new account from the given IP address. If so, that quota is +// spent. If an error is encountered during the check, it is logged but not +// returned. // // TODO(#5545): For now we're simply exercising the new rate limiter codepath. // This should eventually return a berrors.RateLimit error containing the retry @@ -627,11 +628,19 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP // Limiter is disabled. return } - decision, err := wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) - if err != nil { + + warn := func(err error, limit ratelimits.Name) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } // TODO(#5545): Once key-value rate limits are authoritative this log // line should be removed in favor of returning the error. - wfe.log.Warningf("checking %s rate limit: %s", ratelimits.NewRegistrationsPerIPAddress, err) + wfe.log.Warningf("checking %s rate limit: %s", limit, err) + } + + decision, err := wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) + if err != nil { + warn(err, ratelimits.NewRegistrationsPerIPAddress) return } if !decision.Allowed || ip.To4() != nil { @@ -646,8 +655,7 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} _, err = wfe.limiter.Spend(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) if err != nil { - wfe.log.Warningf("checking %s rate limit: %s", ratelimits.NewRegistrationsPerIPv6Range, err) - return + warn(err, ratelimits.NewRegistrationsPerIPv6Range) } } @@ -660,9 +668,19 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I // Limiter is disabled. return } + + warn := func(err error, limit ratelimits.Name) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + // TODO(#5545): Once key-value rate limits are authoritative this log + // line should be removed in favor of returning the error. 
+ wfe.log.Warningf("refunding %s rate limit: %s", limit, err) + } + _, err := wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPAddress, ip.String(), 1) if err != nil { - wfe.log.Warningf("refunding new account rate limit: %s", err) + warn(err, ratelimits.NewRegistrationsPerIPAddress) return } if ip.To4() != nil { @@ -676,7 +694,7 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} _, err = wfe.limiter.Refund(ctx, ratelimits.NewRegistrationsPerIPv6Range, ipNet.String(), 1) if err != nil { - wfe.log.Warningf("refunding new account rate limit: %s", err) + warn(err, ratelimits.NewRegistrationsPerIPv6Range) } } @@ -802,10 +820,15 @@ func (wfe *WebFrontEndImpl) NewAccount( InitialIP: ipBytes, } - // TODO(#5545): This can no longer by async once we start treating the new - // limiter as the source of truth for rate limits. But for now, this saves - // us from eating extra latency for each new account creation. + // TODO(#5545): Spending and Refunding can be async until these rate limits + // are authoritative. This saves us from adding latency to each request. go wfe.checkNewAccountLimits(ctx, ip) + var newRegistrationSuccessful bool + defer func() { + if !newRegistrationSuccessful { + go wfe.refundNewAccountLimits(ctx, ip) + } + }() // Send the registration to the RA via grpc acctPB, err := wfe.ra.NewRegistration(ctx, ®) @@ -814,18 +837,15 @@ func (wfe *WebFrontEndImpl) NewAccount( existingAcct, err := wfe.sa.GetRegistrationByKey(ctx, &sapb.JSONWebKey{Jwk: keyBytes}) if err == nil { returnExistingAcct(existingAcct) - wfe.refundNewAccountLimits(ctx, ip) return } // return error even if berrors.NotFound, as the duplicate key error we got from // ra.NewRegistration indicates it _does_ already exist. wfe.sendError(response, logEvent, web.ProblemDetailsForError(err, "checking for existing account"), err) - wfe.refundNewAccountLimits(ctx, ip) return } wfe.sendError(response, logEvent, web.ProblemDetailsForError(err, "Error creating new account"), err) - wfe.refundNewAccountLimits(ctx, ip) return } @@ -836,14 +856,12 @@ func (wfe *WebFrontEndImpl) NewAccount( if acctPB == nil || !registrationValid(acctPB) { wfe.sendError(response, logEvent, web.ProblemDetailsForError(err, "Error creating new account"), err) - wfe.refundNewAccountLimits(ctx, ip) return } acct, err := bgrpc.PbToRegistration(acctPB) if err != nil { wfe.sendError(response, logEvent, web.ProblemDetailsForError(err, "Error creating new account"), err) - wfe.refundNewAccountLimits(ctx, ip) return } logEvent.Requester = acct.ID @@ -863,9 +881,9 @@ func (wfe *WebFrontEndImpl) NewAccount( // ServerInternal because we just created this account, and it // should be OK. 
wfe.sendError(response, logEvent, probs.ServerInternal("Error marshaling account"), err) - wfe.refundNewAccountLimits(ctx, ip) return } + newRegistrationSuccessful = true } // parseRevocation accepts the payload for a revocation request and parses it diff --git a/wfe2/wfe_test.go b/wfe2/wfe_test.go index 664e023a478..def8b5b6232 100644 --- a/wfe2/wfe_test.go +++ b/wfe2/wfe_test.go @@ -380,9 +380,9 @@ func setupWFE(t *testing.T) (WebFrontEndImpl, clock.FakeClock, requestSigner) { rc.PasswordConfig = cmd.PasswordConfig{ PasswordFile: "../test/secrets/ratelimits_redis_password", } - ring, _, err := rc.NewRingWithPeriodicLookups(stats, log) - test.AssertNotError(t, err, "making redis ring and lookup") - source := ratelimits.NewRedisSource(ring, fc, stats) + ring, err := bredis.NewRingFromConfig(rc, stats, log) + test.AssertNotError(t, err, "making redis ring client") + source := ratelimits.NewRedisSource(ring.Ring, fc, stats) test.AssertNotNil(t, source, "source should not be nil") limiter, err = ratelimits.NewLimiter(fc, source, "../test/config-next/wfe2-ratelimit-defaults.yml", "", stats) test.AssertNotError(t, err, "making limiter")
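
A note on the context handling in this patch: context.WithoutCancel (added in Go 1.21) returns a context that is never canceled and carries no deadline, but still propagates its parent's values. That is why the limiter strips cancellation before every call into the Redis-backed source (a client disconnect should not abort an in-flight rate-limit transaction), and why the source interface comments remind callers that a timeout must be supplied some other way. The sketch below is illustrative only: handleRequest and doRedisWrite are hypothetical stand-ins, not code from this patch, and the fresh WithTimeout is one assumed way of restoring a deadline after cancellation has been removed (in Boulder the actual client-side timeouts are configured elsewhere).

package main

import (
	"context"
	"fmt"
	"time"
)

// doRedisWrite stands in for a call into the limiter's Redis-backed source.
func doRedisWrite(ctx context.Context) error {
	select {
	case <-time.After(50 * time.Millisecond): // pretend the write takes 50ms
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func handleRequest(reqCtx context.Context) error {
	// Drop the request's cancellation signal (and any deadline) so a client
	// disconnect cannot interrupt the write; request-scoped values are kept.
	ctx := context.WithoutCancel(reqCtx)

	// WithoutCancel also removes the deadline, so attach a fresh timeout to
	// keep the operation from blocking indefinitely.
	ctx, cancel := context.WithTimeout(ctx, 250*time.Millisecond)
	defer cancel()

	return doRedisWrite(ctx)
}

func main() {
	// Simulate a client that disconnects immediately after sending a request.
	reqCtx, cancel := context.WithCancel(context.Background())
	cancel()

	// The write still completes because cancellation was stripped.
	fmt.Println(handleRequest(reqCtx)) // prints: <nil>
}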