From a15ef42929c5d44bd78c960619f672bd50976198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vicent=20Mart=C3=AD?= <42793+vmg@users.noreply.github.com> Date: Tue, 7 Nov 2023 11:41:12 +0100 Subject: [PATCH] Tiny Weights (#14402) Signed-off-by: Vicent Marti --- go/hack/runtime.go | 7 +- go/mysql/collations/colldata/8bit.go | 18 ++ go/mysql/collations/colldata/collation.go | 8 + go/mysql/collations/colldata/uca.go | 23 ++ go/mysql/collations/colldata/uca_test.go | 56 ++++ go/mysql/fastparse/fastparse.go | 2 +- go/sqltypes/bind_variables.go | 4 +- go/sqltypes/parse_rows.go | 10 +- go/sqltypes/result.go | 10 +- go/sqltypes/testing.go | 2 +- go/sqltypes/value.go | 119 +++++--- go/vt/vtgate/engine/aggregations.go | 20 +- go/vt/vtgate/engine/cached_size.go | 12 +- go/vt/vtgate/engine/comparer.go | 78 ----- go/vt/vtgate/engine/comparer_test.go | 114 ------- go/vt/vtgate/engine/distinct.go | 4 +- go/vt/vtgate/engine/fake_vcursor_test.go | 12 +- go/vt/vtgate/engine/memory_sort.go | 108 +------ go/vt/vtgate/engine/memory_sort_test.go | 22 +- go/vt/vtgate/engine/merge_sort.go | 131 ++------ go/vt/vtgate/engine/merge_sort_test.go | 16 +- go/vt/vtgate/engine/ordered_aggregate.go | 11 +- go/vt/vtgate/engine/route.go | 73 +---- go/vt/vtgate/engine/route_test.go | 22 +- go/vt/vtgate/evalengine/api_compare.go | 288 +++++++++++++++++- go/vt/vtgate/evalengine/api_compare_test.go | 70 +++++ go/vt/vtgate/evalengine/weights.go | 130 ++++++++ go/vt/vtgate/evalengine/weights_test.go | 74 +++++ go/vt/vtgate/executor_framework_test.go | 38 ++- go/vt/vtgate/executor_select_test.go | 51 ++-- .../planbuilder/operator_transformers.go | 13 +- go/vt/vttablet/tabletmanager/vdiff/utils.go | 4 +- go/vt/wrangler/vdiff.go | 4 +- 33 files changed, 939 insertions(+), 615 deletions(-) delete mode 100644 go/vt/vtgate/engine/comparer.go delete mode 100644 go/vt/vtgate/engine/comparer_test.go diff --git a/go/hack/runtime.go b/go/hack/runtime.go index 724a6c34f8d..5f6b946e33d 100644 --- a/go/hack/runtime.go +++ b/go/hack/runtime.go @@ -52,8 +52,11 @@ func RuntimeAllocSize(size int64) int64 { return int64(roundupsize(uintptr(size))) } -//go:linkname ParseFloatPrefix strconv.parseFloatPrefix -func ParseFloatPrefix(s string, bitSize int) (float64, int, error) +//go:linkname Atof64 strconv.atof64 +func Atof64(s string) (float64, int, error) + +//go:linkname Atof32 strconv.atof32 +func Atof32(s string) (float32, int, error) //go:linkname FastRand runtime.fastrand func FastRand() uint32 diff --git a/go/mysql/collations/colldata/8bit.go b/go/mysql/collations/colldata/8bit.go index 2355888bbab..67ae8541d56 100644 --- a/go/mysql/collations/colldata/8bit.go +++ b/go/mysql/collations/colldata/8bit.go @@ -17,6 +17,8 @@ limitations under the License. package colldata import ( + "encoding/binary" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/vt/vthash" @@ -168,6 +170,16 @@ func (c *Collation_8bit_simple_ci) Collate(left, right []byte, rightIsPrefix boo return len(left) - len(right) } +func (c *Collation_8bit_simple_ci) TinyWeightString(src []byte) uint32 { + var w32 [4]byte + sortOrder := c.sort + sortLen := min(4, len(src)) + for i := 0; i < sortLen; i++ { + w32[i] = sortOrder[src[i]] + } + return binary.BigEndian.Uint32(w32[:4]) +} + func (c *Collation_8bit_simple_ci) WeightString(dst, src []byte, numCodepoints int) []byte { padToMax := false sortOrder := c.sort @@ -272,6 +284,12 @@ func (c *Collation_binary) Collate(left, right []byte, isPrefix bool) int { return collationBinary(left, right, isPrefix) } +func (c *Collation_binary) TinyWeightString(src []byte) uint32 { + var w32 [4]byte + copy(w32[:4], src) + return binary.BigEndian.Uint32(w32[:4]) +} + func (c *Collation_binary) WeightString(dst, src []byte, numCodepoints int) []byte { padToMax := false copyCodepoints := len(src) diff --git a/go/mysql/collations/colldata/collation.go b/go/mysql/collations/colldata/collation.go index ec66fc09b58..7697c08cbed 100644 --- a/go/mysql/collations/colldata/collation.go +++ b/go/mysql/collations/colldata/collation.go @@ -155,6 +155,14 @@ type CaseAwareCollation interface { ToLower(dst []byte, src []byte) []byte } +// TinyWeightCollation implements the TinyWeightString API for collations. +type TinyWeightCollation interface { + Collation + // TinyWeightString returns a 32-bit weight string for a source string based on this collation. + // This is usually the 4-byte prefix of the full weight string, calculated more efficiently. + TinyWeightString(src []byte) uint32 +} + func Lookup(id collations.ID) Collation { if int(id) >= len(collationsById) { return nil diff --git a/go/mysql/collations/colldata/uca.go b/go/mysql/collations/colldata/uca.go index 4b7272bfbc3..a59dbf024bd 100644 --- a/go/mysql/collations/colldata/uca.go +++ b/go/mysql/collations/colldata/uca.go @@ -18,6 +18,7 @@ package colldata import ( "bytes" + "encoding/binary" "math/bits" "vitess.io/vitess/go/mysql/collations" @@ -119,6 +120,28 @@ nextLevel: return int(l) - int(r) } +func (c *Collation_utf8mb4_uca_0900) TinyWeightString(src []byte) uint32 { + it := c.uca.Iterator(src) + defer it.Done() + + if fast, ok := it.(*uca.FastIterator900); ok { + var chunk [16]byte + fast.NextWeightBlock64(chunk[:16]) + return binary.BigEndian.Uint32(chunk[:4]) + } + + var w32 uint32 + w, ok := it.Next() + if ok { + w32 = uint32(w) << 16 + w, ok = it.Next() + if ok { + w32 |= uint32(w) + } + } + return w32 +} + func (c *Collation_utf8mb4_uca_0900) WeightString(dst, src []byte, numCodepoints int) []byte { it := c.uca.Iterator(src) defer it.Done() diff --git a/go/mysql/collations/colldata/uca_test.go b/go/mysql/collations/colldata/uca_test.go index 70c9312636e..e00fb5fd6d1 100644 --- a/go/mysql/collations/colldata/uca_test.go +++ b/go/mysql/collations/colldata/uca_test.go @@ -805,6 +805,62 @@ func TestCompareWithWeightString(t *testing.T) { } } +func TestTinyWeightStrings(t *testing.T) { + var Collations = []Collation{ + testcollation(t, "utf8mb4_0900_as_cs"), + testcollation(t, "utf8mb4_0900_as_ci"), + testcollation(t, "utf8mb4_0900_ai_ci"), + } + + var Strings = []string{ + "a", "A", "aa", "AA", "aaa", "AAA", "aaaa", "AAAA", + "b", "B", "BB", "BB", "bbb", "BBB", "bbbb", "BBBB", + "Abc", "aBC", + "ǍḄÇ", "ÁḆĈ", + "\uA73A", "\uA738", + "\uAC00", "\u326E", + ExampleString, + ExampleStringLong, + JapaneseString, + WhitespaceString, + HungarianString, + JapaneseString2, + ChineseString, + ChineseString2, + SpanishString, + EnglishString, + } + + for _, coll := range Collations { + tw := coll.(TinyWeightCollation) + + for _, a := range Strings { + aw := tw.TinyWeightString([]byte(a)) + + for _, b := range Strings { + bw := tw.TinyWeightString([]byte(b)) + cmp := tw.Collate([]byte(a), []byte(b), false) + + switch { + case cmp == 0: + if aw != bw { + t.Errorf("[%s] %q vs %q: should be equal, got %08x / %08x", coll.Name(), a, b, aw, bw) + } + case cmp < 0: + if aw > bw { + t.Errorf("[%s] %q vs %q: should be <=, got %08x / %08x", coll.Name(), a, b, aw, bw) + } + case cmp > 0: + if aw < bw { + t.Errorf("[%s] %q vs %q: should be >= got %08x / %08x", coll.Name(), a, b, aw, bw) + } + } + } + } + } + +} + func TestFastIterators(t *testing.T) { allASCIICharacters := make([]byte, 128) for n := range allASCIICharacters { diff --git a/go/mysql/fastparse/fastparse.go b/go/mysql/fastparse/fastparse.go index 33aa16105c2..f9aca692abd 100644 --- a/go/mysql/fastparse/fastparse.go +++ b/go/mysql/fastparse/fastparse.go @@ -234,7 +234,7 @@ func ParseFloat64(s string) (float64, error) { // We only care to parse as many of the initial float characters of the // string as possible. This functionality is implemented in the `strconv` package // of the standard library, but not exposed, so we hook into it. - val, l, err := hack.ParseFloatPrefix(s[ws:], 64) + val, l, err := hack.Atof64(s[ws:]) for l < len(s[ws:]) { if !isSpace(s[ws+uint(l)]) { break diff --git a/go/sqltypes/bind_variables.go b/go/sqltypes/bind_variables.go index 041730ec517..18beda37702 100644 --- a/go/sqltypes/bind_variables.go +++ b/go/sqltypes/bind_variables.go @@ -49,7 +49,7 @@ func TupleToProto(v []Value) *querypb.Value { // ValueToProto converts Value to a *querypb.Value. func ValueToProto(v Value) *querypb.Value { - return &querypb.Value{Type: v.typ, Value: v.val} + return &querypb.Value{Type: v.Type(), Value: v.val} } // ProtoToValue converts a *querypb.Value to a Value. @@ -143,7 +143,7 @@ func BytesBindVariable(v []byte) *querypb.BindVariable { // ValueBindVariable converts a Value to a bind var. func ValueBindVariable(v Value) *querypb.BindVariable { - return &querypb.BindVariable{Type: v.typ, Value: v.val} + return &querypb.BindVariable{Type: v.Type(), Value: v.val} } // BuildBindVariable builds a *querypb.BindVariable from a valid input type. diff --git a/go/sqltypes/parse_rows.go b/go/sqltypes/parse_rows.go index 2654141ed3b..5e1db627c8b 100644 --- a/go/sqltypes/parse_rows.go +++ b/go/sqltypes/parse_rows.go @@ -19,7 +19,7 @@ package sqltypes import ( "fmt" "io" - "reflect" + "slices" "strconv" "strings" "text/scanner" @@ -127,6 +127,12 @@ func (e *RowMismatchError) Error() string { return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot: %v", e.err, e.want, e.got) } +func RowEqual(want, got Row) bool { + return slices.EqualFunc(want, got, func(a, b Value) bool { + return a.Equal(b) + }) +} + func RowsEquals(want, got []Row) error { if len(want) != len(got) { return &RowMismatchError{ @@ -143,7 +149,7 @@ func RowsEquals(want, got []Row) error { if matched[i] { continue } - if reflect.DeepEqual(aa, bb) { + if RowEqual(aa, bb) { matched[i] = true ok = true break diff --git a/go/sqltypes/result.go b/go/sqltypes/result.go index 7c04e1d89fa..389b7fff620 100644 --- a/go/sqltypes/result.go +++ b/go/sqltypes/result.go @@ -19,7 +19,7 @@ package sqltypes import ( "crypto/sha256" "fmt" - "reflect" + "slices" "google.golang.org/protobuf/proto" @@ -69,8 +69,8 @@ func (result *Result) Repair(fields []*querypb.Field) { // Usage of j is intentional. for j, f := range fields { for _, r := range result.Rows { - if r[j].typ != Null { - r[j].typ = f.Type + if r[j].Type() != Null { + r[j].typ = uint16(f.Type) } } } @@ -198,7 +198,9 @@ func (result *Result) Equal(other *Result) bool { return FieldsEqual(result.Fields, other.Fields) && result.RowsAffected == other.RowsAffected && result.InsertID == other.InsertID && - reflect.DeepEqual(result.Rows, other.Rows) + slices.EqualFunc(result.Rows, other.Rows, func(a, b Row) bool { + return RowEqual(a, b) + }) } // ResultsEqual compares two arrays of Result. diff --git a/go/sqltypes/testing.go b/go/sqltypes/testing.go index 9042acf6680..3894635eae0 100644 --- a/go/sqltypes/testing.go +++ b/go/sqltypes/testing.go @@ -147,7 +147,7 @@ func TestValue(typ querypb.Type, val string) Value { // This function should only be used for testing. func TestTuple(vals ...Value) Value { return Value{ - typ: Tuple, + typ: uint16(Tuple), val: encodeTuple(vals), } } diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 331c494710e..45415814700 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -18,6 +18,7 @@ limitations under the License. package sqltypes import ( + "bytes" "encoding/base64" "encoding/hex" "encoding/json" @@ -52,6 +53,11 @@ var ( ErrIncompatibleTypeCast = errors.New("Cannot convert value to desired type") ) +const ( + // flagTinyWeight marks this Value as having a Tiny Weight String + flagTinyWeight = 0x1 +) + type ( // BinWriter interface is used for encoding values. // Types like bytes.Buffer conform to this interface. @@ -65,7 +71,15 @@ type ( // an integral type, the bytes are always stored as a canonical // representation that matches how MySQL returns such values. Value struct { - typ querypb.Type + // typ is the value's sqltypes.Type (this always fits in 16 bits) + typ uint16 + // flags are the flags set for this Value; right now this field is only + // used to track whether tinyweight is set + flags uint16 + // tinyweight is a weight string prefix for this Value. + // See: evalengine.TinyWeighter + tinyweight uint32 + // val is the raw byte representation of this Value val []byte } @@ -114,7 +128,7 @@ func MakeTrusted(typ querypb.Type, val []byte) Value { if typ == Null { return NULL } - return Value{typ: typ, val: val} + return Value{typ: uint16(typ), val: val} } // NewHexNum builds an Hex Value. @@ -271,7 +285,7 @@ func InterfaceToValue(goval any) (Value, error) { // Type returns the type of Value. func (v Value) Type() querypb.Type { - return v.typ + return querypb.Type(v.typ) } // Raw returns the internal representation of the value. For newer types, @@ -292,7 +306,7 @@ func (v Value) RawStr() string { // match MySQL's representation for hex encoded binary data or newer types. // If the value is not convertible like in the case of Expression, it returns an error. func (v Value) ToBytes() ([]byte, error) { - switch v.typ { + switch v.Type() { case Expression: return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "expression cannot be converted to bytes") case HexVal: // TODO: all the decode below have problem when decoding odd number of bytes. This needs to be fixed. @@ -345,7 +359,7 @@ func (v Value) ToInt() (int, error) { // ToFloat64 returns the value as MySQL would return it as a float64. func (v Value) ToFloat64() (float64, error) { - if !IsNumber(v.typ) { + if !IsNumber(v.Type()) { return 0, ErrIncompatibleTypeCast } @@ -403,7 +417,7 @@ func (v Value) ToBool() (bool, error) { // ToString returns the value as MySQL would return it as string. // If the value is not convertible like in the case of Expression, it returns nil. func (v Value) ToString() string { - if v.typ == Expression { + if v.Type() == Expression { return "" } return hack.String(v.val) @@ -411,23 +425,23 @@ func (v Value) ToString() string { // String returns a printable version of the value. func (v Value) String() string { - if v.typ == Null { + if v.Type() == Null { return "NULL" } - if v.IsQuoted() || v.typ == Bit { - return fmt.Sprintf("%v(%q)", v.typ, v.val) + if v.IsQuoted() || v.Type() == Bit { + return fmt.Sprintf("%v(%q)", Type(v.typ), v.val) } - return fmt.Sprintf("%v(%s)", v.typ, v.val) + return fmt.Sprintf("%v(%s)", Type(v.typ), v.val) } // EncodeSQL encodes the value into an SQL statement. Can be binary. func (v Value) EncodeSQL(b BinWriter) { switch { - case v.typ == Null: + case v.Type() == Null: b.Write(NullBytes) case v.IsQuoted(): encodeBytesSQL(v.val, b) - case v.typ == Bit: + case v.Type() == Bit: encodeBytesSQLBits(v.val, b) default: b.Write(v.val) @@ -438,13 +452,13 @@ func (v Value) EncodeSQL(b BinWriter) { // as its writer, so it can be inlined for performance. func (v Value) EncodeSQLStringBuilder(b *strings.Builder) { switch { - case v.typ == Null: + case v.Type() == Null: b.Write(NullBytes) case v.IsQuoted(): encodeBytesSQLStringBuilder(v.val, b) - case v.typ == Bit: + case v.Type() == Bit: encodeBytesSQLBits(v.val, b) - case v.typ == Tuple: + case v.Type() == Tuple: b.WriteByte('(') var i int _ = v.ForEachValue(func(bv Value) { @@ -464,11 +478,11 @@ func (v Value) EncodeSQLStringBuilder(b *strings.Builder) { // as its writer, so it can be inlined for performance. func (v Value) EncodeSQLBytes2(b *bytes2.Buffer) { switch { - case v.typ == Null: + case v.Type() == Null: b.Write(NullBytes) case v.IsQuoted(): encodeBytesSQLBytes2(v.val, b) - case v.typ == Bit: + case v.Type() == Bit: encodeBytesSQLBits(v.val, b) default: b.Write(v.val) @@ -478,9 +492,9 @@ func (v Value) EncodeSQLBytes2(b *bytes2.Buffer) { // EncodeASCII encodes the value using 7-bit clean ascii bytes. func (v Value) EncodeASCII(b BinWriter) { switch { - case v.typ == Null: + case v.Type() == Null: b.Write(NullBytes) - case v.IsQuoted() || v.typ == Bit: + case v.IsQuoted() || v.Type() == Bit: encodeBytesASCII(v.val, b) default: b.Write(v.val) @@ -489,75 +503,75 @@ func (v Value) EncodeASCII(b BinWriter) { // IsNull returns true if Value is null. func (v Value) IsNull() bool { - return v.typ == Null + return v.Type() == Null } // IsIntegral returns true if Value is an integral. func (v Value) IsIntegral() bool { - return IsIntegral(v.typ) + return IsIntegral(v.Type()) } // IsSigned returns true if Value is a signed integral. func (v Value) IsSigned() bool { - return IsSigned(v.typ) + return IsSigned(v.Type()) } // IsUnsigned returns true if Value is an unsigned integral. func (v Value) IsUnsigned() bool { - return IsUnsigned(v.typ) + return IsUnsigned(v.Type()) } // IsFloat returns true if Value is a float. func (v Value) IsFloat() bool { - return IsFloat(v.typ) + return IsFloat(v.Type()) } // IsQuoted returns true if Value must be SQL-quoted. func (v Value) IsQuoted() bool { - return IsQuoted(v.typ) + return IsQuoted(v.Type()) } // IsText returns true if Value is a collatable text. func (v Value) IsText() bool { - return IsText(v.typ) + return IsText(v.Type()) } // IsBinary returns true if Value is binary. func (v Value) IsBinary() bool { - return IsBinary(v.typ) + return IsBinary(v.Type()) } // IsDateTime returns true if Value is datetime. func (v Value) IsDateTime() bool { - return v.typ == querypb.Type_DATETIME + return v.Type() == querypb.Type_DATETIME } // IsTimestamp returns true if Value is date. func (v Value) IsTimestamp() bool { - return v.typ == querypb.Type_TIMESTAMP + return v.Type() == querypb.Type_TIMESTAMP } // IsDate returns true if Value is date. func (v Value) IsDate() bool { - return v.typ == querypb.Type_DATE + return v.Type() == querypb.Type_DATE } // IsTime returns true if Value is time. func (v Value) IsTime() bool { - return v.typ == querypb.Type_TIME + return v.Type() == querypb.Type_TIME } // IsDecimal returns true if Value is a decimal. func (v Value) IsDecimal() bool { - return IsDecimal(v.typ) + return IsDecimal(v.Type()) } // IsComparable returns true if the Value is null safe comparable without collation information. func (v *Value) IsComparable() bool { - if v.typ == Null || IsNumber(v.typ) || IsBinary(v.typ) { + if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) { return true } - switch v.typ { + switch v.Type() { case Timestamp, Date, Time, Datetime, Enum, Set, TypeJSON, Bit: return true } @@ -568,9 +582,9 @@ func (v *Value) IsComparable() bool { // It's not a complete implementation. func (v Value) MarshalJSON() ([]byte, error) { switch { - case v.IsQuoted() || v.typ == Bit: + case v.IsQuoted() || v.Type() == Bit: return json.Marshal(v.ToString()) - case v.typ == Null: + case v.Type() == Null: return NullBytes, nil } return v.val, nil @@ -670,7 +684,7 @@ func encodeTuple(tuple []Value) []byte { } func (v *Value) ForEachValue(each func(bv Value)) error { - if v.typ != Tuple { + if v.Type() != Tuple { panic("Value.ForEachValue on non-tuple") } @@ -690,13 +704,42 @@ func (v *Value) ForEachValue(each func(bv Value)) error { } buf = buf[varlen:] - each(Value{val: buf[:sz], typ: Type(ty)}) + each(Value{val: buf[:sz], typ: uint16(ty)}) buf = buf[sz:] } return nil } +// Equal compares this Value to other. It ignores any flags set. +func (v Value) Equal(other Value) bool { + return v.typ == other.typ && bytes.Equal(v.val, other.val) +} + +// SetTinyWeight sets this Value's tiny weight string +func (v *Value) SetTinyWeight(w uint32) { + v.tinyweight = w + v.flags |= flagTinyWeight +} + +// TinyWeightCmp performs a fast comparison of this Value with other if both have a Tiny Weight String set. +// For any 2 instances of Value: if both instances have a Tiny Weight string, +// and the weight strings are **different**, the two values will sort accordingly to the 32-bit +// numerical sort of their tiny weight strings. Otherwise, the relative sorting of the two values +// will not be known, and they will require a full sort using e.g. evalengine.NullsafeCompare +// See: evalengine.TinyWeighter +func (v Value) TinyWeightCmp(other Value) int { + // both values need a tinyweight; otherwise the comparison is invalid + if v.flags&other.flags&flagTinyWeight == 0 { + return 0 + } + return int(int64(v.tinyweight) - int64(other.tinyweight)) +} + +func (v Value) TinyWeight() uint32 { + return v.tinyweight +} + func encodeBytesSQL(val []byte, b BinWriter) { buf := &bytes2.Buffer{} encodeBytesSQLBytes2(val, buf) diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 33b80faab55..dd7a259d1b6 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -108,16 +108,20 @@ type aggregatorDistinct struct { func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { if a.column >= 0 { - if !a.last.IsNull() { - cmp, err := evalengine.NullsafeCompare(a.last, row[a.column], a.coll) - if err != nil { - return true, err - } - if cmp == 0 { - return true, nil + last := a.last + next := row[a.column] + if !last.IsNull() { + if last.TinyWeightCmp(next) == 0 { + cmp, err := evalengine.NullsafeCompare(last, next, a.coll) + if err != nil { + return true, err + } + if cmp == 0 { + return true, nil + } } } - a.last = row[a.column] + a.last = next } return false, nil } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 5e035a2e479..b70f83b192d 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -579,9 +579,9 @@ func (cached *MemorySort) CachedSize(alloc bool) int64 { if cc, ok := cached.UpperLimit.(cachedObject); ok { size += cc.CachedSize(true) } - // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams + // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(27)) } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Input.(cachedObject); ok { @@ -606,9 +606,9 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } } } - // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams + // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(27)) } return size } @@ -799,9 +799,9 @@ func (cached *Route) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.TableName))) // field FieldQuery string size += hack.RuntimeAllocSize(int64(len(cached.FieldQuery))) - // field OrderBy []vitess.io/vitess/go/vt/vtgate/engine.OrderByParams + // field OrderBy vitess.io/vitess/go/vt/vtgate/evalengine.Comparison { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(39)) + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(27)) } // field RoutingParameters *vitess.io/vitess/go/vt/vtgate/engine.RoutingParameters size += cached.RoutingParameters.CachedSize(true) diff --git a/go/vt/vtgate/engine/comparer.go b/go/vt/vtgate/engine/comparer.go deleted file mode 100644 index 591b1cf2be0..00000000000 --- a/go/vt/vtgate/engine/comparer.go +++ /dev/null @@ -1,78 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package engine - -import ( - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vtgate/evalengine" -) - -// comparer is the struct that has the logic for comparing two rows in the result set -type comparer struct { - orderBy, weightString, starColFixedIndex int - collationID collations.ID - desc bool -} - -// compare compares two rows given the comparer and returns which one should be earlier in the result set -// -1 if the first row should be earlier -// 1 is the second row should be earlier -// 0 if both the rows have equal ordering -func (c *comparer) compare(r1, r2 []sqltypes.Value) (int, error) { - var colIndex int - if c.starColFixedIndex > c.orderBy && c.starColFixedIndex < len(r1) { - colIndex = c.starColFixedIndex - } else { - colIndex = c.orderBy - } - cmp, err := evalengine.NullsafeCompare(r1[colIndex], r2[colIndex], c.collationID) - if err != nil { - _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) - _, isCollationErr := err.(evalengine.UnsupportedCollationError) - if !isComparisonErr && !isCollationErr || c.weightString == -1 { - return 0, err - } - // in case of a comparison or collation error switch to using the weight string column for ordering - c.orderBy = c.weightString - c.weightString = -1 - cmp, err = evalengine.NullsafeCompare(r1[c.orderBy], r2[c.orderBy], c.collationID) - if err != nil { - return 0, err - } - } - // change the result if descending ordering is required - if c.desc { - cmp = -cmp - } - return cmp, nil -} - -// extractSlices extracts the three fields of OrderByParams into a slice of comparers -func extractSlices(input []OrderByParams) []*comparer { - var result []*comparer - for _, order := range input { - result = append(result, &comparer{ - orderBy: order.Col, - weightString: order.WeightStringCol, - desc: order.Desc, - starColFixedIndex: order.StarColFixedIndex, - collationID: order.Type.Coll, - }) - } - return result -} diff --git a/go/vt/vtgate/engine/comparer_test.go b/go/vt/vtgate/engine/comparer_test.go deleted file mode 100644 index c1be2c25e82..00000000000 --- a/go/vt/vtgate/engine/comparer_test.go +++ /dev/null @@ -1,114 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package engine - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/sqltypes" -) - -func TestComparer(t *testing.T) { - tests := []struct { - comparer comparer - row1 []sqltypes.Value - row2 []sqltypes.Value - output int - }{ - { - comparer: comparer{ - orderBy: 0, - weightString: -1, - desc: true, - }, - row1: []sqltypes.Value{ - sqltypes.NewInt64(23), - }, - row2: []sqltypes.Value{ - sqltypes.NewInt64(34), - }, - output: 1, - }, { - comparer: comparer{ - orderBy: 0, - weightString: -1, - desc: false, - }, - row1: []sqltypes.Value{ - sqltypes.NewInt64(23), - }, - row2: []sqltypes.Value{ - sqltypes.NewInt64(23), - }, - output: 0, - }, { - comparer: comparer{ - orderBy: 0, - weightString: -1, - desc: false, - }, - row1: []sqltypes.Value{ - sqltypes.NewInt64(23), - }, - row2: []sqltypes.Value{ - sqltypes.NewInt64(12), - }, - output: 1, - }, { - comparer: comparer{ - orderBy: 1, - weightString: 0, - desc: false, - }, - row1: []sqltypes.Value{ - sqltypes.NewInt64(23), - sqltypes.NewVarChar("b"), - }, - row2: []sqltypes.Value{ - sqltypes.NewInt64(34), - sqltypes.NewVarChar("a"), - }, - output: -1, - }, { - comparer: comparer{ - orderBy: 1, - weightString: 0, - desc: true, - }, - row1: []sqltypes.Value{ - sqltypes.NewInt64(23), - sqltypes.NewVarChar("A"), - }, - row2: []sqltypes.Value{ - sqltypes.NewInt64(23), - sqltypes.NewVarChar("a"), - }, - output: 0, - }, - } - - for i, test := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - got, err := test.comparer.compare(test.row1, test.row2) - require.NoError(t, err) - require.Equal(t, test.output, got) - }) - } -} diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 477e0803c1b..cd6b93a9f32 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -139,8 +139,8 @@ func (pt *probeTable) equal(a, b sqltypes.Row) (bool, error) { for i, checkCol := range pt.checkCols { cmp, err := evalengine.NullsafeCompare(a[i], b[i], checkCol.Type.Coll) if err != nil { - _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) - if !isComparisonErr || checkCol.WsCol == nil { + _, isCollErr := err.(evalengine.UnsupportedCollationError) + if !isCollErr || checkCol.WsCol == nil { return false, err } checkCol = checkCol.SwitchToWeightString() diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index c2418c73560..6c99af33313 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -808,8 +808,16 @@ func (t *noopVCursor) GetLogs() ([]ExecuteEntry, error) { func expectResult(t *testing.T, msg string, result, want *sqltypes.Result) { t.Helper() - if !reflect.DeepEqual(result, want) { - t.Errorf("%s:\n%v\nwant:\n%v", msg, result, want) + fieldsResult := fmt.Sprintf("%v", result.Fields) + fieldsWant := fmt.Sprintf("%v", want.Fields) + if fieldsResult != fieldsWant { + t.Errorf("%s (mismatch in Fields):\n%s\nwant:\n%s", msg, fieldsResult, fieldsWant) + } + + rowsResult := fmt.Sprintf("%v", result.Rows) + rowsWant := fmt.Sprintf("%v", want.Rows) + if rowsResult != rowsWant { + t.Errorf("%s (mismatch in Rows):\n%s\nwant:\n%s", msg, rowsResult, rowsWant) } } diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index b1770225211..b896b303923 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -17,19 +17,16 @@ limitations under the License. package engine import ( - "container/heap" "context" "fmt" "math" "reflect" - "sort" "strconv" "strings" - "vitess.io/vitess/go/vt/vtgate/evalengine" - "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) var _ Primitive = (*MemorySort)(nil) @@ -37,7 +34,7 @@ var _ Primitive = (*MemorySort)(nil) // MemorySort is a primitive that performs in-memory sorting. type MemorySort struct { UpperLimit evalengine.Expr - OrderBy []OrderByParams + OrderBy evalengine.Comparison Input Primitive // TruncateColumnCount specifies the number of columns to return @@ -77,15 +74,10 @@ func (ms *MemorySort) TryExecute(ctx context.Context, vcursor VCursor, bindVars if err != nil { return nil, err } - sh := &sortHeap{ - rows: result.Rows, - comparers: extractSlices(ms.OrderBy), - } - sort.Sort(sh) - if sh.err != nil { - return nil, sh.err + + if err = ms.OrderBy.SortResult(result); err != nil { + return nil, err } - result.Rows = sh.rows if len(result.Rows) > count { result.Rows = result.Rows[:count] } @@ -93,7 +85,9 @@ func (ms *MemorySort) TryExecute(ctx context.Context, vcursor VCursor, bindVars } // TryStreamExecute satisfies the Primitive interface. -func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { +func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) (err error) { + defer evalengine.PanicHandler(&err) + count, err := ms.fetchCount(ctx, vcursor, bindVars) if err != nil { return err @@ -103,11 +97,9 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin return callback(qr.Truncate(ms.TruncateColumnCount)) } - // You have to reverse the ordering because the highest values - // must be dropped once the upper limit is reached. - sh := &sortHeap{ - comparers: extractSlices(ms.OrderBy), - reverse: true, + sorter := &evalengine.Sorter{ + Compare: ms.OrderBy, + Limit: count, } err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { @@ -116,14 +108,9 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin } } for _, row := range qr.Rows { - heap.Push(sh, row) - // Remove the highest element from the heap if the size is more than the count - // This optimization means that the maximum size of the heap is going to be (count + 1) - for len(sh.rows) > count { - _ = heap.Pop(sh) - } + sorter.Push(row) } - if vcursor.ExceedsMaxMemoryRows(len(sh.rows)) { + if vcursor.ExceedsMaxMemoryRows(sorter.Len()) { return fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows()) } return nil @@ -131,17 +118,7 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin if err != nil { return err } - if sh.err != nil { - return sh.err - } - // Set ordering to normal for the final ordering. - sh.reverse = false - sort.Sort(sh) - if sh.err != nil { - // Unreachable. - return sh.err - } - return cb(&sqltypes.Result{Rows: sh.rows}) + return cb(&sqltypes.Result{Rows: sorter.Sorted()}) } // GetFields satisfies the Primitive interface. @@ -194,7 +171,8 @@ func (ms *MemorySort) description() PrimitiveDescription { } func orderByParamsToString(i any) string { - return i.(OrderByParams).String() + obp := i.(evalengine.OrderByParams) + return obp.String() } // GenericJoin will iterate over arrays, slices or maps, and executes the f function to get a @@ -216,57 +194,3 @@ func GenericJoin(input any, f func(any) string) string { } return strings.Join(keys, ", ") } - -// sortHeap is sorted based on the orderBy params. -// Implementation is similar to scatterHeap -type sortHeap struct { - rows [][]sqltypes.Value - comparers []*comparer - reverse bool - err error -} - -// Len satisfies sort.Interface and heap.Interface. -func (sh *sortHeap) Len() int { - return len(sh.rows) -} - -// Less satisfies sort.Interface and heap.Interface. -func (sh *sortHeap) Less(i, j int) bool { - for _, c := range sh.comparers { - if sh.err != nil { - return true - } - cmp, err := c.compare(sh.rows[i], sh.rows[j]) - if err != nil { - sh.err = err - return true - } - if cmp == 0 { - continue - } - if sh.reverse { - cmp = -cmp - } - return cmp < 0 - } - return true -} - -// Swap satisfies sort.Interface and heap.Interface. -func (sh *sortHeap) Swap(i, j int) { - sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i] -} - -// Push satisfies heap.Interface. -func (sh *sortHeap) Push(x any) { - sh.rows = append(sh.rows, x.([]sqltypes.Value)) -} - -// Pop satisfies heap.Interface. -func (sh *sortHeap) Pop() any { - n := len(sh.rows) - x := sh.rows[n-1] - sh.rows = sh.rows[:n-1] - return x -} diff --git a/go/vt/vtgate/engine/memory_sort_test.go b/go/vt/vtgate/engine/memory_sort_test.go index 2c73d49e74b..bc9369c57af 100644 --- a/go/vt/vtgate/engine/memory_sort_test.go +++ b/go/vt/vtgate/engine/memory_sort_test.go @@ -54,7 +54,7 @@ func TestMemorySortExecute(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 1, }}, @@ -107,7 +107,7 @@ func TestMemorySortStreamExecuteWeightString(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: 0, Col: 1, }}, @@ -173,7 +173,7 @@ func TestMemorySortExecuteWeightString(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: 1, Col: 0, }}, @@ -227,7 +227,7 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ Col: 0, Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }}, @@ -315,7 +315,7 @@ func TestMemorySortExecuteCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ Col: 0, Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }}, @@ -368,7 +368,7 @@ func TestMemorySortStreamExecute(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 1, }}, @@ -445,7 +445,7 @@ func TestMemorySortExecuteTruncate(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 1, }}, @@ -484,7 +484,7 @@ func TestMemorySortStreamExecuteTruncate(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 1, }}, @@ -527,7 +527,7 @@ func TestMemorySortMultiColumn(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ Col: 1, WeightStringCol: -1, }, { @@ -600,7 +600,7 @@ func TestMemorySortMaxMemoryRows(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 1, }}, @@ -636,7 +636,7 @@ func TestMemorySortExecuteNoVarChar(t *testing.T) { } ms := &MemorySort{ - OrderBy: []OrderByParams{{ + OrderBy: []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }}, diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index 6c694ae9e37..3c26a383594 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -17,11 +17,11 @@ limitations under the License. package engine import ( - "container/heap" "context" "io" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -49,7 +49,7 @@ var _ Primitive = (*MergeSort)(nil) // so that vdiff can use it. In that situation, only StreamExecute is used. type MergeSort struct { Primitives []StreamExecutor - OrderBy []OrderByParams + OrderBy evalengine.Comparison ScatterErrorsAsWarnings bool noInputs noTxNeeded @@ -75,7 +75,9 @@ func (ms *MergeSort) GetFields(ctx context.Context, vcursor VCursor, bindVars ma } // TryStreamExecute performs a streaming exec. -func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { +func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) (err error) { + defer evalengine.PanicHandler(&err) + var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) defer cancel() @@ -90,22 +92,22 @@ func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bind } } + merge := &evalengine.Merger{ + Compare: ms.OrderBy, + } + if wantfields { - err := ms.getStreamingFields(handles, callback) + fields, err := ms.getStreamingFields(handles) if err != nil { return err } - } - - comparers := extractSlices(ms.OrderBy) - sh := &scatterHeap{ - rows: make([]streamRow, 0, len(handles)), - comparers: comparers, + if err := callback(&sqltypes.Result{Fields: fields}); err != nil { + return err + } } var errs []error - // Prime the heap. One element must be pulled from - // each stream. + // Prime the heap. One element must be pulled from each stream. for i, handle := range handles { select { case row, ok := <-handle.row: @@ -121,49 +123,38 @@ func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bind // If so, don't add anything to the heap. continue } - sh.rows = append(sh.rows, streamRow{row: row, id: i}) + merge.Push(row, i) case <-ctx.Done(): return ctx.Err() } } - heap.Init(sh) - if sh.err != nil { - return sh.err - } + merge.Init() // Iterate one row at a time: // Pop a row from the heap and send it out. // Then pull the next row from the stream the popped // row came from and push it into the heap. - for len(sh.rows) != 0 { - sr := heap.Pop(sh).(streamRow) - if sh.err != nil { - // Unreachable: This should never fail. - return sh.err - } - if err := callback(&sqltypes.Result{Rows: [][]sqltypes.Value{sr.row}}); err != nil { + for merge.Len() != 0 { + row, stream := merge.Pop() + if err := callback(&sqltypes.Result{Rows: [][]sqltypes.Value{row}}); err != nil { return err } select { - case row, ok := <-handles[sr.id].row: + case row, ok := <-handles[stream].row: if !ok { - if handles[sr.id].err != nil { - return handles[sr.id].err + if handles[stream].err != nil { + return handles[stream].err } continue } - sr.row = row - heap.Push(sh, sr) - if sh.err != nil { - return sh.err - } + merge.Push(row, stream) case <-ctx.Done(): return ctx.Err() } } - err := vterrors.Aggregate(errs) + err = vterrors.Aggregate(errs) if err != nil && ms.ScatterErrorsAsWarnings && len(errs) < len(handles) { // we got errors, but not all shards failed, so we can hide the error and just warn instead partialSuccessScatterQueries.Add(1) @@ -174,7 +165,7 @@ func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bind return err } -func (ms *MergeSort) getStreamingFields(handles []*streamHandle, callback func(*sqltypes.Result) error) error { +func (ms *MergeSort) getStreamingFields(handles []*streamHandle) ([]*querypb.Field, error) { var fields []*querypb.Field if ms.ScatterErrorsAsWarnings { @@ -193,20 +184,16 @@ func (ms *MergeSort) getStreamingFields(handles []*streamHandle, callback func(* if fields == nil { // something went wrong. need to figure out where the error can be if !ms.ScatterErrorsAsWarnings { - return handles[0].err + return nil, handles[0].err } var errs []error for _, handle := range handles { errs = append(errs, handle.err) } - return vterrors.Aggregate(errs) - } - - if err := callback(&sqltypes.Result{Fields: fields}); err != nil { - return err + return nil, vterrors.Aggregate(errs) } - return nil + return fields, nil } func (ms *MergeSort) description() PrimitiveDescription { @@ -266,65 +253,3 @@ func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bi return handle } - -// A streamRow represents a row identified by the stream -// it came from. It is used as an element in scatterHeap. -type streamRow struct { - row []sqltypes.Value - id int -} - -// scatterHeap is the heap that is used for merge-sorting. -// You can push streamRow elements into it. Popping an -// element will return the one with the lowest value -// as defined by the orderBy criteria. If a comparison -// yielded an error, err is set. This must be checked -// after every heap operation. -type scatterHeap struct { - rows []streamRow - err error - comparers []*comparer -} - -// Len satisfies sort.Interface and heap.Interface. -func (sh *scatterHeap) Len() int { - return len(sh.rows) -} - -// Less satisfies sort.Interface and heap.Interface. -func (sh *scatterHeap) Less(i, j int) bool { - for _, c := range sh.comparers { - if sh.err != nil { - return true - } - // First try to compare the columns that we want to order - cmp, err := c.compare(sh.rows[i].row, sh.rows[j].row) - if err != nil { - sh.err = err - return true - } - if cmp == 0 { - continue - } - return cmp < 0 - } - return true -} - -// Swap satisfies sort.Interface and heap.Interface. -func (sh *scatterHeap) Swap(i, j int) { - sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i] -} - -// Push satisfies heap.Interface. -func (sh *scatterHeap) Push(x any) { - sh.rows = append(sh.rows, x.(streamRow)) -} - -// Pop satisfies heap.Interface. -func (sh *scatterHeap) Pop() any { - n := len(sh.rows) - x := sh.rows[n-1] - sh.rows = sh.rows[:n-1] - return x -} diff --git a/go/vt/vtgate/engine/merge_sort_test.go b/go/vt/vtgate/engine/merge_sort_test.go index be370c0e86b..803c70ca463 100644 --- a/go/vt/vtgate/engine/merge_sort_test.go +++ b/go/vt/vtgate/engine/merge_sort_test.go @@ -60,7 +60,7 @@ func TestMergeSortNormal(t *testing.T) { "8|h", ), }} - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }} @@ -118,7 +118,7 @@ func TestMergeSortWeightString(t *testing.T) { "8|h", ), }} - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: 0, Col: 1, }} @@ -180,7 +180,7 @@ func TestMergeSortCollation(t *testing.T) { }} collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ Col: 0, Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }} @@ -240,7 +240,7 @@ func TestMergeSortDescending(t *testing.T) { "4|d", ), }} - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, Desc: true, @@ -291,7 +291,7 @@ func TestMergeSortEmptyResults(t *testing.T) { }, { results: sqltypes.MakeTestStreamingResults(idColFields), }} - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }} @@ -319,7 +319,7 @@ func TestMergeSortEmptyResults(t *testing.T) { // TestMergeSortResultFailures tests failures at various // stages of result return. func TestMergeSortResultFailures(t *testing.T) { - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }} @@ -365,7 +365,7 @@ func TestMergeSortDataFailures(t *testing.T) { "2.1|b", ), }} - orderBy := []OrderByParams{{ + orderBy := []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }} @@ -391,7 +391,7 @@ func TestMergeSortDataFailures(t *testing.T) { require.EqualError(t, err, want) } -func testMergeSort(shardResults []*shardResult, orderBy []OrderByParams, callback func(qr *sqltypes.Result) error) error { +func testMergeSort(shardResults []*shardResult, orderBy []evalengine.OrderByParams, callback func(qr *sqltypes.Result) error) error { prims := make([]StreamExecutor, 0, len(shardResults)) for _, sr := range shardResults { prims = append(prims, sr) diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 61a3140aa27..07ea06fa5fd 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -342,11 +342,16 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n } for _, gb := range oa.GroupByKeys { - cmp, err := evalengine.NullsafeCompare(currentKey[gb.KeyCol], nextRow[gb.KeyCol], gb.Type.Coll) + v1 := currentKey[gb.KeyCol] + v2 := nextRow[gb.KeyCol] + if v1.TinyWeightCmp(v2) != 0 { + return nextRow, true, nil + } + + cmp, err := evalengine.NullsafeCompare(v1, v2, gb.Type.Coll) if err != nil { - _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) _, isCollationErr := err.(evalengine.UnsupportedCollationError) - if !isComparisonErr && !isCollationErr || gb.WeightStringCol == -1 { + if !isCollationErr || gb.WeightStringCol == -1 { return nil, false, err } gb.KeyCol = gb.WeightStringCol diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 8dc7572e296..30713f45f91 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -20,13 +20,10 @@ import ( "context" "fmt" "math/rand" - "slices" "sort" - "strconv" "strings" "time" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" @@ -71,7 +68,7 @@ type Route struct { // OrderBy specifies the key order for merge sorting. This will be // set only for scatter queries that need the results to be // merge-sorted. - OrderBy []OrderByParams + OrderBy evalengine.Comparison // TruncateColumnCount specifies the number of columns to return // in the final result. Rest of the columns are truncated @@ -112,41 +109,6 @@ func NewRoute(opcode Opcode, keyspace *vindexes.Keyspace, query, fieldQuery stri } } -// OrderByParams specifies the parameters for ordering. -// This is used for merge-sorting scatter queries. -type OrderByParams struct { - Col int - // WeightStringCol is the weight_string column that will be used for sorting. - // It is set to -1 if such a column is not added to the query - WeightStringCol int - Desc bool - StarColFixedIndex int - - // Type for knowing if the collation is relevant - Type evalengine.Type -} - -// String returns a string. Used for plan descriptions -func (obp OrderByParams) String() string { - val := strconv.Itoa(obp.Col) - if obp.StarColFixedIndex > obp.Col { - val = strconv.Itoa(obp.StarColFixedIndex) - } - if obp.WeightStringCol != -1 && obp.WeightStringCol != obp.Col { - val = fmt.Sprintf("(%s|%d)", val, obp.WeightStringCol) - } - if obp.Desc { - val += " DESC" - } else { - val += " ASC" - } - - if sqltypes.IsText(obp.Type.Type) && obp.Type.Coll != collations.Unknown { - val += " COLLATE " + collations.Local().LookupName(obp.Type.Coll) - } - return val -} - var ( partialSuccessScatterQueries = stats.NewCounter("PartialSuccessScatterQueries", "Count of partially successful scatter queries") ) @@ -422,37 +384,15 @@ func (route *Route) GetFields(ctx context.Context, vcursor VCursor, bindVars map } func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { - var err error // Since Result is immutable, we make a copy. // The copy can be shallow because we won't be changing // the contents of any row. out := in.ShallowCopy() - comparers := extractSlices(route.OrderBy) - - slices.SortFunc(out.Rows, func(a, b sqltypes.Row) int { - var cmp int - if err != nil { - return -1 - } - // If there are any errors below, the function sets - // the external err and returns true. Once err is set, - // all subsequent calls return true. This will make - // Slice think that all elements are in the correct - // order and return more quickly. - for _, c := range comparers { - cmp, err = c.compare(a, b) - if err != nil { - return -1 - } - if cmp != 0 { - return cmp - } - } - return 0 - }) - - return out.Truncate(route.TruncateColumnCount), err + if err := route.OrderBy.SortResult(out); err != nil { + return nil, err + } + return out.Truncate(route.TruncateColumnCount), nil } func (route *Route) description() PrimitiveDescription { @@ -590,7 +530,8 @@ func getQueries(query string, bvs []map[string]*querypb.BindVariable) []*querypb } func orderByToString(in any) string { - return in.(OrderByParams).String() + obp := in.(evalengine.OrderByParams) + return obp.String() } func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, queries []*querypb.BoundQuery) { diff --git a/go/vt/vtgate/engine/route_test.go b/go/vt/vtgate/engine/route_test.go index 58e6fb4a9f1..274ac58c7d4 100644 --- a/go/vt/vtgate/engine/route_test.go +++ b/go/vt/vtgate/engine/route_test.go @@ -296,7 +296,7 @@ func TestSelectNone(t *testing.T) { result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) require.Empty(t, vc.log) - expectResult(t, "sel.StreamExecute", result, nil) + require.Nil(t, result) vc.Rewind() @@ -452,7 +452,7 @@ func TestSelectEqualNoRoute(t *testing.T) { `Execute select from, toc from lkp where from in ::from from: type:TUPLE values:{type:INT64 value:"1"} false`, `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationNone()`, }) - expectResult(t, "sel.StreamExecute", result, nil) + require.Nil(t, result) // test with special no-routes handling sel.NoRoutesSpecialHandling = true @@ -888,7 +888,7 @@ func TestRouteSort(t *testing.T) { "dummy_select", "dummy_select_field", ) - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, WeightStringCol: -1, }} @@ -970,7 +970,7 @@ func TestRouteSortWeightStrings(t *testing.T) { "dummy_select", "dummy_select_field", ) - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 1, WeightStringCol: 0, }} @@ -1036,7 +1036,7 @@ func TestRouteSortWeightStrings(t *testing.T) { }) t.Run("Error when no weight string set", func(t *testing.T) { - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 1, WeightStringCol: -1, }} @@ -1075,7 +1075,7 @@ func TestRouteSortCollation(t *testing.T) { collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci") - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID}, }} @@ -1141,7 +1141,7 @@ func TestRouteSortCollation(t *testing.T) { }) t.Run("Error when Unknown Collation", func(t *testing.T) { - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, Type: evalengine.UnknownType(), }} @@ -1167,7 +1167,7 @@ func TestRouteSortCollation(t *testing.T) { }) t.Run("Error when Unsupported Collation", func(t *testing.T) { - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, Type: evalengine.Type{Coll: 1111}, }} @@ -1203,7 +1203,7 @@ func TestRouteSortTruncate(t *testing.T) { "dummy_select", "dummy_select_field", ) - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, }} sel.TruncateColumnCount = 1 @@ -1294,7 +1294,7 @@ func TestRouteStreamSortTruncate(t *testing.T) { "dummy_select", "dummy_select_field", ) - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ Col: 0, }} sel.TruncateColumnCount = 1 @@ -1437,7 +1437,7 @@ func TestExecFail(t *testing.T) { vc.Rewind() vc.resultErr = sqlerror.NewSQLError(sqlerror.ERQueryInterrupted, "", "query timeout -20") // test when there is order by column - sel.OrderBy = []OrderByParams{{ + sel.OrderBy = []evalengine.OrderByParams{{ WeightStringCol: -1, Col: 0, }} diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index d05e86a12bb..1f30a17b9d5 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -19,25 +19,17 @@ package evalengine import ( "bytes" "fmt" + "slices" + "strconv" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) -// UnsupportedComparisonError represents the error where the comparison between the two types is unsupported on vitess -type UnsupportedComparisonError struct { - Type1 sqltypes.Type - Type2 sqltypes.Type -} - -// Error function implements the error interface -func (err UnsupportedComparisonError) Error() string { - return fmt.Sprintf("types are not comparable: %v vs %v", err.Type1, err.Type2) -} - // UnsupportedCollationError represents the error where the comparison using provided collation is unsupported on vitess type UnsupportedCollationError struct { ID collations.ID @@ -171,3 +163,277 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err } return compare(v1, v2, collationID) } + +// OrderByParams specifies the parameters for ordering. +// This is used for merge-sorting scatter queries. +type ( + OrderByParams struct { + Col int + // WeightStringCol is the weight_string column that will be used for sorting. + // It is set to -1 if such a column is not added to the query + WeightStringCol int + Desc bool + + // Type for knowing if the collation is relevant + Type Type + } + + Comparison []OrderByParams + + tinyWeighter struct { + col int + apply func(v *sqltypes.Value) + } +) + +// String returns a string. Used for plan descriptions +func (obp *OrderByParams) String() string { + val := strconv.Itoa(obp.Col) + if obp.WeightStringCol != -1 && obp.WeightStringCol != obp.Col { + val = fmt.Sprintf("(%s|%d)", val, obp.WeightStringCol) + } + if obp.Desc { + val += " DESC" + } else { + val += " ASC" + } + + if sqltypes.IsText(obp.Type.Type) && obp.Type.Coll != collations.Unknown { + val += " COLLATE " + collations.Local().LookupName(obp.Type.Coll) + } + return val +} + +func (obp *OrderByParams) Compare(r1, r2 []sqltypes.Value) int { + v1 := r1[obp.Col] + v2 := r2[obp.Col] + cmp := v1.TinyWeightCmp(v2) + + if cmp == 0 { + var err error + cmp, err = NullsafeCompare(v1, v2, obp.Type.Coll) + if err != nil { + _, isCollationErr := err.(UnsupportedCollationError) + if !isCollationErr || obp.WeightStringCol == -1 { + panic(err) + } + // in case of a comparison or collation error switch to using the weight string column for ordering + obp.Col = obp.WeightStringCol + obp.WeightStringCol = -1 + cmp, err = NullsafeCompare(r1[obp.Col], r2[obp.Col], obp.Type.Coll) + if err != nil { + panic(err) + } + } + } + // change the result if descending ordering is required + if obp.Desc { + cmp = -cmp + } + return cmp +} + +func (cmp Comparison) tinyWeighters(fields []*querypb.Field) []tinyWeighter { + weights := make([]tinyWeighter, 0, len(cmp)) + for _, c := range cmp { + if apply := TinyWeighter(fields[c.Col], c.Type.Coll); apply != nil { + weights = append(weights, tinyWeighter{c.Col, apply}) + } + } + return weights +} + +func (cmp Comparison) ApplyTinyWeights(out *sqltypes.Result) { + weights := cmp.tinyWeighters(out.Fields) + if len(weights) == 0 { + return + } + + for _, row := range out.Rows { + for _, w := range weights { + w.apply(&row[w.col]) + } + } +} + +func (cmp Comparison) Compare(a, b sqltypes.Row) int { + for _, c := range cmp { + if cmp := c.Compare(a, b); cmp != 0 { + return cmp + } + } + return 0 +} + +func (cmp Comparison) Less(a, b sqltypes.Row) bool { + for _, c := range cmp { + if cmp := c.Compare(a, b); cmp != 0 { + return cmp < 0 + } + } + return false +} + +func (cmp Comparison) More(a, b sqltypes.Row) bool { + for _, c := range cmp { + if cmp := c.Compare(a, b); cmp != 0 { + return cmp > 0 + } + } + return false +} + +func PanicHandler(err *error) { + if r := recover(); r != nil { + badness, ok := r.(error) + if !ok { + panic(r) + } + + *err = badness + } +} + +func (cmp Comparison) SortResult(out *sqltypes.Result) (err error) { + defer PanicHandler(&err) + cmp.ApplyTinyWeights(out) + cmp.Sort(out.Rows) + return +} + +func (cmp Comparison) Sort(out []sqltypes.Row) { + slices.SortFunc(out, func(a, b sqltypes.Row) int { + return cmp.Compare(a, b) + }) +} + +type Sorter struct { + Compare Comparison + Limit int + + rows []sqltypes.Row + heap bool +} + +func (s *Sorter) Len() int { + return len(s.rows) +} + +func (s *Sorter) Push(row sqltypes.Row) { + if len(s.rows) < s.Limit { + s.rows = append(s.rows, row) + return + } + if !s.heap { + heapify(s.rows, s.Compare.More) + s.heap = true + } + if s.Compare.Compare(s.rows[0], row) < 0 { + return + } + s.rows[0] = row + fix(s.rows, 0, s.Compare.More) +} + +func (s *Sorter) Sorted() []sqltypes.Row { + if !s.heap { + s.Compare.Sort(s.rows) + return s.rows + } + + h := s.rows + end := len(h) + for end > 1 { + end = end - 1 + h[end], h[0] = h[0], h[end] + down(h[:end], 0, s.Compare.More) + } + return h +} + +type mergeRow struct { + row sqltypes.Row + source int +} + +type Merger struct { + Compare Comparison + + rows []mergeRow + less func(a, b mergeRow) bool +} + +func (m *Merger) Len() int { + return len(m.rows) +} + +func (m *Merger) Init() { + m.less = func(a, b mergeRow) bool { + return m.Compare.Less(a.row, b.row) + } + heapify(m.rows, m.less) +} + +func (m *Merger) Push(row sqltypes.Row, source int) { + m.rows = append(m.rows, mergeRow{row, source}) + if m.less != nil { + up(m.rows, len(m.rows)-1, m.less) + } +} + +func (m *Merger) Pop() (sqltypes.Row, int) { + x := m.rows[0] + m.rows[0] = m.rows[len(m.rows)-1] + m.rows = m.rows[:len(m.rows)-1] + down(m.rows, 0, m.less) + return x.row, x.source +} + +func heapify[T any](h []T, less func(a, b T) bool) { + n := len(h) + for i := n/2 - 1; i >= 0; i-- { + down(h, i, less) + } +} + +func fix[T any](h []T, i int, less func(a, b T) bool) { + if !down(h, i, less) { + up(h, i, less) + } +} + +func down[T any](h []T, i0 int, less func(a, b T) bool) bool { + i := i0 + for { + left, right := 2*i+1, 2*i+2 + if left >= len(h) || left < 0 { // `left < 0` in case of overflow + break + } + + // find the smallest child + j := left + if right < len(h) && less(h[right], h[left]) { + j = right + } + + if !less(h[j], h[i]) { + break + } + + h[i], h[j] = h[j], h[i] + i = j + } + return i > i0 +} + +func up[T any](h []T, i int, less func(a, b T) bool) { + for { + parent := (i - 1) / 2 + if i == 0 || !less(h[i], h[parent]) { + break + } + + h[i], h[parent] = h[parent], h[i] + i = parent + } +} diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index b44d735de22..3f97d9d18e9 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -19,10 +19,14 @@ package evalengine import ( "context" "fmt" + "math" + "math/rand" + "slices" "strings" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" @@ -1325,3 +1329,69 @@ func BenchmarkNullSafeComparison(b *testing.B) { } }) } + +func TestCompareSorter(t *testing.T) { + var cases = []struct { + Count int + Limit int + Random sqltypes.RandomGenerator + Cmp Comparison + }{ + { + Count: 100, + Limit: 10, + Random: sqltypes.RandomGenerators[sqltypes.Int64], + Cmp: Comparison{{Col: 0, Desc: false, Type: Type{Type: sqltypes.Int64}}}, + }, + { + Count: 100, + Limit: 10, + Random: sqltypes.RandomGenerators[sqltypes.Int64], + Cmp: Comparison{{Col: 0, Desc: true, Type: Type{Type: sqltypes.Int64}}}, + }, + { + Count: 100, + Limit: math.MaxInt, + Random: sqltypes.RandomGenerators[sqltypes.Int64], + Cmp: Comparison{{Col: 0, Desc: false, Type: Type{Type: sqltypes.Int64}}}, + }, + { + Count: 100, + Limit: math.MaxInt, + Random: sqltypes.RandomGenerators[sqltypes.Int64], + Cmp: Comparison{{Col: 0, Desc: true, Type: Type{Type: sqltypes.Int64}}}, + }, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s-%d-%d", tc.Cmp[0].Type.Type, tc.Count, tc.Limit), func(t *testing.T) { + unsorted := make([]sqltypes.Row, 0, tc.Count) + for i := 0; i < tc.Count; i++ { + unsorted = append(unsorted, []sqltypes.Value{tc.Random()}) + } + rand.Shuffle(len(unsorted), func(i, j int) { + unsorted[i], unsorted[j] = unsorted[j], unsorted[i] + }) + + want := slices.Clone(unsorted) + tc.Cmp.Sort(want) + if len(want) > tc.Limit { + want = want[:tc.Limit] + } + + sorter := &Sorter{Compare: tc.Cmp, Limit: tc.Limit} + for _, v := range unsorted { + sorter.Push(v) + } + + sorted := sorter.Sorted() + assert.Equal(t, len(want), len(sorted)) + for i := 0; i < len(want); i++ { + if !sqltypes.RowEqual(want[i], sorted[i]) { + t.Fatalf("row %d is not sorted.\nwant: %v\ngot: %v", i, want, sorted) + } + } + }) + } + +} diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go index 08ec844f357..fa7fa7e11a6 100644 --- a/go/vt/vtgate/evalengine/weights.go +++ b/go/vt/vtgate/evalengine/weights.go @@ -20,12 +20,14 @@ import ( "encoding/binary" "math" + "vitess.io/vitess/go/hack" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/mysql/json" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) @@ -176,3 +178,131 @@ func evalWeightString(dst []byte, e eval, length, precision int) ([]byte, bool, return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", e.SQLType()) } + +// TinyWeighter returns a callback to apply a Tiny Weight string to a sqltypes.Value. +// A tiny weight string is a compressed 4-byte representation of the value's full weight string that +// sorts identically to its full weight. Obviously, the tiny weight string can collide because +// it's represented in fewer bytes than the full one. +// Hence, for any 2 instances of sqltypes.Value: if both instances have a Tiny Weight string, +// and the weight strings are **different**, the two values will sort accordingly to the 32-bit +// numerical sort of their tiny weight strings. Otherwise, the relative sorting of the two values +// will not be known, and they will require a full sort using e.g. NullsafeCompare. +func TinyWeighter(f *querypb.Field, collation collations.ID) func(v *sqltypes.Value) { + switch { + case sqltypes.IsNull(f.Type): + return nil + + case sqltypes.IsSigned(f.Type): + return func(v *sqltypes.Value) { + i, err := v.ToInt64() + if err != nil { + return + } + // The full weight string for an integer is just its MSB bit-inverted 64 bit representation. + // However, we only have 4 bytes to work with here, so in order to minimize the amount + // of collisions for the tiny weight string, instead of grabbing the top 32 bits of the + // 64 bit representation, we're going to cast to float32. Floats are sortable once bit-inverted, + // and although they cannot represent the full 64-bit range (duh!), that's perfectly fine + // because close-by numbers will collide into the same tiny weight, allowing us to fall back + // to a full comparison. + raw := math.Float32bits(float32(i)) + if i < 0 { + raw = ^raw + } else { + raw = raw ^ (1 << 31) + } + v.SetTinyWeight(raw) + } + + case sqltypes.IsUnsigned(f.Type): + return func(v *sqltypes.Value) { + u, err := v.ToUint64() + if err != nil { + return + } + // See comment for the IsSigned block. No bit-inversion is required here as all floats will be positive. + v.SetTinyWeight(math.Float32bits(float32(u))) + } + + case sqltypes.IsFloat(f.Type): + return func(v *sqltypes.Value) { + fl, err := v.ToFloat64() + if err != nil { + return + } + // Similarly as the IsSigned block, we could take the top 32 bits of the float64 bit representation, + // but by down-sampling to a float32 we reduce the amount of collisions. + raw := math.Float32bits(float32(fl)) + if math.Signbit(fl) { + raw = ^raw + } else { + raw = raw ^ (1 << 31) + } + v.SetTinyWeight(raw) + } + + case sqltypes.IsBinary(f.Type): + return func(v *sqltypes.Value) { + if v.IsNull() { + return + } + + var w32 [4]byte + copy(w32[:4], v.Raw()) + v.SetTinyWeight(binary.BigEndian.Uint32(w32[:4])) + } + + case sqltypes.IsText(f.Type): + if coll := colldata.Lookup(collation); coll != nil { + if twcoll, ok := coll.(colldata.TinyWeightCollation); ok { + return func(v *sqltypes.Value) { + if v.IsNull() { + return + } + v.SetTinyWeight(twcoll.TinyWeightString(v.Raw())) + } + } + } + return nil + + case sqltypes.IsDecimal(f.Type): + return func(v *sqltypes.Value) { + if v.IsNull() { + return + } + // To generate a 32-bit weight string of the decimal, we'll just attempt a fast 32bit atof parse + // of its contents. This can definitely fail for many corner cases, but that's OK: we'll just fall + // back to a full decimal comparison in those cases. + fl, _, err := hack.Atof32(v.RawStr()) + if err != nil { + return + } + raw := math.Float32bits(fl) + if raw&(1<<31) != 0 { + raw = ^raw + } else { + raw = raw ^ (1 << 31) + } + v.SetTinyWeight(raw) + } + + case f.Type == sqltypes.TypeJSON: + return func(v *sqltypes.Value) { + if v.IsNull() { + return + } + j, err := json.NewFromSQL(*v) + if err != nil { + return + } + var w32 [4]byte + // TODO: this can be done more efficiently without having to calculate the full weight string and + // extracting its prefix. + copy(w32[:4], j.WeightString(nil)) + v.SetTinyWeight(binary.BigEndian.Uint32(w32[:4])) + } + + default: + return nil + } +} diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 50a1d91f20c..0dee4c72d03 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -25,8 +25,82 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) +func TestTinyWeightStrings(t *testing.T) { + const Length = 10000 + + var cases = []struct { + typ sqltypes.Type + gen func() sqltypes.Value + col collations.ID + len int + prec int + }{ + {typ: sqltypes.Int32, gen: sqltypes.RandomGenerators[sqltypes.Int32], col: collations.CollationBinaryID}, + {typ: sqltypes.Int64, gen: sqltypes.RandomGenerators[sqltypes.Int64], col: collations.CollationBinaryID}, + {typ: sqltypes.Uint32, gen: sqltypes.RandomGenerators[sqltypes.Uint32], col: collations.CollationBinaryID}, + {typ: sqltypes.Uint64, gen: sqltypes.RandomGenerators[sqltypes.Uint64], col: collations.CollationBinaryID}, + {typ: sqltypes.Float64, gen: sqltypes.RandomGenerators[sqltypes.Float64], col: collations.CollationBinaryID}, + {typ: sqltypes.VarChar, gen: sqltypes.RandomGenerators[sqltypes.VarChar], col: collations.CollationUtf8mb4ID}, + {typ: sqltypes.VarBinary, gen: sqltypes.RandomGenerators[sqltypes.VarBinary], col: collations.CollationBinaryID}, + {typ: sqltypes.Decimal, gen: sqltypes.RandomGenerators[sqltypes.Decimal], col: collations.CollationBinaryID, len: 20, prec: 10}, + {typ: sqltypes.TypeJSON, gen: sqltypes.RandomGenerators[sqltypes.TypeJSON], col: collations.CollationBinaryID}, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%v", tc.typ), func(t *testing.T) { + field := &querypb.Field{ + Type: tc.typ, + ColumnLength: uint32(tc.len), + Charset: uint32(tc.col), + Decimals: uint32(tc.prec), + } + weight := TinyWeighter(field, tc.col) + if weight == nil { + t.Fatalf("could not generate Tiny Weight function") + } + + items := make([]sqltypes.Value, 0, Length) + for i := 0; i < Length; i++ { + v := tc.gen() + weight(&v) + items = append(items, v) + } + + var fastComparisons int + var fullComparisons int + slices.SortFunc(items, func(a, b sqltypes.Value) int { + if cmp := a.TinyWeightCmp(b); cmp != 0 { + fastComparisons++ + return cmp + } + + cmp, err := NullsafeCompare(a, b, tc.col) + require.NoError(t, err) + + fullComparisons++ + return cmp + }) + + for i := 0; i < Length-1; i++ { + a := items[i] + b := items[i+1] + + cmp, err := NullsafeCompare(a, b, tc.col) + require.NoError(t, err) + + if cmp > 0 { + t.Fatalf("expected %v [pos=%d] to come after %v [pos=%d]\n%v | %032b\n%v | %032b", a, i, b, i+1, a, a.TinyWeight(), b, b.TinyWeight()) + } + } + + t.Logf("%d fast comparisons, %d full comparisons (%.02f%% were fast)", fastComparisons, fullComparisons, 100.0*float64(fastComparisons)/float64(fastComparisons+fullComparisons)) + }) + } +} + func TestWeightStrings(t *testing.T) { const Length = 1000 diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index ceda947e8bb..8baffdfde09 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -131,7 +131,7 @@ func init() { vindexes.Register("keyrange_lookuper_unique", newKeyRangeLookuperUnique) } -func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup *sandboxconn.SandboxConn, ctx context.Context) { +func createExecutorEnvCallback(t testing.TB, eachShard func(shard, ks string, tabletType topodatapb.TabletType, conn *sandboxconn.SandboxConn)) (executor *Executor, ctx context.Context) { var cancel context.CancelFunc ctx, cancel = context.WithCancel(context.Background()) cell := "aa" @@ -166,19 +166,15 @@ func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup } resolver := newTestResolver(ctx, hc, serv, cell) - sbc1 = hc.AddTestTablet(cell, "-20", 1, "TestExecutor", "-20", topodatapb.TabletType_PRIMARY, true, 1, nil) - sbc2 = hc.AddTestTablet(cell, "40-60", 1, "TestExecutor", "40-60", topodatapb.TabletType_PRIMARY, true, 1, nil) - // Create these connections so scatter queries don't fail. - _ = hc.AddTestTablet(cell, "20-40", 1, "TestExecutor", "20-40", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "60-60", 1, "TestExecutor", "60-80", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "80-a0", 1, "TestExecutor", "80-a0", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "a0-c0", 1, "TestExecutor", "a0-c0", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "c0-e0", 1, "TestExecutor", "c0-e0", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "e0-", 1, "TestExecutor", "e0-", topodatapb.TabletType_PRIMARY, true, 1, nil) - // Below is needed so that SendAnyWherePlan doesn't fail + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} - sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) - _ = hc.AddTestTablet(cell, "2", 3, KsTestUnsharded, "0", topodatapb.TabletType_REPLICA, true, 1, nil) + for _, shard := range shards { + conn := hc.AddTestTablet(cell, shard, 1, KsTestSharded, shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + eachShard(shard, KsTestSharded, topodatapb.TabletType_PRIMARY, conn) + } + + eachShard("0", KsTestUnsharded, topodatapb.TabletType_PRIMARY, hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)) + eachShard("0", KsTestUnsharded, topodatapb.TabletType_REPLICA, hc.AddTestTablet(cell, "2", 3, KsTestUnsharded, "0", topodatapb.TabletType_REPLICA, true, 1, nil)) queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) @@ -199,7 +195,21 @@ func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup cancel() }) - return executor, sbc1, sbc2, sbclookup, ctx + return executor, ctx +} + +func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup *sandboxconn.SandboxConn, ctx context.Context) { + executor, ctx = createExecutorEnvCallback(t, func(shard, ks string, tabletType topodatapb.TabletType, conn *sandboxconn.SandboxConn) { + switch { + case ks == KsTestSharded && shard == "-20": + sbc1 = conn + case ks == KsTestSharded && shard == "40-60": + sbc2 = conn + case ks == KsTestUnsharded && tabletType == topodatapb.TabletType_PRIMARY: + sbclookup = conn + } + }) + return } func createCustomExecutor(t testing.TB, vschema string) (executor *Executor, sbc1, sbc2, sbclookup *sandboxconn.SandboxConn, ctx context.Context) { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 62dec384e33..f3544c1362e 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -663,17 +663,8 @@ func TestStreamBuffering(t *testing.T) { } func TestStreamLimitOffset(t *testing.T) { - executor, sbc1, sbc2, _, _ := createExecutorEnv(t) - - // This test is similar to TestStreamUnsharded except that it returns a Result > 10 bytes, - // such that the splitting of the Result into multiple Result responses gets tested. - sbc1.SetResults([]*sqltypes.Result{{ - Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, - {Name: "textcol", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, - {Name: "weight_string(id)", Type: sqltypes.VarBinary, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_BINARY_FLAG)}, - }, - Rows: [][]sqltypes.Value{{ + returnRows := map[string][]sqltypes.Row{ + "-20": [][]sqltypes.Value{{ sqltypes.NewInt32(1), sqltypes.NewVarChar("1234"), sqltypes.NULL, @@ -682,20 +673,30 @@ func TestStreamLimitOffset(t *testing.T) { sqltypes.NewVarChar("4567"), sqltypes.NULL, }}, - }}) - - sbc2.SetResults([]*sqltypes.Result{{ - Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, - {Name: "textcol", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, - {Name: "weight_string(id)", Type: sqltypes.VarBinary, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_BINARY_FLAG)}, - }, - Rows: [][]sqltypes.Value{{ + "40-60": [][]sqltypes.Value{{ sqltypes.NewInt32(2), sqltypes.NewVarChar("2345"), sqltypes.NULL, }}, - }}) + "80-a0": [][]sqltypes.Value{{ + sqltypes.NewInt32(3), + sqltypes.NewVarChar("3456"), + sqltypes.NULL, + }}, + } + + executor, _ := createExecutorEnvCallback(t, func(shard, ks string, tabletType topodatapb.TabletType, conn *sandboxconn.SandboxConn) { + if ks == KsTestSharded { + conn.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "textcol", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "weight_string(id)", Type: sqltypes.VarBinary, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_BINARY_FLAG)}, + }, + Rows: returnRows[shard], + }}) + } + }) results := make(chan *sqltypes.Result, 10) session := &vtgatepb.Session{ @@ -722,11 +723,11 @@ func TestStreamLimitOffset(t *testing.T) { }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt32(1), - sqltypes.NewVarChar("1234"), + sqltypes.NewInt32(3), + sqltypes.NewVarChar("3456"), }, { - sqltypes.NewInt32(1), - sqltypes.NewVarChar("foo"), + sqltypes.NewInt32(4), + sqltypes.NewVarChar("4567"), }}, } var gotResults []*sqltypes.Result diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index f51cbee047b..a04c4b00c2c 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -269,12 +269,11 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin for idx, order := range ordering.Order { typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) - ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, engine.OrderByParams{ - Col: ordering.Offset[idx], - WeightStringCol: ordering.WOffset[idx], - Desc: order.Inner.Direction == sqlparser.DescOrder, - StarColFixedIndex: ordering.Offset[idx], - Type: typ, + ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, evalengine.OrderByParams{ + Col: ordering.Offset[idx], + WeightStringCol: ordering.WOffset[idx], + Desc: order.Inner.Direction == sqlparser.DescOrder, + Type: typ, }) } @@ -500,7 +499,7 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route eroute, err := routeToEngineRoute(ctx, op, hints) for _, order := range op.Ordering { typ, _ := ctx.SemTable.TypeForExpr(order.AST) - eroute.OrderBy = append(eroute.OrderBy, engine.OrderByParams{ + eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{ Col: order.Offset, WeightStringCol: order.WOffset, Desc: order.Direction == sqlparser.DescOrder, diff --git a/go/vt/vttablet/tabletmanager/vdiff/utils.go b/go/vt/vttablet/tabletmanager/vdiff/utils.go index d756e6f6984..5904fd41795 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/utils.go +++ b/go/vt/vttablet/tabletmanager/vdiff/utils.go @@ -38,7 +38,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare for _, participant := range participants { prims = append(prims, participant) } - ob := make([]engine.OrderByParams, len(comparePKs)) + ob := make([]evalengine.OrderByParams, len(comparePKs)) for i, cpk := range comparePKs { weightStringCol := -1 // if the collation is nil or unknown, use binary collation to compare as bytes @@ -49,7 +49,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare if cpk.collation != collations.Unknown { t.Coll = cpk.collation } - ob[i] = engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t} + ob[i] = evalengine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t} } return &engine.MergeSort{ Primitives: prims, diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index d09232e8997..2d6e49b73d7 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -780,7 +780,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare for _, participant := range participants { prims = append(prims, participant) } - ob := make([]engine.OrderByParams, 0, len(comparePKs)) + ob := make([]evalengine.OrderByParams, 0, len(comparePKs)) for _, cpk := range comparePKs { weightStringCol := -1 // if the collation is nil or unknown, use binary collation to compare as bytes @@ -788,7 +788,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare if cpk.collation != collations.Unknown { t.Coll = cpk.collation } - ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t}) + ob = append(ob, evalengine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, Type: t}) } return &engine.MergeSort{ Primitives: prims,