From 042927f266a7939012b0d763a7649152669cf1dc Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 18 Feb 2025 17:02:53 +0530 Subject: [PATCH] allow dml query to return empty fields list on prepare Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/delete.go | 11 +++-------- go/vt/vtgate/engine/dml_with_input.go | 8 +------- go/vt/vtgate/engine/fk_cascade.go | 8 +------- go/vt/vtgate/engine/fk_verify.go | 6 +----- go/vt/vtgate/engine/insert_common.go | 7 +------ go/vt/vtgate/engine/primitive.go | 7 +++++++ go/vt/vtgate/engine/send.go | 3 +++ go/vt/vtgate/engine/send_test.go | 27 +++++++++++++++++++-------- go/vt/vtgate/engine/sequential.go | 9 ++------- go/vt/vtgate/engine/set.go | 6 +----- go/vt/vtgate/engine/singlerow.go | 6 +----- go/vt/vtgate/engine/update.go | 7 +------ go/vt/vtgate/engine/upsert.go | 7 ++----- go/vt/vtgate/executor.go | 5 +++-- 14 files changed, 46 insertions(+), 71 deletions(-) diff --git a/go/vt/vtgate/engine/delete.go b/go/vt/vtgate/engine/delete.go index 4dea88db81a..32f7efeaa95 100644 --- a/go/vt/vtgate/engine/delete.go +++ b/go/vt/vtgate/engine/delete.go @@ -33,10 +33,10 @@ var _ Primitive = (*Delete)(nil) // Delete represents the instructions to perform a delete. type Delete struct { - *DML - - // Delete does not take inputs noInputs + noFields + + *DML } // TryExecute performs a non-streaming exec. @@ -70,11 +70,6 @@ func (del *Delete) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa return callback(res) } -// GetFields fetches the field info. -func (del *Delete) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, fmt.Errorf("BUG: unreachable code for %q", del.Query) -} - // deleteVindexEntries performs an delete if table owns vindex. // Note: the commit order may be different from the DML order because it's possible // for DMLs to reuse existing transactions. diff --git a/go/vt/vtgate/engine/dml_with_input.go b/go/vt/vtgate/engine/dml_with_input.go index ce8fff3f463..2ab6464792b 100644 --- a/go/vt/vtgate/engine/dml_with_input.go +++ b/go/vt/vtgate/engine/dml_with_input.go @@ -20,8 +20,6 @@ import ( "context" "fmt" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -33,6 +31,7 @@ const DmlVals = "dml_vals" // DMLWithInput represents the instructions to perform a DML operation based on the input result. type DMLWithInput struct { txNeeded + noFields Input Primitive @@ -160,11 +159,6 @@ func (dml *DMLWithInput) TryStreamExecute(ctx context.Context, vcursor VCursor, return callback(res) } -// GetFields fetches the field info. -func (dml *DMLWithInput) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.VT13001("unreachable code for DMLs") -} - func (dml *DMLWithInput) description() PrimitiveDescription { var offsets []string for idx, offset := range dml.OutputCols { diff --git a/go/vt/vtgate/engine/fk_cascade.go b/go/vt/vtgate/engine/fk_cascade.go index 35122ac9563..b73ab15546b 100644 --- a/go/vt/vtgate/engine/fk_cascade.go +++ b/go/vt/vtgate/engine/fk_cascade.go @@ -23,8 +23,6 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" ) // FkChild contains the Child Primitive to be executed collecting the values from the Selection Primitive using the column indexes. @@ -54,6 +52,7 @@ type NonLiteralUpdateInfo struct { // On success, it executes the Parent Primitive. type FkCascade struct { txNeeded + noFields // Selection is the Primitive that is used to find the rows that are going to be modified in the child tables. Selection Primitive @@ -78,11 +77,6 @@ func (fkc *FkCascade) GetTableName() string { return fkc.Parent.GetTableName() } -// GetFields implements the Primitive interface. -func (fkc *FkCascade) GetFields(_ context.Context, _ VCursor, _ map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields should not be called") -} - // TryExecute implements the Primitive interface. func (fkc *FkCascade) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { // Execute the Selection primitive to find the rows that are going to modified. diff --git a/go/vt/vtgate/engine/fk_verify.go b/go/vt/vtgate/engine/fk_verify.go index 7184e5d8381..da850f7a366 100644 --- a/go/vt/vtgate/engine/fk_verify.go +++ b/go/vt/vtgate/engine/fk_verify.go @@ -36,6 +36,7 @@ type Verify struct { // It does this by executing a select distinct query on the parent table with the values that are being inserted/updated. type FkVerify struct { txNeeded + noFields Verify []*Verify Exec Primitive @@ -62,11 +63,6 @@ func (f *FkVerify) GetTableName() string { return f.Exec.GetTableName() } -// GetFields implements the Primitive interface -func (f *FkVerify) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields should not be called") -} - // TryExecute implements the Primitive interface func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { for _, v := range f.Verify { diff --git a/go/vt/vtgate/engine/insert_common.go b/go/vt/vtgate/engine/insert_common.go index f2eff953b47..e325ab7b6cc 100644 --- a/go/vt/vtgate/engine/insert_common.go +++ b/go/vt/vtgate/engine/insert_common.go @@ -37,8 +37,8 @@ import ( type ( InsertCommon struct { - // Insert needs tx handling txNeeded + noFields // Opcode is the execution opcode. Opcode InsertOpcode @@ -144,11 +144,6 @@ func (ic *InsertCommon) GetTableName() string { return ic.TableName } -// GetFields fetches the field info. -func (ic *InsertCommon) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.VT13001("unexpected fields call for insert query") -} - func (ins *InsertCommon) executeUnshardedTableQuery(ctx context.Context, vcursor VCursor, loggingPrimitive Primitive, bindVars map[string]*querypb.BindVariable, query string, insertID uint64) (*sqltypes.Result, error) { rss, _, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, nil, []key.ShardDestination{key.DestinationAllShards{}}) if err != nil { diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 4abe55f50de..d1222475148 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -269,6 +269,9 @@ type ( // txNeeded is a default implementation for Primitives that need transaction handling txNeeded struct{} + + // noFields is a default implementation for Primitives that do not return fields + noFields struct{} ) // Find will return the first Primitive that matches the evaluate function. If no match is found, nil will be returned @@ -303,3 +306,7 @@ func (noTxNeeded) NeedsTransaction() bool { func (txNeeded) NeedsTransaction() bool { return true } + +func (noFields) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return &sqltypes.Result{}, nil +} diff --git a/go/vt/vtgate/engine/send.go b/go/vt/vtgate/engine/send.go index 2f24119d762..01a87bd76b9 100644 --- a/go/vt/vtgate/engine/send.go +++ b/go/vt/vtgate/engine/send.go @@ -185,6 +185,9 @@ func (s *Send) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars m // GetFields implements Primitive interface func (s *Send) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + if s.IsDML || s.IsDDL { + return &sqltypes.Result{}, nil + } qr, err := vcursor.ExecutePrimitive(ctx, s, bindVars, false) if err != nil { return nil, err diff --git a/go/vt/vtgate/engine/send_test.go b/go/vt/vtgate/engine/send_test.go index c2765ff5ce0..68bbdb10726 100644 --- a/go/vt/vtgate/engine/send_test.go +++ b/go/vt/vtgate/engine/send_test.go @@ -305,16 +305,27 @@ func TestSendGetFields(t *testing.T) { }, Query: "dummy_query", TargetDestination: key.DestinationAllShards{}, - IsDML: true, SingleShardOnly: false, } vc := &loggingVCursor{shards: []string{"-20", "20-"}, results: results} - qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - vc.ExpectLog(t, []string{ - `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.-20: dummy_query {} ks.20-: dummy_query {} true false`, + + t.Run("GetFields - not a dml query", func(t *testing.T) { + qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.-20: dummy_query {} ks.20-: dummy_query {} false false`, + }) + require.Nil(t, qr.Rows) + require.Equal(t, 4, len(qr.Fields)) + }) + + vc.Rewind() + t.Run("GetFields - a dml query", func(t *testing.T) { + send.IsDML = true + qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + require.Empty(t, qr.Fields) + vc.ExpectLog(t, nil) }) - require.Nil(t, qr.Rows) - require.Equal(t, 4, len(qr.Fields)) } diff --git a/go/vt/vtgate/engine/sequential.go b/go/vt/vtgate/engine/sequential.go index ecf74d663a2..56be78c12e1 100644 --- a/go/vt/vtgate/engine/sequential.go +++ b/go/vt/vtgate/engine/sequential.go @@ -21,14 +21,14 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" ) // Sequential Primitive is used to execute DML statements in a fixed order. // Any failure, stops the execution and returns. type Sequential struct { txNeeded + noFields + Sources []Primitive } @@ -92,11 +92,6 @@ func (s *Sequential) TryStreamExecute(ctx context.Context, vcursor VCursor, bind return callback(qr) } -// GetFields fetches the field info. -func (s *Sequential) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unreachable code for Sequential engine") -} - // Inputs returns the input primitives for this func (s *Sequential) Inputs() ([]Primitive, []map[string]any) { return s.Sources, nil diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index 258f9bfe66c..32a297404f3 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -41,6 +41,7 @@ type ( // Set contains the instructions to perform set. Set struct { noTxNeeded + noFields Ops []SetOp Input Primitive @@ -142,11 +143,6 @@ func (s *Set) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars ma return callback(result) } -// GetFields implements the Primitive interface method. -func (s *Set) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return &sqltypes.Result{}, nil -} - // Inputs implements the Primitive interface func (s *Set) Inputs() ([]Primitive, []map[string]any) { return []Primitive{s.Input}, nil diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go index 35ecdaff90b..8a3fff2cacd 100644 --- a/go/vt/vtgate/engine/singlerow.go +++ b/go/vt/vtgate/engine/singlerow.go @@ -29,6 +29,7 @@ var _ Primitive = (*SingleRow)(nil) type SingleRow struct { noInputs noTxNeeded + noFields } // RouteType returns a description of the query routing type used by the primitive @@ -65,11 +66,6 @@ func (s *SingleRow) TryStreamExecute(ctx context.Context, vcursor VCursor, bindV return callback(res) } -// GetFields fetches the field info. -func (s *SingleRow) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return &sqltypes.Result{}, nil -} - func (s *SingleRow) description() PrimitiveDescription { return PrimitiveDescription{ OperatorType: "SingleRow", diff --git a/go/vt/vtgate/engine/update.go b/go/vt/vtgate/engine/update.go index 2aafa1ee060..8f869b37537 100644 --- a/go/vt/vtgate/engine/update.go +++ b/go/vt/vtgate/engine/update.go @@ -40,8 +40,8 @@ type VindexValues struct { // Update represents the instructions to perform an update. type Update struct { - // Update does not take inputs noInputs + noFields *DML @@ -81,11 +81,6 @@ func (upd *Update) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa } -// GetFields fetches the field info. -func (upd *Update) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, fmt.Errorf("BUG: unreachable code for %q", upd.Query) -} - // updateVindexEntries performs an update when a vindex is being modified // by the statement. // Note: the commit order may be different from the DML order because it's possible diff --git a/go/vt/vtgate/engine/upsert.go b/go/vt/vtgate/engine/upsert.go index 2d224e9fdf3..58d996b2b2b 100644 --- a/go/vt/vtgate/engine/upsert.go +++ b/go/vt/vtgate/engine/upsert.go @@ -32,6 +32,8 @@ var _ Primitive = (*Upsert)(nil) // if there is `Duplicate Key` error, it executes the update primitive. type Upsert struct { txNeeded + noFields + Upserts []upsert } @@ -69,11 +71,6 @@ func (u *Upsert) GetTableName() string { return "" } -// GetFields implements Primitive interface type. -func (u *Upsert) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return nil, vterrors.VT13001("unexpected to receive GetFields call for insert on duplicate key update query") -} - // TryExecute implements Primitive interface type. func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { result := &sqltypes.Result{} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index dd783399925..537ac4c09f6 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1403,9 +1403,10 @@ func (e *Executor) prepare(ctx context.Context, safeSession *econtext.SafeSessio } switch stmtType { - case sqlparser.StmtSelect, sqlparser.StmtShow: + case sqlparser.StmtSelect, sqlparser.StmtShow, + sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: return e.handlePrepare(ctx, safeSession, sql, bindVars, logStats) - case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete, + case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtAnalyze, sqlparser.StmtComment, sqlparser.StmtExplain, sqlparser.StmtFlush, sqlparser.StmtKill: return nil, nil }