diff --git a/internal/allocator/allocator.go b/internal/allocator/allocator.go index 666eb7371..9b0219916 100644 --- a/internal/allocator/allocator.go +++ b/internal/allocator/allocator.go @@ -67,6 +67,7 @@ func New() (v *Allocator) { return allocatorPool.Get() } +//nolint:funlen func (a *Allocator) Free() { a.valueAllocator.free() a.typeAllocator.free() diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index 4fdbd32e6..2795592e8 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -3,8 +3,6 @@ package decimal import ( "math/big" "math/bits" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) const ( @@ -99,27 +97,56 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } + s, neg, specialValue := setSpecialValue(s, v) + if specialValue != nil { + return specialValue, nil + } + var err error + v, err = parseNumber(s, v, precision, scale, neg) + if err != nil { + return nil, err + } + + return v, nil +} + +func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) { + s, neg := parseSign(s) + + return parseSpecialValue(s, neg, v) +} + +func parseSign(s string) (string, bool) { neg := s[0] == '-' if neg || s[0] == '+' { s = s[1:] } + + return s, neg +} + +func parseSpecialValue(s string, neg bool, v *big.Int) (string, bool, *big.Int) { if isInf(s) { if neg { - return v.Set(neginf), nil + return s, neg, v.Set(neginf) } - return v.Set(inf), nil + return s, neg, v.Set(inf) } if isNaN(s) { if neg { - return v.Set(negnan), nil + return s, neg, v.Set(negnan) } - return v.Set(nan), nil + return s, neg, v.Set(nan) } - integral := precision - scale + return s, neg, nil +} +func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big.Int, error) { + var err error + integral := precision - scale var dot bool for ; len(s) > 0; s = s[1:] { c := s[0] @@ -131,12 +158,10 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { continue } - if dot { - if scale > 0 { - scale-- - } else { - break - } + if dot && scale > 0 { + scale-- + } else if dot { + break } if !isDigit(c) { @@ -155,30 +180,10 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { } integral-- } - //nolint:nestif if len(s) > 0 { // Characters remaining. - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus := c > '5' - if !plus && c == '5' { - var x big.Int - plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. - for !plus && len(s) > 1 { - s = s[1:] - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus = c != '0' - } - } - if plus { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) - } + v, err = handleRemainingDigits(s, v, precision) + if err != nil { + return nil, err } } v.Mul(v, pow(ten, scale)) @@ -189,26 +194,54 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } +func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, error) { + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) + } + } + + return v, nil +} + // Format returns the string representation of x with the given precision and // scale. func Format(x *big.Int, precision, scale uint32) string { - switch { - case x.CmpAbs(inf) == 0: + // Check for special values and nil pointer upfront. + if x == nil { + return "0" + } + if x.CmpAbs(inf) == 0 { if x.Sign() < 0 { return "-inf" } return "inf" - - case x.CmpAbs(nan) == 0: + } + if x.CmpAbs(nan) == 0 { if x.Sign() < 0 { return "-nan" } return "nan" - - case x == nil: - return "0" } v := big.NewInt(0).Set(x) @@ -232,42 +265,59 @@ func Format(x *big.Int, precision, scale uint32) string { digit.Mod(v, ten) d := int(digit.Int64()) - if d != 0 || scale == 0 || pos > 0 { - const numbers = "0123456789" - pos-- - bts[pos] = numbers[d] + + pos-- + if d != 0 || scale == 0 || pos >= 0 { + setDigitAtPosition(bts, pos, d) } + if scale > 0 { scale-- if scale == 0 && pos > 0 { + bts[pos-1] = '.' pos-- - bts[pos] = '.' } } } - if scale > 0 { - for ; scale > 0; scale-- { - if precision == 0 { - return errorTag - } - precision-- - pos-- - bts[pos] = '0' - } + for ; scale > 0; scale-- { + if precision == 0 { + pos = 0 + + break + } + precision-- pos-- - bts[pos] = '.' + bts[pos] = '0' } + if bts[pos] == '.' { pos-- bts[pos] = '0' } + if neg { pos-- bts[pos] = '-' } - return xstring.FromBytes(bts[pos:]) + return string(bts[pos:]) +} + +func abs(x *big.Int) (*big.Int, bool) { + v := big.NewInt(0).Set(x) + neg := x.Sign() < 0 + if neg { + // Convert negative to positive. + v.Neg(x) + } + + return v, neg +} + +func setDigitAtPosition(bts []byte, pos, digit int) { + const numbers = "0123456789" + bts[pos] = numbers[digit] } // BigIntToByte returns the 16-byte array representation of x. diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index fd7391da1..a4fe5fcaf 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -2,7 +2,10 @@ package decimal import ( "encoding/binary" + "math/big" "testing" + + "github.com/stretchr/testify/require" ) func TestFromBytes(t *testing.T) { @@ -57,6 +60,181 @@ func TestFromBytes(t *testing.T) { } } +func TestSetSpecialValue(t *testing.T) { + tests := []struct { + name string + input string + expectedS string + expectedNeg bool + expectedV *big.Int + }{ + { + name: "Positive infinity", + input: "inf", + expectedS: "inf", + expectedNeg: false, + expectedV: inf, + }, + { + name: "Negative infinity", + input: "-inf", + expectedS: "inf", + expectedNeg: true, + expectedV: neginf, + }, + { + name: "Positive NaN", + input: "nan", + expectedS: "nan", + expectedNeg: false, + expectedV: nan, + }, + { + name: "Negative NaN", + input: "-nan", + expectedS: "nan", + expectedNeg: true, + expectedV: negnan, + }, + { + name: "Regular number", + input: "123", + expectedS: "123", + expectedNeg: false, + expectedV: nil, + }, + { + name: "Negative regular number", + input: "-123", + expectedS: "123", + expectedNeg: true, + expectedV: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := big.NewInt(0) + gotS, gotNeg, gotV := setSpecialValue(tt.input, v) + require.Equal(t, tt.expectedS, gotS) + require.Equal(t, tt.expectedNeg, gotNeg) + if tt.expectedV != nil { + require.Equal(t, 0, tt.expectedV.Cmp(gotV)) + } else { + require.Nil(t, gotV) + } + }) + } +} + +func TestPrepareValue(t *testing.T) { + tests := []struct { + name string + input *big.Int + expectedValue *big.Int + expectedNeg bool + }{ + { + name: "Positive value", + input: big.NewInt(123), + expectedValue: big.NewInt(123), + expectedNeg: false, + }, + { + name: "Negative value", + input: big.NewInt(-123), + expectedValue: big.NewInt(123), + expectedNeg: true, + }, + { + name: "Zero value", + input: big.NewInt(0), + expectedValue: big.NewInt(0), + expectedNeg: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, neg := abs(tt.input) + require.Equal(t, tt.expectedValue, value) + require.Equal(t, tt.expectedNeg, neg) + }) + } +} + +func TestParseNumber(t *testing.T) { + // Mock or define these as per your actual implementation. + tests := []struct { + name string + s string + wantValue *big.Int + precision uint32 + scale uint32 + neg bool + wantErr bool + }{ + { + name: "Valid number without decimal", + s: "123", + precision: 3, + scale: 0, + neg: false, + wantValue: big.NewInt(123), + wantErr: false, + }, + { + name: "Valid number with decimal", + s: "123.45", + precision: 5, + scale: 2, + neg: false, + wantValue: big.NewInt(12345), + wantErr: false, + }, + { + name: "Valid negative number", + s: "123", + precision: 3, + scale: 0, + neg: true, + wantValue: big.NewInt(-123), + wantErr: false, + }, + { + name: "Syntax error with non-digit", + s: "123a", + precision: 4, + scale: 0, + neg: false, + wantValue: nil, + wantErr: true, + }, + { + name: "Multiple decimal points", + s: "12.3.4", + precision: 5, + scale: 2, + neg: false, + wantValue: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := big.NewInt(0) + gotValue, gotErr := parseNumber(tt.s, v, tt.precision, tt.scale, tt.neg) + if tt.wantErr { + require.Error(t, gotErr) + } else { + require.NoError(t, gotErr) + require.Equal(t, 0, tt.wantValue.Cmp(gotValue)) + } + }) + } +} + func uint128(hi, lo uint64) []byte { p := make([]byte, 16) binary.BigEndian.PutUint64(p[:8], hi) @@ -68,3 +246,143 @@ func uint128(hi, lo uint64) []byte { func uint128s(lo uint64) []byte { return uint128(0, lo) } + +func TestParse(t *testing.T) { + tests := []struct { + name string + s string + precision uint32 + scale uint32 + }{ + { + name: "Specific Parse test", + s: "100", + precision: 0, + scale: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expectedRes, expectedErr := oldParse(tt.s, tt.precision, tt.scale) + res, err := Parse(tt.s, tt.precision, tt.scale) + if expectedErr == nil { + require.Equal(t, expectedRes, res) + } else { + require.Error(t, err) + } + }) + } +} + +func FuzzParse(f *testing.F) { + f.Fuzz(func(t *testing.T, s string, precision, scale uint32) { + expectedRes, expectedErr := oldParse(s, precision, scale) + res, err := Parse(s, precision, scale) + if expectedErr == nil { + require.Equal(t, expectedRes, res) + } else { + require.Error(t, err) + } + }) +} + +func oldParse(s string, precision, scale uint32) (*big.Int, error) { + if scale > precision { + return nil, precisionError(s, precision, scale) + } + + v := big.NewInt(0) + if s == "" { + return v, nil + } + + neg := s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } + if isInf(s) { + if neg { + return v.Set(neginf), nil + } + + return v.Set(inf), nil + } + if isNaN(s) { + if neg { + return v.Set(negnan), nil + } + + return v.Set(nan), nil + } + + integral := precision - scale + + var dot bool + for ; len(s) > 0; s = s[1:] { + c := s[0] + if c == '.' { + if dot { + return nil, syntaxError(s) + } + dot = true + + continue + } + if dot { + if scale > 0 { + scale-- + } else { + break + } + } + + if !isDigit(c) { + return nil, syntaxError(s) + } + + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) + + if !dot && v.Cmp(zero) > 0 && integral == 0 { + if neg { + return neginf, nil + } + + return inf, nil + } + integral-- + } + //nolint:nestif + if len(s) > 0 { // Characters remaining. + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) + } + } + } + v.Mul(v, pow(ten, scale)) + if neg { + v.Neg(v) + } + + return v, nil +} diff --git a/internal/stack/record.go b/internal/stack/record.go index 2098c1ec8..17c156736 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -2,6 +2,7 @@ package stack import ( "fmt" + "path" "runtime" "strings" @@ -18,6 +19,14 @@ type recordOptions struct { lambdas bool } +type functionDetails struct { + pkgPath string + pkgName string + structName string + funcName string + lambdas []string +} + type recordOption func(opts *recordOptions) func PackageName(b bool) recordOption { @@ -91,67 +100,81 @@ func (c call) Record(opts ...recordOption) string { opt(&optionsHolder) } } - name := runtime.FuncForPC(c.function).Name() - var ( - pkgPath string - pkgName string - structName string - funcName string - file = c.file - ) - if i := strings.LastIndex(file, "/"); i > -1 { - file = file[i+1:] - } + + name, file := extractName(c.function, c.file) + fnDetails := parseFunctionName(name) + + return buildRecordString(optionsHolder, &fnDetails, file, c.line) +} + +func extractName(function uintptr, file string) (name, fileName string) { + name = runtime.FuncForPC(function).Name() + _, fileName = path.Split(file) + name = strings.ReplaceAll(name, "[...]", "") + + return name, fileName +} + +func parseFunctionName(name string) functionDetails { + var details functionDetails if i := strings.LastIndex(name, "/"); i > -1 { - pkgPath, name = name[:i], name[i+1:] + details.pkgPath, name = name[:i], name[i+1:] } - name = strings.ReplaceAll(name, "[...]", "") split := strings.Split(name, ".") - lambdas := make([]string, 0, len(split)) + details.lambdas = make([]string, 0, len(split)) for i := range split { elem := split[len(split)-i-1] if !strings.HasPrefix(elem, "func") { break } - lambdas = append(lambdas, elem) + details.lambdas = append(details.lambdas, elem) } - split = split[:len(split)-len(lambdas)] + split = split[:len(split)-len(details.lambdas)] if len(split) > 0 { - pkgName = split[0] + details.pkgName = split[0] } if len(split) > 1 { - funcName = split[len(split)-1] + details.funcName = split[len(split)-1] } if len(split) > 2 { //nolint:gomnd - structName = split[1] + details.structName = split[1] } + return details +} + +func buildRecordString( + optionsHolder recordOptions, + fnDetails *functionDetails, + file string, + line int, +) string { buffer := xstring.Buffer() defer buffer.Free() if optionsHolder.packagePath { - buffer.WriteString(pkgPath) + buffer.WriteString(fnDetails.pkgPath) } if optionsHolder.packageName { if buffer.Len() > 0 { buffer.WriteByte('/') } - buffer.WriteString(pkgName) + buffer.WriteString(fnDetails.pkgName) } - if optionsHolder.structName && len(structName) > 0 { + if optionsHolder.structName && len(fnDetails.structName) > 0 { if buffer.Len() > 0 { buffer.WriteByte('.') } - buffer.WriteString(structName) + buffer.WriteString(fnDetails.structName) } if optionsHolder.functionName { if buffer.Len() > 0 { buffer.WriteByte('.') } - buffer.WriteString(funcName) + buffer.WriteString(fnDetails.funcName) if optionsHolder.lambdas { - for i := range lambdas { + for i := range fnDetails.lambdas { buffer.WriteByte('.') - buffer.WriteString(lambdas[len(lambdas)-i-1]) + buffer.WriteString(fnDetails.lambdas[len(fnDetails.lambdas)-i-1]) } } } @@ -164,7 +187,7 @@ func (c call) Record(opts ...recordOption) string { buffer.WriteString(file) if optionsHolder.line { buffer.WriteByte(':') - fmt.Fprintf(buffer, "%d", c.line) + fmt.Fprintf(buffer, "%d", line) } if closeBrace { buffer.WriteByte(')') diff --git a/internal/stack/record_test.go b/internal/stack/record_test.go index d0cdec245..9f10be303 100644 --- a/internal/stack/record_test.go +++ b/internal/stack/record_test.go @@ -1,6 +1,9 @@ package stack import ( + "reflect" + "runtime" + "strings" "testing" "github.com/stretchr/testify/require" @@ -30,13 +33,13 @@ func TestRecord(t *testing.T) { }{ { act: Record(0), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:32)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:35)", }, { act: func() string { return Record(1) }(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:38)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:41)", }, { act: func() string { @@ -44,7 +47,7 @@ func TestRecord(t *testing.T) { return Record(2) }() }(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:46)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:49)", }, { act: testStruct{depth: 0, opts: []recordOption{ @@ -164,7 +167,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}.TestFunc(), - exp: "record_test.go:16", + exp: "record_test.go:19", }, { act: testStruct{depth: 0, opts: []recordOption{ @@ -236,7 +239,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}.TestFunc(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.testStruct.TestFunc.func1(record_test.go:16)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.testStruct.TestFunc.func1(record_test.go:19)", }, { act: (&testStruct{depth: 0, opts: []recordOption{ @@ -248,7 +251,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}).TestPointerFunc(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.(*testStruct).TestPointerFunc.func1(record_test.go:22)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.(*testStruct).TestPointerFunc.func1(record_test.go:25)", }, } { t.Run("", func(t *testing.T) { @@ -257,6 +260,61 @@ func TestRecord(t *testing.T) { } } +func TestExtractNames(t *testing.T) { + testFunc := func() {} + funcPtr := reflect.ValueOf(testFunc).Pointer() + + funcNameExpected := runtime.FuncForPC(funcPtr).Name() + + _, file, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller should return true indicating success") + + fileParts := strings.Split(file, "/") + fileNameExpected := fileParts[len(fileParts)-1] + + name, fileName := extractName(funcPtr, file) + + require.Equal(t, funcNameExpected, name, "Function name should match expected value") + require.Equal(t, fileNameExpected, fileName, "File name should match expected value") +} + +func TestParseFunctionName(t *testing.T) { + name := "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestParseFunctionName.func1" + fnDetails := parseFunctionName(name) + + require.Equal(t, "github.com/ydb-platform/ydb-go-sdk/v3/internal", fnDetails.pkgPath) + require.Equal(t, "stack", fnDetails.pkgName) + require.Empty(t, fnDetails.structName, "Struct name should be empty for standalone functions") + require.Equal(t, "TestParseFunctionName", fnDetails.funcName) + require.Contains(t, fnDetails.lambdas, "func1", "Lambdas should include 'func1'") +} + +func TestBuildRecordString(t *testing.T) { + optionsHolder := recordOptions{ + packagePath: true, + packageName: false, + structName: true, + functionName: true, + fileName: true, + line: true, + lambdas: true, + } + fnDetails := functionDetails{ + pkgPath: "github.com/ydb-platform/ydb-go-sdk/v3/internal", + pkgName: "", + structName: "testStruct", + funcName: "TestFunc", + + lambdas: []string{"func1"}, + } + file := "record_test.go" + line := 319 + + result := buildRecordString(optionsHolder, &fnDetails, file, line) + expected := "github.com/ydb-platform/ydb-go-sdk/v3/internal.testStruct.TestFunc.func1(record_test.go:319)" + require.Equal(t, expected, result) +} + func BenchmarkCall(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { diff --git a/internal/types/types.go b/internal/types/types.go index 4942cd787..b40a04003 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -82,6 +82,7 @@ func TypeFromYDB(x *Ydb.Type) Type { } } +//nolint:funlen func primitiveTypeFromYDB(t Ydb.Type_PrimitiveTypeId) Type { switch t { case Ydb.Type_BOOL: diff --git a/retry/retry.go b/retry/retry.go index 8b4e4d05d..c2e601603 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -257,6 +257,8 @@ func WithPanicCallback(panicCallback func(e interface{})) panicCallbackOption { // Warning: if context without deadline or cancellation func was passed, Retry will work infinitely. // // If you need to retry your op func on some logic errors - you must return RetryableError() from retryOperation +// +//nolint:funlen func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr error) { options := &retryOptions{ call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/retry.Retry"), @@ -300,26 +302,11 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err select { case <-ctx.Done(): return xerrors.WithStackTrace( - fmt.Errorf("retry failed on attempt No.%d: %w", - attempts, ctx.Err(), - ), + fmt.Errorf("retry failed on attempt No.%d: %w", attempts, ctx.Err()), ) default: - err := func() (err error) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - err = xerrors.WithStackTrace( - fmt.Errorf("panic recovered: %v", e), - ) - } - }() - } - - return op(ctx) - }() + err := opWithRecover(ctx, options, op) if err == nil { return nil @@ -336,8 +323,7 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err if !m.MustRetry(options.idempotent) { return xerrors.WithStackTrace( fmt.Errorf("non-retryable error occurred on attempt No.%d (idempotent=%v): %w", - attempts, options.idempotent, err, - ), + attempts, options.idempotent, err), ) } @@ -373,6 +359,21 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } } +func opWithRecover(ctx context.Context, options *retryOptions, op retryOperation) (err error) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + err = xerrors.WithStackTrace( + fmt.Errorf("panic recovered: %v", e), + ) + } + }() + } + + return op(ctx) +} + // Check returns retry mode for queryErr. func Check(err error) (m retryMode) { code, errType, backoffType, deleteSession := xerrors.Check(err) diff --git a/retry/retry_test.go b/retry/retry_test.go index 29afcf6fa..40e2f0294 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -209,3 +209,45 @@ func TestRetryWithBudget(t *testing.T) { require.ErrorIs(t, err, errNoQuota) }) } + +type MockPanicCallback struct { + called bool + received interface{} +} + +func (m *MockPanicCallback) Call(e interface{}) { + m.called = true + m.received = e +} + +func TestOpWithRecover_NoPanic(t *testing.T) { + ctx := context.Background() + options := &retryOptions{ + panicCallback: nil, + } + op := func(ctx context.Context) error { + return nil + } + + err := opWithRecover(ctx, options, op) + + require.NoError(t, err) +} + +func TestOpWithRecover_WithPanic(t *testing.T) { + ctx := context.Background() + mockCallback := new(MockPanicCallback) + options := &retryOptions{ + panicCallback: mockCallback.Call, + } + op := func(ctx context.Context) error { + panic("test panic") + } + + err := opWithRecover(ctx, options, op) + + require.Error(t, err) + require.Contains(t, err.Error(), "panic recovered: test panic") + require.True(t, mockCallback.called) + require.Equal(t, "test panic", mockCallback.received) +}