From 39cc3c193bc76b0a418407ec3793e056713de745 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 14:28:34 +0200 Subject: [PATCH] [release-16.0] Protect `ExecuteFetchAsDBA` against multi-statements, excluding a sequence of `CREATE TABLE|VIEW`. (#14954) (#14982) Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com> Co-authored-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- .../backup/vtbackup/backup_only_test.go | 8 +- go/test/endtoend/cluster/vttablet_process.go | 28 +++++++ go/test/endtoend/clustertest/vtctld_test.go | 48 +++++++++++- go/test/endtoend/mysqlserver/main_test.go | 4 +- .../reparent/emergencyreparent/ers_test.go | 5 +- .../reparent/plannedreparent/reparent_test.go | 29 ++++--- go/test/endtoend/reparent/utils/utils.go | 9 +++ go/test/endtoend/vtorc/general/vtorc_test.go | 21 +++-- go/test/endtoend/vtorc/utils/utils.go | 20 +++++ go/vt/sqlparser/ast_test.go | 21 ++++- go/vt/vttablet/tabletmanager/rpc_query.go | 69 ++++++++++++++-- .../vttablet/tabletmanager/rpc_query_test.go | 78 +++++++++++++++++++ 12 files changed, 303 insertions(+), 37 deletions(-) diff --git a/go/test/endtoend/backup/vtbackup/backup_only_test.go b/go/test/endtoend/backup/vtbackup/backup_only_test.go index 3c0a2b6c279..782937a48d7 100644 --- a/go/test/endtoend/backup/vtbackup/backup_only_test.go +++ b/go/test/endtoend/backup/vtbackup/backup_only_test.go @@ -311,12 +311,12 @@ func tearDown(t *testing.T, initMysql bool) { } caughtUp := waitForReplicationToCatchup([]cluster.Vttablet{*replica1, *replica2}) require.True(t, caughtUp, "Timed out waiting for all replicas to catch up") - promoteCommands := "STOP SLAVE; RESET SLAVE ALL; RESET MASTER;" - disableSemiSyncCommands := "SET GLOBAL rpl_semi_sync_master_enabled = false; SET GLOBAL rpl_semi_sync_slave_enabled = false" + promoteCommands := []string{"STOP SLAVE", "RESET SLAVE ALL", "RESET MASTER"} + disableSemiSyncCommands := []string{"SET GLOBAL rpl_semi_sync_master_enabled = false", " SET GLOBAL rpl_semi_sync_slave_enabled = false"} for _, tablet := range []cluster.Vttablet{*primary, *replica1, *replica2} { - _, err := tablet.VttabletProcess.QueryTablet(promoteCommands, keyspaceName, true) + err := tablet.VttabletProcess.QueryTabletMultiple(promoteCommands, keyspaceName, true) require.Nil(t, err) - _, err = tablet.VttabletProcess.QueryTablet(disableSemiSyncCommands, keyspaceName, true) + err = tablet.VttabletProcess.QueryTabletMultiple(disableSemiSyncCommands, keyspaceName, true) require.Nil(t, err) } diff --git a/go/test/endtoend/cluster/vttablet_process.go b/go/test/endtoend/cluster/vttablet_process.go index 747bdc8fb39..13dd11819e8 100644 --- a/go/test/endtoend/cluster/vttablet_process.go +++ b/go/test/endtoend/cluster/vttablet_process.go @@ -435,6 +435,34 @@ func (vttablet *VttabletProcess) QueryTablet(query string, keyspace string, useD return executeQuery(conn, query) } +// QueryTabletMultiple lets you execute multiple queries -- without any +// results -- against the tablet. +func (vttablet *VttabletProcess) QueryTabletMultiple(queries []string, keyspace string, useDb bool) error { + conn, err := vttablet.TabletConn(keyspace, useDb) + if err != nil { + return err + } + defer conn.Close() + + for _, query := range queries { + log.Infof("Executing query %s (on %s)", query, vttablet.Name) + _, err := executeQuery(conn, query) + if err != nil { + return err + } + } + return nil +} + +// TabletConn opens a MySQL connection on this tablet +func (vttablet *VttabletProcess) TabletConn(keyspace string, useDb bool) (*mysql.Conn, error) { + if !useDb { + keyspace = "" + } + dbParams := NewConnParams(vttablet.DbPort, vttablet.DbPassword, path.Join(vttablet.Directory, "mysql.sock"), keyspace) + return vttablet.conn(&dbParams) +} + func (vttablet *VttabletProcess) defaultConn(dbname string) (*mysql.Conn, error) { dbParams := mysql.ConnParams{ Uname: "vt_dba", diff --git a/go/test/endtoend/clustertest/vtctld_test.go b/go/test/endtoend/clustertest/vtctld_test.go index 0ba4af1ee41..c741ce7fcb9 100644 --- a/go/test/endtoend/clustertest/vtctld_test.go +++ b/go/test/endtoend/clustertest/vtctld_test.go @@ -128,9 +128,51 @@ func testTabletStatus(t *testing.T) { } func testExecuteAsDba(t *testing.T) { - result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, `SELECT 1 AS a`) - require.NoError(t, err) - assert.Equal(t, result, oneTableOutput) + tcases := []struct { + query string + result string + expectErr bool + }{ + { + query: "", + expectErr: true, + }, + { + query: "SELECT 1 AS a", + result: oneTableOutput, + }, + { + query: "SELECT 1 AS a; SELECT 1 AS a", + expectErr: true, + }, + { + query: "create table t(id int)", + result: "", + }, + { + query: "create table if not exists t(id int)", + result: "", + }, + { + query: "create table if not exists t(id int); create table if not exists t(id int);", + result: "", + }, + { + query: "create table if not exists t(id int); create table if not exists t(id int); SELECT 1 AS a", + expectErr: true, + }, + } + for _, tcase := range tcases { + t.Run(tcase.query, func(t *testing.T) { + result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, tcase.query) + if tcase.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tcase.result, result) + } + }) + } } func testExecuteAsApp(t *testing.T) { diff --git a/go/test/endtoend/mysqlserver/main_test.go b/go/test/endtoend/mysqlserver/main_test.go index 42b4e6ea235..18b169e33d7 100644 --- a/go/test/endtoend/mysqlserver/main_test.go +++ b/go/test/endtoend/mysqlserver/main_test.go @@ -51,7 +51,7 @@ var ( PARTITION BY HASH( TO_DAYS(created) ) PARTITIONS 10; ` - createProcSQL = `use vt_test_keyspace; + createProcSQL = ` CREATE PROCEDURE testing() BEGIN delete from vt_insert_test; @@ -144,7 +144,7 @@ func TestMain(m *testing.M) { } primaryTabletProcess := clusterInstance.Keyspaces[0].Shards[0].PrimaryTablet().VttabletProcess - if _, err := primaryTabletProcess.QueryTablet(createProcSQL, keyspaceName, false); err != nil { + if _, err := primaryTabletProcess.QueryTablet(createProcSQL, keyspaceName, true); err != nil { return 1, err } diff --git a/go/test/endtoend/reparent/emergencyreparent/ers_test.go b/go/test/endtoend/reparent/emergencyreparent/ers_test.go index 8f6638ecb7e..fbd4770e15e 100644 --- a/go/test/endtoend/reparent/emergencyreparent/ers_test.go +++ b/go/test/endtoend/reparent/emergencyreparent/ers_test.go @@ -349,8 +349,11 @@ func TestNoReplicationStatusAndIOThreadStopped(t *testing.T) { tablets := clusterInstance.Keyspaces[0].Shards[0].Vttablets utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[1], tablets[2], tablets[3]}) - err := clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `STOP SLAVE; RESET SLAVE ALL`) + err := clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `STOP SLAVE`) require.NoError(t, err) + err = clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `RESET SLAVE ALL`) + require.NoError(t, err) + // err = clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[3].Alias, `STOP SLAVE IO_THREAD;`) require.NoError(t, err) // Run an additional command in the current primary which will only be acked by tablets[2] and be in its relay log. diff --git a/go/test/endtoend/reparent/plannedreparent/reparent_test.go b/go/test/endtoend/reparent/plannedreparent/reparent_test.go index ba8e17eb4d2..c1599cf093c 100644 --- a/go/test/endtoend/reparent/plannedreparent/reparent_test.go +++ b/go/test/endtoend/reparent/plannedreparent/reparent_test.go @@ -96,7 +96,7 @@ func TestPRSWithDrainedLaggingTablet(t *testing.T) { utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[1], tablets[2], tablets[3]}) // make tablets[1 lag from the other tablets by setting the delay to a large number - utils.RunSQL(context.Background(), t, `stop slave;CHANGE MASTER TO MASTER_DELAY = 1999;start slave;`, tablets[1]) + utils.RunSQLs(context.Background(), t, []string{`stop slave`, `CHANGE MASTER TO MASTER_DELAY = 1999`, `start slave;`}, tablets[1]) // insert another row in tablets[1 utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[2], tablets[3]}) @@ -213,26 +213,33 @@ func reparentFromOutside(t *testing.T, clusterInstance *cluster.LocalProcessClus } // commands to convert a replica to be writable - promoteReplicaCommands := "STOP SLAVE; RESET SLAVE ALL; SET GLOBAL read_only = OFF;" - utils.RunSQL(ctx, t, promoteReplicaCommands, tablets[1]) + promoteReplicaCommands := []string{"STOP SLAVE", "RESET SLAVE ALL", "SET GLOBAL read_only = OFF"} + utils.RunSQLs(ctx, t, promoteReplicaCommands, tablets[1]) // Get primary position _, gtID := cluster.GetPrimaryPosition(t, *tablets[1], utils.Hostname) // tablets[0] will now be a replica of tablets[1 - changeReplicationSourceCommands := fmt.Sprintf("RESET MASTER; RESET SLAVE; SET GLOBAL gtid_purged = '%s';"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", gtID, utils.Hostname, tablets[1].MySQLPort) - utils.RunSQL(ctx, t, changeReplicationSourceCommands, tablets[0]) + changeReplicationSourceCommands := []string{ + "RESET MASTER", + "RESET SLAVE", + fmt.Sprintf("SET GLOBAL gtid_purged = '%s'", gtID), + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, tablets[1].MySQLPort), + } + utils.RunSQLs(ctx, t, changeReplicationSourceCommands, tablets[0]) // Capture time when we made tablets[1 writable baseTime := time.Now().UnixNano() / 1000000000 // tablets[2 will be a replica of tablets[1 - changeReplicationSourceCommands = fmt.Sprintf("STOP SLAVE; RESET MASTER; SET GLOBAL gtid_purged = '%s';"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", gtID, utils.Hostname, tablets[1].MySQLPort) - utils.RunSQL(ctx, t, changeReplicationSourceCommands, tablets[2]) + changeReplicationSourceCommands = []string{ + "STOP SLAVE", + "RESET MASTER", + fmt.Sprintf("SET GLOBAL gtid_purged = '%s'", gtID), + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, tablets[1].MySQLPort), + "START SLAVE", + } + utils.RunSQLs(ctx, t, changeReplicationSourceCommands, tablets[2]) // To test the downPrimary, we kill the old primary first and delete its tablet record if downPrimary { diff --git a/go/test/endtoend/reparent/utils/utils.go b/go/test/endtoend/reparent/utils/utils.go index 8bab1fe4c0c..d9211d4f5f6 100644 --- a/go/test/endtoend/reparent/utils/utils.go +++ b/go/test/endtoend/reparent/utils/utils.go @@ -262,6 +262,15 @@ func getMysqlConnParam(tablet *cluster.Vttablet) mysql.ConnParams { return connParams } +// RunSQLs is used to run SQL commands directly on the MySQL instance of a vttablet +func RunSQLs(ctx context.Context, t *testing.T, sqls []string, tablet *cluster.Vttablet) (results []*sqltypes.Result) { + for _, sql := range sqls { + result := RunSQL(ctx, t, sql, tablet) + results = append(results, result) + } + return results +} + // RunSQL is used to run a SQL command directly on the MySQL instance of a vttablet func RunSQL(ctx context.Context, t *testing.T, sql string, tablet *cluster.Vttablet) *sqltypes.Result { tabletParams := getMysqlConnParam(tablet) diff --git a/go/test/endtoend/vtorc/general/vtorc_test.go b/go/test/endtoend/vtorc/general/vtorc_test.go index c0a845a5699..d03ba8af735 100644 --- a/go/test/endtoend/vtorc/general/vtorc_test.go +++ b/go/test/endtoend/vtorc/general/vtorc_test.go @@ -184,9 +184,13 @@ func TestVTOrcRepairs(t *testing.T) { t.Run("ReplicationFromOtherReplica", func(t *testing.T) { // point replica at otherReplica - changeReplicationSourceCommand := fmt.Sprintf("STOP SLAVE; RESET SLAVE ALL;"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1; START SLAVE", utils.Hostname, otherReplica.MySQLPort) - _, err := utils.RunSQL(t, changeReplicationSourceCommand, replica, "") + changeReplicationSourceCommands := []string{ + "STOP SLAVE", + "RESET SLAVE ALL", + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, otherReplica.MySQLPort), + "START SLAVE", + } + err := utils.RunSQLs(t, changeReplicationSourceCommands, replica, "") require.NoError(t, err) // wait until the source port is set back correctly by vtorc @@ -199,10 +203,13 @@ func TestVTOrcRepairs(t *testing.T) { t.Run("CircularReplication", func(t *testing.T) { // change the replication source on the primary - changeReplicationSourceCommands := fmt.Sprintf("STOP SLAVE; RESET SLAVE ALL;"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", replica.VttabletProcess.TabletHostname, replica.MySQLPort) - _, err := utils.RunSQL(t, changeReplicationSourceCommands, curPrimary, "") + changeReplicationSourceCommands := []string{ + "STOP SLAVE", + "RESET SLAVE ALL", + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", replica.VttabletProcess.TabletHostname, replica.MySQLPort), + "START SLAVE", + } + err := utils.RunSQLs(t, changeReplicationSourceCommands, curPrimary, "") require.NoError(t, err) // wait for curPrimary to reach stable state diff --git a/go/test/endtoend/vtorc/utils/utils.go b/go/test/endtoend/vtorc/utils/utils.go index 8d30e477e2d..4e2e87e8d6d 100644 --- a/go/test/endtoend/vtorc/utils/utils.go +++ b/go/test/endtoend/vtorc/utils/utils.go @@ -585,6 +585,26 @@ func RunSQL(t *testing.T, sql string, tablet *cluster.Vttablet, db string) (*sql return execute(t, conn, sql) } +// RunSQLs is used to run a list of SQL statements on the given tablet +func RunSQLs(t *testing.T, sqls []string, tablet *cluster.Vttablet, db string) error { + // Get Connection + tabletParams := getMysqlConnParam(tablet, db) + var timeoutDuration = time.Duration(5 * len(sqls)) + ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration*time.Second) + defer cancel() + conn, err := mysql.Connect(ctx, &tabletParams) + require.Nil(t, err) + defer conn.Close() + + // Run SQLs + for _, sql := range sqls { + if _, err := execute(t, conn, sql); err != nil { + return err + } + } + return nil +} + func execute(t *testing.T, conn *mysql.Conn, query string) (*sqltypes.Result, error) { t.Helper() return conn.ExecuteFetch(query, 1000, true) diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 88e9cedf55d..8a7a8d9d622 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -746,7 +746,26 @@ func TestSplitStatementToPieces(t *testing.T) { "`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," + "`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," + "PRIMARY KEY (`id`))", - }} + }, { + input: "create table t1 (id int primary key); create table t2 (id int primary key);", + output: "create table t1 (id int primary key); create table t2 (id int primary key)", + }, { + input: ";;; create table t1 (id int primary key);;; ;create table t2 (id int primary key);", + output: " create table t1 (id int primary key);create table t2 (id int primary key)", + }, { + // The input doesn't have to be valid SQL statements! + input: ";create table t1 ;create table t2 (id;", + output: "create table t1 ;create table t2 (id", + }, { + // Ignore quoted semicolon + input: ";create table t1 ';';;;create table t2 (id;", + output: "create table t1 ';';create table t2 (id", + }, { + // Ignore quoted semicolon + input: "stop replica; start replica", + output: "stop replica; start replica", + }, + } for _, tcase := range testcases { t.Run(tcase.input, func(t *testing.T) { diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 835cb698d8b..254d370091f 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -23,11 +23,50 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" + "vitess.io/vitess/go/vt/proto/vtrpc" ) +// analyzeExecuteFetchAsDbaMultiQuery reutrns 'true' when at least one of the queries +// in the given SQL has a `/*vt+ allowZeroInDate=true */` directive. +func analyzeExecuteFetchAsDbaMultiQuery(sql string) (queries []string, parseable bool, countCreate int, allowZeroInDate bool, err error) { + queries, err = sqlparser.SplitStatementToPieces(sql) + if err != nil { + return nil, false, 0, false, err + } + if len(queries) == 0 { + return nil, false, 0, false, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "no statements found in query: %s", sql) + } + parseable = true + for _, query := range queries { + // Some of the queries we receive here are legitimately non-parseable by our + // current parser, such as `CHANGE REPLICATION SOURCE TO...`. We must allow + // them and so we skip parsing errors. + stmt, err := sqlparser.Parse(query) + if err != nil { + parseable = false + continue + } + switch stmt.(type) { + case *sqlparser.CreateTable, *sqlparser.CreateView: + countCreate++ + default: + } + + if cmnt, ok := stmt.(sqlparser.Commented); ok { + directives := cmnt.GetParsedComments().Directives() + if directives.IsSet("allowZeroInDate") { + allowZeroInDate = true + } + } + + } + return queries, parseable, countCreate, allowZeroInDate, nil +} + // ExecuteFetchAsDba will execute the given query, possibly disabling binlogs and reload schema. func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanagerdatapb.ExecuteFetchAsDbaRequest) (*querypb.QueryResult, error) { // get a connection @@ -51,20 +90,34 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanag _, _ = conn.ExecuteFetch("USE "+sqlescape.EscapeID(req.DbName), 1, false) } - // Handle special possible directives - var directives *sqlparser.CommentDirectives - if stmt, err := sqlparser.Parse(string(req.Query)); err == nil { - if cmnt, ok := stmt.(sqlparser.Commented); ok { - directives = cmnt.GetParsedComments().Directives() + statements, _, countCreate, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(string(req.Query)) + if err != nil { + return nil, err + } + if len(statements) > 1 { + // Up to v19, we allow multi-statement SQL in ExecuteFetchAsDba, but only for the specific case + // where all statements are CREATE TABLE or CREATE VIEW. This is to support `ApplySchema --batch-size`. + // In v20, we will not support multi statements whatsoever. + // v20 will throw an error by virtua of using ExecuteFetch instead of ExecuteFetchMulti. + if countCreate != len(statements) { + return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "multi statement queries are not supported in ExecuteFetchAsDba unless all are CREATE TABLE or CREATE VIEW") } } - if directives.IsSet("allowZeroInDate") { + if allowZeroInDate { if _, err := conn.ExecuteFetch("set @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 1, false); err != nil { return nil, err } } - // run the query - result, err := conn.ExecuteFetch(string(req.Query), int(req.MaxRows), true /*wantFields*/) + // TODO(shlomi): we use ExecuteFetchMulti for backwards compatibility. In v20 we will not accept + // multi statement queries in ExecuteFetchAsDBA. This will be rewritten as: + // (in v20): result, err := ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) + result, more, err := conn.ExecuteFetchMulti(string(req.Query), int(req.MaxRows), true /*wantFields*/) + for more { + _, more, _, err = conn.ReadQueryResult(0, false) + if err != nil { + return nil, err + } + } // re-enable binlogs if necessary if req.DisableBinlogs && !conn.IsClosed() { diff --git a/go/vt/vttablet/tabletmanager/rpc_query_test.go b/go/vt/vttablet/tabletmanager/rpc_query_test.go index f6167e24917..0490bea4798 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query_test.go +++ b/go/vt/vttablet/tabletmanager/rpc_query_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" @@ -33,6 +34,83 @@ import ( tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" ) +func TestAnalyzeExecuteFetchAsDbaMultiQuery(t *testing.T) { + tcases := []struct { + query string + count int + parseable bool + allowZeroInDate bool + allCreate bool + expectErr bool + }{ + { + query: "", + expectErr: true, + }, + { + query: "select * from t1 ; select * from t2", + count: 2, + parseable: true, + }, + { + query: "create table t(id int)", + count: 1, + allCreate: true, + parseable: true, + }, + { + query: "create table t(id int); create view v as select 1 from dual", + count: 2, + allCreate: true, + parseable: true, + }, + { + query: "create table t(id int); create view v as select 1 from dual; drop table t3", + count: 3, + allCreate: false, + parseable: true, + }, + { + query: "create /*vt+ allowZeroInDate=true */ table t (id int)", + count: 1, + allCreate: true, + allowZeroInDate: true, + parseable: true, + }, + { + query: "create table a (id int) ; create /*vt+ allowZeroInDate=true */ table b (id int)", + count: 2, + allCreate: true, + allowZeroInDate: true, + parseable: true, + }, + { + query: "stop replica; start replica", + count: 2, + parseable: false, + }, + { + query: "create table a (id int) ; --comment ; what", + count: 3, + parseable: false, + }, + } + for _, tcase := range tcases { + t.Run(tcase.query, func(t *testing.T) { + queries, parseable, countCreate, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(tcase.query) + if tcase.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tcase.count, len(queries)) + assert.Equal(t, tcase.parseable, parseable) + assert.Equal(t, tcase.allCreate, (countCreate == len(queries))) + assert.Equal(t, tcase.allowZeroInDate, allowZeroInDate) + } + }) + } +} + func TestTabletManager_ExecuteFetchAsDba(t *testing.T) { ctx := context.Background() cp := mysql.ConnParams{}