Skip to content

Commit

Permalink
wip - add reference directive
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Oct 8, 2024
1 parent f683c16 commit 2333fef
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
92 changes: 63 additions & 29 deletions go/tester/comparing_query_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package tester
import (
"fmt"
log "github.com/sirupsen/logrus"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
"vitess.io/vitess/go/vt/sqlparser"

Expand All @@ -31,37 +32,37 @@ type (
reporter Reporter
handleCreateTable CreateTableHandler
comparer utils.MySQLCompare
cluster *cluster.LocalProcessCluster
}
CreateTableHandler func(create *sqlparser.CreateTable) func()
ComparingQueryRunnerFactory struct{}
)

func (f ComparingQueryRunnerFactory) Close() {}

func (f ComparingQueryRunnerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner {
return newComparingQueryRunner(reporter, handleCreateTable, comparer)
func (f ComparingQueryRunnerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare, cluster *cluster.LocalProcessCluster, table func(name string) (ks string, err error)) QueryRunner {
return newComparingQueryRunner(reporter, handleCreateTable, comparer, cluster)
}

func newComparingQueryRunner(
reporter Reporter,
handleCreateTable CreateTableHandler,
comparer utils.MySQLCompare,
cluster *cluster.LocalProcessCluster,
) *ComparingQueryRunner {
return &ComparingQueryRunner{
reporter: reporter,
handleCreateTable: handleCreateTable,
comparer: comparer,
cluster: cluster,
}
}

func (nqr ComparingQueryRunner) runQuery(q data.Query, expectedErrs bool, cfg QueryRunConfig) error {
if !cfg.vitess && !cfg.mysql {
return fmt.Errorf("both vitess and mysql are false")
}
return nqr.execute(q, expectedErrs, cfg.ast, cfg.vitess, cfg.mysql)
return nqr.execute(q, expectedErrs, cfg)
}

func (nqr *ComparingQueryRunner) execute(query data.Query, expectedErrs bool, ast sqlparser.Statement, vitess bool, mysql bool) error {
func (nqr *ComparingQueryRunner) execute(query data.Query, expectedErrs bool, cfg QueryRunConfig) error {
if len(query.Query) == 0 {
return nil
}
Expand All @@ -71,22 +72,22 @@ func (nqr *ComparingQueryRunner) execute(query data.Query, expectedErrs bool, as
expectedErrs = false
}()

if err := nqr.executeStmt(query.Query, ast, expectedErrs, vitess, mysql); err != nil {
if err := nqr.executeStmt(query.Query, cfg, expectedErrs); err != nil {
return fmt.Errorf("run \"%v\" at line %d err %v", query.Query, query.Line, err)
}

return nil
}

func (nqr *ComparingQueryRunner) executeStmt(query string, ast sqlparser.Statement, expectedErrs bool, vitess bool, mysql bool) (err error) {
_, commentOnly := ast.(*sqlparser.CommentOnly)
func (nqr *ComparingQueryRunner) executeStmt(query string, cfg QueryRunConfig, expectedErrs bool) (err error) {
_, commentOnly := cfg.ast.(*sqlparser.CommentOnly)
if commentOnly {
return nil
}

log.Debugf("executeStmt: %s", query)
create, isCreateStatement := ast.(*sqlparser.CreateTable)
if isCreateStatement && !expectedErrs && vitess {
create, isCreateStatement := cfg.ast.(*sqlparser.CreateTable)
if isCreateStatement && !expectedErrs && cfg.vitess {
closer := nqr.handleCreateTable(create)
defer func() {
if err == nil {
Expand All @@ -97,19 +98,21 @@ func (nqr *ComparingQueryRunner) executeStmt(query string, ast sqlparser.Stateme

switch {
case expectedErrs:
err := nqr.execAndExpectErr(query, vitess, mysql)
err := nqr.execAndExpectErr(query)
if err != nil {
nqr.reporter.AddFailure(err)
}
default:
var err error
switch {
case vitess && !mysql:
case cfg.reference:
return nqr.executeReference(query, cfg.ast)
case cfg.mysql && cfg.vitess:
nqr.comparer.Exec(query)
case cfg.vitess:
_, err = nqr.comparer.VtConn.ExecuteFetch(query, 1000, true)
case mysql && !vitess:
case cfg.mysql:
_, err = nqr.comparer.MySQLConn.ExecuteFetch(query, 1000, true)
case mysql:
nqr.comparer.Exec(query)
}
if err != nil {
nqr.reporter.AddFailure(err)
Expand All @@ -118,21 +121,52 @@ func (nqr *ComparingQueryRunner) executeStmt(query string, ast sqlparser.Stateme
return nil
}

func (nqr *ComparingQueryRunner) execAndExpectErr(query string, vitess bool, mysql bool) error {
var err error
switch {
case vitess && !mysql:
_, err = nqr.comparer.VtConn.ExecuteFetch(query, 1000, true)
case mysql && !vitess:
_, err = nqr.comparer.MySQLConn.ExecuteFetch(query, 1000, true)
case mysql:
_, err = nqr.comparer.ExecAllowAndCompareError(query, utils.CompareOptions{CompareColumnNames: true})
return err
}

func (nqr *ComparingQueryRunner) execAndExpectErr(query string) error {
_, err := nqr.comparer.ExecAllowAndCompareError(query, utils.CompareOptions{CompareColumnNames: true})
if err == nil {
// If we expected an error, but didn't get one, return an error
return fmt.Errorf("expected error, but got none")
}
return nil
}

func (nqr *ComparingQueryRunner) executeReference(query string, ast sqlparser.Statement) error {
_, err := nqr.comparer.MySQLConn.ExecuteFetch(query, 1000, true)
if err != nil {
return err
}

tables := sqlparser.ExtractAllTables(ast)
if len(tables) != 1 {
return fmt.Errorf("expected exactly one table in the query, got %d", len(tables))
}

tableName := tables[0]

tbl, err := vschema.FindTable("" /*empty means global search*/, tableName)
if err != nil {
return err
}

for _, ks := range nqr.cluster.Keyspaces {
if ks.Name == tbl.Keyspace.Name {
for _, shard := range ks.Shards {
_, err := nqr.comparer.VtConn.ExecuteFetch(fmt.Sprintf("use `%s/%s`", ks.Name, shard.Name), 1000, true)
if err != nil {
return fmt.Errorf("error setting keyspace/shard: %w", err)
}
_, err = nqr.comparer.VtConn.ExecuteFetch(query, 1000, true)
if err != nil {
return fmt.Errorf("error executing query on vtgate: %w", err)
}
}
q := fmt.Sprintf("use %s", ks.Name)
_, err = nqr.comparer.VtConn.ExecuteFetch(q, 1000, true)
if err != nil {
return fmt.Errorf("error setting keyspace: %s %w", q, err)
}
}
}

return nil
}
8 changes: 4 additions & 4 deletions go/tester/tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ type (
}

QueryRunConfig struct {
ast sqlparser.Statement
vitess, mysql bool
ast sqlparser.Statement
vitess, mysql, reference bool
}

QueryRunner interface {
runQuery(q data.Query, expectedErrs bool, cfg QueryRunConfig) error
}

QueryRunnerFactory interface {
NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner
NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare, cluster *cluster.LocalProcessCluster, table func(name string) (ks string, err error)) QueryRunner
Close()
}
)
Expand Down Expand Up @@ -117,7 +117,7 @@ func NewTester(
if !t.autoVSchema() {
createTableHandler = func(*sqlparser.CreateTable) func() { return func() {} }
}
t.qr = factory.NewQueryRunner(reporter, createTableHandler, mcmp)
t.qr = factory.NewQueryRunner(reporter, createTableHandler, mcmp, clusterInstance, t.findTable)

return t
}
Expand Down
5 changes: 3 additions & 2 deletions go/tester/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"vitess.io/vitess/go/test/endtoend/cluster"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/utils"
Expand Down Expand Up @@ -38,8 +39,8 @@ func NewTracerFactory(traceFile *os.File, inner QueryRunnerFactory) *TracerFacto
}
}

func (t *TracerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare) QueryRunner {
inner := t.inner.NewQueryRunner(reporter, handleCreateTable, comparer)
func (t *TracerFactory) NewQueryRunner(reporter Reporter, handleCreateTable CreateTableHandler, comparer utils.MySQLCompare, cluster *cluster.LocalProcessCluster, table func(name string) (ks string, err error)) QueryRunner {
inner := t.inner.NewQueryRunner(reporter, handleCreateTable, comparer, cluster, APA)

Check failure on line 43 in go/tester/tracer.go

View workflow job for this annotation

GitHub Actions / test

undefined: APA

Check failure on line 43 in go/tester/tracer.go

View workflow job for this annotation

GitHub Actions / test

undefined: APA
return newTracer(t.traceFile, comparer.MySQLConn, comparer.VtConn, reporter, inner)
}

Expand Down
1 change: 1 addition & 0 deletions t/demo.test
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ create table name_idx
primary key (name, customer_id)
);

--reference
insert into pincode_areas(pincode, area_name)
values (110001, 'Connaught Place'),
(110002, 'Lodhi Road'),
Expand Down

0 comments on commit 2333fef

Please sign in to comment.