Skip to content

Commit

Permalink
evalengine: fix numeric coercibility (vitessio#14473)
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
vmg and systay authored Nov 6, 2023
1 parent a7f0ead commit a7a44c6
Show file tree
Hide file tree
Showing 28 changed files with 229 additions and 181 deletions.
2 changes: 1 addition & 1 deletion go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func IsNull(t querypb.Type) bool {
// switch statements for those who want to cover types
// by their category.
const (
Unknown = -1
Unknown = querypb.Type(-1)
Null = querypb.Type_NULL_TYPE
Int8 = querypb.Type_INT8
Uint8 = querypb.Type_UINT8
Expand Down
1 change: 1 addition & 0 deletions go/test/endtoend/vtgate/queries/random/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,5 @@ func TestBuggyQueries(t *testing.T) {
mcmp.Exec("select count(tbl1.dname) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.dname > tbl1.loc where tbl1.loc <=> tbl1.dname group by tbl1.dname order by tbl1.dname asc")
mcmp.Exec("select count(*) from (select count(*) from dept as tbl0) as tbl0")
mcmp.Exec("select count(*), count(*) from (select count(*) from dept as tbl0) as tbl0, dept as tbl1")
mcmp.Exec(`select distinct case max(tbl0.ename) when min(tbl0.job) then 'sole' else count(case when false then -27 when 'gazelle' then tbl0.deptno end) end as caggr0 from emp as tbl0`)
}
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,19 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
case AggregateUnassigned:
return sqltypes.Null
case AggregateGroupConcat:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
if sqltypes.IsBinary(typ) {
return sqltypes.Blob
}
return sqltypes.Text
case AggregateMax, AggregateMin, AggregateAnyValue:
return typ
case AggregateSumDistinct, AggregateSum:
if typ == sqltypes.Unknown {
return sqltypes.Unknown
}
if sqltypes.IsIntegral(typ) || sqltypes.IsDecimal(typ) {
return sqltypes.Decimal
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/api_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column {
return &Column{
Offset: offset,
Type: typ.Type,
Collation: defaultCoercionCollation(typ.Coll),
Collation: typedCoercionCollation(typ.Type, typ.Coll),
Original: original,
dynamicTypeOffset: -1,
}
Expand Down
115 changes: 109 additions & 6 deletions go/vt/vtgate/evalengine/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,115 @@ limitations under the License.

package evalengine

import "vitess.io/vitess/go/mysql/collations"
import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/collations/colldata"
"vitess.io/vitess/go/sqltypes"
)

func defaultCoercionCollation(id collations.ID) collations.TypedCollation {
return collations.TypedCollation{
Collation: id,
Coercibility: collations.CoerceCoercible,
Repertoire: collations.RepertoireUnicode,
func typedCoercionCollation(typ sqltypes.Type, id collations.ID) collations.TypedCollation {
switch {
case sqltypes.IsNull(typ):
return collationNull
case sqltypes.IsNumber(typ) || sqltypes.IsDateOrTime(typ):
return collationNumeric
case typ == sqltypes.TypeJSON:
return collationJSON
default:
return collations.TypedCollation{
Collation: id,
Coercibility: collations.CoerceCoercible,
Repertoire: collations.RepertoireUnicode,
}
}
}

func evalCollation(e eval) collations.TypedCollation {
switch e := e.(type) {
case nil:
return collationNull
case evalNumeric, *evalTemporal:
return collationNumeric
case *evalJSON:
return collationJSON
case *evalBytes:
return e.col
default:
return collationBinary
}
}

func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (collations.TypedCollation, colldata.Coercion, colldata.Coercion, error) {
if c1.Collation == c2.Collation {
return c1, nil, nil, nil
}

lt := sqltypes.IsText(t1) || sqltypes.IsBinary(t1)
rt := sqltypes.IsText(t2) || sqltypes.IsBinary(t2)
if !lt || !rt {
if lt {
return c1, nil, nil, nil
}
if rt {
return c2, nil, nil, nil
}
return collationBinary, nil, nil, nil
}

env := collations.Local()
return colldata.Merge(env, c1, c2, colldata.CoercionOptions{
ConvertToSuperset: true,
ConvertWithCoercion: true,
})
}

func mergeAndCoerceCollations(left, right eval) (eval, eval, collations.TypedCollation, error) {
lt := left.SQLType()
rt := right.SQLType()

mc, coerceLeft, coerceRight, err := mergeCollations(evalCollation(left), evalCollation(right), lt, rt)
if err != nil {
return nil, nil, collations.TypedCollation{}, err
}
if coerceLeft == nil && coerceRight == nil {
return left, right, mc, nil
}

left1 := newEvalRaw(lt, left.(*evalBytes).bytes, mc)
right1 := newEvalRaw(rt, right.(*evalBytes).bytes, mc)

if coerceLeft != nil {
left1.bytes, err = coerceLeft(nil, left1.bytes)
if err != nil {
return nil, nil, collations.TypedCollation{}, err
}
}
if coerceRight != nil {
right1.bytes, err = coerceRight(nil, right1.bytes)
if err != nil {
return nil, nil, collations.TypedCollation{}, err
}
}
return left1, right1, mc, nil
}

type collationAggregation struct {
cur collations.TypedCollation
}

func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) error {
if ca.cur.Collation == collations.Unknown {
ca.cur = tc
} else {
var err error
ca.cur, _, _, err = colldata.Merge(env, ca.cur, tc, colldata.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true})
if err != nil {
return err
}
}
return nil
}

func (ca *collationAggregation) result() collations.TypedCollation {
return ca.cur
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4165,13 +4165,13 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
}

tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.Unknown, env.now)
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
}

func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.TypedCollation) {
func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.ID) {
asm.adjustStack(-1)
asm.emit(func(env *ExpressionEnv) int {
var interval *datetime.Interval
Expand Down
13 changes: 13 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func TestCompilerSingle(t *testing.T) {
expression string
values []sqltypes.Value
result string
collation collations.ID
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -489,6 +490,12 @@ func TestCompilerSingle(t *testing.T) {
expression: `'2020-01-01' + interval month(date_sub(FROM_UNIXTIME(1234), interval 1 month))-1 month`,
result: `CHAR("2020-12-01")`,
},
{
expression: `case column0 when 1 then column1 else column2 end`,
values: []sqltypes.Value{sqltypes.NewInt64(42), sqltypes.NewVarChar("sole"), sqltypes.NewInt64(0)},
result: `VARCHAR("0")`,
collation: collations.CollationUtf8mb4ID,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -524,6 +531,9 @@ func TestCompilerSingle(t *testing.T) {
if expected.String() != tc.result {
t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result)
}
if tc.collation != collations.Unknown && tc.collation != expected.Collation() {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
Expand All @@ -535,6 +545,9 @@ func TestCompilerSingle(t *testing.T) {
if res.String() != tc.result {
t.Errorf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i)
}
if tc.collation != collations.Unknown && tc.collation != res.Collation() {
t.Fatalf("bad collation evaluation from compiler: got %d, want %d", res.Collation(), tc.collation)
}
}
})
}
Expand Down
20 changes: 10 additions & 10 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
fval, _ := fastparse.ParseFloat64(v.RawStr())
return newEvalFloat(fval), nil
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -265,7 +265,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
fval, _ := fastparse.ParseFloat64(v.RawStr())
dec = decimal.NewFromFloat(fval)
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -285,7 +285,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
i, err := fastparse.ParseInt64(v.RawStr(), 10)
return newEvalInt64(i), err
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -304,7 +304,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
u, err := fastparse.ParseUint64(v.RawStr(), 10)
return newEvalUint64(u), err
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -315,15 +315,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
case sqltypes.IsText(typ) || sqltypes.IsBinary(typ):
switch {
case v.IsText() || v.IsBinary():
return newEvalRaw(v.Type(), v.Raw(), defaultCoercionCollation(collation)), nil
return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil
case sqltypes.IsText(typ):
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
return evalToVarchar(e, collation, true)
default:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -333,7 +333,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
case typ == sqltypes.TypeJSON:
return json.NewFromSQL(v)
case typ == sqltypes.Date:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -344,7 +344,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
}
return d, nil
case typ == sqltypes.Datetime || typ == sqltypes.Timestamp:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand All @@ -355,7 +355,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
}
return dt, nil
case typ == sqltypes.Time:
e, err := valueToEval(v, defaultCoercionCollation(collation))
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
return e.dt.IsZero()
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval {
var tmp *evalTemporal
var ok bool

Expand All @@ -150,16 +150,16 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
tmp.dt.Date, ok = e.dt.Date.AddInterval(interval)
case tt == sqltypes.Time && !interval.Unit().HasDateParts():
tmp = &evalTemporal{t: e.t}
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, coll != collations.Unknown)
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
tmp = e.toDateTime(int(e.prec), now)
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, coll != collations.Unknown)
}
if !ok {
return nil
}
if strcoll.Valid() {
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), strcoll)
if coll != collations.Unknown {
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), typedCoercionCollation(sqltypes.Char, coll))
}
return tmp
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {

tuple := make([]eval, 0, len(bvar.Values))
for _, value := range bvar.Values {
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.CollationForType(value.Type, bv.Collation)))
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation)))
if err != nil {
return nil, err
}
Expand All @@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
if bv.typed() {
typ = bv.Type
}
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.CollationForType(typ, bv.Collation)))
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)))
}
}

Expand All @@ -110,7 +110,7 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) {
case sqltypes.BitNum:
return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil
default:
return ctype{Type: tt, Flag: 0, Col: defaultCoercionCollation(collations.CollationForType(tt, bv.Collation))}, nil
return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
}
}

Expand All @@ -119,7 +119,7 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) {

if bvar.typed() {
typ.Type = bvar.Type
typ.Col = defaultCoercionCollation(collations.CollationForType(bvar.Type, bvar.Collation))
typ.Col = typedCoercionCollation(bvar.Type, collations.CollationForType(bvar.Type, bvar.Collation))
} else if c.dynamicTypes != nil {
typ = c.dynamicTypes[bvar.dynamicTypeOffset]
} else {
Expand Down
Loading

0 comments on commit a7a44c6

Please sign in to comment.