Skip to content

Commit

Permalink
test: Add coverage for retry package (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivov authored Nov 27, 2024
1 parent df2e9c0 commit 9eb223c
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 1 deletion.
2 changes: 1 addition & 1 deletion internal/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const (
var (
defaultMaxRetryTime = 60 * time.Second
defaultMaxRetries = 100
defaultWaitTimeBetweenRetries = 5 * time.Second
Expand Down
272 changes: 272 additions & 0 deletions internal/retry/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
package retry

import (
"errors"
"testing"
"time"
)

func setRetryTimings(t *testing.T) func() {
t.Helper()
origMaxRetryTime := defaultMaxRetryTime
origMaxRetries := defaultMaxRetries
origWaitTime := defaultWaitTimeBetweenRetries

defaultMaxRetryTime = 100 * time.Millisecond
defaultMaxRetries = 3
defaultWaitTimeBetweenRetries = 10 * time.Millisecond

return func() {
defaultMaxRetryTime = origMaxRetryTime
defaultMaxRetries = origMaxRetries
defaultWaitTimeBetweenRetries = origWaitTime
}
}

func TestUnlimitedRetry(t *testing.T) {
restoreFn := setRetryTimings(t)
defer restoreFn()

tests := []struct {
name string
operationFn func() (string, error)
expectedCalls int
expectError bool
expectedValue string
}{
{
name: "succeeds on first try",
operationFn: func() (string, error) {
return "success", nil
},
expectedCalls: 1,
expectedValue: "success",
expectError: false,
},
{
name: "succeeds after multiple retries",
operationFn: (func() func() (string, error) {
count := 0
return func() (string, error) {
count++
if count < 3 {
return "", errors.New("temporary error")
}
return "success after retries", nil
}
})(),
expectedCalls: 3,
expectedValue: "success after retries",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
callCount := 0
trackedFn := func() (string, error) {
callCount++
return tt.operationFn()
}

result, err := UnlimitedRetry("test-operation", trackedFn)

if tt.expectError && err == nil {
t.Error("expected error but got nil")
}

if !tt.expectError && err != nil {
t.Errorf("unexpected error: %v", err)
}

if result != tt.expectedValue {
t.Errorf("expected value %v but got %v", tt.expectedValue, result)
}

if callCount != tt.expectedCalls {
t.Errorf("expected %d calls but got %d", tt.expectedCalls, callCount)
}
})
}
}

func TestLimitedRetry(t *testing.T) {
restoreFn := setRetryTimings(t)
defer restoreFn()

tests := []struct {
name string
operationFn func() (string, error)
expectedCalls int
expectError bool
expectedValue string
}{
{
name: "succeeds on first try",
operationFn: func() (string, error) {
return "success", nil
},
expectedCalls: 1,
expectedValue: "success",
expectError: false,
},
{
name: "succeeds within retry limits",
operationFn: (func() func() (string, error) {
count := 0
return func() (string, error) {
count++
if count < 3 {
return "", errors.New("dummy error")
}
return "success after retries", nil
}
})(),
expectedCalls: 3,
expectedValue: "success after retries",
expectError: false,
},
{
name: "fails after max attempts",
operationFn: func() (string, error) {
return "", errors.New("persistent error")
},
expectedCalls: defaultMaxRetries,
expectError: true,
expectedValue: "",
},
{
name: "fails after max retry time",
operationFn: func() (string, error) {
time.Sleep(defaultMaxRetryTime + time.Second)
return "", errors.New("timeout error")
},
expectedCalls: 1,
expectError: true,
expectedValue: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
callCount := 0
trackedFn := func() (string, error) {
callCount++
return tt.operationFn()
}

result, err := LimitedRetry("test-operation", trackedFn)

if tt.expectError && err == nil {
t.Error("expected error but got nil")
}

if !tt.expectError && err != nil {
t.Errorf("unexpected error: %v", err)
}

if result != tt.expectedValue {
t.Errorf("expected value %v but got %v", tt.expectedValue, result)
}

if callCount != tt.expectedCalls {
t.Errorf("expected %d calls but got %d", tt.expectedCalls, callCount)
}
})
}
}

func TestRetryConfiguration(t *testing.T) {
tests := []struct {
name string
cfg retryConfig
fn func() (string, error)
want error
}{
{
name: "respects custom max retry time",
cfg: retryConfig{
MaxRetryTime: 100 * time.Millisecond,
MaxAttempts: 0,
WaitTimeBetweenRetries: time.Millisecond,
},
fn: func() (string, error) {
return "", errors.New("error")
},
want: errors.New("gave up retrying operation `test` on reaching max retry time 100ms, last error: error"),
},
{
name: "respects custom max attempts",
cfg: retryConfig{
MaxRetryTime: 0,
MaxAttempts: 2,
WaitTimeBetweenRetries: time.Millisecond,
},
fn: func() (string, error) {
return "", errors.New("error")
},
want: errors.New("gave up retrying operation `test` on reaching max retry attempts 2, last error: error"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := retry("test", tt.fn, tt.cfg)
if err == nil {
t.Fatal("expected error but got nil")
}

if err.Error() != tt.want.Error() {
t.Errorf("got error %q, want %q", err, tt.want)
}
})
}
}

func TestRetryWithDifferentTypes(t *testing.T) {
t.Run("works with string", func(t *testing.T) {
result, err := UnlimitedRetry("string-operation", func() (string, error) {
return "test", nil
})

if err != nil {
t.Errorf("unexpected error: %v", err)
}

if result != "test" {
t.Errorf("expected string \"test\" but got %s", result)
}
})

t.Run("works with int", func(t *testing.T) {
result, err := UnlimitedRetry("int-operation", func() (int, error) {
return 123, nil
})

if err != nil {
t.Errorf("unexpected error: %v", err)
}

if result != 123 {
t.Errorf("expected int 123 but got %d", result)
}
})

type testStruct struct {
value string
}

t.Run("works with struct", func(t *testing.T) {
result, err := UnlimitedRetry("struct-operation", func() (testStruct, error) {
return testStruct{value: "test"}, nil
})

if err != nil {
t.Errorf("unexpected error: %v", err)
}

if result.value != "test" {
t.Errorf("expected 'test' in struct.value but got %s", result.value)
}
})
}

0 comments on commit 9eb223c

Please sign in to comment.