diff --git a/main.go b/main.go index de61b91..89ad296 100644 --- a/main.go +++ b/main.go @@ -16,9 +16,13 @@ package main import ( "flag" "fmt" + "io" + "net/http" "os" + "time" log "github.com/sirupsen/logrus" + "vitess.io/vitess/go/test/endtoend/cluster" "github.com/vitessio/vitess-tester/src/cmd" vitess_tester "github.com/vitessio/vitess-tester/src/vitess-tester" @@ -92,9 +96,9 @@ func main() { if xunit { reporterSuite = vitess_tester.NewXMLTestSuite() } else { - reporterSuite = vitess_tester.NewFileReporterSuite() + reporterSuite = vitess_tester.NewFileReporterSuite(getVschema(clusterInstance)) } - failed := cmd.ExecuteTests(clusterInstance, vtParams, mysqlParams, tests, reporterSuite, ksNames, vschemaFile, vtexplainVschemaFile, olap, traceFile) + failed := cmd.ExecuteTests(clusterInstance, vtParams, mysqlParams, tests, reporterSuite, ksNames, vschemaFile, vtexplainVschemaFile, olap, getQueryRunnerFactory()) outputFile := reporterSuite.Close() if failed { log.Errorf("some tests failed 😭\nsee errors in %v", outputFile) @@ -102,3 +106,40 @@ func main() { } println("Great, All tests passed") } + +func getQueryRunnerFactory() vitess_tester.QueryRunnerFactory { + inner := vitess_tester.ComparingQueryRunnerFactory{} + if traceFile == "" { + return inner + } + + var err error + writer, err := os.Create(traceFile) + if err != nil { + panic(err) + } + _, err = writer.Write([]byte("[")) + if err != nil { + panic(err.Error()) + } + return vitess_tester.NewTracerFactory(writer, inner) +} + +func getVschema(clusterInstance *cluster.LocalProcessCluster) func() []byte { + return func() []byte { + httpClient := &http.Client{Timeout: 5 * time.Second} + resp, err := httpClient.Get(clusterInstance.VtgateProcess.VSchemaURL) + if err != nil { + log.Errorf(err.Error()) + return nil + } + defer resp.Body.Close() + res, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf(err.Error()) + return nil + } + + return res + } +} diff --git a/src/cmd/cmd.go b/src/cmd/cmd.go index cd894e3..5ce4659 100644 --- a/src/cmd/cmd.go +++ b/src/cmd/cmd.go @@ -53,28 +53,16 @@ func ExecuteTests( ksNames []string, vschemaFile, vtexplainVschemaFile string, olap bool, - traceFile string, + factory vitess_tester.QueryRunnerFactory, ) (failed bool) { vschemaF := vschemaFile if vschemaF == "" { vschemaF = vtexplainVschemaFile } - var writer *os.File - if traceFile != "" { - // create the file and store the writer in the Tester struct - var err error - writer, err = os.Create(traceFile) - if err != nil { - panic(err) - } - _, err = writer.Write([]byte("[")) - if err != nil { - panic(err.Error()) - } - } + for _, name := range fileNames { errReporter := s.NewReporterForFile(name) - vTester := vitess_tester.NewTester(name, errReporter, clusterInstance, vtParams, mysqlParams, olap, ksNames, vschema, vschemaF, writer) + vTester := vitess_tester.NewTester(name, errReporter, clusterInstance, vtParams, mysqlParams, olap, ksNames, vschema, vschemaF, factory) err := vTester.Run() if err != nil { failed = true @@ -83,17 +71,10 @@ func ExecuteTests( failed = failed || errReporter.Failed() s.CloseReportForFile() } - if writer != nil { - _, err := writer.Write([]byte("]")) - if err != nil { - panic(err.Error()) - } - err = writer.Close() - if err != nil { - panic(err.Error()) - } - } - return + + factory.Close() + + return failed } func SetupCluster( diff --git a/src/vitess-tester/comparing_query_runner.go b/src/vitess-tester/comparing_query_runner.go new file mode 100644 index 0000000..88e0205 --- /dev/null +++ b/src/vitess-tester/comparing_query_runner.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vitess_tester + +import ( + "fmt" + + "github.com/pingcap/errors" + log "github.com/sirupsen/logrus" + "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/sqlparser" +) + +type ( + // ComparingQueryRunner is a QueryRunner that compares the results of the queries between MySQL and Vitess + ComparingQueryRunner struct { + reporter Reporter + handleCreateTable CreateTableHandler + comparer utils.MySQLCompare + } + CreateTableHandler func(create *sqlparser.CreateTable) func() + ComparingQueryRunnerFactory struct{} +) + +func (f ComparingQueryRunnerFactory) Close() {} + +func (f ComparingQueryRunnerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner { + return newComparingQueryRunner(reporter, handleCreateTable, comparer) +} + +func newComparingQueryRunner( + reporter Reporter, + handleCreateTable CreateTableHandler, + comparer utils.MySQLCompare, +) *ComparingQueryRunner { + return &ComparingQueryRunner{ + reporter: reporter, + handleCreateTable: handleCreateTable, + comparer: comparer, + } +} + +func (nqr ComparingQueryRunner) runQuery(q query, expectedErrs bool, ast sqlparser.Statement) error { + return nqr.execute(q, expectedErrs, ast) +} + +func (nqr *ComparingQueryRunner) execute(query query, expectedErrs bool, ast sqlparser.Statement) error { + if len(query.Query) == 0 { + return nil + } + + if err := nqr.executeStmt(query.Query, ast, expectedErrs); err != nil { + return errors.Trace(errors.Errorf("run \"%v\" at line %d err %v", query.Query, query.Line, err)) + } + // clear expected errors after we execute + expectedErrs = false + + return nil +} + +func (nqr *ComparingQueryRunner) executeStmt(query string, ast sqlparser.Statement, expectedErrs bool) (err error) { + _, commentOnly := ast.(*sqlparser.CommentOnly) + if commentOnly { + return nil + } + + log.Debugf("executeStmt: %s", query) + create, isCreateStatement := ast.(*sqlparser.CreateTable) + if isCreateStatement && !expectedErrs { + closer := nqr.handleCreateTable(create) + defer func() { + if err == nil { + closer() + } + }() + } + + switch { + case expectedErrs: + _, err := nqr.comparer.ExecAllowAndCompareError(query, utils.CompareOptions{CompareColumnNames: true}) + if err == nil { + // If we expected an error, but didn't get one, return an error + return fmt.Errorf("expected error, but got none") + } + default: + _ = nqr.comparer.Exec(query) + } + return nil +} diff --git a/src/vitess-tester/reporter.go b/src/vitess-tester/reporter.go index 059daae..cb2cc31 100644 --- a/src/vitess-tester/reporter.go +++ b/src/vitess-tester/reporter.go @@ -23,27 +23,32 @@ import ( "path" "strings" "time" + + "vitess.io/vitess/go/test/endtoend/utils" ) type Suite interface { NewReporterForFile(name string) Reporter CloseReportForFile() - Close() string + Close() string // returns the path to the file or directory with files } type Reporter interface { + utils.TestingT AddTestCase(query string, lineNo int) EndTestCase() - AddFailure(vschema []byte, err error) + AddFailure(err error) AddInfo(info string) Report() string Failed() bool } -type FileReporterSuite struct{} +type FileReporterSuite struct { + getVschema func() []byte +} func (frs *FileReporterSuite) NewReporterForFile(name string) Reporter { - return newFileReporter(name) + return newFileReporter(name, frs.getVschema) } func (frs *FileReporterSuite) CloseReportForFile() {} @@ -52,8 +57,10 @@ func (frs *FileReporterSuite) Close() string { return "errors" } -func NewFileReporterSuite() *FileReporterSuite { - return &FileReporterSuite{} +func NewFileReporterSuite(getVschema func() []byte) *FileReporterSuite { + return &FileReporterSuite{ + getVschema: getVschema, + } } type FileReporter struct { @@ -69,12 +76,15 @@ type FileReporter struct { failureCount int queryCount int successCount int + + getVschema func() []byte } -func newFileReporter(name string) *FileReporter { +func newFileReporter(name string, getVschema func() []byte) *FileReporter { return &FileReporter{ - name: name, - startTime: time.Now(), + name: name, + startTime: time.Now(), + getVschema: getVschema, } } @@ -115,7 +125,7 @@ func (e *FileReporter) EndTestCase() { } } -func (e *FileReporter) AddFailure(vschema []byte, err error) { +func (e *FileReporter) AddFailure(err error) { e.failureCount++ e.currentQueryFailed = true if e.currentQuery == "" { @@ -130,7 +140,7 @@ func (e *FileReporter) AddFailure(vschema []byte, err error) { panic("failed to write error file\n" + err.Error()) } - e.createVSchemaDump(vschema) + e.createVSchemaDump() } func (e *FileReporter) AddInfo(info string) { @@ -162,14 +172,14 @@ func (e *FileReporter) createErrorFileFor() *os.File { return file } -func (e *FileReporter) createVSchemaDump(vschema []byte) { +func (e *FileReporter) createVSchemaDump() { errorDir := e.errorDir() err := os.MkdirAll(errorDir, PERM) if err != nil { panic("failed to create vschema directory\n" + err.Error()) } - err = os.WriteFile(path.Join(errorDir, "vschema.json"), vschema, PERM) + err = os.WriteFile(path.Join(errorDir, "vschema.json"), e.getVschema(), PERM) if err != nil { panic("failed to write vschema\n" + err.Error()) } @@ -189,4 +199,14 @@ func (e *FileReporter) errorDir() string { return path.Join("errors", errFileName) } +func (e *FileReporter) Errorf(format string, args ...interface{}) { + e.AddFailure(fmt.Errorf(format, args...)) +} + +func (e *FileReporter) FailNow() { + // we don't need to do anything here +} + +func (e *FileReporter) Helper() {} + var _ Reporter = (*FileReporter)(nil) diff --git a/src/vitess-tester/tester.go b/src/vitess-tester/tester.go index 9f99046..8b38626 100644 --- a/src/vitess-tester/tester.go +++ b/src/vitess-tester/tester.go @@ -20,13 +20,13 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "os" "strconv" "strings" "time" - "vitess.io/vitess/go/vt/vterrors" "github.com/pingcap/errors" log "github.com/sirupsen/logrus" @@ -37,30 +37,42 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -type Tester struct { - name string - - clusterInstance *cluster.LocalProcessCluster - vtParams, mysqlParams mysql.ConnParams - curr utils.MySQLCompare - - skipBinary string - skipVersion int - skipNext bool - olap bool - ksNames []string - vschema vindexes.VSchema - vschemaFile string - vexplain string - - // check expected error, use --error before the statement - // we only care if an error is returned, not the exact error message. - expectedErrs bool - - reporter Reporter - traceFile io.Writer - alreadyWrittenTraces bool // we need to keep track of it is the first trace or not, to add commas in between traces -} +type ( + Tester struct { + name string + + clusterInstance *cluster.LocalProcessCluster + vtParams, mysqlParams mysql.ConnParams + curr utils.MySQLCompare + + skipBinary string + skipVersion int + skipNext bool + olap bool + ksNames []string + vschema vindexes.VSchema + vschemaFile string + vexplain string + + // check expected error, use --error before the statement + // we only care if an error is returned, not the exact error message. + expectedErrs bool + + reporter Reporter + alreadyWrittenTraces bool // we need to keep track of it is the first trace or not, to add commas in between traces + + qr QueryRunner + } + + QueryRunner interface { + runQuery(q query, expectedErrs bool, ast sqlparser.Statement) error + } + + QueryRunnerFactory interface { + NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner + Close() + } +) func NewTester( name string, @@ -71,7 +83,7 @@ func NewTester( ksNames []string, vschema vindexes.VSchema, vschemaFile string, - traceFile *os.File, + factory QueryRunnerFactory, ) *Tester { t := &Tester{ name: name, @@ -84,18 +96,18 @@ func NewTester( vschemaFile: vschemaFile, olap: olap, } - if traceFile != nil { - t.traceFile = traceFile - } - return t -} -func (t *Tester) preProcess() { - mcmp, err := utils.NewMySQLCompare(t, t.vtParams, t.mysqlParams) + mcmp, err := utils.NewMySQLCompare(t.reporter, t.vtParams, t.mysqlParams) if err != nil { panic(err.Error()) } t.curr = mcmp + t.qr = factory.NewQueryRunner(reporter, t.handleCreateTable, mcmp) + + return t +} + +func (t *Tester) preProcess() { if t.olap { _, err := t.curr.VtConn.ExecuteFetch("set workload = 'olap'", 0, false) if err != nil { @@ -117,25 +129,23 @@ func (t *Tester) postProcess() { var PERM os.FileMode = 0755 -func (t *Tester) addSuccess() { - -} +func (t *Tester) getVschema() func() []byte { + return func() []byte { + httpClient := &http.Client{Timeout: 5 * time.Second} + resp, err := httpClient.Get(t.clusterInstance.VtgateProcess.VSchemaURL) + if err != nil { + log.Errorf(err.Error()) + return nil + } + defer resp.Body.Close() + res, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf(err.Error()) + return nil + } -func (t *Tester) getVschema() []byte { - httpClient := &http.Client{Timeout: 5 * time.Second} - resp, err := httpClient.Get(t.clusterInstance.VtgateProcess.VSchemaURL) - if err != nil { - log.Errorf(err.Error()) - return nil - } - defer resp.Body.Close() - res, err := io.ReadAll(resp.Body) - if err != nil { - log.Errorf(err.Error()) - return nil + return res } - - return res } func (t *Tester) Run() error { @@ -145,7 +155,7 @@ func (t *Tester) Run() error { } queries, err := t.loadQueries() if err != nil { - t.reporter.AddFailure(t.getVschema(), err) + t.reporter.AddFailure(err) return err } @@ -167,18 +177,18 @@ func (t *Tester) Run() error { case Q_SKIP: t.skipNext = true case Q_BEGIN_CONCURRENT, Q_END_CONCURRENT, Q_CONNECT, Q_CONNECTION, Q_DISCONNECT, Q_LET, Q_REPLACE_COLUMN: - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("%s not supported", String(q.tp))) + t.reporter.AddFailure(fmt.Errorf("%s not supported", String(q.tp))) case Q_SKIP_IF_BELOW_VERSION: strs := strings.Split(q.Query, " ") if len(strs) != 3 { - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("incorrect syntax for Q_SKIP_IF_BELOW_VERSION in: %v", q.Query)) + t.reporter.AddFailure(fmt.Errorf("incorrect syntax for Q_SKIP_IF_BELOW_VERSION in: %v", q.Query)) continue } t.skipBinary = strs[1] var err error t.skipVersion, err = strconv.Atoi(strs[2]) if err != nil { - t.reporter.AddFailure(t.getVschema(), err) + t.reporter.AddFailure(err) continue } case Q_ERROR: @@ -186,7 +196,7 @@ func (t *Tester) Run() error { case Q_VEXPLAIN: strs := strings.Split(q.Query, " ") if len(strs) != 2 { - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("incorrect syntax for Q_VEXPLAIN in: %v", q.Query)) + t.reporter.AddFailure(fmt.Errorf("incorrect syntax for Q_VEXPLAIN in: %v", q.Query)) continue } @@ -194,6 +204,16 @@ func (t *Tester) Run() error { case Q_WAIT_FOR_AUTHORITATIVE: t.waitAuthoritative(q.Query) case Q_QUERY: + if t.vexplain != "" { + result, err := t.curr.VtConn.ExecuteFetch(fmt.Sprintf("vexplain %s %s", t.vexplain, q.Query), -1, false) + t.vexplain = "" + if err != nil { + t.reporter.AddFailure(err) + } + + t.reporter.AddInfo(fmt.Sprintf("VExplain Output:\n %s\n", result.Rows[0][0].ToString())) + } + t.runQuery(q) case Q_REMOVE_FILE: err = os.Remove(strings.TrimSpace(q.Query)) @@ -201,7 +221,7 @@ func (t *Tester) Run() error { return errors.Annotate(err, "failed to remove file") } default: - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("%s not supported", String(q.tp))) + t.reporter.AddFailure(fmt.Errorf("%s not supported", String(q.tp))) } } fmt.Printf("%s\n", t.reporter.Report()) @@ -222,18 +242,16 @@ func (t *Tester) runQuery(q query) { } } t.reporter.AddTestCase(q.Query, q.Line) - if t.vexplain != "" { - result, err := t.curr.VtConn.ExecuteFetch("vexplain "+t.vexplain+" "+q.Query, -1, false) - t.vexplain = "" - if err != nil { - t.reporter.AddFailure(t.getVschema(), err) - return - } - - t.reporter.AddInfo(fmt.Sprintf("VExplain Output:\n %s\n", result.Rows[0][0].ToString())) + parser := sqlparser.NewTestParser() + ast, err := parser.Parse(q.Query) + if err != nil { + t.reporter.AddFailure(err) + return } - if err := t.execute(q); err != nil && !t.expectedErrs { - t.reporter.AddFailure(t.getVschema(), err) + + err = t.qr.runQuery(q, t.expectedErrs, ast) + if err != nil { + t.reporter.AddFailure(err) } t.reporter.EndTestCase() // clear expected errors and current query after we execute any query @@ -266,7 +284,7 @@ func (t *Tester) waitAuthoritative(query string) { var err error ksName, err = t.findTable(tblName) if err != nil { - t.reporter.AddFailure(t.getVschema(), err) + t.reporter.AddFailure(err) return } case 3: @@ -274,13 +292,13 @@ func (t *Tester) waitAuthoritative(query string) { ksName = strs[2] default: - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("expected table name and keyspace for wait_authoritative in: %v", query)) + t.reporter.AddFailure(fmt.Errorf("expected table name and keyspace for wait_authoritative in: %v", query)) } log.Infof("Waiting for authoritative schema for table %s", tblName) - err := utils.WaitForAuthoritative(t, ksName, tblName, t.clusterInstance.VtgateProcess.ReadVSchema) + err := utils.WaitForAuthoritative(t.reporter, ksName, tblName, t.clusterInstance.VtgateProcess.ReadVSchema) if err != nil { - t.reporter.AddFailure(t.getVschema(), fmt.Errorf("failed to wait for authoritative schema for table %s: %v", tblName, err)) + t.reporter.AddFailure(fmt.Errorf("failed to wait for authoritative schema for table %s: %v", tblName, err)) } } @@ -339,91 +357,6 @@ func (t *Tester) readData() ([]byte, error) { return os.ReadFile(t.name) } -func (t *Tester) execute(query query) error { - if len(query.Query) == 0 { - return nil - } - - parser := sqlparser.NewTestParser() - ast, err := parser.Parse(query.Query) - if err != nil { - return err - } - - if sqlparser.IsDMLStatement(ast) && t.traceFile != nil && !t.expectedErrs { - // we don't want to run DMLs twice, so we just run them once while tracing - var errs []error - err := t.trace(query) - if err != nil { - errs = append(errs, err) - } - - // we need to run the DMLs on mysql as well - _, err = t.curr.MySQLConn.ExecuteFetch(query.Query, 10000, false) - if err != nil { - errs = append(errs, err) - } - return vterrors.Aggregate(errs) - } - - if err = t.executeStmt(query.Query, ast); err != nil { - return errors.Trace(errors.Errorf("run \"%v\" at line %d err %v", query.Query, query.Line, err)) - } - // clear expected errors after we execute - t.expectedErrs = false - - if t.traceFile == nil { - return nil - } - - _, isDDL := ast.(sqlparser.DDLStatement) - if isDDL { - return nil - } - - return t.trace(query) -} - -// trace writes the query and its trace (fetched from VtConn) as a JSON object into traceFile -func (t *Tester) trace(query query) error { - // Marshal the query into JSON format for safe embedding - queryJSON, err := json.Marshal(query.Query) - if err != nil { - return err - } - - // Fetch the trace for the query using "vexplain trace" - rs, err := t.curr.VtConn.ExecuteFetch(fmt.Sprintf("vexplain trace %s", query.Query), 10000, false) - if err != nil { - return err - } - - // Extract the trace result and format it with indentation for pretty printing - var prettyTrace bytes.Buffer - if err := json.Indent(&prettyTrace, []byte(rs.Rows[0][0].ToString()), "", " "); err != nil { - return err - } - - // Construct the entire JSON entry in memory - var traceEntry bytes.Buffer - if t.alreadyWrittenTraces { - traceEntry.WriteString(",") // Prepend a comma if there are already written traces - } - traceEntry.WriteString(fmt.Sprintf(`{"Query": %s, "LineNumber": "%d", "Trace": `, queryJSON, query.Line)) - traceEntry.Write(prettyTrace.Bytes()) // Add the formatted trace - traceEntry.WriteString("}") // Close the JSON object - - // Mark that at least one trace has been written - t.alreadyWrittenTraces = true - - // Write the fully constructed JSON entry to the file - if _, err := t.traceFile.Write(traceEntry.Bytes()); err != nil { - return err - } - - return nil -} - func newPrimaryKeyIndexDefinitionSingleColumn(name sqlparser.IdentifierCI) *sqlparser.IndexDefinition { index := &sqlparser.IndexDefinition{ Info: &sqlparser.IndexInfo{ @@ -435,39 +368,6 @@ func newPrimaryKeyIndexDefinitionSingleColumn(name sqlparser.IdentifierCI) *sqlp return index } -func (t *Tester) executeStmt(query string, ast sqlparser.Statement) error { - _, commentOnly := ast.(*sqlparser.CommentOnly) - if commentOnly { - return nil - } - - log.Debugf("executeStmt: %s", query) - create, isCreateStatement := ast.(*sqlparser.CreateTable) - handleVSchema := isCreateStatement && !t.expectedErrs && t.autoVSchema() - if handleVSchema { - t.handleCreateTable(create) - } - - switch { - case t.expectedErrs: - _, err := t.curr.ExecAllowAndCompareError(query, utils.CompareOptions{CompareColumnNames: true}) - if err == nil { - // If we expected an error, but didn't get one, return an error - return fmt.Errorf("expected error, but got none") - } - default: - _ = t.curr.Exec(query) - } - - if handleVSchema { - err := utils.WaitForAuthoritative(t, t.ksNames[0], create.Table.Name.String(), t.clusterInstance.VtgateProcess.ReadVSchema) - if err != nil { - panic(err) - } - } - return nil -} - func (t *Tester) autoVSchema() bool { return t.vschemaFile == "" } @@ -499,7 +399,7 @@ func getShardingKeysForTable(create *sqlparser.CreateTable) (sks []sqlparser.Ide return } -func (t *Tester) handleCreateTable(create *sqlparser.CreateTable) { +func (t *Tester) handleCreateTable(create *sqlparser.CreateTable) func() { sks := getShardingKeysForTable(create) shardingKeys := &vindexes.ColumnVindex{ @@ -525,14 +425,11 @@ func (t *Tester) handleCreateTable(create *sqlparser.CreateTable) { if err != nil { panic(err) } -} -func (t *Tester) Errorf(format string, args ...interface{}) { - t.reporter.AddFailure(t.getVschema(), errors.Errorf(format, args...)) -} - -func (t *Tester) FailNow() { - // we don't need to do anything here + return func() { + err := utils.WaitForAuthoritative(t.reporter, t.ksNames[0], create.Table.Name.String(), t.clusterInstance.VtgateProcess.ReadVSchema) + if err != nil { + panic(err) + } + } } - -func (t *Tester) Helper() {} diff --git a/src/vitess-tester/tracer.go b/src/vitess-tester/tracer.go new file mode 100644 index 0000000..d6d7f38 --- /dev/null +++ b/src/vitess-tester/tracer.go @@ -0,0 +1,138 @@ +package vitess_tester + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" +) + +var _ QueryRunner = (*Tracer)(nil) +var _ QueryRunnerFactory = (*TracerFactory)(nil) + +type ( + Tracer struct { + traceFile *os.File + MySQLConn, VtConn *mysql.Conn + reporter Reporter + inner QueryRunner + alreadyWrittenTraces bool + } + TracerFactory struct { + traceFile *os.File + inner QueryRunnerFactory + } +) + +func NewTracerFactory(traceFile *os.File, inner QueryRunnerFactory) *TracerFactory { + return &TracerFactory{ + traceFile: traceFile, + inner: inner, + } +} + +func (t *TracerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner { + inner := t.inner.NewQueryRunner(reporter, handleCreateTable, comparer) + return newTracer(t.traceFile, comparer.MySQLConn, comparer.VtConn, reporter, inner) +} + +func (t *TracerFactory) Close() { + _, err := t.traceFile.Write([]byte("]")) + if err != nil { + panic(err.Error()) + } + err = t.traceFile.Close() + if err != nil { + panic(err.Error()) + } +} + +func newTracer(traceFile *os.File, + mySQLConn, vtConn *mysql.Conn, + reporter Reporter, + inner QueryRunner, +) QueryRunner { + return &Tracer{ + traceFile: traceFile, + MySQLConn: mySQLConn, + VtConn: vtConn, + reporter: reporter, + inner: inner, + } +} + +func (t *Tracer) runQuery(q query, expectErr bool, ast sqlparser.Statement) error { + if sqlparser.IsDMLStatement(ast) && t.traceFile != nil && !expectErr { + // we don't want to run DMLs twice, so we just run them once while tracing + var errs []error + err := t.trace(q) + if err != nil { + errs = append(errs, err) + } + + // we need to run the DMLs on mysql as well + _, err = t.MySQLConn.ExecuteFetch(q.Query, 10000, false) + if err != nil { + errs = append(errs, err) + } + + return vterrors.Aggregate(errs) + } + + err := t.inner.runQuery(q, expectErr, ast) + if err != nil { + return err + } + + _, isDDL := ast.(sqlparser.DDLStatement) + if isDDL { + // we don't want to trace DDLs + return nil + } + + return t.trace(q) +} + +// trace writes the query and its trace (fetched from VtConn) as a JSON object into traceFile +func (t *Tracer) trace(query query) error { + // Marshal the query into JSON format for safe embedding + queryJSON, err := json.Marshal(query.Query) + if err != nil { + return err + } + + // Fetch the trace for the query using "vexplain trace" + rs, err := t.VtConn.ExecuteFetch(fmt.Sprintf("vexplain trace %s", query.Query), 10000, false) + if err != nil { + return err + } + + // Extract the trace result and format it with indentation for pretty printing + var prettyTrace bytes.Buffer + if err := json.Indent(&prettyTrace, []byte(rs.Rows[0][0].ToString()), "", " "); err != nil { + return err + } + + // Construct the entire JSON entry in memory + var traceEntry bytes.Buffer + if t.alreadyWrittenTraces { + traceEntry.WriteString(",") // Prepend a comma if there are already written traces + } + traceEntry.WriteString(fmt.Sprintf(`{"Query": %s, "LineNumber": "%d", "Trace": `, queryJSON, query.Line)) + traceEntry.Write(prettyTrace.Bytes()) // Add the formatted trace + traceEntry.WriteString("}") // Close the JSON object + + // Mark that at least one trace has been written + t.alreadyWrittenTraces = true + + // Write the fully constructed JSON entry to the file + if _, err := t.traceFile.Write(traceEntry.Bytes()); err != nil { + return err + } + + return nil +} diff --git a/src/vitess-tester/xunit.go b/src/vitess-tester/xunit.go index 1706d24..ce83773 100644 --- a/src/vitess-tester/xunit.go +++ b/src/vitess-tester/xunit.go @@ -77,10 +77,10 @@ func (xml *XMLTestSuite) EndTestCase() { xml.currTestCase = nil } -func (xml *XMLTestSuite) AddFailure(vschema []byte, err error) { +func (xml *XMLTestSuite) AddFailure(err error) { if xml.currTestCase == nil { xml.AddTestCase("SETUP", 0) - xml.AddFailure(vschema, err) + xml.AddFailure(err) xml.EndTestCase() return } @@ -118,3 +118,13 @@ func (xml *XMLTestSuite) AddInfo(info string) { xml.currTestCase.SystemOut.Data += info + "\n" } } + +func (xml *XMLTestSuite) Errorf(format string, args ...interface{}) { + xml.AddFailure(fmt.Errorf(format, args...)) +} + +func (xml *XMLTestSuite) FailNow() { + // we don't need to do anything here +} + +func (xml *XMLTestSuite) Helper() {}