Skip to content

Commit

Permalink
Add RetryWithContext() and respect cancellation while sleeping
Browse files Browse the repository at this point in the history
This is a breaking change for developers using custom strategies.
However, there shouldn't be any impact on code using the
strategies included in this package.

Because the time.Sleep() call is now abstracted, strategies are
tested without actually sleeping, and the strategies don't
need to be aware of contexts.

Context is passed through to the action in case the action is defined
separately from the retry.RetryWithContext() call, is reused at
multiple points, etc.
  • Loading branch information
CodyDWJones committed Mar 12, 2021
1 parent 272ad12 commit 2c55b7d
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 94 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file
### HTTP request with strategies and backoff

```go
var response *http.Response
action := func(ctx context.Context, attempt uint) error {
var response *http.Response

action := func(attempt uint) error {
var err error

response, err = http.Get("https://api.github.com/repos/Rican7/retry")
req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil)
if err == nil {
response, err = c.Do(req)
}

if nil == err && nil != response && response.StatusCode > 200 {
err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode)
Expand All @@ -69,7 +70,8 @@ action := func(attempt uint) error {
return err
}

err := retry.Retry(
err := retry.RetryWithContext(
context.TODO(),
action,
strategy.Limit(5),
strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)),
Expand Down
43 changes: 38 additions & 5 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,66 @@
// Copyright © 2016 Trevor N. Suarez (Rican7)
package retry

import "github.com/Rican7/retry/strategy"
import (
"context"
"time"

"github.com/Rican7/retry/strategy"
)

// Action defines a callable function that package retry can handle.
type Action func(attempt uint) error

// ActionWithContext defines a callable function that package retry can handle.
type ActionWithContext func(ctx context.Context, attempt uint) error

// Retry takes an action and performs it, repetitively, until successful.
//
// Optionally, strategies may be passed that assess whether or not an attempt
// should be made.
func Retry(action Action, strategies ...strategy.Strategy) error {
return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...)
}

// RetryWithContext takes an action and performs it, repetitively, until successful
// or the context is done.
//
// Optionally, strategies may be passed that assess whether or not an attempt
// should be made.
func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error {
var err error

for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ {
err = action(attempt)
if ctx.Err() != nil {
return ctx.Err()
}

for attempt := uint(0); (0 == attempt || nil != err && nil == ctx.Err()) && shouldAttempt(attempt, sleepFunc(ctx), strategies...); attempt++ {
err = action(ctx, attempt)
}

return err
}

// shouldAttempt evaluates the provided strategies with the given attempt to
// determine if the Retry loop should make another attempt.
func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool {
func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool {
shouldAttempt := true

for i := 0; shouldAttempt && i < len(strategies); i++ {
shouldAttempt = shouldAttempt && strategies[i](attempt)
shouldAttempt = shouldAttempt && strategies[i](attempt, sleep)
}

return shouldAttempt
}

// sleepFunc returns a function with the same signature as time.Sleep()
// that blocks for the given duration, but will return sooner if the context is
// cancelled or its deadline passes.
func sleepFunc(ctx context.Context) func(time.Duration) {
return func(d time.Duration) {
select {
case <-ctx.Done():
case <-time.After(d):
}
}
}
85 changes: 73 additions & 12 deletions retry_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package retry

import (
"context"
"errors"
"testing"
"time"
)

// timeMarginOfError represents the acceptable amount of time that may pass for
// a time-based (sleep) unit before considering invalid.
const timeMarginOfError = time.Millisecond

func TestRetry(t *testing.T) {
action := func(attempt uint) error {
return nil
Expand Down Expand Up @@ -47,8 +53,63 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) {
}
}

func TestRetryWithContextAlreadyCancelled(t *testing.T) {
action := func(ctx context.Context, attempt uint) error {
return errors.New("erroring")
}

ctx, cancel := context.WithCancel(context.Background())
cancel()

err := RetryWithContext(ctx, action)

if ctx.Err() != err {
t.Error("expected a context error")
}
}

func TestRetryWithContextSleepIsInterrupted(t *testing.T) {
const sleepDuration = 100 * timeMarginOfError
noSleepDeadline := time.Now().Add(sleepDuration)

strategy := func(attempt uint, sleep func(time.Duration)) bool {
sleep(sleepDuration)
return true
}

var numCalls int
expectedErr := errors.New("erroring")

action := func(ctx context.Context, attempt uint) error {
numCalls++
return expectedErr
}

stopAfter := 10 * timeMarginOfError
deadline := time.Now().Add(stopAfter)
ctx, _ := context.WithDeadline(context.Background(), deadline)

err := RetryWithContext(ctx, action, strategy)

if time.Now().Before(deadline) {
t.Errorf("expected to stop after %s", stopAfter)
}

if time.Now().After(noSleepDeadline) {
t.Errorf("expected to stop before %s", sleepDuration)
}

if 1 != numCalls {
t.Errorf("expected the action to be tried once, not %d times", numCalls)
}

if expectedErr != err {
t.Error("expected to receive the error returned by the action")
}
}

func TestShouldAttempt(t *testing.T) {
shouldAttempt := shouldAttempt(1)
shouldAttempt := shouldAttempt(1, time.Sleep)

if !shouldAttempt {
t.Error("expected to return true")
Expand All @@ -58,63 +119,63 @@ func TestShouldAttempt(t *testing.T) {
func TestShouldAttemptWithStrategy(t *testing.T) {
const attemptNumberShouldReturnFalse = 7

strategy := func(attempt uint) bool {
strategy := func(attempt uint, sleep func(time.Duration)) bool {
return (attemptNumberShouldReturnFalse != attempt)
}

should := shouldAttempt(1, strategy)
should := shouldAttempt(1, time.Sleep, strategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy)
should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(attemptNumberShouldReturnFalse, strategy)
should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy)

if should {
t.Error("expected to return false")
}
}

func TestShouldAttemptWithMultipleStrategies(t *testing.T) {
trueStrategy := func(attempt uint) bool {
trueStrategy := func(attempt uint, sleep func(time.Duration)) bool {
return true
}

falseStrategy := func(attempt uint) bool {
falseStrategy := func(attempt uint, sleep func(time.Duration)) bool {
return false
}

should := shouldAttempt(1, trueStrategy)
should := shouldAttempt(1, time.Sleep, trueStrategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1, falseStrategy)
should = shouldAttempt(1, time.Sleep, falseStrategy)

if should {
t.Error("expected to return false")
}

should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy)
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy)

if !should {
t.Error("expected to return true")
}

should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy)
should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy)

if should {
t.Error("expected to return false")
}

should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy)
should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy)

if should {
t.Error("expected to return false")
Expand Down
18 changes: 9 additions & 9 deletions strategy/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ import (
// The strategy will be passed an "attempt" number on each successive retry
// iteration, starting with a `0` value before the first attempt is actually
// made. This allows for a pre-action delay, etc.
type Strategy func(attempt uint) bool
type Strategy func(attempt uint, sleep func(time.Duration)) bool

// Limit creates a Strategy that limits the number of attempts that Retry will
// make.
func Limit(attemptLimit uint) Strategy {
return func(attempt uint) bool {
return (attempt <= attemptLimit)
return func(attempt uint, sleep func(time.Duration)) bool {
return attempt <= attemptLimit
}
}

// Delay creates a Strategy that waits the given duration before the first
// attempt is made.
func Delay(duration time.Duration) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 == attempt {
time.Sleep(duration)
sleep(duration)
}

return true
Expand All @@ -44,15 +44,15 @@ func Delay(duration time.Duration) Strategy {
// the first. If the number of attempts is greater than the number of durations
// provided, then the strategy uses the last duration provided.
func Wait(durations ...time.Duration) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 < attempt && 0 < len(durations) {
durationIndex := int(attempt - 1)

if len(durations) <= durationIndex {
durationIndex = len(durations) - 1
}

time.Sleep(durations[durationIndex])
sleep(durations[durationIndex])
}

return true
Expand All @@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy {
// BackoffWithJitter creates a Strategy that waits before each attempt, with a
// duration as defined by the given backoff.Algorithm and jitter.Transformation.
func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy {
return func(attempt uint) bool {
return func(attempt uint, sleep func(time.Duration)) bool {
if 0 < attempt {
time.Sleep(transformation(algorithm(attempt)))
sleep(transformation(algorithm(attempt)))
}

return true
Expand Down
Loading

0 comments on commit 2c55b7d

Please sign in to comment.