diff --git a/internal/tx/id.go b/internal/tx/id.go index e43406bbc..d7113978b 100644 --- a/internal/tx/id.go +++ b/internal/tx/id.go @@ -4,6 +4,7 @@ var _ Identifier = LazyID{} const ( LazyTxID = "LAZY_TX" + FakeTxID = "FAKE_TX" ) type ( diff --git a/internal/xsql/conn.go b/internal/xsql/conn.go index 769e1c95b..69490966e 100644 --- a/internal/xsql/conn.go +++ b/internal/xsql/conn.go @@ -33,7 +33,8 @@ var ( type ( connWrapper struct { - cc conn.Conn + cc conn.Conn + currentTx *txWrapper connector *Connector lastUsage xsync.LastUsage @@ -54,16 +55,22 @@ func (c *connWrapper) CheckNamedValue(value *driver.NamedValue) error { } func (c *connWrapper) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if c.currentTx != nil { + return nil, xerrors.WithStackTrace(xerrors.AlreadyHasTx(c.currentTx.ID())) + } + tx, err := c.cc.BeginTx(ctx, opts) if err != nil { return nil, xerrors.WithStackTrace(err) } - return &txWrapper{ + c.currentTx = &txWrapper{ conn: c, ctx: ctx, tx: tx, - }, nil + } + + return c.currentTx, nil } func (c *connWrapper) Close() error { @@ -179,6 +186,10 @@ func (c *connWrapper) QueryContext(ctx context.Context, sql string, args []drive return rowByAstPlan(ast, plan), nil } + if c.currentTx != nil { + return c.currentTx.tx.Query(ctx, sql, params) + } + return c.cc.Query(ctx, sql, params) } @@ -191,6 +202,10 @@ func (c *connWrapper) ExecContext(ctx context.Context, sql string, args []driver return nil, xerrors.WithStackTrace(err) } + if c.currentTx != nil { + return c.currentTx.tx.Exec(ctx, sql, params) + } + return c.cc.Exec(ctx, sql, params) } diff --git a/internal/xsql/conn/query/conn.go b/internal/xsql/conn/query/conn.go index 01edaf783..18169044d 100644 --- a/internal/xsql/conn/query/conn.go +++ b/internal/xsql/conn/query/conn.go @@ -32,6 +32,7 @@ type Conn struct { session *query.Session onClose []func() closed atomic.Bool + fakeTx bool } func (c *Conn) Exec(ctx context.Context, sql string, params *params.Params) ( @@ -114,6 +115,10 @@ func (c *Conn) isReady() bool { } func (c *Conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (tx conn.Tx, finalErr error) { + if c.fakeTx { + return beginTxFake(ctx, c), nil + } + tx, err := beginTx(ctx, c, txOptions) if err != nil { return nil, xerrors.WithStackTrace(err) diff --git a/internal/xsql/conn/query/options.go b/internal/xsql/conn/query/options.go index 683f30899..2c3a1a94d 100644 --- a/internal/xsql/conn/query/options.go +++ b/internal/xsql/conn/query/options.go @@ -7,3 +7,9 @@ func WithOnClose(onClose func()) Option { c.onClose = append(c.onClose, onClose) } } + +func WithFakeTx() Option { + return func(c *Conn) { + c.fakeTx = true + } +} diff --git a/internal/xsql/conn/query/tx_fake.go b/internal/xsql/conn/query/tx_fake.go new file mode 100644 index 000000000..9f55c0f6c --- /dev/null +++ b/internal/xsql/conn/query/tx_fake.go @@ -0,0 +1,62 @@ +package query + +import ( + "context" + "database/sql/driver" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/table/badconn" +) + +type txFake struct { + conn *Conn + ctx context.Context //nolint:containedctx +} + +func (t *txFake) Exec(ctx context.Context, sql string, params *params.Params) (driver.Result, error) { + result, err := t.conn.Exec(ctx, sql, params) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return result, nil +} + +func (t *txFake) Query(ctx context.Context, sql string, params *params.Params) (driver.RowsNextResultSet, error) { + rows, err := t.conn.Query(ctx, sql, params) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return rows, nil +} + +func (t *txFake) ID() string { + return tx.FakeTxID +} + +func beginTxFake(ctx context.Context, c *Conn) conn.Tx { + return &txFake{ + conn: c, + ctx: ctx, + } +} + +func (t *txFake) Commit(ctx context.Context) (err error) { + if !t.conn.isReady() { + return badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) + } + + return nil +} + +func (t *txFake) Rollback(ctx context.Context) (err error) { + if !t.conn.isReady() { + return badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) + } + + return err +} diff --git a/internal/xsql/conn/table/errors.go b/internal/xsql/conn/table/errors.go index 060e95e3a..90db8c026 100644 --- a/internal/xsql/conn/table/errors.go +++ b/internal/xsql/conn/table/errors.go @@ -11,4 +11,5 @@ var ( ErrUnsupported = driver.ErrSkip errConnClosedEarly = xerrors.Retryable(errors.New("conn closed early"), xerrors.InvalidObject()) errNotReadyConn = xerrors.Retryable(errors.New("conn not ready"), xerrors.InvalidObject()) + ErrWrongQueryMode = errors.New("wrong query mode") ) diff --git a/internal/xsql/conn/table/tx.go b/internal/xsql/conn/table/tx.go index 86e8a6b0f..762dcdf4e 100644 --- a/internal/xsql/conn/table/tx.go +++ b/internal/xsql/conn/table/tx.go @@ -26,15 +26,7 @@ func (tx *transaction) ID() string { func (tx *transaction) Exec(ctx context.Context, sql string, params *params.Params) (driver.Result, error) { m := queryModeFromContext(ctx, tx.conn.defaultQueryMode) if m != DataQueryMode { - return nil, badconn.Map( - xerrors.WithStackTrace( - xerrors.Retryable( - fmt.Errorf("wrong query mode: %s", m.String()), - xerrors.InvalidObject(), - xerrors.WithName("WRONG_QUERY_MODE"), - ), - ), - ) + return nil, xerrors.WithStackTrace(fmt.Errorf("%q: %w", m.String(), ErrWrongQueryMode)) } _, err := tx.tx.Execute(ctx, sql, params, tx.conn.dataQueryOptions(ctx)...) if err != nil { diff --git a/internal/xsql/conn/table/tx_fake.go b/internal/xsql/conn/table/tx_fake.go index 277f9f95f..505ca7919 100644 --- a/internal/xsql/conn/table/tx_fake.go +++ b/internal/xsql/conn/table/tx_fake.go @@ -12,14 +12,12 @@ import ( ) type txFake struct { - tx.Identifier - conn *Conn ctx context.Context //nolint:containedctx } -func (tx *txFake) Exec(ctx context.Context, sql string, params *params.Params) (driver.Result, error) { - result, err := tx.conn.Exec(ctx, sql, params) +func (t *txFake) Exec(ctx context.Context, sql string, params *params.Params) (driver.Result, error) { + result, err := t.conn.Exec(ctx, sql, params) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -27,8 +25,8 @@ func (tx *txFake) Exec(ctx context.Context, sql string, params *params.Params) ( return result, nil } -func (tx *txFake) Query(ctx context.Context, sql string, params *params.Params) (driver.RowsNextResultSet, error) { - rows, err := tx.conn.Query(ctx, sql, params) +func (t *txFake) Query(ctx context.Context, sql string, params *params.Params) (driver.RowsNextResultSet, error) { + rows, err := t.conn.Query(ctx, sql, params) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -36,24 +34,27 @@ func (tx *txFake) Query(ctx context.Context, sql string, params *params.Params) return rows, nil } +func (t *txFake) ID() string { + return tx.FakeTxID +} + func beginTxFake(ctx context.Context, c *Conn) conn.Tx { return &txFake{ - Identifier: tx.ID("FAKE"), - conn: c, - ctx: ctx, + conn: c, + ctx: ctx, } } -func (tx *txFake) Commit(ctx context.Context) (err error) { - if !tx.conn.isReady() { +func (t *txFake) Commit(ctx context.Context) (err error) { + if !t.conn.isReady() { return badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) } return nil } -func (tx *txFake) Rollback(ctx context.Context) (err error) { - if !tx.conn.isReady() { +func (t *txFake) Rollback(ctx context.Context) (err error) { + if !t.conn.isReady() { return badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) } diff --git a/internal/xsql/connector.go b/internal/xsql/connector.go index 530cafe07..3f65d3588 100644 --- a/internal/xsql/connector.go +++ b/internal/xsql/connector.go @@ -15,8 +15,8 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" - query2 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/query" - table2 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/table" + connOverQueryServiceClient "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/query" + connOverTableServiceClient "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/table" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" "github.com/ydb-platform/ydb-go-sdk/v3/retry/budget" "github.com/ydb-platform/ydb-go-sdk/v3/scheme" @@ -38,8 +38,8 @@ type ( queryProcessor queryProcessor - TableOpts []table2.Option - QueryOpts []query2.Option + TableOpts []connOverTableServiceClient.Option + QueryOpts []connOverQueryServiceClient.Option disableServerBalancer bool onCLose []func(*Connector) @@ -122,9 +122,9 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { id := uuid.New() conn := &connWrapper{ - cc: query2.New(ctx, c, s, append( + cc: connOverQueryServiceClient.New(ctx, c, s, append( c.QueryOpts, - query2.WithOnClose(func() { + connOverQueryServiceClient.WithOnClose(func() { c.conns.Delete(id) }))..., ), @@ -145,8 +145,8 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { id := uuid.New() conn := &connWrapper{ - cc: table2.New(ctx, c, s, append(c.TableOpts, - table2.WithOnClose(func() { + cc: connOverTableServiceClient.New(ctx, c, s, append(c.TableOpts, + connOverTableServiceClient.WithOnClose(func() { c.conns.Delete(id) }))..., ), diff --git a/internal/xsql/errors.go b/internal/xsql/errors.go index f564b005c..252485975 100644 --- a/internal/xsql/errors.go +++ b/internal/xsql/errors.go @@ -12,5 +12,5 @@ var ( errDeprecated = driver.ErrSkip errAlreadyClosed = errors.New("already closed") errWrongQueryProcessor = errors.New("wrong query processor") - errNotReadyConn = xerrors.Retryable(errors.New("connWrapper not ready"), xerrors.InvalidObject()) + errNotReadyConn = xerrors.Retryable(errors.New("conn not ready"), xerrors.InvalidObject()) ) diff --git a/internal/xsql/options.go b/internal/xsql/options.go index 9a6b7a3c6..f15e98403 100644 --- a/internal/xsql/options.go +++ b/internal/xsql/options.go @@ -170,6 +170,22 @@ func WithIdleThreshold(idleThreshold time.Duration) Option { } } +type mergedOptions []Option + +func (opts mergedOptions) Apply(c *Connector) error { + for _, opt := range opts { + if err := opt.Apply(c); err != nil { + return err + } + } + + return nil +} + +func Merge(opts ...Option) Option { + return mergedOptions(opts) +} + func WithTableOptions(opts ...table.Option) Option { return tableQueryOptionsOption{ tableOps: opts, diff --git a/internal/xsql/tx.go b/internal/xsql/tx.go index 7dc92338a..9652da89d 100644 --- a/internal/xsql/tx.go +++ b/internal/xsql/tx.go @@ -28,6 +28,10 @@ var ( ) func (tx *txWrapper) Commit() (finalErr error) { + defer func() { + tx.conn.currentTx = nil + }() + var ( ctx = tx.ctx onDone = trace.DatabaseSQLOnTxCommit(tx.conn.connector.Trace(), &ctx, @@ -47,6 +51,10 @@ func (tx *txWrapper) Commit() (finalErr error) { } func (tx *txWrapper) Rollback() (finalErr error) { + defer func() { + tx.conn.currentTx = nil + }() + var ( ctx = tx.ctx onDone = trace.DatabaseSQLOnTxRollback(tx.conn.connector.Trace(), &ctx, diff --git a/sql.go b/sql.go index 4fbe59f04..42373018c 100644 --- a/sql.go +++ b/sql.go @@ -9,7 +9,8 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" - table2 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/table" + connOverQueryServiceClient "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/query" + connOverTableServiceClient "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/conn/table" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" "github.com/ydb-platform/ydb-go-sdk/v3/table" "github.com/ydb-platform/ydb-go-sdk/v3/table/options" @@ -78,35 +79,40 @@ func (d *sqlDriver) detach(c *xsql.Connector) { d.connectors.Delete(c) } -type QueryMode = table2.QueryMode +type QueryMode int const ( - DataQueryMode = iota + 1 + _ = QueryMode(iota) + DataQueryMode ExplainQueryMode ScanQueryMode SchemeQueryMode ScriptingQueryMode + QueryExecuteQueryMode ) +// WithQueryMode set query mode for legacy database/sql driver +// +// For actual database/sql driver works over query service client and no needs query mode func WithQueryMode(ctx context.Context, mode QueryMode) context.Context { switch mode { case ExplainQueryMode: return xsql.WithExplain(ctx) case DataQueryMode: - return table2.WithQueryMode(ctx, table2.DataQueryMode) + return connOverTableServiceClient.WithQueryMode(ctx, connOverTableServiceClient.DataQueryMode) case ScanQueryMode: - return table2.WithQueryMode(ctx, table2.ScanQueryMode) + return connOverTableServiceClient.WithQueryMode(ctx, connOverTableServiceClient.ScanQueryMode) case SchemeQueryMode: - return table2.WithQueryMode(ctx, table2.SchemeQueryMode) + return connOverTableServiceClient.WithQueryMode(ctx, connOverTableServiceClient.SchemeQueryMode) case ScriptingQueryMode: - return table2.WithQueryMode(ctx, table2.ScriptingQueryMode) + return connOverTableServiceClient.WithQueryMode(ctx, connOverTableServiceClient.ScriptingQueryMode) default: return ctx } } func WithTxControl(ctx context.Context, txc *table.TransactionControl) context.Context { - return table2.WithTxControl(ctx, txc) + return connOverTableServiceClient.WithTxControl(ctx, txc) } type ConnectorOption = xsql.Option @@ -116,12 +122,63 @@ type QueryBindConnectorOption interface { bind.Bind } +func modeToMode(mode QueryMode) connOverTableServiceClient.QueryMode { + switch mode { + case ScanQueryMode: + return connOverTableServiceClient.ScanQueryMode + case SchemeQueryMode: + return connOverTableServiceClient.SchemeQueryMode + case ScriptingQueryMode: + return connOverTableServiceClient.ScriptingQueryMode + default: + return connOverTableServiceClient.DataQueryMode + } +} + func WithDefaultQueryMode(mode QueryMode) ConnectorOption { - return xsql.WithTableOptions(table2.WithDefaultQueryMode(mode)) + return xsql.WithTableOptions( + connOverTableServiceClient.WithDefaultQueryMode(modeToMode(mode)), + ) } -func WithFakeTx(mode QueryMode) ConnectorOption { - return xsql.WithTableOptions(table2.WithFakeTxModes(mode)) +func WithFakeTx(modes ...QueryMode) ConnectorOption { + opts := make([]ConnectorOption, 0, len(modes)) + + for _, mode := range modes { + switch mode { + case DataQueryMode: + opts = append(opts, + xsql.WithTableOptions(connOverTableServiceClient.WithFakeTxModes( + connOverTableServiceClient.DataQueryMode, + )), + ) + case ScanQueryMode: + opts = append(opts, + xsql.WithTableOptions(connOverTableServiceClient.WithFakeTxModes( + connOverTableServiceClient.ScanQueryMode, + )), + ) + case SchemeQueryMode: + opts = append(opts, + xsql.WithTableOptions(connOverTableServiceClient.WithFakeTxModes( + connOverTableServiceClient.SchemeQueryMode, + )), + ) + case ScriptingQueryMode: + opts = append(opts, + xsql.WithTableOptions(connOverTableServiceClient.WithFakeTxModes( + connOverTableServiceClient.DataQueryMode, + )), + ) + case QueryExecuteQueryMode: + opts = append(opts, + xsql.WithQueryOptions(connOverQueryServiceClient.WithFakeTx()), + ) + default: + } + } + + return xsql.Merge(opts...) } func WithTablePathPrefix(tablePathPrefix string) QueryBindConnectorOption { @@ -141,15 +198,15 @@ func WithNumericArgs() QueryBindConnectorOption { } func WithDefaultTxControl(txControl *table.TransactionControl) ConnectorOption { - return xsql.WithTableOptions(table2.WithDefaultTxControl(txControl)) + return xsql.WithTableOptions(connOverTableServiceClient.WithDefaultTxControl(txControl)) } func WithDefaultDataQueryOptions(opts ...options.ExecuteDataQueryOption) ConnectorOption { - return xsql.WithTableOptions(table2.WithDataOpts(opts...)) + return xsql.WithTableOptions(connOverTableServiceClient.WithDataOpts(opts...)) } func WithDefaultScanQueryOptions(opts ...options.ExecuteScanQueryOption) ConnectorOption { - return xsql.WithTableOptions(table2.WithScanOpts(opts...)) + return xsql.WithTableOptions(connOverTableServiceClient.WithScanOpts(opts...)) } func WithDatabaseSQLTrace( diff --git a/tests/integration/database_sql_ddl_in_transaction_test.go b/tests/integration/database_sql_ddl_in_transaction_test.go new file mode 100644 index 000000000..7c2057e12 --- /dev/null +++ b/tests/integration/database_sql_ddl_in_transaction_test.go @@ -0,0 +1,96 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/retry" +) + +func TestDatabaseSqlDDLInTransaction(t *testing.T) { + var ( + scope = newScope(t) + db = scope.SQLDriverWithFolder() + ) + + defer func() { + _ = db.Close() + }() + + f := func(ctx context.Context, tx *sql.Tx) (err error) { + _, err = tx.ExecContext( + ydb.WithQueryMode(ctx, ydb.SchemeQueryMode), + `DROP TABLE IF EXISTS test`, + ) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `CREATE TABLE test (id Uint64, PRIMARY KEY (id))`) + if err != nil { + return err + } + + return err + } + + t.Run("InTransaction", func(t *testing.T) { + t.Run("WrongQueryMode", func(t *testing.T) { + err := retry.DoTx(scope.Ctx, db, f, + retry.WithIdempotent(true), retry.WithTxOptions(&sql.TxOptions{ + Isolation: sql.LevelSnapshot, + ReadOnly: true, + }), + ) + require.Error(t, err) + }) + t.Run("SnapshotROIsolation", func(t *testing.T) { + err := retry.DoTx(scope.Ctx, db, f, + retry.WithIdempotent(true), retry.WithTxOptions(&sql.TxOptions{ + Isolation: sql.LevelSnapshot, + ReadOnly: true, + }), + ) + require.Error(t, err) + }) + t.Run("SerializableRWIsolation", func(t *testing.T) { + err := retry.DoTx(scope.Ctx, db, + f, retry.WithIdempotent(true), retry.WithTxOptions(&sql.TxOptions{ + Isolation: sql.LevelSerializable, + ReadOnly: false, + }), + ) + require.Error(t, err) + }) + t.Run("FakeTx", func(t *testing.T) { + connector, err := ydb.Connector(scope.Driver(), + ydb.WithTablePathPrefix(scope.Folder()), + ydb.WithFakeTx( + ydb.SchemeQueryMode, + ydb.QueryExecuteQueryMode, + ), + ) + require.NoError(t, err) + + db := sql.OpenDB(connector) + + err = db.PingContext(scope.Ctx) + require.NoError(t, err) + + err = retry.DoTx(ydb.WithQueryMode(scope.Ctx, ydb.SchemeQueryMode), db, + f, retry.WithIdempotent(true), retry.WithTxOptions(&sql.TxOptions{ + Isolation: sql.LevelSerializable, + ReadOnly: false, + }), + ) + require.NoError(t, err) + }) + }) +}