Skip to content

Commit

Permalink
feature: add
Browse files Browse the repository at this point in the history
  • Loading branch information
Lvnszn committed Jul 16, 2023
1 parent c95f464 commit 708b7cf
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 4 deletions.
29 changes: 26 additions & 3 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
stmt := o.Stmt.(*ast.SelectStatement)
enableLocalMathComputation := ctx.Value(proto.ContextKeyEnableLocalComputation{}).(bool)
if enableLocalMathComputation && len(stmt.From) == 0 {
isLocalFlag := true
var columnList []string
var valueList []proto.Value
var (
isLocalFlag = true
isSequence = false
columnList []string
valueList []proto.Value
vts []*rule.VTable
)
for i := range stmt.Select {
switch selectItem := stmt.Select[i].(type) {
case *ast.SelectElementExpr:
Expand Down Expand Up @@ -94,8 +98,27 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
}
valueList = append(valueList, calculateRes)
columnList = append(columnList, stmt.Select[i].DisplayName())
case *ast.SelectElementColumn:
isSequence = true
if len(selectItem.Name) != 2 {
return nil, proto.ErrorNotFoundSequence
}
vt, ok := o.Rule.VTable(selectItem.Name[0])
if !ok {
return nil, proto.ErrorNotFoundSequence
}

vts = append(vts, vt)
}
}
if isSequence {
ret := &dml.LocalSequencePlan{
Stmt: stmt,
VTs: vts,
ColumnList: columnList,
}
ret.BindArgs(o.Args)
return ret, nil
}
if isLocalFlag {

Expand Down
4 changes: 3 additions & 1 deletion pkg/runtime/plan/dml/local_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func (s *LocalSelectPlan) Type() proto.PlanType {
func (s *LocalSelectPlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) {
_, span := plan.Tracer.Start(ctx, "LocalSelectPlan.ExecIn")
defer span.End()
var theadLocalSelect thead.Thead
var (
theadLocalSelect thead.Thead
)

for i, item := range s.ColumnList {
sRes := s.Result[i].String()
Expand Down
95 changes: 95 additions & 0 deletions pkg/runtime/plan/dml/local_sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package dml

import (
"context"
"github.com/arana-db/arana/pkg/proto/rule"
"strings"
)

import (
consts "github.com/arana-db/arana/pkg/constants/mysql"
"github.com/arana-db/arana/pkg/dataset"
"github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/mysql/thead"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
"github.com/arana-db/arana/pkg/runtime/plan"
)

var _ proto.Plan = (*LocalSequencePlan)(nil)

type LocalSequencePlan struct {
plan.BasePlan
Stmt *ast.SelectStatement
VTs []*rule.VTable
ColumnList []string
}

func (s *LocalSequencePlan) Type() proto.PlanType {
return proto.PlanTypeQuery
}

func (s *LocalSequencePlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) {
_, span := plan.Tracer.Start(ctx, "LocalSequencePlan.ExecIn")

defer span.End()
var (
theadLocalSelect thead.Thead
columns []proto.Field
values []proto.Value
)

for idx := 0; s.Stmt.From == nil && idx < len(s.Stmt.Select); idx++ {
if seqColumn, ok := s.Stmt.Select[idx].(*ast.SelectElementColumn); ok && len(seqColumn.Name) == 2 {
_, seqFunc := seqColumn.Name[0], seqColumn.Name[1]
colName := seqColumn.Alias()
if colName == "" {
colName = strings.Join(seqColumn.Name, ".")
}
theadLocalSelect = append(theadLocalSelect, thead.Col{Name: colName, FieldType: consts.FieldTypeLong})

sequenceSupplier, ok := proto.GetSequenceSupplier(s.VTs[idx].GetAutoIncrement().Type)
if !ok {
return nil, proto.ErrorNotFoundSequence
}
seq := sequenceSupplier()
switch strings.ToLower(seqFunc) {
case "currval":
values = append(values, proto.NewValueInt64(seq.CurrentVal()))
case "nextval":
nextSeqVal, err := seq.Acquire(ctx)
if err != nil {
return nil, err
}
values = append(values, proto.NewValueInt64(nextSeqVal))
}
}
}

columns = theadLocalSelect.ToFields()
ds := &dataset.VirtualDataset{
Columns: columns,
}

ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(columns, values))
return resultx.New(resultx.WithDataset(ds)), nil

}
5 changes: 5 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ func (pi *defaultRuntime) Execute(ctx *proto.Context) (res proto.Result, warn ui
pi.Namespace().SlowLogger().Warnf("slow logs elapsed %v sql %s", since, ctx.GetQuery())
}
}()

if ctx.GetQuery() == "select student.currVal" {
print("test")
}

args := ctx.GetArgs()

if rcontext.IsDirect(ctx.Context) {
Expand Down
45 changes: 45 additions & 0 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,48 @@ func (s *IntegrationSuite) TestMysqlOptimizerHints() {
})
}
}

func (s *IntegrationSuite) TestSequence() {
var (
db = s.DB()
t = s.T()
)

type testCase struct {
sql string
exceptVal int64
}

for _, it := range [...]testCase{
{
"select student.nextVal",
0,
},
{
"select student.currVal",
0,
},
{
"select notexist.currVal",
-1,
},
} {
t.Run(it.sql, func(t *testing.T) {
rows, err := db.Query(it.sql)
if it.exceptVal == -1 {
assert.True(t, err != nil, err)
return
}
defer rows.Close()
assert.NoError(t, err, "should query successfully")
var val int64
records, _ := utils.PrintTable(rows)
val, err = strconv.ParseInt(records[0][0], 10, 64)
if it.sql == "select student.currVal" {
assert.Equal(t, it.exceptVal, val)
} else {
assert.Equal(t, true, val > it.exceptVal)
}
})
}
}

0 comments on commit 708b7cf

Please sign in to comment.