From ee24ba9ea4580e0473b2540ff0367dfbd675b303 Mon Sep 17 00:00:00 2001 From: marioooo0 <35252889+marioooo0@users.noreply.github.com> Date: Wed, 19 Jul 2023 11:17:32 +0800 Subject: [PATCH] feat: support explain statement by loop shards (#695) * support explain by loop all shards and ordered by table name * fix * Add license header. * fix explain select 1 && import format --------- Co-authored-by: maorui --- pkg/dataset/parallel.go | 22 +++ pkg/runtime/ast/ast.go | 14 ++ pkg/runtime/ast/describe.go | 3 +- pkg/runtime/ast/proto.go | 2 + pkg/runtime/optimize/utility/explain.go | 54 +++++++ pkg/runtime/plan/utility/explain.go | 194 ++++++++++++++++++++++++ test/integration_test.go | 19 +++ 7 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 pkg/runtime/optimize/utility/explain.go create mode 100644 pkg/runtime/plan/utility/explain.go diff --git a/pkg/dataset/parallel.go b/pkg/dataset/parallel.go index c3202c19..e85c7abc 100644 --- a/pkg/dataset/parallel.go +++ b/pkg/dataset/parallel.go @@ -229,6 +229,28 @@ func Parallel(first GenerateFunc, others ...GenerateFunc) (RandomAccessDataset, }, nil } +type parallelBuilder struct { + genFuns []GenerateFunc +} + +func NewParallelBuilder() parallelBuilder { + return parallelBuilder{} +} + +func (pb *parallelBuilder) Add(genFunc GenerateFunc) { + pb.genFuns = append(pb.genFuns, genFunc) +} + +func (pb *parallelBuilder) Build() (RandomAccessDataset, error) { + if len(pb.genFuns) == 0 { + return nil, errors.New("failed to create parallel datasets") + } + if len(pb.genFuns) == 1 { + return Parallel(pb.genFuns[0], nil) + } + return Parallel(pb.genFuns[0], pb.genFuns[1:]...) +} + // Peekable converts a dataset to a peekable one. func Peekable(origin proto.Dataset) PeekableDataset { return &peekableDataset{ diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 443581ac..e8333c62 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -115,6 +115,20 @@ func FromStmtNode(node ast.StmtNode) (Statement, error) { switch tgt := result.(type) { case *ShowColumns: return &DescribeStatement{Table: tgt.TableName, Column: tgt.Column}, nil + case *SelectStatement: + if len(tgt.From) != 0 { + return &ExplainStatement{Target: tgt, Table: tgt.From[0].Source.(TableName)}, nil + } else { + return &ExplainStatement{Target: tgt}, nil + } + case *DeleteStatement: + return &ExplainStatement{Target: tgt, Table: tgt.Table}, nil + case *InsertStatement: + return &ExplainStatement{Target: tgt, Table: tgt.Table}, nil + case *InsertSelectStatement: + return &ExplainStatement{Target: tgt, Table: tgt.Table}, nil + case *UpdateStatement: + return &ExplainStatement{Target: tgt, Table: tgt.Table}, nil default: return &ExplainStatement{Target: tgt}, nil } diff --git a/pkg/runtime/ast/describe.go b/pkg/runtime/ast/describe.go index 4ee4c1e9..e9b12137 100644 --- a/pkg/runtime/ast/describe.go +++ b/pkg/runtime/ast/describe.go @@ -57,6 +57,7 @@ func (d *DescribeStatement) Mode() SQLType { // ExplainStatement represents mysql explain statement. see https://dev.mysql.com/doc/refman/8.0/en/explain.html type ExplainStatement struct { Target Statement + Table TableName } func (e *ExplainStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { @@ -68,5 +69,5 @@ func (e *ExplainStatement) Restore(flag RestoreFlag, sb *strings.Builder, args * } func (e *ExplainStatement) Mode() SQLType { - return SQLTypeSelect + return SQLTypeExplain } diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index e75b7aca..73c55c86 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -47,6 +47,7 @@ const ( SQLTypeShowShardingTable // SHOW SHARDING TABLE SQLTypeShowCreateSequence // SHOW CREATE SEQUENCE SQLTypeDescribe // DESCRIBE + SQLTypeExplain // EXPLAIN SQLTypeUnion // UNION SQLTypeDropTrigger // DROP TRIGGER SQLTypeCreateIndex // CREATE INDEX @@ -89,6 +90,7 @@ var _sqlTypeNames = [...]string{ SQLTypeShowVariables: "SHOW VARIABLES", SQLTypeShowCreateSequence: "SHOW CREATE SEQUENCE", SQLTypeDescribe: "DESCRIBE", + SQLTypeExplain: "EXPLAIN", SQLTypeUnion: "UNION", SQLTypeDropTrigger: "DROP TRIGGER", SQLTypeCreateIndex: "CREATE INDEX", diff --git a/pkg/runtime/optimize/utility/explain.go b/pkg/runtime/optimize/utility/explain.go new file mode 100644 index 00000000..cabb841a --- /dev/null +++ b/pkg/runtime/optimize/utility/explain.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utility + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/utility" +) + +func init() { + optimize.Register(ast.SQLTypeExplain, optimzeExplainStatement) +} + +func optimzeExplainStatement(ctx context.Context, o *optimize.Optimizer) (proto.Plan, error) { + stmt := o.Stmt.(*ast.ExplainStatement) + + ret := utility.NewExplainPlan(stmt) + + var ( + shards rule.DatabaseTables + err error + ) + + shards, err = o.ComputeShards(ctx, stmt.Table, nil, o.Args) + if err != nil { + return nil, err + } + + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/plan/utility/explain.go b/pkg/runtime/plan/utility/explain.go new file mode 100644 index 00000000..7cd05a57 --- /dev/null +++ b/pkg/runtime/plan/utility/explain.go @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utility + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + constant "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +type ExplainPlan struct { + plan.BasePlan + stmt *ast.ExplainStatement + dataBase string + shards rule.DatabaseTables +} + +func NewExplainPlan(stmt *ast.ExplainStatement) *ExplainPlan { + return &ExplainPlan{stmt: stmt} +} + +func (e *ExplainPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (e *ExplainPlan) ExecIn(ctx context.Context, vConn proto.VConn) (proto.Result, error) { + if e.shards == nil || e.shards.IsEmpty() { + return resultx.New(), nil + } + + var ( + sb strings.Builder + stmt = new(ast.ExplainStatement) + args []int + ) + + // prepare + sb.Grow(256) + *stmt = *e.stmt + + //build parallel dataset + pBuilder := dataset.NewParallelBuilder() + + for db, tables := range e.shards { + for _, table := range tables { + if err := e.resetTargetTable(table); err != nil { + return nil, err + } + if err := stmt.Restore(ast.RestoreDefault, &sb, &args); err != nil { + return nil, errors.Wrap(err, "failed to restore EXPLAIN statement") + } + + res, err := e.execOne(ctx, vConn, db, sb.String(), e.ToArgs(args)) + if err != nil { + return nil, errors.WithStack(err) + } + var ( + rr = res.(*mysql.RawResult) + fields []proto.Field + ) + + ds, err := rr.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + if fields, err = ds.Fields(); err != nil { + return nil, errors.WithStack(err) + } + + // add column table_name to result + newField := append([]proto.Field{mysql.NewField("table_name", constant.FieldTypeVarString)}, fields...) + ds = dataset.Pipe(ds, + dataset.Map( + func(oriField []proto.Field) []proto.Field { + return newField + }, + func(oriRow proto.Row) (proto.Row, error) { + oriVal := make([]proto.Value, len(fields)) + err = oriRow.Scan(oriVal) + if err != nil { + return nil, err + } + newVal := append([]proto.Value{proto.NewValueString(table)}, oriVal...) + if oriRow.IsBinary() { + return rows.NewBinaryVirtualRow(newField, newVal), nil + } + return rows.NewTextVirtualRow(newField, newVal), nil + })) + + // add single result to parallel ds + vcol, _ := ds.Fields() + vrow, err := ds.Next() + if err == nil { + pBuilder.Add(func() (proto.Dataset, error) { + return &dataset.VirtualDataset{Columns: vcol, Rows: []proto.Row{vrow}}, nil + }) + } + + // cleanup + if len(args) > 0 { + args = args[:0] + } + sb.Reset() + rr.Discard() + } + } + + // parallel ds + pDs, err := pBuilder.Build() + if err != nil { + return nil, errors.WithStack(err) + } + + // order by table_name + return resultx.New(resultx.WithDataset(dataset.NewOrderedDataset(pDs, []dataset.OrderByItem{{ + Column: "table_name", + Desc: false, + }}))), nil +} + +func (e *ExplainPlan) SetShards(shards rule.DatabaseTables) { + e.shards = shards +} + +func (e *ExplainPlan) resetTargetTable(table string) error { + switch e.stmt.Target.Mode() { + case ast.SQLTypeSelect: + targetStmt, ok := e.stmt.Target.(*ast.SelectStatement) + if !ok { + return errors.New("fail to get explain target statement") + } + // reset table for select stmt is complicated + targetTable := targetStmt.From[0].Source.(ast.TableName) + targetStmt.From[0].Source = targetTable.ResetSuffix(table) + return nil + case ast.SQLTypeDelete: + targetStmt, ok := e.stmt.Target.(*ast.DeleteStatement) + if !ok { + return errors.New("fail to get explain target statement") + } + targetStmt.Table = targetStmt.Table.ResetSuffix(table) + return nil + case ast.SQLTypeInsert: + targetStmt, ok := e.stmt.Target.(*ast.InsertStatement) + if !ok { + return errors.New("fail to get explain target statement") + } + targetStmt.Table = targetStmt.Table.ResetSuffix(table) + return nil + case ast.SQLTypeUpdate: + targetStmt, ok := e.stmt.Target.(*ast.UpdateStatement) + if !ok { + return errors.New("fail to get explain target statement") + } + targetStmt.Table = targetStmt.Table.ResetSuffix(table) + return nil + } + return errors.New("no target statement found for explain statement") +} + +func (e *ExplainPlan) execOne(ctx context.Context, conn proto.VConn, db, query string, args []proto.Value) (proto.Result, error) { + return conn.Query(ctx, db, query, args...) +} diff --git a/test/integration_test.go b/test/integration_test.go index 1364a9c1..4584ad89 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -1412,3 +1412,22 @@ func (s *IntegrationSuite) TestSequence() { }) } } + +func (s *IntegrationSuite) TestExplain() { + type tt struct { + sql string + } + + for _, it := range [...]tt{ + {"explain select * from student where uid = 1"}, + {"explain delete from student where uid = 1"}, + {"explain INSERT INTO student(uid,name) values(1,'name_1'),(2,'name_2'), (9,'name_3')"}, + {"explain update student set score=100.0 where uid = 1"}, + } { + t.Run(it.sql, func(t *testing.T) { + rows, err := db.Query(it.sql) + assert.NoError(t, err) + defer rows.Close() + }) + } +}