diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 3608dafe..f471bcb6 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -56,9 +56,13 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err stmt := o.Stmt.(*ast.SelectStatement) enableLocalMathComputation := ctx.Value(proto.ContextKeyEnableLocalComputation{}).(bool) if enableLocalMathComputation && len(stmt.From) == 0 { - isLocalFlag := true - var columnList []string - var valueList []proto.Value + var ( + isLocalFlag = true + isSequence = false + columnList []string + valueList []proto.Value + vts []*rule.VTable + ) for i := range stmt.Select { switch selectItem := stmt.Select[i].(type) { case *ast.SelectElementExpr: @@ -94,9 +98,26 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err } valueList = append(valueList, calculateRes) columnList = append(columnList, stmt.Select[i].DisplayName()) - + case *ast.SelectElementColumn: + if len(selectItem.Name) == 2 && (strings.ToLower(selectItem.Name[1]) == "currval" || strings.ToLower(selectItem.Name[1]) == "nextval") { + isSequence = true + vt, ok := o.Rule.VTable(selectItem.Name[0]) + if !ok { + return nil, proto.ErrorNotFoundSequence + } + vts = append(vts, vt) + } } } + if isSequence { + ret := &dml.LocalSequencePlan{ + Stmt: stmt, + VTs: vts, + ColumnList: columnList, + } + ret.BindArgs(o.Args) + return ret, nil + } if isLocalFlag { ret := &dml.LocalSelectPlan{ diff --git a/pkg/runtime/plan/dml/local_select.go b/pkg/runtime/plan/dml/local_select.go index 18779097..022694ed 100644 --- a/pkg/runtime/plan/dml/local_select.go +++ b/pkg/runtime/plan/dml/local_select.go @@ -49,7 +49,9 @@ func (s *LocalSelectPlan) Type() proto.PlanType { func (s *LocalSelectPlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) { _, span := plan.Tracer.Start(ctx, "LocalSelectPlan.ExecIn") defer span.End() - var theadLocalSelect thead.Thead + var ( + theadLocalSelect thead.Thead + ) for i, item := range s.ColumnList { sRes := s.Result[i].String() diff --git a/pkg/runtime/plan/dml/local_sequence.go b/pkg/runtime/plan/dml/local_sequence.go new file mode 100644 index 00000000..62eba4d6 --- /dev/null +++ b/pkg/runtime/plan/dml/local_sequence.go @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/pkg/errors" + "strings" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/mysql/thead" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +var _ proto.Plan = (*LocalSequencePlan)(nil) + +type LocalSequencePlan struct { + plan.BasePlan + Stmt *ast.SelectStatement + VTs []*rule.VTable + ColumnList []string +} + +func (s *LocalSequencePlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (s *LocalSequencePlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) { + _, span := plan.Tracer.Start(ctx, "LocalSequencePlan.ExecIn") + + defer span.End() + var ( + theadLocalSelect thead.Thead + columns []proto.Field + values []proto.Value + ) + + for idx := 0; s.Stmt.From == nil && idx < len(s.Stmt.Select); idx++ { + if seqColumn, ok := s.Stmt.Select[idx].(*ast.SelectElementColumn); ok && len(seqColumn.Name) == 2 { + seqName, seqFunc := seqColumn.Name[0], seqColumn.Name[1] + colName := seqColumn.Alias() + if colName == "" { + colName = strings.Join(seqColumn.Name, ".") + } + theadLocalSelect = append(theadLocalSelect, thead.Col{Name: colName, FieldType: consts.FieldTypeLong}) + seq, err := proto.LoadSequenceManager().GetSequence(ctx, rcontext.Tenant(ctx), rcontext.Schema(ctx), seqName) + if err != nil { + return nil, errors.WithStack(err) + } + + switch strings.ToLower(seqFunc) { + case "currval": + values = append(values, proto.NewValueInt64(seq.(proto.EnhancedSequence).CurrentVal())) + case "nextval": + nextSeqVal, err := seq.Acquire(ctx) + if err != nil { + return nil, err + } + values = append(values, proto.NewValueInt64(nextSeqVal)) + } + } + } + + columns = theadLocalSelect.ToFields() + ds := &dataset.VirtualDataset{ + Columns: columns, + } + + ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(columns, values)) + return resultx.New(resultx.WithDataset(ds)), nil + +} diff --git a/test/integration_test.go b/test/integration_test.go index a95a1bb8..33f6a747 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -21,8 +21,11 @@ package test import ( + "context" "database/sql" "fmt" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime" "sort" "strconv" "strings" @@ -1354,11 +1357,6 @@ func (s *IntegrationSuite) TestMysqlOptimizerHints() { } func (s *IntegrationSuite) TestExplain() { - var ( - db = s.DB() - t = s.T() - ) - type tt struct { sql string } @@ -1376,3 +1374,60 @@ func (s *IntegrationSuite) TestExplain() { }) } } + +func (s *IntegrationSuite) TestSequence() { + var ( + db = s.DB() + t = s.T() + ) + + rt, err := runtime.Load("arana", "employees") + if err != nil { + panic(err) + } + ctx := context.WithValue(context.Background(), proto.RuntimeCtxKey{}, rt) + ctx = context.WithValue(ctx, proto.ContextKeyTenant{}, "arana") + ctx = context.WithValue(ctx, proto.ContextKeySchema{}, "employees") + _, err = proto.LoadSequenceManager().CreateSequence(ctx, "arana", "employees", proto.SequenceConfig{Name: "student", Type: "group"}) + if err != nil { + panic(err) + } + + type testCase struct { + sql string + exceptVal int64 + } + + for _, it := range [...]testCase{ + { + "select student.nextVal", + 1, + }, + { + "select student.currVal", + 1, + }, + { + "select student.nextVal", + 2, + }, + { + "select notexist.currVal", + -1, + }, + } { + t.Run(it.sql, func(t *testing.T) { + rows, err := db.Query(it.sql) + if it.exceptVal == -1 { + assert.True(t, err != nil, err) + return + } + defer rows.Close() + assert.NoError(t, err, "should query successfully") + var val int64 + records, _ := utils.PrintTable(rows) + val, err = strconv.ParseInt(records[0][0], 10, 64) + assert.Equal(t, it.exceptVal, val) + }) + } +}