From a7a44c6108918f99c0c25b1c7d250452bb01cc44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vicent=20Mart=C3=AD?= <42793+vmg@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:28:20 +0100 Subject: [PATCH] evalengine: fix numeric coercibility (#14473) Signed-off-by: Vicent Marti Signed-off-by: Andres Taylor Co-authored-by: Andres Taylor --- go/sqltypes/type.go | 2 +- .../vtgate/queries/random/random_test.go | 1 + go/vt/vtgate/engine/opcode/constants.go | 6 + go/vt/vtgate/evalengine/api_literal.go | 2 +- go/vt/vtgate/evalengine/collation.go | 115 +++++++++++++++++- go/vt/vtgate/evalengine/compiler_asm.go | 4 +- go/vt/vtgate/evalengine/compiler_test.go | 13 ++ go/vt/vtgate/evalengine/eval.go | 20 +-- go/vt/vtgate/evalengine/eval_temporal.go | 10 +- go/vt/vtgate/evalengine/expr_bvar.go | 8 +- go/vt/vtgate/evalengine/expr_collate.go | 91 -------------- go/vt/vtgate/evalengine/expr_column.go | 2 +- go/vt/vtgate/evalengine/fn_base64.go | 6 +- go/vt/vtgate/evalengine/fn_crypto.go | 12 +- go/vt/vtgate/evalengine/fn_hex.go | 6 +- go/vt/vtgate/evalengine/fn_misc.go | 19 +-- go/vt/vtgate/evalengine/fn_numeric.go | 4 +- go/vt/vtgate/evalengine/fn_string.go | 14 +-- go/vt/vtgate/evalengine/fn_time.go | 20 +-- go/vt/vtgate/evalengine/translate.go | 2 +- .../planbuilder/operators/dml_planning.go | 9 +- go/vt/vtgate/planbuilder/operators/insert.go | 18 ++- .../planbuilder/operators/projection.go | 5 +- .../planbuilder/operators/route_planning.go | 3 +- go/vt/vtgate/planbuilder/operators/update.go | 2 +- .../planbuilder/testdata/tpch_cases.json | 8 +- go/vt/vtgate/semantics/dependencies.go | 1 + go/vt/vtgate/semantics/typer.go | 7 +- 28 files changed, 229 insertions(+), 181 deletions(-) diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index eeaa4e9ddf6..9157db685e9 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -136,7 +136,7 @@ func IsNull(t querypb.Type) bool { // switch statements for those who want to cover types // by their category. const ( - Unknown = -1 + Unknown = querypb.Type(-1) Null = querypb.Type_NULL_TYPE Int8 = querypb.Type_INT8 Uint8 = querypb.Type_UINT8 diff --git a/go/test/endtoend/vtgate/queries/random/random_test.go b/go/test/endtoend/vtgate/queries/random/random_test.go index 31b48b4ee52..aea43c2f929 100644 --- a/go/test/endtoend/vtgate/queries/random/random_test.go +++ b/go/test/endtoend/vtgate/queries/random/random_test.go @@ -339,4 +339,5 @@ func TestBuggyQueries(t *testing.T) { mcmp.Exec("select count(tbl1.dname) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.dname > tbl1.loc where tbl1.loc <=> tbl1.dname group by tbl1.dname order by tbl1.dname asc") mcmp.Exec("select count(*) from (select count(*) from dept as tbl0) as tbl0") mcmp.Exec("select count(*), count(*) from (select count(*) from dept as tbl0) as tbl0, dept as tbl1") + mcmp.Exec(`select distinct case max(tbl0.ename) when min(tbl0.job) then 'sole' else count(case when false then -27 when 'gazelle' then tbl0.deptno end) end as caggr0 from emp as tbl0`) } diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 07a39020f8b..dd73a78974d 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -139,6 +139,9 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { case AggregateUnassigned: return sqltypes.Null case AggregateGroupConcat: + if typ == sqltypes.Unknown { + return sqltypes.Unknown + } if sqltypes.IsBinary(typ) { return sqltypes.Blob } @@ -146,6 +149,9 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { case AggregateMax, AggregateMin, AggregateAnyValue: return typ case AggregateSumDistinct, AggregateSum: + if typ == sqltypes.Unknown { + return sqltypes.Unknown + } if sqltypes.IsIntegral(typ) || sqltypes.IsDecimal(typ) { return sqltypes.Decimal } diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 5711c325fc6..f12988233e8 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -223,7 +223,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column { return &Column{ Offset: offset, Type: typ.Type, - Collation: defaultCoercionCollation(typ.Coll), + Collation: typedCoercionCollation(typ.Type, typ.Coll), Original: original, dynamicTypeOffset: -1, } diff --git a/go/vt/vtgate/evalengine/collation.go b/go/vt/vtgate/evalengine/collation.go index 9d53a9d8ea9..7cb341f52b0 100644 --- a/go/vt/vtgate/evalengine/collation.go +++ b/go/vt/vtgate/evalengine/collation.go @@ -16,12 +16,115 @@ limitations under the License. package evalengine -import "vitess.io/vitess/go/mysql/collations" +import ( + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/colldata" + "vitess.io/vitess/go/sqltypes" +) -func defaultCoercionCollation(id collations.ID) collations.TypedCollation { - return collations.TypedCollation{ - Collation: id, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, +func typedCoercionCollation(typ sqltypes.Type, id collations.ID) collations.TypedCollation { + switch { + case sqltypes.IsNull(typ): + return collationNull + case sqltypes.IsNumber(typ) || sqltypes.IsDateOrTime(typ): + return collationNumeric + case typ == sqltypes.TypeJSON: + return collationJSON + default: + return collations.TypedCollation{ + Collation: id, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + } } } + +func evalCollation(e eval) collations.TypedCollation { + switch e := e.(type) { + case nil: + return collationNull + case evalNumeric, *evalTemporal: + return collationNumeric + case *evalJSON: + return collationJSON + case *evalBytes: + return e.col + default: + return collationBinary + } +} + +func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (collations.TypedCollation, colldata.Coercion, colldata.Coercion, error) { + if c1.Collation == c2.Collation { + return c1, nil, nil, nil + } + + lt := sqltypes.IsText(t1) || sqltypes.IsBinary(t1) + rt := sqltypes.IsText(t2) || sqltypes.IsBinary(t2) + if !lt || !rt { + if lt { + return c1, nil, nil, nil + } + if rt { + return c2, nil, nil, nil + } + return collationBinary, nil, nil, nil + } + + env := collations.Local() + return colldata.Merge(env, c1, c2, colldata.CoercionOptions{ + ConvertToSuperset: true, + ConvertWithCoercion: true, + }) +} + +func mergeAndCoerceCollations(left, right eval) (eval, eval, collations.TypedCollation, error) { + lt := left.SQLType() + rt := right.SQLType() + + mc, coerceLeft, coerceRight, err := mergeCollations(evalCollation(left), evalCollation(right), lt, rt) + if err != nil { + return nil, nil, collations.TypedCollation{}, err + } + if coerceLeft == nil && coerceRight == nil { + return left, right, mc, nil + } + + left1 := newEvalRaw(lt, left.(*evalBytes).bytes, mc) + right1 := newEvalRaw(rt, right.(*evalBytes).bytes, mc) + + if coerceLeft != nil { + left1.bytes, err = coerceLeft(nil, left1.bytes) + if err != nil { + return nil, nil, collations.TypedCollation{}, err + } + } + if coerceRight != nil { + right1.bytes, err = coerceRight(nil, right1.bytes) + if err != nil { + return nil, nil, collations.TypedCollation{}, err + } + } + return left1, right1, mc, nil +} + +type collationAggregation struct { + cur collations.TypedCollation +} + +func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) error { + if ca.cur.Collation == collations.Unknown { + ca.cur = tc + } else { + var err error + ca.cur, _, _, err = colldata.Merge(env, ca.cur, tc, colldata.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true}) + if err != nil { + return err + } + } + return nil +} + +func (ca *collationAggregation) result() collations.TypedCollation { + return ca.cur +} diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 58ea99fcbef..6230627c26a 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -4165,13 +4165,13 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) { } tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal) - env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now) + env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.Unknown, env.now) env.vm.sp-- return 1 }, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)") } -func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.TypedCollation) { +func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.ID) { asm.adjustStack(-1) asm.emit(func(env *ExpressionEnv) int { var interval *datetime.Interval diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 9111d1f7090..4fba65aeb75 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -164,6 +164,7 @@ func TestCompilerSingle(t *testing.T) { expression string values []sqltypes.Value result string + collation collations.ID }{ { expression: "1 + column0", @@ -489,6 +490,12 @@ func TestCompilerSingle(t *testing.T) { expression: `'2020-01-01' + interval month(date_sub(FROM_UNIXTIME(1234), interval 1 month))-1 month`, result: `CHAR("2020-12-01")`, }, + { + expression: `case column0 when 1 then column1 else column2 end`, + values: []sqltypes.Value{sqltypes.NewInt64(42), sqltypes.NewVarChar("sole"), sqltypes.NewInt64(0)}, + result: `VARCHAR("0")`, + collation: collations.CollationUtf8mb4ID, + }, } tz, _ := time.LoadLocation("Europe/Madrid") @@ -524,6 +531,9 @@ func TestCompilerSingle(t *testing.T) { if expected.String() != tc.result { t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result) } + if tc.collation != collations.Unknown && tc.collation != expected.Collation() { + t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation) + } // re-run the same evaluation multiple times to ensure results are always consistent for i := 0; i < 8; i++ { @@ -535,6 +545,9 @@ func TestCompilerSingle(t *testing.T) { if res.String() != tc.result { t.Errorf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i) } + if tc.collation != collations.Unknown && tc.collation != res.Collation() { + t.Fatalf("bad collation evaluation from compiler: got %d, want %d", res.Collation(), tc.collation) + } } }) } diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index ee09f96cded..e327b9d5651 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -238,7 +238,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) return newEvalFloat(fval), nil default: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -265,7 +265,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -285,7 +285,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I i, err := fastparse.ParseInt64(v.RawStr(), 10) return newEvalInt64(i), err default: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -304,7 +304,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I u, err := fastparse.ParseUint64(v.RawStr(), 10) return newEvalUint64(u), err default: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -315,15 +315,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case sqltypes.IsText(typ) || sqltypes.IsBinary(typ): switch { case v.IsText() || v.IsBinary(): - return newEvalRaw(v.Type(), v.Raw(), defaultCoercionCollation(collation)), nil + return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil case sqltypes.IsText(typ): - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } return evalToVarchar(e, collation, true) default: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -333,7 +333,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case typ == sqltypes.TypeJSON: return json.NewFromSQL(v) case typ == sqltypes.Date: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -344,7 +344,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return d, nil case typ == sqltypes.Datetime || typ == sqltypes.Timestamp: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } @@ -355,7 +355,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return dt, nil case typ == sqltypes.Time: - e, err := valueToEval(v, defaultCoercionCollation(collation)) + e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation)) if err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index 34d1f17d7f8..d44839a6853 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -140,7 +140,7 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } -func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval { +func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -150,16 +150,16 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio tmp.dt.Date, ok = e.dt.Date.AddInterval(interval) case tt == sqltypes.Time && !interval.Unit().HasDateParts(): tmp = &evalTemporal{t: e.t} - tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid()) + tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, coll != collations.Unknown) case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()): tmp = e.toDateTime(int(e.prec), now) - tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid()) + tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, coll != collations.Unknown) } if !ok { return nil } - if strcoll.Valid() { - return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), strcoll) + if coll != collations.Unknown { + return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), typedCoercionCollation(sqltypes.Char, coll)) } return tmp } diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 4b08ee7683c..6bc49caf660 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { tuple := make([]eval, 0, len(bvar.Values)) for _, value := range bvar.Values { - e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.CollationForType(value.Type, bv.Collation))) + e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation))) if err != nil { return nil, err } @@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { if bv.typed() { typ = bv.Type } - return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.CollationForType(typ, bv.Collation))) + return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation))) } } @@ -110,7 +110,7 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) { case sqltypes.BitNum: return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil default: - return ctype{Type: tt, Flag: 0, Col: defaultCoercionCollation(collations.CollationForType(tt, bv.Collation))}, nil + return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil } } @@ -119,7 +119,7 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { if bvar.typed() { typ.Type = bvar.Type - typ.Col = defaultCoercionCollation(collations.CollationForType(bvar.Type, bvar.Collation)) + typ.Col = typedCoercionCollation(bvar.Type, collations.CollationForType(bvar.Type, bvar.Collation)) } else if c.dynamicTypes != nil { typ = c.dynamicTypes[bvar.dynamicTypeOffset] } else { diff --git a/go/vt/vtgate/evalengine/expr_collate.go b/go/vt/vtgate/evalengine/expr_collate.go index 60791311e28..47e65a0dcc7 100644 --- a/go/vt/vtgate/evalengine/expr_collate.go +++ b/go/vt/vtgate/evalengine/expr_collate.go @@ -18,7 +18,6 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -127,96 +126,6 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) { return ct, nil } -func evalCollation(e eval) collations.TypedCollation { - switch e := e.(type) { - case nil: - return collationNull - case evalNumeric, *evalTemporal: - return collationNumeric - case *evalJSON: - return collationJSON - case *evalBytes: - return e.col - default: - return collationBinary - } -} - -func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (collations.TypedCollation, colldata.Coercion, colldata.Coercion, error) { - if c1.Collation == c2.Collation { - return c1, nil, nil, nil - } - - lt := sqltypes.IsText(t1) || sqltypes.IsBinary(t1) - rt := sqltypes.IsText(t2) || sqltypes.IsBinary(t2) - if !lt || !rt { - if lt { - return c1, nil, nil, nil - } - if rt { - return c2, nil, nil, nil - } - return collationBinary, nil, nil, nil - } - - env := collations.Local() - return colldata.Merge(env, c1, c2, colldata.CoercionOptions{ - ConvertToSuperset: true, - ConvertWithCoercion: true, - }) -} - -func mergeAndCoerceCollations(left, right eval) (eval, eval, collations.TypedCollation, error) { - lt := left.SQLType() - rt := right.SQLType() - - mc, coerceLeft, coerceRight, err := mergeCollations(evalCollation(left), evalCollation(right), lt, rt) - if err != nil { - return nil, nil, collations.TypedCollation{}, err - } - if coerceLeft == nil && coerceRight == nil { - return left, right, mc, nil - } - - left1 := newEvalRaw(lt, left.(*evalBytes).bytes, mc) - right1 := newEvalRaw(rt, right.(*evalBytes).bytes, mc) - - if coerceLeft != nil { - left1.bytes, err = coerceLeft(nil, left1.bytes) - if err != nil { - return nil, nil, collations.TypedCollation{}, err - } - } - if coerceRight != nil { - right1.bytes, err = coerceRight(nil, right1.bytes) - if err != nil { - return nil, nil, collations.TypedCollation{}, err - } - } - return left1, right1, mc, nil -} - -type collationAggregation struct { - cur collations.TypedCollation -} - -func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) error { - if ca.cur.Collation == collations.Unknown { - ca.cur = tc - } else { - var err error - ca.cur, _, _, err = colldata.Merge(env, ca.cur, tc, colldata.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true}) - if err != nil { - return err - } - } - return nil -} - -func (ca *collationAggregation) result() collations.TypedCollation { - return ca.cur -} - var _ IR = (*IntroducerExpr)(nil) func (expr *IntroducerExpr) eval(env *ExpressionEnv) (eval, error) { diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index 93670b5ca12..741d04c6a06 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -68,7 +68,7 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { return ctype{ Type: field.Type, - Col: defaultCoercionCollation(collations.ID(field.Charset)), + Col: typedCoercionCollation(field.Type, collations.ID(field.Charset)), Flag: f, }, nil } diff --git a/go/vt/vtgate/evalengine/fn_base64.go b/go/vt/vtgate/evalengine/fn_base64.go index dfc5e037629..d404d391dd6 100644 --- a/go/vt/vtgate/evalengine/fn_base64.go +++ b/go/vt/vtgate/evalengine/fn_base64.go @@ -82,9 +82,9 @@ func (call *builtinToBase64) eval(env *ExpressionEnv) (eval, error) { encoded := mysqlBase64Encode(b.bytes) if arg.SQLType() == sqltypes.Blob || arg.SQLType() == sqltypes.TypeJSON { - return newEvalRaw(sqltypes.Text, encoded, defaultCoercionCollation(call.collate)), nil + return newEvalRaw(sqltypes.Text, encoded, typedCoercionCollation(sqltypes.Text, call.collate)), nil } - return newEvalText(encoded, defaultCoercionCollation(call.collate)), nil + return newEvalText(encoded, typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinToBase64) compile(c *compiler) (ctype, error) { @@ -106,7 +106,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, t, 0, false) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(t, c.collation) c.asm.Fn_TO_BASE64(t, col) c.asm.jumpDestination(skip) diff --git a/go/vt/vtgate/evalengine/fn_crypto.go b/go/vt/vtgate/evalengine/fn_crypto.go index a10183dd12f..31783291ce7 100644 --- a/go/vt/vtgate/evalengine/fn_crypto.go +++ b/go/vt/vtgate/evalengine/fn_crypto.go @@ -48,7 +48,7 @@ func (call *builtinMD5) eval(env *ExpressionEnv) (eval, error) { sum := md5.Sum(b.bytes) buf := make([]byte, hex.EncodedLen(len(sum))) hex.Encode(buf, sum[:]) - return newEvalText(buf, defaultCoercionCollation(call.collate)), nil + return newEvalText(buf, typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinMD5) compile(c *compiler) (ctype, error) { @@ -65,7 +65,7 @@ func (call *builtinMD5) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.Binary, 0, false) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) c.asm.Fn_MD5(col) c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil @@ -91,7 +91,7 @@ func (call *builtinSHA1) eval(env *ExpressionEnv) (eval, error) { sum := sha1.Sum(b.bytes) buf := make([]byte, hex.EncodedLen(len(sum))) hex.Encode(buf, sum[:]) - return newEvalText(buf, defaultCoercionCollation(call.collate)), nil + return newEvalText(buf, typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinSHA1) compile(c *compiler) (ctype, error) { @@ -107,7 +107,7 @@ func (call *builtinSHA1) compile(c *compiler) (ctype, error) { default: c.asm.Convert_xb(1, sqltypes.Binary, 0, false) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) c.asm.Fn_SHA1(col) c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil @@ -153,7 +153,7 @@ func (call *builtinSHA2) eval(env *ExpressionEnv) (eval, error) { buf := make([]byte, hex.EncodedLen(len(sum))) hex.Encode(buf, sum[:]) - return newEvalText(buf, defaultCoercionCollation(call.collate)), nil + return newEvalText(buf, typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinSHA2) compile(c *compiler) (ctype, error) { @@ -186,7 +186,7 @@ func (call *builtinSHA2) compile(c *compiler) (ctype, error) { c.asm.Convert_xi(1) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) c.asm.Fn_SHA2(col) c.asm.jumpDestination(skip1, skip2) return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag | flagNullable}, nil diff --git a/go/vt/vtgate/evalengine/fn_hex.go b/go/vt/vtgate/evalengine/fn_hex.go index 0641f7dcf90..8552ab888ae 100644 --- a/go/vt/vtgate/evalengine/fn_hex.go +++ b/go/vt/vtgate/evalengine/fn_hex.go @@ -49,9 +49,9 @@ func (call *builtinHex) eval(env *ExpressionEnv) (eval, error) { encoded = hex.EncodeBytes(arg.ToRawBytes()) } if arg.SQLType() == sqltypes.Blob || arg.SQLType() == sqltypes.TypeJSON { - return newEvalRaw(sqltypes.Text, encoded, defaultCoercionCollation(call.collate)), nil + return newEvalRaw(sqltypes.Text, encoded, typedCoercionCollation(sqltypes.Text, call.collate)), nil } - return newEvalText(encoded, defaultCoercionCollation(call.collate)), nil + return newEvalText(encoded, typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinHex) compile(c *compiler) (ctype, error) { @@ -61,11 +61,11 @@ func (call *builtinHex) compile(c *compiler) (ctype, error) { } skip := c.compileNullCheck1(str) - col := defaultCoercionCollation(c.collation) t := sqltypes.VarChar if str.Type == sqltypes.Blob || str.Type == sqltypes.TypeJSON { t = sqltypes.Text } + col := typedCoercionCollation(t, c.collation) switch { case sqltypes.IsNumber(str.Type): diff --git a/go/vt/vtgate/evalengine/fn_misc.go b/go/vt/vtgate/evalengine/fn_misc.go index 56b49fdfd24..2f228ff55fa 100644 --- a/go/vt/vtgate/evalengine/fn_misc.go +++ b/go/vt/vtgate/evalengine/fn_misc.go @@ -141,7 +141,7 @@ func (call *builtinInetNtoa) eval(env *ExpressionEnv) (eval, error) { } b := binary.BigEndian.AppendUint32(nil, uint32(rawIp)) - return newEvalText(hack.StringBytes(netip.AddrFrom4([4]byte(b)).String()), defaultCoercionCollation(call.collate)), nil + return newEvalText(hack.StringBytes(netip.AddrFrom4([4]byte(b)).String()), typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinInetNtoa) compile(c *compiler) (ctype, error) { @@ -153,11 +153,11 @@ func (call *builtinInetNtoa) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) c.compileToUint64(arg, 1) - col := defaultCoercionCollation(call.collate) + col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_INET_NTOA(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: defaultCoercionCollation(call.collate)}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } func (call *builtinInet6Aton) eval(env *ExpressionEnv) (eval, error) { @@ -241,11 +241,12 @@ func (call *builtinInet6Ntoa) eval(env *ExpressionEnv) (eval, error) { return nil, nil } + col := typedCoercionCollation(sqltypes.VarChar, call.collate) if ip, ok := printIPv6AsIPv4(ip); ok { - return newEvalText(hack.StringBytes("::"+ip.String()), defaultCoercionCollation(call.collate)), nil + return newEvalText(hack.StringBytes("::"+ip.String()), col), nil } - return newEvalText(hack.StringBytes(ip.String()), defaultCoercionCollation(call.collate)), nil + return newEvalText(hack.StringBytes(ip.String()), col), nil } func (call *builtinInet6Ntoa) compile(c *compiler) (ctype, error) { @@ -255,16 +256,16 @@ func (call *builtinInet6Ntoa) compile(c *compiler) (ctype, error) { } skip := c.compileNullCheck1(arg) + col := typedCoercionCollation(sqltypes.VarChar, call.collate) switch arg.Type { case sqltypes.VarBinary, sqltypes.Blob, sqltypes.Binary: - col := defaultCoercionCollation(call.collate) c.asm.Fn_INET6_NTOA(col) default: c.asm.SetNull(1) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: defaultCoercionCollation(call.collate)}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } func (call *builtinIsIPV4) eval(env *ExpressionEnv) (eval, error) { @@ -445,7 +446,7 @@ func (call *builtinBinToUUID) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, errIncorrectUUID(raw, "bin_to_uuid") } - return newEvalText(hack.StringBytes(parsed.String()), defaultCoercionCollation(call.collate)), nil + return newEvalText(hack.StringBytes(parsed.String()), typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (call *builtinBinToUUID) compile(c *compiler) (ctype, error) { @@ -461,7 +462,7 @@ func (call *builtinBinToUUID) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } - col := defaultCoercionCollation(call.collate) + col := typedCoercionCollation(sqltypes.VarChar, call.collate) ct := ctype{Type: sqltypes.VarChar, Flag: arg.Flag, Col: col} if len(call.Arguments) == 1 { diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 954aceb66c9..7bdd8d8b92e 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -1342,7 +1342,7 @@ func (call *builtinConv) eval(env *ExpressionEnv) (eval, error) { } else { out = strconv.AppendUint(out, u, int(toBase)) } - return newEvalText(upcaseASCII(out), defaultCoercionCollation(call.collate)), nil + return newEvalText(upcaseASCII(out), typedCoercionCollation(sqltypes.VarChar, call.collate)), nil } func (expr *builtinConv) compile(c *compiler) (ctype, error) { @@ -1383,7 +1383,7 @@ func (expr *builtinConv) compile(c *compiler) (ctype, error) { c.asm.Fn_CONV_bu(3, 2) } - col := defaultCoercionCollation(expr.collate) + col := typedCoercionCollation(t, expr.collate) c.asm.Fn_CONV_uc(t, col) c.asm.jumpDestination(skip) diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 2549572605e..8d61905d237 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -587,7 +587,7 @@ func (call *builtinLeftRight) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck2(str, l) - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) switch { case str.isTextual(): col = str.Col @@ -693,7 +693,7 @@ func (call *builtinPad) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck3(str, l, pad) - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) switch { case str.isTextual(): col = str.Col @@ -880,7 +880,7 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) { skip1 := c.compileNullCheck1(str) - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) switch { case str.isTextual(): col = str.Col @@ -989,7 +989,7 @@ func (call *builtinConcat) eval(env *ExpressionEnv) (eval, error) { // If we only had numbers, we instead fall back to the default // collation instead of using the numeric collation. if tc.Coercibility == collations.CoerceNumeric { - tc = defaultCoercionCollation(call.collate) + tc = typedCoercionCollation(tt, call.collate) } var buf []byte @@ -1041,7 +1041,7 @@ func (call *builtinConcat) compile(c *compiler) (ctype, error) { // If we only had numbers, we instead fall back to the default // collation instead of using the numeric collation. if tc.Coercibility == collations.CoerceNumeric { - tc = defaultCoercionCollation(call.collate) + tc = typedCoercionCollation(tt, call.collate) } for i, arg := range args { @@ -1103,7 +1103,7 @@ func (call *builtinConcatWs) eval(env *ExpressionEnv) (eval, error) { // If we only had numbers, we instead fall back to the default // collation instead of using the numeric collation. if tc.Coercibility == collations.CoerceNumeric { - tc = defaultCoercionCollation(call.collate) + tc = typedCoercionCollation(tt, call.collate) } var sep []byte @@ -1173,7 +1173,7 @@ func (call *builtinConcatWs) compile(c *compiler) (ctype, error) { // If we only had numbers, we instead fall back to the default // collation instead of using the numeric collation. if tc.Coercibility == collations.CoerceNumeric { - tc = defaultCoercionCollation(call.collate) + tc = typedCoercionCollation(tt, call.collate) } for i, arg := range args { diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 1796921b6f6..430b975974b 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -275,7 +275,7 @@ func (b *builtinDateFormat) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return newEvalText(d, defaultCoercionCollation(b.collate)), nil + return newEvalText(d, typedCoercionCollation(sqltypes.VarChar, b.collate)), nil } func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { @@ -305,7 +305,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) c.asm.Fn_DATE_FORMAT(col) c.asm.jumpDestination(skip1, skip2) return ctype{Type: sqltypes.VarChar, Col: col, Flag: arg.Flag | flagNullable}, nil @@ -623,7 +623,7 @@ func (b *builtinFromUnixtime) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return newEvalText(d, defaultCoercionCollation(b.collate)), nil + return newEvalText(d, typedCoercionCollation(sqltypes.VarChar, b.collate)), nil } func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { @@ -676,7 +676,7 @@ func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } - col := defaultCoercionCollation(c.collation) + col := typedCoercionCollation(sqltypes.VarChar, c.collation) c.asm.Fn_DATE_FORMAT(col) c.asm.jumpDestination(skip1, skip2) return ctype{Type: sqltypes.VarChar, Col: col, Flag: arg.Flag | flagNullable}, nil @@ -1142,7 +1142,7 @@ func (b *builtinMonthName) eval(env *ExpressionEnv) (eval, error) { return nil, nil } - return newEvalText(hack.StringBytes(time.Month(d.dt.Date.Month()).String()), defaultCoercionCollation(b.collate)), nil + return newEvalText(hack.StringBytes(time.Month(d.dt.Date.Month()).String()), typedCoercionCollation(sqltypes.VarChar, b.collate)), nil } func (call *builtinMonthName) compile(c *compiler) (ctype, error) { @@ -1158,7 +1158,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { default: c.asm.Convert_xD(1) } - col := defaultCoercionCollation(call.collate) + col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_MONTHNAME(col) c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Col: col, Flag: arg.Flag | flagNullable}, nil @@ -1596,11 +1596,11 @@ func (call *builtinDateMath) eval(env *ExpressionEnv) (eval, error) { } if tmp, ok := date.(*evalTemporal); ok { - return tmp.addInterval(interval, collations.TypedCollation{}, env.now), nil + return tmp.addInterval(interval, collations.Unknown, env.now), nil } if tmp := evalToTemporal(date); tmp != nil { - return tmp.addInterval(interval, defaultCoercionCollation(call.collate), env.now), nil + return tmp.addInterval(interval, call.collate, env.now), nil } return nil, nil @@ -1634,8 +1634,8 @@ func (call *builtinDateMath) compile(c *compiler) (ctype, error) { c.asm.Fn_DATEADD_D(call.unit, call.sub) default: ret.Type = sqltypes.Char - ret.Col = defaultCoercionCollation(c.collation) - c.asm.Fn_DATEADD_s(call.unit, call.sub, ret.Col) + ret.Col = typedCoercionCollation(sqltypes.Char, c.collation) + c.asm.Fn_DATEADD_s(call.unit, call.sub, ret.Col.Collation) } return ret, nil } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 57393a97bc0..c8f6f7d1337 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -253,7 +253,7 @@ func translateLiteral(lit *sqlparser.Literal, collation collations.ID) (*Literal case sqlparser.DecimalVal: return NewLiteralDecimalFromBytes(lit.Bytes()) case sqlparser.StrVal: - return NewLiteralString(lit.Bytes(), defaultCoercionCollation(collation)), nil + return NewLiteralString(lit.Bytes(), typedCoercionCollation(sqltypes.VarChar, collation)), nil case sqlparser.HexNum: return NewLiteralBinaryFromHexNum(lit.Bytes()) case sqlparser.HexVal: diff --git a/go/vt/vtgate/planbuilder/operators/dml_planning.go b/go/vt/vtgate/planbuilder/operators/dml_planning.go index 9618c34e21e..8f87a71c95f 100644 --- a/go/vt/vtgate/planbuilder/operators/dml_planning.go +++ b/go/vt/vtgate/planbuilder/operators/dml_planning.go @@ -19,6 +19,8 @@ package operators import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" @@ -56,7 +58,7 @@ func getVindexInformation(id semantics.TableSet, table *vindexes.Table) ( return primaryVindex, vindexesAndPredicates, nil } -func buildChangedVindexesValues(update *sqlparser.Update, table *vindexes.Table, ksidCols []sqlparser.IdentifierCI, assignments []SetExpr) (vv map[string]*engine.VindexValues, ownedVindexQuery string, subQueriesArgOnChangedVindex []string, err error) { +func buildChangedVindexesValues(ctx *plancontext.PlanningContext, update *sqlparser.Update, table *vindexes.Table, ksidCols []sqlparser.IdentifierCI, assignments []SetExpr) (vv map[string]*engine.VindexValues, ownedVindexQuery string, subQueriesArgOnChangedVindex []string, err error) { changedVindexes := make(map[string]*engine.VindexValues) buf, offset := initialQuery(ksidCols, table) for i, vindex := range table.ColumnVindexes { @@ -73,7 +75,10 @@ func buildChangedVindexesValues(update *sqlparser.Update, table *vindexes.Table, return nil, "", nil, vterrors.VT03015(assignment.Name.Name) } found = true - pv, err := evalengine.Translate(assignment.Expr.EvalExpr, nil) + pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + }) if err != nil { return nil, "", nil, invalidUpdateExpr(assignment.Name.Name.String(), assignment.Expr.EvalExpr) } diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 29000142e57..a48e53c18b1 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -174,7 +174,7 @@ func createInsertOperator(ctx *plancontext.PlanningContext, insStmt *sqlparser.I } // modify column list or values for autoincrement column. - autoIncGen, err := modifyForAutoinc(insStmt, vTbl) + autoIncGen, err := modifyForAutoinc(ctx, insStmt, vTbl) if err != nil { return nil, err } @@ -186,7 +186,7 @@ func createInsertOperator(ctx *plancontext.PlanningContext, insStmt *sqlparser.I insOp.ColVindexes = getColVindexes(insOp) switch rows := insStmt.Rows.(type) { case sqlparser.Values: - route.Source, err = insertRowsPlan(insOp, insStmt, rows) + route.Source, err = insertRowsPlan(ctx, insOp, insStmt, rows) if err != nil { return nil, err } @@ -277,7 +277,7 @@ func columnMismatch(gen *Generate, ins *sqlparser.Insert, sel sqlparser.SelectSt return false } -func insertRowsPlan(insOp *Insert, ins *sqlparser.Insert, rows sqlparser.Values) (*Insert, error) { +func insertRowsPlan(ctx *plancontext.PlanningContext, insOp *Insert, ins *sqlparser.Insert, rows sqlparser.Values) (*Insert, error) { for _, row := range rows { if len(ins.Columns) != len(row) { return nil, vterrors.VT03006() @@ -300,7 +300,10 @@ func insertRowsPlan(insOp *Insert, ins *sqlparser.Insert, rows sqlparser.Values) routeValues[vIdx][colIdx] = make([]evalengine.Expr, len(rows)) colNum, _ := findOrAddColumn(ins, col) for rowNum, row := range rows { - innerpv, err := evalengine.Translate(row[colNum], nil) + innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + }) if err != nil { return nil, err } @@ -401,7 +404,7 @@ func populateInsertColumnlist(ins *sqlparser.Insert, table *vindexes.Table) *sql // modifyForAutoinc modifies the AST and the plan to generate necessary autoinc values. // For row values cases, bind variable names are generated using baseName. -func modifyForAutoinc(ins *sqlparser.Insert, vTable *vindexes.Table) (*Generate, error) { +func modifyForAutoinc(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, vTable *vindexes.Table) (*Generate, error) { if vTable.AutoIncrement == nil { return nil, nil } @@ -425,7 +428,10 @@ func modifyForAutoinc(ins *sqlparser.Insert, vTable *vindexes.Table) (*Generate, row[colNum] = sqlparser.NewArgument(engine.SeqVarName + strconv.Itoa(rowNum)) } var err error - gen.Values, err = evalengine.Translate(autoIncValues, nil) + gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + }) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 31be577913d..1c751467890 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -585,7 +585,10 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) { } // for everything else, we'll turn to the evalengine - eexpr, err := evalengine.Translate(rewritten, nil) + eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + }) if err != nil { panic(err) } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index e55789f40da..079813388b3 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -149,6 +149,7 @@ func generateOwnedVindexQuery(tblExpr sqlparser.TableExpr, del *sqlparser.Delete } func getUpdateVindexInformation( + ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, vindexTable *vindexes.Table, tableID semantics.TableSet, @@ -163,7 +164,7 @@ func getUpdateVindexInformation( return nil, nil, "", nil, err } - changedVindexValues, ownedVindexQuery, subQueriesArgOnChangedVindex, err := buildChangedVindexesValues(updStmt, vindexTable, primaryVindex.Columns, assignments) + changedVindexValues, ownedVindexQuery, subQueriesArgOnChangedVindex, err := buildChangedVindexesValues(ctx, updStmt, vindexTable, primaryVindex.Columns, assignments) if err != nil { return nil, nil, "", nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 52dae6bd9b5..743812f9dd7 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -149,7 +149,7 @@ func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.U } } - vp, cvv, ovq, subQueriesArgOnChangedVindex, err := getUpdateVindexInformation(updStmt, vindexTable, qt.ID, assignments) + vp, cvv, ovq, subQueriesArgOnChangedVindex, err := getUpdateVindexInformation(ctx, updStmt, vindexTable, qt.ID, assignments) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 1c3b76c92b1..f40ea961334 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -22,7 +22,7 @@ { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "1 DESC, (2|5) ASC", + "OrderBy": "1 DESC COLLATE utf8mb4_0900_ai_ci, (2|5) ASC", "ResultColumns": 4, "Inputs": [ { @@ -233,7 +233,7 @@ "Instructions": { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "1 DESC", + "OrderBy": "1 DESC COLLATE utf8mb4_0900_ai_ci", "ResultColumns": 2, "Inputs": [ { @@ -758,7 +758,7 @@ { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "2 DESC", + "OrderBy": "2 DESC COLLATE utf8mb4_0900_ai_ci", "ResultColumns": 8, "Inputs": [ { @@ -1075,7 +1075,7 @@ { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "1 DESC", + "OrderBy": "1 DESC COLLATE utf8mb4_0900_ai_ci", "Inputs": [ { "OperatorType": "Aggregate", diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index d93d895c8e3..89b6da7045d 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -68,6 +68,7 @@ func createUncertain(direct TableSet, recursive TableSet) *uncertain { dependency: dependency{ direct: direct, recursive: recursive, + typ: evalengine.UnknownType(), }, } } diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index b43ea49c4d1..625077f4da1 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -58,11 +58,10 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { if !ok { return nil } - var inputType sqltypes.Type + inputType := sqltypes.Unknown if arg := node.GetArg(); arg != nil { - t, ok := t.m[arg] - if ok { - inputType = t.Type + if tt, ok := t.m[arg]; ok { + inputType = tt.Type } } type_ := code.Type(inputType)