diff --git a/internal/table/client.go b/internal/table/client.go index 57e20a827..90bedc239 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -666,9 +666,30 @@ func (c *Client) Do(ctx context.Context, op table.Operation, opts ...table.Optio onDone(attempts, finalErr) }() - err := do(ctx, c, c.config, op, func(err error) { - attempts++ - }, config.RetryOptions...) + err := retryBackoff(ctx, c, + func(ctx context.Context, s table.Session) (err error) { + attempts++ + + err = func() error { + if panicCallback := c.config.PanicCallback(); panicCallback != nil { + defer func() { + if e := recover(); e != nil { + panicCallback(e) + } + }() + } + + return op(xcontext.MarkRetryCall(ctx), s) + }() + + if err != nil { + return xerrors.WithStackTrace(err) + } + + return nil + }, + config.RetryOptions..., + ) if err != nil { return xerrors.WithStackTrace(err) } @@ -695,14 +716,19 @@ func (c *Client) DoTx(ctx context.Context, op table.TxOperation, opts ...table.O onDone(attempts, finalErr) }() - return retryBackoff(ctx, c, - func(ctx context.Context, s table.Session) (err error) { + err := retryBackoff(ctx, c, + func(ctx context.Context, s table.Session) (finalErr error) { attempts++ tx, err := s.BeginTransaction(ctx, config.TxSettings) if err != nil { return xerrors.WithStackTrace(err) } + defer func() { + if finalErr != nil { + _ = tx.Rollback(ctx) + } + }() err = func() error { if panicCallback := c.config.PanicCallback(); panicCallback != nil { @@ -729,6 +755,11 @@ func (c *Client) DoTx(ctx context.Context, op table.TxOperation, opts ...table.O }, config.RetryOptions..., ) + if err != nil { + return xerrors.WithStackTrace(err) + } + + return nil } func (c *Client) internalPoolGCTick(ctx context.Context, idleThreshold time.Duration) { diff --git a/internal/table/retry.go b/internal/table/retry.go index e2b522b45..92d95bacf 100644 --- a/internal/table/retry.go +++ b/internal/table/retry.go @@ -3,8 +3,6 @@ package table import ( "context" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/retry" "github.com/ydb-platform/ydb-go-sdk/v3/table" @@ -22,44 +20,6 @@ type SessionProvider interface { Put(ctx context.Context, s *session) (err error) } -func do( - ctx context.Context, - c SessionProvider, - config *config.Config, - op table.Operation, - onAttempt func(err error), - opts ...retry.Option, -) (err error) { - return retryBackoff(ctx, c, - func(ctx context.Context, s table.Session) (err error) { - defer func() { - if onAttempt != nil { - onAttempt(err) - } - }() - - err = func() error { - if panicCallback := config.PanicCallback(); panicCallback != nil { - defer func() { - if e := recover(); e != nil { - panicCallback(e) - } - }() - } - - return op(xcontext.MarkRetryCall(ctx), s) - }() - - if err != nil { - return xerrors.WithStackTrace(err) - } - - return nil - }, - opts..., - ) -} - func retryBackoff( ctx context.Context, p SessionProvider, diff --git a/internal/table/retry_test.go b/internal/table/retry_test.go index 008633786..332126360 100644 --- a/internal/table/retry_test.go +++ b/internal/table/retry_test.go @@ -11,7 +11,6 @@ import ( grpcCodes "google.golang.org/grpc/codes" grpcStatus "google.golang.org/grpc/status" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" @@ -41,12 +40,10 @@ func TestRetryerBackoffRetryCancelation(t *testing.T) { ctx, cancel := xcontext.WithCancel(context.Background()) results := make(chan error) go func() { - err := do(ctx, p, - config.New(), + err := retryBackoff(ctx, p, func(ctx context.Context, _ table.Session) error { return testErr }, - nil, retry.WithFastBackoff( testutil.BackoffFunc(func(n int) <-chan time.Time { ch := make(chan time.Time) @@ -103,7 +100,7 @@ func TestRetryerBadSession(t *testing.T) { sessions []table.Session ) ctx, cancel := xcontext.WithCancel(context.Background()) - err := do(ctx, p, config.New(), + err := retryBackoff(ctx, p, func(ctx context.Context, s table.Session) error { sessions = append(sessions, s) i++ @@ -115,7 +112,6 @@ func TestRetryerBadSession(t *testing.T) { xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION), ) }, - func(err error) {}, ) if !xerrors.Is(err, context.Canceled) { t.Errorf("unexpected error: %v", err) @@ -154,17 +150,13 @@ func TestRetryerSessionClosing(t *testing.T) { } var sessions []table.Session for i := 0; i < 1000; i++ { - err := do( - context.Background(), - p, - config.New(), + err := retryBackoff(context.Background(), p, func(ctx context.Context, s table.Session) error { sessions = append(sessions, s) s.(*session).SetStatus(table.SessionClosing) return nil }, - nil, ) if err != nil { t.Errorf("unexpected error: %v", err) @@ -208,14 +200,10 @@ func TestRetryerImmediateReturn(t *testing.T) { p := SingleSession( simpleSession(t), ) - err := do( - context.Background(), - p, - config.New(), + err := retryBackoff(context.Background(), p, func(ctx context.Context, _ table.Session) error { return testErr }, - nil, retry.WithFastBackoff( testutil.BackoffFunc(func(n int) <-chan time.Time { panic("this code will not be called") @@ -341,10 +329,7 @@ func TestRetryContextDeadline(t *testing.T) { t.Run(fmt.Sprintf("Timeout=%v,Sleep=%v", timeout, sleep), func(t *testing.T) { ctx, cancel := xcontext.WithTimeout(context.Background(), timeout) defer cancel() - _ = do( - ctx, - p, - config.New(), + _ = retryBackoff(ctx, p, func(ctx context.Context, _ table.Session) error { select { case <-ctx.Done(): @@ -353,7 +338,6 @@ func TestRetryContextDeadline(t *testing.T) { return errs[r.Int(len(errs))] } }, - nil, ) }) } @@ -442,10 +426,7 @@ func TestRetryWithCustomErrors(t *testing.T) { i = 0 sessions = make(map[table.Session]int) ) - err := do( - ctx, - p, - config.New(), + err := retryBackoff(ctx, p, func(ctx context.Context, s table.Session) (err error) { sessions[s]++ i++ @@ -455,7 +436,6 @@ func TestRetryWithCustomErrors(t *testing.T) { return nil }, - nil, ) //nolint:nestif if test.retriable { diff --git a/retry/sql.go b/retry/sql.go index fca156431..aa263cee6 100644 --- a/retry/sql.go +++ b/retry/sql.go @@ -166,6 +166,11 @@ func DoTx(ctx context.Context, db *sql.DB, op func(context.Context, *sql.Tx) err if err != nil { return unwrapErrBadConn(xerrors.WithStackTrace(err)) } + defer func() { + if finalErr != nil { + _ = tx.Rollback() + } + }() if err = op(xcontext.MarkRetryCall(ctx), tx); err != nil { return unwrapErrBadConn(xerrors.WithStackTrace(err))