diff --git a/retrier/retrier.go b/retrier/retrier.go index 04a116b..b51c1a7 100644 --- a/retrier/retrier.go +++ b/retrier/retrier.go @@ -11,12 +11,13 @@ import ( // Retrier implements the "retriable" resiliency pattern, abstracting out the process of retrying a failed action // a certain number of times with an optional back-off between each retry. type Retrier struct { - backoff []time.Duration - infiniteRetry bool - class Classifier - jitter float64 - rand *rand.Rand - randMu sync.Mutex + backoff []time.Duration + infiniteRetry bool + surfaceWorkErrors bool + class Classifier + jitter float64 + rand *rand.Rand + randMu sync.Mutex } // New constructs a Retrier with the given backoff pattern and classifier. The length of the backoff pattern @@ -43,6 +44,13 @@ func (r *Retrier) WithInfiniteRetry() *Retrier { return r } +// WithSurfaceWorkErrors configures the retrier to always return the last error received from work function +// even if a context timeout/deadline is hit. +func (r *Retrier) WithSurfaceWorkErrors() *Retrier { + r.surfaceWorkErrors = true + return r +} + // Run executes the given work function by executing RunCtx without context.Context. func (r *Retrier) Run(work func() error) error { return r.RunFn(context.Background(), func(c context.Context, r int) error { @@ -83,6 +91,9 @@ func (r *Retrier) RunFn(ctx context.Context, work func(ctx context.Context, retr timer := time.NewTimer(r.calcSleep(retries)) if err := r.sleep(ctx, timer); err != nil { + if r.surfaceWorkErrors { + return ret + } return err } diff --git a/retrier/retrier_test.go b/retrier/retrier_test.go index c23407d..b71cd26 100644 --- a/retrier/retrier_test.go +++ b/retrier/retrier_test.go @@ -153,6 +153,52 @@ func TestRetrierRunFnWithInfinite(t *testing.T) { } } +func TestRetrierRunFnWithSurfaceWorkErrors(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithSurfaceWorkErrors() + errExpected := []error{errFoo, errBar, errBaz} + + err := r.RunFn(ctx, func(ctx context.Context, retries int) error { + if retries >= len(errExpected) { + return nil + } + if retries == 1 { + // Context canceled inside second call to work function. + cancel() + } + err := errExpected[retries] + retries++ + return err + }) + if err != errBar { + t.Error(err) + } +} + +func TestRetrierRunFnWithoutSurfaceWorkErrors(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil) + errExpected := []error{errFoo, errBar, errBaz} + + err := r.RunFn(ctx, func(ctx context.Context, retries int) error { + if retries >= len(errExpected) { + return nil + } + if retries == 1 { + // Context canceled inside second call to work function. + cancel() + } + err := errExpected[retries] + retries++ + return err + }) + if err != context.Canceled { + t.Error(err) + } +} + func TestRetrierNone(t *testing.T) { r := New(nil, nil)