Skip to content

Commit

Permalink
allow dml query to return empty fields list on prepare
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Feb 18, 2025
1 parent 6448789 commit 042927f
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 71 deletions.
11 changes: 3 additions & 8 deletions go/vt/vtgate/engine/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ var _ Primitive = (*Delete)(nil)

// Delete represents the instructions to perform a delete.
type Delete struct {
*DML

// Delete does not take inputs
noInputs
noFields

*DML
}

// TryExecute performs a non-streaming exec.
Expand Down Expand Up @@ -70,11 +70,6 @@ func (del *Delete) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa
return callback(res)
}

// GetFields fetches the field info.
func (del *Delete) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, fmt.Errorf("BUG: unreachable code for %q", del.Query)
}

// deleteVindexEntries performs an delete if table owns vindex.
// Note: the commit order may be different from the DML order because it's possible
// for DMLs to reuse existing transactions.
Expand Down
8 changes: 1 addition & 7 deletions go/vt/vtgate/engine/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"context"
"fmt"

"vitess.io/vitess/go/vt/vterrors"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)
Expand All @@ -33,6 +31,7 @@ const DmlVals = "dml_vals"
// DMLWithInput represents the instructions to perform a DML operation based on the input result.
type DMLWithInput struct {
txNeeded
noFields

Input Primitive

Expand Down Expand Up @@ -160,11 +159,6 @@ func (dml *DMLWithInput) TryStreamExecute(ctx context.Context, vcursor VCursor,
return callback(res)
}

// GetFields fetches the field info.
func (dml *DMLWithInput) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.VT13001("unreachable code for DMLs")
}

func (dml *DMLWithInput) description() PrimitiveDescription {
var offsets []string
for idx, offset := range dml.OutputCols {
Expand Down
8 changes: 1 addition & 7 deletions go/vt/vtgate/engine/fk_cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// FkChild contains the Child Primitive to be executed collecting the values from the Selection Primitive using the column indexes.
Expand Down Expand Up @@ -54,6 +52,7 @@ type NonLiteralUpdateInfo struct {
// On success, it executes the Parent Primitive.
type FkCascade struct {
txNeeded
noFields

// Selection is the Primitive that is used to find the rows that are going to be modified in the child tables.
Selection Primitive
Expand All @@ -78,11 +77,6 @@ func (fkc *FkCascade) GetTableName() string {
return fkc.Parent.GetTableName()
}

// GetFields implements the Primitive interface.
func (fkc *FkCascade) GetFields(_ context.Context, _ VCursor, _ map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields should not be called")
}

// TryExecute implements the Primitive interface.
func (fkc *FkCascade) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
// Execute the Selection primitive to find the rows that are going to modified.
Expand Down
6 changes: 1 addition & 5 deletions go/vt/vtgate/engine/fk_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Verify struct {
// It does this by executing a select distinct query on the parent table with the values that are being inserted/updated.
type FkVerify struct {
txNeeded
noFields

Verify []*Verify
Exec Primitive
Expand All @@ -62,11 +63,6 @@ func (f *FkVerify) GetTableName() string {
return f.Exec.GetTableName()
}

// GetFields implements the Primitive interface
func (f *FkVerify) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields should not be called")
}

// TryExecute implements the Primitive interface
func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
for _, v := range f.Verify {
Expand Down
7 changes: 1 addition & 6 deletions go/vt/vtgate/engine/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import (

type (
InsertCommon struct {
// Insert needs tx handling
txNeeded
noFields

// Opcode is the execution opcode.
Opcode InsertOpcode
Expand Down Expand Up @@ -144,11 +144,6 @@ func (ic *InsertCommon) GetTableName() string {
return ic.TableName
}

// GetFields fetches the field info.
func (ic *InsertCommon) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.VT13001("unexpected fields call for insert query")
}

func (ins *InsertCommon) executeUnshardedTableQuery(ctx context.Context, vcursor VCursor, loggingPrimitive Primitive, bindVars map[string]*querypb.BindVariable, query string, insertID uint64) (*sqltypes.Result, error) {
rss, _, err := vcursor.ResolveDestinations(ctx, ins.Keyspace.Name, nil, []key.ShardDestination{key.DestinationAllShards{}})
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ type (

// txNeeded is a default implementation for Primitives that need transaction handling
txNeeded struct{}

// noFields is a default implementation for Primitives that do not return fields
noFields struct{}
)

// Find will return the first Primitive that matches the evaluate function. If no match is found, nil will be returned
Expand Down Expand Up @@ -303,3 +306,7 @@ func (noTxNeeded) NeedsTransaction() bool {
func (txNeeded) NeedsTransaction() bool {
return true
}

func (noFields) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return &sqltypes.Result{}, nil
}
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ func (s *Send) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars m

// GetFields implements Primitive interface
func (s *Send) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
if s.IsDML || s.IsDDL {
return &sqltypes.Result{}, nil
}
qr, err := vcursor.ExecutePrimitive(ctx, s, bindVars, false)
if err != nil {
return nil, err
Expand Down
27 changes: 19 additions & 8 deletions go/vt/vtgate/engine/send_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,27 @@ func TestSendGetFields(t *testing.T) {
},
Query: "dummy_query",
TargetDestination: key.DestinationAllShards{},
IsDML: true,
SingleShardOnly: false,
}
vc := &loggingVCursor{shards: []string{"-20", "20-"}, results: results}
qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{})
require.NoError(t, err)
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ks.-20: dummy_query {} ks.20-: dummy_query {} true false`,

t.Run("GetFields - not a dml query", func(t *testing.T) {
qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{})
require.NoError(t, err)
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ks.-20: dummy_query {} ks.20-: dummy_query {} false false`,
})
require.Nil(t, qr.Rows)
require.Equal(t, 4, len(qr.Fields))
})

vc.Rewind()
t.Run("GetFields - a dml query", func(t *testing.T) {
send.IsDML = true
qr, err := send.GetFields(context.Background(), vc, map[string]*querypb.BindVariable{})
require.NoError(t, err)
require.Empty(t, qr.Fields)
vc.ExpectLog(t, nil)
})
require.Nil(t, qr.Rows)
require.Equal(t, 4, len(qr.Fields))
}
9 changes: 2 additions & 7 deletions go/vt/vtgate/engine/sequential.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ import (

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// Sequential Primitive is used to execute DML statements in a fixed order.
// Any failure, stops the execution and returns.
type Sequential struct {
txNeeded
noFields

Sources []Primitive
}

Expand Down Expand Up @@ -92,11 +92,6 @@ func (s *Sequential) TryStreamExecute(ctx context.Context, vcursor VCursor, bind
return callback(qr)
}

// GetFields fetches the field info.
func (s *Sequential) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unreachable code for Sequential engine")
}

// Inputs returns the input primitives for this
func (s *Sequential) Inputs() ([]Primitive, []map[string]any) {
return s.Sources, nil
Expand Down
6 changes: 1 addition & 5 deletions go/vt/vtgate/engine/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type (
// Set contains the instructions to perform set.
Set struct {
noTxNeeded
noFields

Ops []SetOp
Input Primitive
Expand Down Expand Up @@ -142,11 +143,6 @@ func (s *Set) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars ma
return callback(result)
}

// GetFields implements the Primitive interface method.
func (s *Set) GetFields(context.Context, VCursor, map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return &sqltypes.Result{}, nil
}

// Inputs implements the Primitive interface
func (s *Set) Inputs() ([]Primitive, []map[string]any) {
return []Primitive{s.Input}, nil
Expand Down
6 changes: 1 addition & 5 deletions go/vt/vtgate/engine/singlerow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var _ Primitive = (*SingleRow)(nil)
type SingleRow struct {
noInputs
noTxNeeded
noFields
}

// RouteType returns a description of the query routing type used by the primitive
Expand Down Expand Up @@ -65,11 +66,6 @@ func (s *SingleRow) TryStreamExecute(ctx context.Context, vcursor VCursor, bindV
return callback(res)
}

// GetFields fetches the field info.
func (s *SingleRow) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return &sqltypes.Result{}, nil
}

func (s *SingleRow) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "SingleRow",
Expand Down
7 changes: 1 addition & 6 deletions go/vt/vtgate/engine/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ type VindexValues struct {

// Update represents the instructions to perform an update.
type Update struct {
// Update does not take inputs
noInputs
noFields

*DML

Expand Down Expand Up @@ -81,11 +81,6 @@ func (upd *Update) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa

}

// GetFields fetches the field info.
func (upd *Update) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, fmt.Errorf("BUG: unreachable code for %q", upd.Query)
}

// updateVindexEntries performs an update when a vindex is being modified
// by the statement.
// Note: the commit order may be different from the DML order because it's possible
Expand Down
7 changes: 2 additions & 5 deletions go/vt/vtgate/engine/upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ var _ Primitive = (*Upsert)(nil)
// if there is `Duplicate Key` error, it executes the update primitive.
type Upsert struct {
txNeeded
noFields

Upserts []upsert
}

Expand Down Expand Up @@ -69,11 +71,6 @@ func (u *Upsert) GetTableName() string {
return ""
}

// GetFields implements Primitive interface type.
func (u *Upsert) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return nil, vterrors.VT13001("unexpected to receive GetFields call for insert on duplicate key update query")
}

// TryExecute implements Primitive interface type.
func (u *Upsert) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
result := &sqltypes.Result{}
Expand Down
5 changes: 3 additions & 2 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1403,9 +1403,10 @@ func (e *Executor) prepare(ctx context.Context, safeSession *econtext.SafeSessio
}

switch stmtType {
case sqlparser.StmtSelect, sqlparser.StmtShow:
case sqlparser.StmtSelect, sqlparser.StmtShow,
sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete:
return e.handlePrepare(ctx, safeSession, sql, bindVars, logStats)
case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete,
case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet,
sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtAnalyze, sqlparser.StmtComment, sqlparser.StmtExplain, sqlparser.StmtFlush, sqlparser.StmtKill:
return nil, nil
}
Expand Down

0 comments on commit 042927f

Please sign in to comment.