Skip to content

Commit

Permalink
Simplify construction and cleanup. Address context cancellation.
Browse files Browse the repository at this point in the history
  • Loading branch information
beautifulentropy committed Sep 28, 2023
1 parent 30de5d7 commit f2e93de
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 87 deletions.
28 changes: 6 additions & 22 deletions cmd/boulder-wfe2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/jmhodges/clock"
"github.com/prometheus/client_golang/prometheus"
"github.com/redis/go-redis/v9"

"github.com/letsencrypt/boulder/cmd"
"github.com/letsencrypt/boulder/config"
Expand Down Expand Up @@ -341,20 +340,13 @@ func main() {
pendingAuthorizationLifetime := time.Duration(c.WFE.PendingAuthorizationLifetimeDays) * 24 * time.Hour

var limiter *ratelimits.Limiter
var limiterLookup *bredis.Lookup
var limiterRedis *bredis.Ring
if c.WFE.Limiter.Defaults != "" {
// Setup rate limiting.
var ring *redis.Ring
if len(c.WFE.Limiter.Redis.Lookups) > 0 {
// Configure a Redis client with periodic SRV lookups.
ring, limiterLookup, err = c.WFE.Limiter.Redis.NewRingWithPeriodicLookups(stats, logger)
cmd.FailOnError(err, "Failed to create Redis SRV Lookup for rate limiting")
} else {
// Configure a Redis client with static shard addresses.
ring, err = c.WFE.Limiter.Redis.NewRing(stats)
cmd.FailOnError(err, "Failed to create Redis client for rate limiting")
}
source := ratelimits.NewRedisSource(ring, clk, stats)
limiterRedis, err = bredis.NewRingFromConfig(*c.WFE.Limiter.Redis, stats, logger)
cmd.FailOnError(err, "Failed to create Redis ring")

source := ratelimits.NewRedisSource(limiterRedis.Ring, clk, stats)
limiter, err = ratelimits.NewLimiter(clk, source, c.WFE.Limiter.Defaults, c.WFE.Limiter.Overrides, stats)
cmd.FailOnError(err, "Failed to create rate limiter")
}
Expand Down Expand Up @@ -391,13 +383,6 @@ func main() {
)
cmd.FailOnError(err, "Unable to create WFE")

var limiterCtx context.Context
var shutdownLimiterLookup context.CancelFunc = func() {}
if limiterLookup != nil {
limiterCtx, shutdownLimiterLookup = context.WithCancel(context.Background())
limiterLookup.Start(limiterCtx)
}

wfe.SubscriberAgreementURL = c.WFE.SubscriberAgreementURL
wfe.AllowOrigins = c.WFE.AllowOrigins
wfe.DirectoryCAAIdentity = c.WFE.DirectoryCAAIdentity
Expand Down Expand Up @@ -447,12 +432,11 @@ func main() {
// ListenAndServe() and ListenAndServeTLS() to immediately return, then waits
// for any lingering connection-handling goroutines to finish their work.
defer func() {

ctx, cancel := context.WithTimeout(context.Background(), c.WFE.ShutdownStopTimeout.Duration)
defer cancel()
_ = srv.Shutdown(ctx)
_ = tlsSrv.Shutdown(ctx)
shutdownLimiterLookup()
limiterRedis.StopLookups()
oTelShutdown(ctx)
}()

Expand Down
24 changes: 20 additions & 4 deletions ratelimits/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) (
return nil, ErrInvalidCostOverLimit
}

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucketKey(name, id))
if err != nil {
if !errors.Is(err, ErrBucketNotFound) {
Expand Down Expand Up @@ -190,11 +193,14 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (
}

start := l.clk.Now()
decisionStatus := Denied
status := Denied
defer func() {
l.spendLatency.WithLabelValues(nameToString[name], decisionStatus).Observe(l.clk.Since(start).Seconds())
l.spendLatency.WithLabelValues(name.String(), status).Observe(l.clk.Since(start).Seconds())
}()

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucketKey(name, id))
if err != nil {
if errors.Is(err, ErrBucketNotFound) {
Expand All @@ -204,7 +210,7 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (
return nil, err
}
if d.Allowed {
decisionStatus = Allowed
status = Allowed
}
return d, nil
}
Expand All @@ -228,7 +234,7 @@ func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (
if err != nil {
return nil, err
}
decisionStatus = Allowed
status = Allowed
return d, nil
}

Expand Down Expand Up @@ -256,6 +262,9 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64)
return nil, err
}

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucketKey(name, id))
if err != nil {
return nil, err
Expand All @@ -271,13 +280,20 @@ func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64)

// Reset resets the specified bucket.
func (l *Limiter) Reset(ctx context.Context, name Name, id string) error {
// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
return l.source.Delete(ctx, bucketKey(name, id))
}

// initialize creates a new bucket, specified by limit name and id, with the
// cost of the request factored into the initial state.
func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string, cost int64) (*Decision, error) {
d := maybeSpend(l.clk, rl, l.clk.Now(), cost)

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
err := l.source.Set(ctx, bucketKey(name, id), d.newTAT)
if err != nil {
return nil, err
Expand Down
12 changes: 9 additions & 3 deletions ratelimits/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@ var ErrBucketNotFound = fmt.Errorf("bucket not found")

// source is an interface for creating and modifying TATs.
type source interface {
// Set stores the TAT at the specified bucketKey ('name:id').
// Set stores the TAT at the specified bucketKey ('name:id'). Contexts
// passed to this method should have a timeout or deadline set to prevent
// the operation from blocking indefinitely.
Set(ctx context.Context, bucketKey string, tat time.Time) error

// Get retrieves the TAT at the specified bucketKey ('name:id').
// Get retrieves the TAT at the specified bucketKey ('name:id'). Contexts
// passed to this method should have a timeout or deadline set to prevent
// the operation from blocking indefinitely.
Get(ctx context.Context, bucketKey string) (time.Time, error)

// Delete deletes the TAT at the specified bucketKey ('name:id').
// Delete deletes the TAT at the specified bucketKey ('name:id'). Contexts
// passed to this method should have a timeout or deadline set to prevent
// the operation from blocking indefinitely.
Delete(ctx context.Context, bucketKey string) error
}

Expand Down
55 changes: 35 additions & 20 deletions redis/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,18 @@ type Config struct {
IdleCheckFrequency config.Duration `validate:"-"`
}

// NewRing returns a new Redis ring client.
func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) {
// Ring is a wrapper around the go-redis/v9 Ring client that adds support for
// (optional) periodic SRV lookups.
type Ring struct {
*redis.Ring
lookup *lookup
}

// NewRingFromConfig returns a new *redis.Ring client. If periodic SRV lookups
// are supplied, a goroutine will be started to periodically perform lookups.
// Callers should defer a call to StopLookups() to ensure that this goroutine is
// gracefully shutdown.
func NewRingFromConfig(c Config, stats prometheus.Registerer, log blog.Logger) (*Ring, error) {
password, err := c.Pass()
if err != nil {
return nil, fmt.Errorf("loading password: %w", err)
Expand All @@ -120,7 +130,7 @@ func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) {
return nil, fmt.Errorf("loading TLS config: %w", err)
}

client := redis.NewRing(&redis.RingOptions{
inner := redis.NewRing(&redis.RingOptions{
Addrs: c.ShardAddrs,
Username: c.Username,
Password: password,
Expand All @@ -141,26 +151,31 @@ func (c *Config) NewRing(stats prometheus.Registerer) (*redis.Ring, error) {
})
if len(c.ShardAddrs) > 0 {
// Client was statically configured with a list of shards.
MustRegisterClientMetricsCollector(client, stats, c.ShardAddrs, c.Username)
MustRegisterClientMetricsCollector(inner, stats, c.ShardAddrs, c.Username)
}

return client, nil
}

// NewRingWithPeriodicLookups returns a new Redis ring client whose shards are
// periodically updated via SRV lookups. An initial SRV lookup is performed to
// populate the Redis ring shards. If this lookup fails or otherwise results in
// an empty set of resolved shards, an error is returned.
func (c *Config) NewRingWithPeriodicLookups(stats prometheus.Registerer, logger blog.Logger) (*redis.Ring, *Lookup, error) {
ring, err := c.NewRing(stats)
if err != nil {
return nil, nil, err
var lookup *lookup
if len(c.Lookups) != 0 {
lookup, err = newLookup(c.Lookups, c.LookupDNSAuthority, c.LookupFrequency.Duration, inner, log, stats)
if err != nil {
return nil, err
}
lookup.start()
}

lookup, err := newLookup(c.Lookups, c.LookupDNSAuthority, c.LookupFrequency.Duration, ring, logger, stats)
if err != nil {
return nil, nil, err
}
return &Ring{
Ring: inner,
lookup: lookup,
}, nil
}

return ring, lookup, nil
// StopLookups stops the goroutine responsible for keeping the shards of the
// inner *redis.Ring up-to-date. It is a no-op if the Ring was not constructed
// with periodic lookups or if the lookups have already been stopped.
func (r *Ring) StopLookups() {
if r == nil || r.lookup == nil {
// No-op.
return
}
r.lookup.stop()
}
31 changes: 18 additions & 13 deletions redis/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (

var ErrNoShardsResolved = errors.New("0 shards were resolved")

// Lookup wraps a Redis ring client by reference and keeps the Redis ring shards
// lookup wraps a Redis ring client by reference and keeps the Redis ring shards
// up to date via periodic SRV lookups.
type Lookup struct {
type lookup struct {
// srvLookups is a list of SRV records to be looked up.
srvLookups []cmd.ServiceDomain

Expand All @@ -38,16 +38,20 @@ type Lookup struct {
// will be used for resolution.
dnsAuthority string

// stop is a context.CancelFunc that can be used to stop the goroutine
// responsible for performing periodic SRV lookups.
stop context.CancelFunc

resolver *net.Resolver
ring *redis.Ring
logger blog.Logger
stats prometheus.Registerer
}

// newLookup constructs and returns a new Lookup instance. An initial SRV lookup
// newLookup constructs and returns a new lookup instance. An initial SRV lookup
// is performed to populate the Redis ring shards. If this lookup fails or
// otherwise results in an empty set of resolved shards, an error is returned.
func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency time.Duration, ring *redis.Ring, logger blog.Logger, stats prometheus.Registerer) (*Lookup, error) {
func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency time.Duration, ring *redis.Ring, logger blog.Logger, stats prometheus.Registerer) (*lookup, error) {
updateFrequency := frequency
if updateFrequency <= 0 {
// Set default frequency.
Expand All @@ -56,7 +60,7 @@ func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency ti
// Set default timeout to 90% of the update frequency.
updateTimeout := updateFrequency - updateFrequency/10

lookup := &Lookup{
lookup := &lookup{
srvLookups: srvLookups,
ring: ring,
logger: logger,
Expand Down Expand Up @@ -111,7 +115,7 @@ func newLookup(srvLookups []cmd.ServiceDomain, dnsAuthority string, frequency ti
// lookup succeeds, the Redis ring is updated, and all errors are discarded.
// Non-temporary DNS errors are always logged as they occur, as they're likely
// to be indicative of a misconfiguration.
func (look *Lookup) updateNow(ctx context.Context) (tempError, nonTempError error) {
func (look *lookup) updateNow(ctx context.Context) (tempError, nonTempError error) {
var tempErrs []error
handleDNSError := func(err error, srv cmd.ServiceDomain) {
var dnsErr *net.DNSError
Expand Down Expand Up @@ -176,20 +180,21 @@ func (look *Lookup) updateNow(ctx context.Context) (tempError, nonTempError erro
return nil, nil
}

// Start starts a goroutine that keeps the Redis ring shards up to date via
// periodic SRV lookups. The goroutine will exit when the provided context is
// cancelled.
func (look *Lookup) Start(ctx context.Context) {
// start starts a goroutine that keeps the Redis ring shards up-to-date by
// periodically performing SRV lookups.
func (look *lookup) start() {
var lookupCtx context.Context
lookupCtx, look.stop = context.WithCancel(context.Background())
go func() {
ticker := time.NewTicker(look.updateFrequency)
defer ticker.Stop()
for {
// Check for context cancellation before we do any work.
if ctx.Err() != nil {
if lookupCtx.Err() != nil {
return
}

timeoutCtx, cancel := context.WithTimeout(ctx, look.updateTimeout)
timeoutCtx, cancel := context.WithTimeout(lookupCtx, look.updateTimeout)
tempErrs, nonTempErrs := look.updateNow(timeoutCtx)
cancel()
if tempErrs != nil {
Expand All @@ -205,7 +210,7 @@ func (look *Lookup) Start(ctx context.Context) {
case <-ticker.C:
continue

case <-ctx.Done():
case <-lookupCtx.Done():
return
}
}
Expand Down
6 changes: 2 additions & 4 deletions redis/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,8 @@ func TestStart(t *testing.T) {
)
test.AssertNotError(t, err, "Expected newLookup construction to succeed")

testCtx, cancel := context.WithCancel(context.Background())
defer cancel()

lookup.Start(testCtx)
lookup.start()
lookup.stop()
}

func TestNewLookupWithOneFailingSRV(t *testing.T) {
Expand Down
Loading

0 comments on commit f2e93de

Please sign in to comment.