diff --git a/common/sql/statistics.go b/common/sql/statistics.go
index 7b38ea61c..6e0ca05f1 100644
--- a/common/sql/statistics.go
+++ b/common/sql/statistics.go
@@ -15,8 +15,7 @@ type Statistics struct {
 	RowCount int64
 
 	ColumnStatistics []ColumnStatistics
-
-	//Selectivity, for plan statistics
+	// NOTE: above may be better as []any to work with a generic ColStatsT[T]
 }
 
 func (s *Statistics) String() string {
@@ -27,17 +26,18 @@ func (s *Statistics) String() string {
 	}
 	for i, cs := range s.ColumnStatistics {
 		fmt.Fprintf(&st, " Column %d:\n", i)
-		fmt.Fprintf(&st, " - Min/Max = %v / %v\n", cs.Min, cs.Max)
+		if _, ok := cs.Min.(string); ok {
+			fmt.Fprintf(&st, " - Min/Max = %.64s / %.64s\n", cs.Min, cs.Max)
+		} else {
+			fmt.Fprintf(&st, " - Min/Max = %v / %v\n", cs.Min, cs.Max)
+		}
 		fmt.Fprintf(&st, " - NULL count = %v\n", cs.NullCount)
+		fmt.Fprintf(&st, " - Num MCVs = %v\n", len(cs.MCFreqs))
+		fmt.Fprintf(&st, " - Histogram = {%v}\n", cs.Histogram) // it's any, but also a fmt.Stringer
 	}
 	return st.String()
 }
 
-type ValCount struct {
-	Val   any
-	Count int
-}
-
 // ColumnStatistics contains statistics about a column.
 type ColumnStatistics struct {
 	NullCount int64
@@ -58,16 +58,133 @@ type ColumnStatistics struct {
 
 	// MCVs []ValCount
 	// MCVs map[cmp.Ordered]
-	// MCVals []any
-	// MCFreqs []int
+	MCVals  []any // []T
+	MCFreqs []int
+	// ^ NOTE: MCVals was easier in many ways with just any.([]T), but other
+	// ways much more inconvenient, so we have it as an []any. May go back.
 
 	// DistinctCount is harder. For example, unless we sub-sample
 	// (deterministically), tracking distinct values could involve a data
 	// structure with the same number of elements as rows in the table,
-	DistinctCount int64
+	// or a sophisticated algo, e.g. https://github.com/axiomhq/hyperloglog
+	// DistinctCount int64
+	// alt, -1 means don't know
 
+	// AvgSize can affect cost as it changes the number of "pages" in postgres
+	// terminology, representing the size of data returned or processed by an
+	// expression.
 	AvgSize int64 // maybe: length of text, length of array, otherwise not used for scalar?
 
-	// without histogram, we can make uniformity assumption to simplify the cost model
-	//Histogram []HistogramBucket
+	Histogram any // histo[T]
+}
+
+/* Perhaps I should have started fresh with a fully generic column stats struct... under consideration.
+
+type ColStatsT[T any] struct {
+	NullCount int
+	Min       T
+	MinCount  int
+	Max       T
+	MaxCount  int
+	MCVals    []T
+	MCFreqs   []int
+}
+*/
+
+func NewEmptyStatistics(numCols int) *Statistics {
+	return &Statistics{
+		RowCount:         0,
+		ColumnStatistics: make([]ColumnStatistics, numCols),
+	}
+}
+
+// ALL of the following types are from the initial query plan draft PR by Yaiba.
+// Only TableRef gets much use in the current statistics work. An integration
+// branch uses the other field and schema types a bit more, but it's easy to
+// change any of this...
+
+// TableRef is a PostgreSQL-schema-qualified table name.
+type TableRef struct {
+	Namespace string // e.g. schema in Postgres, derived from Kwil dataset schema DBID
+	Table     string
+}
+
+// String returns the fully qualified table name as "namespace.table" if
+// Namespace is set, otherwise it just returns the table name.
+func (t *TableRef) String() string {
+	if t.Namespace != "" {
+		return fmt.Sprintf("%s.%s", t.Namespace, t.Table)
+	}
+	return t.Table
+}
+
+type ColumnDef struct {
+	Relation *TableRef
+	Name     string
+}
+
+func ColumnUnqualified(name string) *ColumnDef {
+	return &ColumnDef{Name: name}
+}
+
+func Column(table *TableRef, name string) *ColumnDef {
+	return &ColumnDef{Relation: table, Name: name}
+}
+
+// Field represents a field (column) in a schema.
+type Field struct {
+	Rel *TableRef
+
+	Name     string
+	Type     string
+	Nullable bool
+	HasIndex bool
+}
+
+func NewField(name string, dataType string, nullable bool) Field {
+	return Field{Name: name, Type: dataType, Nullable: nullable}
+}
+
+func NewFieldWithRelation(name string, dataType string, nullable bool, relation *TableRef) Field {
+	return Field{Name: name, Type: dataType, Nullable: nullable, Rel: relation}
+}
+
+func (f *Field) Relation() *TableRef {
+	return f.Rel
+}
+
+func (f *Field) QualifiedColumn() *ColumnDef {
+	return Column(f.Rel, f.Name)
+}
+
+// Schema represents a database as a slice of all columns in all relations. See
+// also Field.
+type Schema struct {
+	Fields []Field
+}
+
+func NewSchema(fields ...Field) *Schema {
+	return &Schema{Fields: fields}
+}
+
+func NewSchemaQualified(relation *TableRef, fields ...Field) *Schema {
+	for i := range fields {
+		fields[i].Rel = relation
+	}
+	return &Schema{Fields: fields}
+}
+
+func (s *Schema) String() string {
+	var fields []string
+	for _, f := range s.Fields {
+		fields = append(fields, fmt.Sprintf("%s/%s", f.Name, f.Type))
+	}
+	return fmt.Sprintf("[%s]", strings.Join(fields, ", "))
+}
+
+type DataSource interface {
+	// Schema returns the schema for the underlying data source.
+	Schema() *Schema
+	// Statistics returns the statistics of the data source.
+	Statistics() *Statistics
+}
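Aside on the new String method: it leans on two fmt behaviors — a precision on a string verb truncates the operand, and %v on the any-typed Histogram field dispatches to the dynamic type's String method, since the histo[T] type added below in internal/sql/pg is a fmt.Stringer. A minimal sketch of both (hypothetical values, assuming an int64 histogram):

	fmt.Printf("%.5s\n", "abcdefgh") // "abcde" -- precision caps string output
	var h any = makeHisto([]int64{0, 10}, cmp.Compare[int64])
	fmt.Printf("{%v}\n", h) // {total = 0, bounds = [0 10], freqs = [0 0 0]}
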
diff --git a/core/types/uuid.go b/core/types/uuid.go
index a45c740b4..9a7b1fbc4 100644
--- a/core/types/uuid.go
+++ b/core/types/uuid.go
@@ -1,6 +1,7 @@
 package types
 
 import (
+	"bytes"
 	"database/sql"
 	"database/sql/driver"
 	"encoding/json"
@@ -14,6 +15,13 @@ var namespace = uuid.MustParse("cc1cd90f-b4db-47f4-b6df-4bbe5fca88eb")
 // UUID is a rfc4122 compliant uuidv5
 type UUID [16]byte
 
+// CmpUUID compares two UUIDs, returning 0 if equal, -1 if u < v, or 1 if u > v.
+// This satisfies the comparison function required by many generic functions in
+// the standard library and Kwil.
+func CmpUUID(u, v UUID) int {
+	return bytes.Compare(u[:], v[:])
+}
+
 // NewUUIDV5 generates a uuidv5 from a byte slice.
 // This is used to deterministically generate uuids.
 func NewUUIDV5(from []byte) *UUID {
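Because CmpUUID has the standard cmp-style signature, it slots straight into the generic slices helpers; a quick illustration with hypothetical values u1..u3 (assuming the slices package is imported):

	ids := []types.UUID{u1, u2, u3}     // u1..u3: hypothetical UUIDs
	slices.SortFunc(ids, types.CmpUUID) // ascending byte order
	_, ok := slices.BinarySearchFunc(ids, u2, types.CmpUUID) // log-time lookup
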
diff --git a/internal/sql/pg/db_live_test.go b/internal/sql/pg/db_live_test.go
index 20a58f935..92bc2d462 100644
--- a/internal/sql/pg/db_live_test.go
+++ b/internal/sql/pg/db_live_test.go
@@ -7,6 +7,7 @@ import (
 	"cmp"
 	"context"
 	"fmt"
+	"os"
 	"reflect"
 	"slices"
 	"strconv"
@@ -28,7 +29,7 @@ import (
 
 func TestMain(m *testing.M) {
 	// UseLogger(log.NewStdOut(log.InfoLevel))
-	m.Run()
+	os.Exit(m.Run())
 }
 
 const (
@@ -266,13 +267,6 @@ func TestNULL(t *testing.T) {
 	require.Equal(t, bvn.Int64, insB)
 }
 
-// typeFor returns the reflect.Type that represents the type argument T. TODO:
-// Remove this in favor of reflect.TypeFor when Go 1.22 becomes the minimum
-// required version since it is not available in Go 1.21.
-func typeFor[T any]() reflect.Type {
-	return reflect.TypeOf((*T)(nil)).Elem()
-}
-
 func TestScanVal(t *testing.T) {
 	cols := []ColInfo{
 		{Pos: 1, Name: "a", DataType: "bigint", Nullable: false},
@@ -974,32 +968,6 @@ func TestTypeRoundtrip(t *testing.T) {
 	}
 }
 
-// mustDecimal panics if the string cannot be converted to a decimal.
-func mustDecimal(s string) *decimal.Decimal {
-	d, err := decimal.NewFromString(s)
-	if err != nil {
-		panic(err)
-	}
-	return d
-}
-
-func mustParseUUID(s string) *types.UUID {
-	u, err := types.ParseUUID(s)
-	if err != nil {
-		panic(err)
-	}
-	return u
-}
-
-// mustUint256 panics if the string cannot be converted to a Uint256.
-func mustUint256(s string) *types.Uint256 {
-	u, err := types.Uint256FromString(s)
-	if err != nil {
-		panic(err)
-	}
-	return u
-}
-
 func Test_DelayedTx(t *testing.T) {
 	ctx := context.Background()
 
diff --git a/internal/sql/pg/histo.go b/internal/sql/pg/histo.go
new file mode 100644
index 000000000..cd741864d
--- /dev/null
+++ b/internal/sql/pg/histo.go
@@ -0,0 +1,323 @@
+package pg
+
+// This file defines a generic histogram type, and many of the interpolation
+// functions necessary to use it and create its bounds.
+
+import (
+	"fmt"
+	"math"
+	"math/big"
+	"slices"
+	"strings"
+
+	"github.com/kwilteam/kwil-db/core/types"
+	"github.com/kwilteam/kwil-db/core/types/decimal"
+)
+
+type histo[T any] struct {
+	// Normally given N boundaries, there would be N-1 bins/freqs.
+	// In this histogram, there are two special bins that catch all values
+	// below the lowest bin boundary, and above the highest bin boundary.
+	// This is done because we need to decide the boundaries before the
+	// actual min and max are known, *and* we have to limit the number of bins.
+	//
+	// bounds:        h_0       h_1       h_2
+	//                 |         |         |
+	// freqs:   f_0    |   f_1   |   f_2   |   f_3
+	//
+	// bins include values on (h_i,h_j]
+
+	bounds []T   // len n
+	freqs  []int // len n+1
+
+	comp func(a, b T) int // -1 for a < b, 0 for a == b, 1 for a > b
+
+	// For more accurate summation up to a given value that is not also equal to
+	// one of the bounds, linear interpolation can be used. (TODO)
+	// interp func(f float64, a, b T) T
+	// interpF func(v, a, b T) float64 // (b-v)/(b-a) -- [a...v...b]
+}
+
+// ins adds an observed value into the appropriate bin and returns the index of
+// the updated bin.
+func (h histo[T]) ins(v T) int {
+	loc, _ := slices.BinarySearchFunc(h.bounds, v, h.comp)
+	h.freqs[loc]++
+	return loc
+}
+
+// rm removes an observed value from the appropriate bin and returns the index
+// of the updated bin.
+func (h histo[T]) rm(v T) int {
+	loc, _ := slices.BinarySearchFunc(h.bounds, v, h.comp)
+	h.freqs[loc]--
+	if f := h.freqs[loc]; f < 0 {
+		panic("accounting error -- negative bin count on rm")
+	}
+	return loc
+}
+
+func (h *histo[T]) TotalCount() int {
+	var total int
+	for _, f := range h.freqs {
+		total += f
+	}
+	return total
+}
+
+func (h histo[T]) String() string {
+	totalFreq := h.TotalCount()
+	return fmt.Sprintf("total = %d, bounds = %v, freqs = %v",
+		totalFreq, h.bounds, h.freqs)
+}
+
+// ltTotal returns the cumulative frequency for values less than (or equal to)
+// the given value. There will be a host of methods along these lines (e.g.
+// range, greater, equal) to support selectivity computation.
+//
+// Presently this is a simple summation, but it should be updated to perform
+// interpolation when a value is not also exactly a boundary.
+func (h histo[T]) ltTotal(v T) int { //nolint:unused
+	loc, _ := slices.BinarySearchFunc(h.bounds, v, h.comp)
+
+	var freq int
+	for i, f := range h.freqs {
+		if i == loc {
+			/*if found { // no interp from next bin, just add this freq and break
+				freq += f
+			} else { // the value is somewhere between bins (before bin i) => linearly interpolate
+				freq += int(float64(f) * h.interp(v, h.bounds[i-1], h.bounds[i]))
+			}*/
+			freq += f
+			break
+		}
+
+		freq += f
+	}
+	return freq
+}
+
+type SignedInt interface {
+	~int | ~int8 | ~int16 | ~int32 | ~int64
+}
+
+type UnsignedInt interface {
+	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
+}
+
+type Num interface {
+	SignedInt | UnsignedInt | float32 | float64
+}
+
+// The following interpolation functions are used to create histogram bounds
+// from min and max values. They will also be used for partial summation.
+
+func interpNumF[T Num](f float64, a, b T) float64 {
+	return f*float64(b) + (1-f)*float64(a)
+	// return float64(a) + f*(float64(b)-float64(a))
+}
+
+func interpNum[T Num](f float64, a, b T) T {
+	return T(interpNumF(f, a, b))
+	// return a + T(f*(float64(b)-float64(a)))
+}
+
+func interpBig(f float64, a, b *big.Int) *big.Int {
+	if b.Cmp(a) <= 0 {
+		panic("b must not be less than a")
+	}
+
+	// return a number on [a,b] computed via interpolation with f on [0,1]
+	// representing where between the two numbers.
+
+	diff := new(big.Int).Sub(b, a)
+	frac := new(big.Float).SetPrec(big.MaxPrec).SetInt(diff)
+	frac = frac.Mul(frac, big.NewFloat(f).SetPrec(big.MaxPrec)) // f*(b-a)
+
+	// a + frac
+	fracInt, _ := frac.Int(nil)
+
+	return new(big.Int).Add(a, fracInt)
+}
+
+// interpUint256 and interpDec below are both defined in terms of interpBig.
+
+func interpUint256(f float64, a, b *types.Uint256) *types.Uint256 {
+	c := interpBig(f, a.ToBig(), b.ToBig())
+	d, err := types.Uint256FromBig(c)
+	if err != nil { // shouldn't even be possible if a and b were not NaNs
+		panic(err.Error())
+	}
+	return d
+}
+
+func interpDec(f float64, a, b *decimal.Decimal) *decimal.Decimal {
+	c := interpBig(f, a.BigInt(), b.BigInt())
+	// d, err := decimal.NewFromBigInt(c, a.Exp())
+	// d.SetPrecisionAndScale(a.Precision(), a.Scale())
+	d, err := decimal.NewExplicit(c.String(), a.Precision(), a.Scale())
+	if err != nil {
+		panic(err.Error())
+	}
+	// This is messier with Decimal's math methods, and I don't yet know if
+	// there'd be a benefit:
+	// bma, err := decimal.Sub(b, a) // etc
+	return d
+}
+
+func interpBts(f float64, a, b []byte) []byte {
+	ai := big.NewInt(0).SetBytes(a)
+	bi := big.NewInt(0).SetBytes(b)
+	return interpBig(f, ai, bi).Bytes()
+}
+
+func interpUUID(f float64, a, b types.UUID) types.UUID {
+	return types.UUID(interpBts(f, a[:], b[:]))
+}
+
+// interpBool is largely nonsense and should be unused as there should not ever
+// be a boolean histogram, but it is here for completeness. Could just panic...
+func interpBool(f float64, a, b bool) bool {
+	if f < 0.5 {
+		return a
+	}
+	return b
+}
+
+// interpString needs more consideration and testing. It MUST be consistent with
+// lexicographic comparison a la strings.Compare (e.g. "x" > "abc"), so we can't
+// just interpolate as bytes, which have numeric semantics. For now we
+// right-pad to make them the same length, then interpolate each character.
+func interpString(f float64, a, b string) string {
+	if f < 0 || f > 1 {
+		panic("f out of range")
+	}
+	if a > b {
+		panic("a > b")
+	}
+	// Ensure both strings are the same length by padding with \0
+	maxLen := len(a)
+	if len(b) > maxLen {
+		maxLen = len(b)
+	}
+
+	a = padString(a, maxLen)
+	b = padString(b, maxLen)
+
+	result := make([]byte, maxLen)
+	// Interpolate each character, wonky but remains legible
+	for i := range result {
+		charA, charB := a[i], b[i]
+		diff := float64(charB) - float64(charA)
+		result[i] = charA + byte(math.Round(f*diff))
+	}
+
+	// Convert the byte slice back to a string and trim any null characters
+	return strings.TrimRight(string(result), "\x00")
+}
+
+func padString(s string, length int) string {
+	if len(s) < length {
+		return s + strings.Repeat("\x00", length-len(s))
+	}
+	return s
+}
+
+// makeBounds computes n+1 evenly spaced histogram bounds given the range [a,b].
+// However, if two bounds would be equal, as is possible when T is an integer
+// and n is too small, there will be fewer bounds.
+func makeBounds[T any](n int, a, b T, comp func(a, b T) int, interp func(f float64, a, b T) T) []T {
+	if comp(a, b) == 1 {
+		panic("no good")
+	}
+	bounds := make([]T, 0, n+1)
+	f := 1 / float64(n)
+	for i := 0; i <= n; i++ {
+		next := interp(f*float64(i), a, b)
+		if i > 0 && comp(next, bounds[len(bounds)-1]) == 0 {
+			continue // trying to over-subdivide, easy with integers
+		}
+		bounds = append(bounds, next)
+	}
+	return bounds
+}
+
+func makeHisto[T any](bounds []T, comp func(a, b T) int) histo[T] {
+	return histo[T]{
+		bounds: bounds,
+		freqs:  make([]int, len(bounds)+1),
+		comp:   comp,
+	}
+}
+
+// interpNumRat provides a possibly more accurate approach, particularly when T
+// is floating point. May remove.
+func interpNumRat[T Num](f *big.Rat, a, b T) T {
+	return T(interpNumRatF(f, a, b))
+}
+
+func interpNumRatF[T Num](f *big.Rat, a, b T) float64 {
+	ra, _ := new(big.Float).SetPrec(big.MaxPrec).SetFloat64(float64(a)).Rat(nil)
+	rb, _ := new(big.Float).SetPrec(big.MaxPrec).SetFloat64(float64(b)).Rat(nil)
+	// ra := new(big.Rat).SetFloat64(float64(a))
+	// rb := new(big.Rat).SetFloat64(float64(b))
+
+	// a + f*(b-a)
+	rbma := new(big.Rat).Sub(rb, ra)
+	fab := new(big.Rat).Mul(f, rbma)
+	res, exact := new(big.Rat).Add(ra, fab).Float64()
+
+	if exact {
+		return res
+	}
+
+	// just go with all float
+	ff, _ := f.Float64()
+	return interpNumF(ff, a, b)
+
+	/* different groupings that may yield better results:
+	// f*b + (1-f)*a
+
+	// 1-f
+	mf := big.NewRat(1, 1)
+	mf.Sub(mf, f)
+	// (1-f) * a
+	mf.Mul(mf, ra)
+
+	// f*b
+	fb := new(big.Rat).Mul(f, rb)
+	// f*b + (1-f) * a
+	res, exact = fb.Add(fb, mf).Float64()
+	//fmt.Println(exact)
+
+	if exact {
+		return res
+	}
+
+	// f*b + a - f*a
+	fb = new(big.Rat).Mul(f, rb)
+	fb.Add(fb, ra)
+	fa := new(big.Rat).Mul(f, ra)
+	res, exact = fb.Sub(fb, fa).Float64()
+	//fmt.Println(exact)
+
+	return res
+	*/
+}
+
+func makeBoundsNum[T Num](n int, a, b T) []T {
+	if b <= a {
+		panic("no good")
+	}
+	bounds := make([]T, 0, n+1)
+	// f := 1 / float64(n)
+	for i := 0; i <= n; i++ {
+		fi := big.NewRat(int64(i), int64(n)) // i/n
+		next := interpNumRat(fi, a, b)
+		// next := interpNum(f*float64(i), a, b)
+		if i > 0 && next == bounds[len(bounds)-1] {
+			continue // trying to over-subdivide, easy with integers
+		}
+		bounds = append(bounds, next)
+	}
+	return bounds
+}
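Putting the pieces together, a minimal sketch of the intended flow with hypothetical values (ltTotal is the draft cumulative-frequency query above; note the two catch-all bins):

	comp := cmp.Compare[int64]
	bounds := makeBounds(4, int64(0), int64(100), comp, interpNum) // [0 25 50 75 100]
	h := makeHisto(bounds, comp)
	for _, v := range []int64{-5, 10, 30, 30, 60, 101} {
		h.ins(v) // -5 and 101 land in the low/high catch-all bins
	}
	total := h.ltTotal(50) // 4: counts -5, 10, and both 30s (bins up through bound 50)
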
diff --git a/internal/sql/pg/histo_test.go b/internal/sql/pg/histo_test.go
new file mode 100644
index 000000000..bb1553fba
--- /dev/null
+++ b/internal/sql/pg/histo_test.go
@@ -0,0 +1,538 @@
+package pg
+
+import (
+	"bytes"
+	"cmp"
+	"math"
+	"math/big"
+	"math/rand"
+	"reflect"
+	"strings"
+	"testing"
+)
+
+func Test_histo_explicit_int(t *testing.T) {
+	comp := cmp.Compare[int]
+	interp := interpNum[int]
+	bounds := makeBounds(10, int(-100), 800, comp, interp)
+	hInt := makeHisto(bounds, comp)
+	for i := 0; i < 1000; i++ {
+		hInt.ins(rand.Intn(1200) - 200)
+	}
+
+	t.Log(hInt)
+}
+
+func Test_histo_num(t *testing.T) {
+	bounds := makeBoundsNum[int](10, -100, 800)
+	hInt := makeHisto(bounds, cmp.Compare[int])
+	for i := 0; i < 1000; i++ {
+		hInt.ins(rand.Intn(1200) - 200)
+	}
+
+	t.Log(hInt)
+}
+
+func Test_histo_float(t *testing.T) {
+	bounds := makeBoundsNum[float64](10, -100, 800)
+	hInt := makeHisto(bounds, cmp.Compare[float64])
+	for i := 0; i < 1000; i++ {
+		hInt.ins(float64(rand.Intn(1200) - 200))
+	}
+	// Without big.Rat impl:
+	// bounds = [-100 -10 80 170.00000000000003 260 350 440.00000000000006 530 620 710 800]
+
+	t.Log(hInt)
+}
+
+func Test_interpNumF(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        int
+		b        int
+		expected float64
+	}{
+		{"Zero interpolation", 0, 10, 20, 10},
+		{"Full interpolation", 1, 10, 20, 20},
+		{"Mid interpolation", 0.5, 10, 20, 15},
+		{"Quarter interpolation", 0.25, 10, 20, 12.5},
+		{"Three-quarter interpolation", 0.75, 10, 20, 17.5},
+		{"Negative numbers", 0.5, -10, 10, 0},
+		{"Same numbers", 0.5, 5, 5, 5},
+		{"Large numbers", 0.5, 1000000, 2000000, 1500000},
+		{"Small fractional numbers", 0.1, 1, 2, 1.1},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpNumF(tt.f, tt.a, tt.b)
+			if result != tt.expected {
+				t.Errorf("interpNumF(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+		})
+	}
+}
+
+func Test_interpNumF_float64(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        float64
+		b        float64
+		expected float64
+	}{
+		{"Fractional interpolation", 0.3, 1.5, 3.5, 2.1},
+		{"Negative fractional interpolation", 0.7, -2.5, -1.5, -1.8},
+		{"Zero to one interpolation", 0.5, 0, 1, 0.5},
+		{"Very small numbers", 0.5, 1e-10, 2e-10, 1.5e-10},
+		{"Very large numbers", 0.5, 1e10, 2e10, 1.5e10},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpNumF(tt.f, tt.a, tt.b)
+			if !almostEqual(result, tt.expected, 1e-9) {
+				t.Errorf("interpNumF(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+		})
+	}
+}
+
+func almostEqual(a, b, tolerance float64) bool {
+	return math.Abs(a-b) <= tolerance
+}
+
+func Test_interpNum(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        int
+		b        int
+		expected int
+	}{
+		{"Extreme interpolation near 1", 0.99, 0, 100, 99},
+		{"Extreme interpolation near 0", 0.01, 0, 100, 1},
+		{"Interpolation with negative and positive", 0.6, -50, 50, 10},
+		{"Interpolation with both negative", 0.4, -100, -50, -80},
+		{"Interpolation with large numbers", 0.75, 1000000, 2000000, 1750000},
+		{"Zero interpolation with large difference", 0, -1000000, 1000000, -1000000},
+		{"Full interpolation with large difference", 1, -1000000, 1000000, 1000000},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpNum(tt.f, tt.a, tt.b)
+			if result != tt.expected {
+				t.Errorf("interpNum(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+		})
+	}
+}
+
+func Test_interpNum_float64(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        float64
+		b        float64
+		expected float64
+	}{
{"Interpolation with small fractional difference", 0.3, 1.1, 1.2, 1.13}, + {"Interpolation with large fractional difference", 0.7, 0.001, 0.1, 0.0703}, + {"Interpolation with negative fractionals", 0.4, -0.5, -0.1, -0.34}, + {"Extreme values near float64 limits", 0.5, -math.MaxFloat64 / 2, math.MaxFloat64 / 2, 0}, + {"Very small positive numbers", 0.6, 1e-15, 1e-14, 6.4e-15}, + {"Very small negative numbers", 0.8, -1e-14, -1e-15, -2.8e-15}, + {"Numbers close to each other", 0.5, 1.00000001, 1.00000002, 1.000000015}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := interpNum(tt.f, tt.a, tt.b) + if !almostEqual(result, tt.expected, 1e-9) { + t.Errorf("interpNum(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected) + } + }) + } +} + +func Test_makeBounds_int(t *testing.T) { + tests := []struct { + name string + n int + a int + b int + expected []int + }{ + { + name: "Five equal intervals", + n: 5, + a: 0, + b: 100, + expected: []int{0, 20, 40, 60, 80, 100}, + }, + { + name: "two intervals (3 bounds) with negative num", + n: 2, + a: -10, + b: 10, + expected: []int{-10, 0, 10}, + }, + { + name: "single interval", + n: 1, + a: 5, + b: 10, + expected: []int{5, 10}, + }, + { + name: "too small integer range", + n: 10, + a: 0, + b: 4, + expected: []int{0, 1, 2, 3, 4}, // integer forces fewer bounds + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := makeBoundsNum(tt.n, tt.a, tt.b) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("makeBounds(%v, %v, %v, interpNum[int]) = %v, want %v", tt.n, tt.a, tt.b, result, tt.expected) + } + }) + } +} + +func Test_makeBounds_float64(t *testing.T) { + tests := []struct { + name string + n int + a float64 + b float64 + expected []float64 + }{ + { + name: "Four intervals with fractional values", + n: 4, + a: 0.5, + b: 2.5, + expected: []float64{0.5, 1.0, 1.5, 2.0, 2.5}, + }, + { + name: "Three intervals with very small numbers", + n: 3, + a: 1e-6, + b: 1e-5, + expected: []float64{1e-6, 4e-6, 7e-6, 1e-5}, + }, + { + name: "four intervals with negative fractional values", + n: 4, + a: -1.5, + b: 1.5, + expected: []float64{-1.5, -0.75, 0, 0.75, 1.5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := makeBoundsNum(tt.n, tt.a, tt.b) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("makeBounds(%v, %v, %v, interpNum[float64]) = %v, want %v", tt.n, tt.a, tt.b, result, tt.expected) + } + }) + } +} + +func Test_interpNum_edge_cases(t *testing.T) { + tests := []struct { + name string + f float64 + a int + b int + expected int + }{ + {"Interpolation with f = 0", 0, 10, 20, 10}, + {"Interpolation with f = 1", 1, 10, 20, 20}, + {"Interpolation with f > 1", 1.5, 10, 20, 25}, + {"Interpolation with f < 0", -0.5, 10, 20, 5}, + {"Interpolation with a = b", 0.5, 15, 15, 15}, + {"Interpolation with very large numbers", 0.5, math.MaxInt32 / 2, math.MaxInt32, 1610612735}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := interpNum(tt.f, tt.a, tt.b) + if result != tt.expected { + t.Errorf("interpNum(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected) + } + }) + } +} + +func Test_interpBig(t *testing.T) { + tests := []struct { + name string + f float64 + a *big.Int + b *big.Int + expected *big.Int + }{ + { + name: "Simple interpolation", + f: 0.5, + a: big.NewInt(100), + b: big.NewInt(200), + expected: big.NewInt(150), + }, + { + name: "Zero interpolation", + f: 0, + a: big.NewInt(1000), + b: 
+			b:        big.NewInt(2000),
+			expected: big.NewInt(1000),
+		},
+		{
+			name:     "Full interpolation",
+			f:        1,
+			a:        big.NewInt(500),
+			b:        big.NewInt(1500),
+			expected: big.NewInt(1500),
+		},
+		{
+			name:     "Interpolation with negative numbers",
+			f:        0.25,
+			a:        big.NewInt(-1000),
+			b:        big.NewInt(1000),
+			expected: big.NewInt(-500),
+		},
+		{
+			name: "Interpolation with large numbers",
+			f:    0.5,
+			a:    new(big.Int).Exp(big.NewInt(2), big.NewInt(100), nil),
+			b:    new(big.Int).Exp(big.NewInt(2), big.NewInt(101), nil),
+			expected: new(big.Int).Add(
+				new(big.Int).Exp(big.NewInt(2), big.NewInt(100), nil),
+				new(big.Int).Div(new(big.Int).Exp(big.NewInt(2), big.NewInt(100), nil), big.NewInt(2)),
+			),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpBig(tt.f, tt.a, tt.b)
+			if result.Cmp(tt.expected) != 0 {
+				t.Errorf("interpBig(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+		})
+	}
+}
+
+func Test_interpBig_Panic(t *testing.T) {
+	defer func() {
+		if r := recover(); r == nil {
+			t.Errorf("The code did not panic")
+		}
+	}()
+
+	a := big.NewInt(100)
+	b := big.NewInt(50)
+	interpBig(0.5, a, b)
+}
+
+func Test_interpStr(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        string
+		b        string
+		expected string
+	}{
+		{
+			name:     "simple",
+			f:        0.5,
+			a:        "abc",
+			b:        "xyz",
+			expected: "mno",
+		},
+		{
+			name:     "zero f",
+			f:        0,
+			a:        "hello",
+			b:        "world",
+			expected: "hello",
+		},
+		{
+			name:     "very simple",
+			f:        0.5,
+			a:        "a",
+			b:        "c",
+			expected: "b",
+		},
+		{
+			name:     "full interp",
+			f:        1,
+			a:        "a",
+			b:        "c",
+			expected: "c",
+		},
+		{
+			name:     "Interpolation with empty strings",
+			f:        0.5,
+			a:        "",
+			b:        "test",
+			expected: ":3::",
+		},
+		{
+			name:     "Interpolation with unicode characters",
+			f:        0.5,
+			a:        "αβγ",
+			b:        "δεζ",
+			expected: "γδε",
+		},
+		// {
+		// 	name:     "Interpolation with f > 1",
+		// 	f:        1.5,
+		// 	a:        "abc",
+		// 	b:        "xyz",
+		// 	expected: "",
+		// },
+		// {
+		// 	name:     "Interpolation with f < 0",
+		// 	f:        -0.5,
+		// 	a:        "abc",
+		// 	b:        "xyz",
+		// 	expected: "UVW",
+		// },
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpString(tt.f, tt.a, tt.b)
+			if result != tt.expected {
+				t.Errorf("interpStr(%v, %q, %q) = %q, want %q", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+
+			// t.Log(result)
+
+			if strings.Compare(tt.a, result) == 1 {
+				t.Errorf("%v not <= %v", tt.a, result)
+			}
+			if strings.Compare(result, tt.b) == 1 {
+				t.Errorf("%v not >= %v", tt.b, result)
+			}
+		})
+	}
+}
+
+func Test_interpBts(t *testing.T) {
+	tests := []struct {
+		name     string
+		f        float64
+		a        []byte
+		b        []byte
+		expected []byte
+	}{
+		{
+			name:     "Simple interpolation",
+			f:        0.5,
+			a:        []byte{0x00},
+			b:        []byte{0xFF},
+			expected: []byte{0x7F},
+		},
+		{
+			name:     "Zero interpolation",
+			f:        0,
+			a:        []byte{0x10, 0x20},
+			b:        []byte{0x30, 0x40},
+			expected: []byte{0x10, 0x20},
+		},
+		{
+			name:     "Full interpolation",
+			f:        1,
+			a:        []byte{0x00, 0x00},
+			b:        []byte{0xFF, 0xFF},
+			expected: []byte{0xFF, 0xFF},
+		},
+		{
+			name:     "Interpolation with different byte lengths",
+			f:        0.25,
+			a:        []byte{0x01},
+			b:        []byte{0x01, 0x00},
+			expected: []byte{0x40},
+		},
+		{
+			name:     "Interpolation with large numbers",
+			f:        0.75,
+			a:        []byte{0x00, 0x00, 0x00, 0x00},
+			b:        []byte{0xFF, 0xFF, 0xFF, 0xFF},
+			expected: []byte{0xBF, 0xFF, 0xFF, 0xFF},
+		},
+		{
+			name:     "Interpolation with empty byte slice",
+			f:        0.5,
+			a:        []byte{},
+			b:        []byte{0x2},
+			expected: []byte{0x1},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpBts(tt.f, tt.a, tt.b)
+			if !bytes.Equal(result, tt.expected) {
+				t.Errorf("interpBts(%v, %v, %v) = %v, want %v", tt.f, tt.a, tt.b, result, tt.expected)
+			}
+		})
+	}
+}
+
+func Test_interpBts_EdgeCases(t *testing.T) {
+	tests := []struct {
+		name string
+		f    float64
+		a    []byte
+		b    []byte
+		want []byte
+	}{
+		{
+			name: "Interpolation with f > 1",
+			f:    1.5,
+			a:    []byte{0x00},
+			b:    []byte{0xFF},
+			want: []byte{1, 126},
+		},
+		{
+			name: "Interpolation with f < 0",
+			f:    -0.5,
+			a:    []byte{0x00},
+			b:    []byte{0xFF},
+			want: []byte{127},
+		},
+		{
+			name: "Interpolation with very large byte slices",
+			f:    0.5,
+			a:    make([]byte, 1000),
+			b:    bytes.Repeat([]byte{0xFF}, 1000),
+			want: append([]byte{127}, bytes.Repeat([]byte{0xFF}, 999)...),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := interpBts(tt.f, tt.a, tt.b)
+			if len(result) == 0 {
+				t.Errorf("interpBts(%v, %v, %v) returned empty slice", tt.f, tt.a, tt.b)
+			}
+			if bytes.Compare(result, tt.a) < 0 || bytes.Compare(result, tt.b) > 0 {
+				t.Errorf("interpBts(%v, %v, %v) = %v, which is out of range", tt.f, tt.a, tt.b, result)
+			}
+			// t.Log(result)
+			if !bytes.Equal(result, tt.want) {
+				t.Errorf("wanted %x, got %x", tt.want, result)
+			}
+		})
+	}
+}
diff --git a/internal/sql/pg/stats.go b/internal/sql/pg/stats.go
index d0a2f9e96..b079262c6 100644
--- a/internal/sql/pg/stats.go
+++ b/internal/sql/pg/stats.go
@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"slices"
+	"sort"
 	"strings"
 
 	"github.com/jackc/pgx/v5/pgtype"
@@ -16,9 +17,12 @@ import (
 	"github.com/kwilteam/kwil-db/core/types/decimal"
 )
 
-// RowCount gets a precise row count for the named fully qualified table. If the
-// Executor satisfies the RowCounter interface, that method will be used
-// directly. Otherwise a simple select query is used.
+// statsCap is the limit on the number of MCVs and histogram bins when working
+// with column statistics. Fixed for now, but we should consider making it a
+// stats field or settable another way.
+const statsCap = 100
+
+// RowCount gets a precise row count for the named fully qualified table.
 func RowCount(ctx context.Context, qualifiedTable string, db sql.Executor) (int64, error) {
 	stmt := fmt.Sprintf(`SELECT count(1) FROM %s`, qualifiedTable)
 	res, err := db.Execute(ctx, stmt)
@@ -120,24 +124,75 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db
 		pkCols[i] = row[0].(string)
 	}
 
-	numCols := len(colInfo)
-	colTypes := make([]ColType, numCols)
-	for i := range colInfo {
-		colTypes[i] = colInfo[i].Type()
+	stmt := `SELECT * FROM ` + qualifiedTable
+	if len(pkCols) > 0 {
+		stmt += ` ORDER BY ` + strings.Join(pkCols, ",")
 	}
+
+	colStats, err := colStatsInternal(ctx, nil, stmt, colInfo, db)
+	if err != nil {
+		return nil, err
+	}
+
+	// I sunk a bunch of time into a two-pass experiment to get more accurate
+	// MCVs and better histogram bounds, but this is quite costly. May remove.
+
+	return colStats, nil // single ASC pass
+	// second DESC pass:
+	// return colStatsInternal(ctx, colStats, stmt, colInfo, db)
+}
+
+func cmpBool(a, b bool) int {
+	if b {
+		if a { // true == true
+			return 0
+		}
+		return -1 // false < true
+	}
+	if a {
+		return 1 // true > false
+	}
+	return 0 // false == false
+}
%v or %v", err, val, mm)) + } + return d +} + +func colStatsInternal(ctx context.Context, firstPass []sql.ColumnStatistics, + stmt string, colInfo []ColInfo, db sql.Executor) ([]sql.ColumnStatistics, error) { + + // The idea I'm testing with a two-pass scan is to collect mcvs from a + // reverse-order scan. This is pointless if the MCVs is not at capacity as + // nothing new can be added. + if firstPass != nil { + stmt += ` DESC` // otherwise ASC is implied with ORDER BY } - colStats := make([]sql.ColumnStatistics, numCols) + getLast := func(i int) *sql.ColumnStatistics { + if len(firstPass) != 0 { + return &firstPass[i] + } + return nil + } + + // Iterate over all rows (select *), scan into NULLable values like + // pgtype.Int8, then make statistics with Go native or Kwil types. - // iterate over all rows (select *) var scans []any - for _, col := range colInfo { + colTypes := make([]ColType, len(colInfo)) + for i, col := range colInfo { + colTypes[i] = colInfo[i].Type() scans = append(scans, col.scanVal()) } - stmt := `SELECT * FROM ` + qualifiedTable - if len(pkCols) > 0 { - stmt += ` ORDER BY ` + strings.Join(pkCols, ",") - } - err = QueryRowFunc(ctx, db, stmt, scans, + + colStats := make([]sql.ColumnStatistics, len(colInfo)) + + err := QueryRowFunc(ctx, db, stmt, scans, func() error { var err error for i, val := range scans { @@ -172,7 +227,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db } } - ins(stat, valInt, cmp.Compare[int64]) + ins(stat, getLast(i), valInt, cmp.Compare[int64], interpNum) case ColTypeText: // use string in stats valStr, null, ok := TextValue(val) // val.(string) @@ -184,7 +239,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db continue } - ins(stat, valStr, strings.Compare) + ins(stat, getLast(i), valStr, strings.Compare, interpString) case ColTypeByteA: // use []byte in stats var valBytea []byte @@ -220,7 +275,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db return fmt.Errorf("not bytea: %T", val) } - ins(stat, valBytea, bytes.Compare) + ins(stat, getLast(i), valBytea, bytes.Compare, interpBts) case ColTypeBool: // use bool in stats var b bool @@ -246,7 +301,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db return fmt.Errorf("invalid bool (%T)", val) } - ins(stat, b, cmpBool) + ins(stat, getLast(i), b, cmpBool, interpBool) // there should never be a boolean histogram! case ColTypeNumeric: // use *decimal.Decimal in stats var dec *decimal.Decimal @@ -298,7 +353,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db dec = &v2 } - ins(stat, dec, cmpDecimal) + ins(stat, getLast(i), dec, cmpDecimal, interpDec) case ColTypeUINT256: v, ok := val.(*types.Uint256) @@ -311,7 +366,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db continue } - ins(stat, v.Clone(), types.CmpUint256) + ins(stat, getLast(i), v.Clone(), types.CmpUint256, interpUint256) case ColTypeFloat: // we don't want, don't have var varFloat float64 @@ -353,7 +408,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db return fmt.Errorf("invalid float (%T)", val) } - ins(stat, varFloat, cmp.Compare[float64]) + ins(stat, getLast(i), varFloat, cmp.Compare[float64], interpNum) case ColTypeUUID: fallthrough // TODO @@ -369,35 +424,324 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db return nil, err } + // If this is a second pass, merge. 
+	// If this is a second pass, merge. See other comments about the two-pass
+	// approach. It is costly and complex, and its benefit is hard to quantify.
+	// Likely to be removed!
+	for i, p := range firstPass {
+		if len(colStats[i].MCFreqs) == 0 { // nothing new was recorded for this column
+			colStats[i].MCFreqs = p.MCFreqs
+			colStats[i].MCVals = p.MCVals
+			continue
+		}
+
+		// merge up the last mcvs. This is horribly inefficient for now:
+		// concat, build slice of sortable struct, sort by frequency, extract
+		// up to statsCap, re-sort by value.
+		freqs := append(colStats[i].MCFreqs, p.MCFreqs...)
+		vals := append(colStats[i].MCVals, p.MCVals...)
+
+		type mcv struct {
+			freq int
+			val  any
+		}
+		mcvs := make([]mcv, len(freqs))
+		for i := range freqs {
+			mcvs[i] = mcv{freqs[i], vals[i]}
+		}
+		if len(mcvs) > statsCap { // drop the lowest frequency values
+			slices.SortFunc(mcvs, func(a, b mcv) int {
+				return cmp.Compare(b.freq, a.freq) // descending freq
+			})
+			mcvs = mcvs[:statsCap] // mcvs = slices.Delete(mcvs, statsCap, len(mcvs))
+		}
+
+		valCompFun := compFun(vals[0]) // based on prototype value
+		slices.SortFunc(mcvs, func(a, b mcv) int {
+			return valCompFun(a.val, b.val) // ascending value
+		})
+
+		// extract the values and frequencies slices
+		freqs, vals = nil, nil // clear and preserve type
+		for i := range mcvs {
+			freqs = append(freqs, mcvs[i].freq)
+			vals = append(vals, mcvs[i].val)
+			if i == statsCap {
+				break
+			}
+		}
+
+		/* ALT in-place with sort.Interface
+		if len(mcvs) > statsCap { // drop the lowest frequency values
+			sort.Sort(mcvDescendingFreq{
+				vals:  vals,
+				freqs: freqs,
+			})
+			vals = vals[:statsCap]
+			freqs = freqs[:statsCap]
+		}
+		sort.Sort(mcvAscendingValue{
+			vals:  vals,
+			freqs: freqs,
+			comp:  compFun(vals[0]),
+		}) */
+
+		colStats[i].MCFreqs = freqs
+		colStats[i].MCVals = vals
+	}
+
 	return colStats, nil
 }
 
-func cmpBool(a, b bool) int {
-	if b {
-		if a { // true == true
-			return 0
+/* xx will remove if we decide a two-pass scan isn't worth it
+type mcvDescendingFreq struct {
+	vals  []any
+	freqs []int
+}
+
+func (m mcvDescendingFreq) Len() int { return len(m.vals) }
+
+func (m mcvDescendingFreq) Less(i int, j int) bool {
+	return m.freqs[i] < m.freqs[j]
+}
+
+func (m mcvDescendingFreq) Swap(i int, j int) {
+	m.vals[i], m.vals[j] = m.vals[j], m.vals[i]
+	m.freqs[i], m.freqs[j] = m.freqs[j], m.freqs[i]
+}
+
+type mcvAscendingValue struct {
+	vals  []any
+	freqs []int
+	comp  func(a, b any) int
+}
+
+func (m mcvAscendingValue) Len() int { return len(m.vals) }
+
+func (m mcvAscendingValue) Less(i int, j int) bool {
+	return m.comp(m.vals[i], m.vals[j]) == -1
+}
+
+func (m mcvAscendingValue) Swap(i int, j int) {
+	m.vals[i], m.vals[j] = m.vals[j], m.vals[i]
+	m.freqs[i], m.freqs[j] = m.freqs[j], m.freqs[i]
+}
+*/
+
+// the pain of []any vs. []T (in an any)
+func wrapCompFun[T any](f func(a, b T) int) func(a, b any) int {
+	return func(a, b any) int { // must not be nil
+		return f(a.(T), b.(T))
+	}
+}
+
+// vs func compFun[T any]() func(a, b T) int
+func compFun(val any) func(a, b any) int {
+	switch val.(type) {
+	case []byte:
+		return wrapCompFun(bytes.Compare)
+	case int64:
+		return wrapCompFun(cmp.Compare[int64])
+	case float64:
+		return wrapCompFun(cmp.Compare[float64])
+	case string:
+		return wrapCompFun(strings.Compare)
+	case bool:
+		return wrapCompFun(cmpBool)
+	case *decimal.Decimal:
+		return wrapCompFun(cmpDecimal)
+	case *types.Uint256:
+		return wrapCompFun(types.CmpUint256)
+	case types.UUID:
+		return wrapCompFun(types.CmpUUID)
+
+	case decimal.DecimalArray: // TODO
+	case types.Uint256Array: // TODO
+	case []string:
+	case []int64:
+	}
+
+	panic(fmt.Sprintf("no comp fun for type %T", val))
+}
+
+// The following functions perform a type switch to correctly handle null values
+// and then dispatch to the generic ins/up functions with the appropriate
+// comparison function for the underlying type: upColStatsWithInsert,
+// upColStatsWithDelete, and upColStatsWithUpdate.
+
+// upColStatsWithInsert expects a value of the type created with a
+// (*datatype).DeserializeChangeset method.
+func upColStatsWithInsert(stats *sql.ColumnStatistics, val any) error {
+	if val == nil {
+		stats.NullCount++
+		return nil
+	}
+	// INSERT
+	switch nt := val.(type) {
+	case []byte:
+		return ins(stats, nil, nt, bytes.Compare, interpBts)
+	case int64:
+		return ins(stats, nil, nt, cmp.Compare[int64], interpNum[int64])
+	case float64:
+		return ins(stats, nil, nt, cmp.Compare[float64], interpNum[float64])
+	case string:
+		return ins(stats, nil, nt, strings.Compare, interpString)
+	case bool:
+		return ins(stats, nil, nt, cmpBool, interpBool)
+
+	case *decimal.Decimal:
+		if nt.NaN() {
+			stats.NullCount++
+			return nil // ignore, don't put in stats
 		}
-		return -1 // false < true
+		nt2 := *nt
+		return ins(stats, nil, &nt2, cmpDecimal, interpDec)
+
+	case *types.Uint256:
+		if nt.Null {
+			stats.NullCount++
+			return nil
+		}
+
+		return ins(stats, nil, nt.Clone(), types.CmpUint256, interpUint256)
+
+	case types.UUID: // TODO
+
+		return ins(stats, nil, nt, types.CmpUUID, interpUUID)
+
+	case decimal.DecimalArray: // TODO
+	case types.Uint256Array: // TODO
+	case []string:
+	case []int64:
+
+	default:
+		return fmt.Errorf("unrecognized tuple column type %T", val)
 	}
-	if a {
-		return 1 // true > false
+
+	fmt.Printf("unhandled %T", val)
+
+	return nil // known type, just no stats handling
+}
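For example (a minimal sketch, assuming an int64 column), feeding a few changeset values through upColStatsWithInsert:

	cs := &sql.ColumnStatistics{}
	for _, v := range []any{int64(3), int64(1), nil, int64(3)} {
		_ = upColStatsWithInsert(cs, v)
	}
	// cs.Min == int64(1), cs.Max == int64(3), cs.NullCount == 1
	// cs.MCVals == []any{int64(1), int64(3)}, cs.MCFreqs == []int{1, 2}
	// no histogram yet: the MCV set is far below statsCap
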
+
+func upColStatsWithDelete(stats *sql.ColumnStatistics, old any) error {
+	// DELETE:
+	// - unset min/max if removing it; reference MinCount and MaxCount to know
+	// - update null count
+	// - adjust MCVs / histogram
+
+	if old == nil {
+		stats.NullCount--
+		return nil
 	}
-	return 0 // false == false
+
+	switch nt := old.(type) {
+	case int64:
+		del(stats, nt, cmp.Compare[int64])
+	case string:
+		del(stats, nt, strings.Compare)
+	case float64:
+		del(stats, nt, cmp.Compare[float64])
+	case bool:
+		del(stats, nt, cmpBool)
+	case []byte:
+		del(stats, nt, bytes.Compare)
+
+	case *decimal.Decimal:
+		if nt.NaN() {
+			stats.NullCount--
+			return nil // ignore, don't put in stats
+		}
+		nt2 := *nt
+
+		del(stats, &nt2, cmpDecimal)
+
+	case *types.Uint256:
+		if nt.Null {
+			stats.NullCount--
+			return nil
+		}
+
+		return del(stats, nt.Clone(), types.CmpUint256)
+
+	case types.UUID:
+
+		return del(stats, nt, types.CmpUUID)
+
+	case decimal.DecimalArray: // TODO
+	case types.Uint256Array: // TODO
+	case []string:
+	case []int64:
+
+	default:
+		return fmt.Errorf("unrecognized tuple column type %T", old)
+	}
+
+	return nil
 }
 
-func cmpDecimal(val, mm *decimal.Decimal) int {
-	d, err := val.Cmp(mm)
+func upColStatsWithUpdate(stats *sql.ColumnStatistics, old, up any) error { //nolint:unused
+	// update may or may not affect null count
+	err := upColStatsWithDelete(stats, old)
 	if err != nil {
-		panic(fmt.Sprintf("%s: (nan decimal?) %v or %v", err, val, mm))
+		return err
 	}
-	return d
+	return upColStatsWithInsert(stats, up)
+}
+
+// The following are the generic functions that handle the re-typed and non-NULL
+// values from the upColStatsWith* functions above.
+
+// insMCVs attempts to insert a new value into the MCV set. If the set already
+// includes the value, its frequency is incremented. If the value is new and the
+// set is not yet at capacity, it is inserted at the appropriate location in the
+// slices to keep them sorted by value. The sorting is needed later when
+// computing cumulative frequency with inequality conditions, and it allows
+// locating known values in log time.
+func insMCVs[T any](vals []any, freqs []int, val T, comp func(a, b T) int) ([]any, bool, []int) {
+	var spill bool
+	// sort.Search is much harder to use than slices.BinarySearchFunc, but here we are
+	loc := sort.Search(len(vals), func(i int) bool {
+		v := vals[i].(T)
+		return comp(v, val) != -1 // v[i] >= val
+	})
+	found := loc != len(vals) && comp(vals[loc].(T), val) == 0
+	if found {
+		if comp(vals[loc].(T), val) != 0 {
+			panic("wrong loc")
+		}
+		freqs[loc]++
+	} else if len(vals) < statsCap {
+		vals = slices.Insert(vals, loc, any(val))
+		freqs = slices.Insert(freqs, loc, 1)
+	} else {
+		spill = true
+	}
+	return vals, spill, freqs
+}
 
-func ins[T any](stats *sql.ColumnStatistics, val T, comp func(v, m T) int) error {
-	if stats.Min == nil {
+// this is SOOO much easier with vals as a []T rather than []any.
+// alas, that created pains in other places... might switch back.
+
+/*func insMCVs[T any](vals []T, freqs []int, val T, comp func(a, b T) int) ([]T, []int) {
+	loc, found := slices.BinarySearchFunc(vals, val, comp)
+	if found {
+		freqs[loc]++
+	} else if len(vals) < statsCap {
+		vals = slices.Insert(vals, loc, val)
+		freqs = slices.Insert(freqs, loc, 1)
+	}
+	return vals, freqs
+}*/
+
+// ins is used to insert a non-NULL value, but it can be used in different contexts:
+// (1) performing a full scan, where prev may be non-nil on a second pass,
+// (2) maintaining the stats estimate on an insert. del and update only apply in
+// the latter context.
+func ins[T any](stats, prev *sql.ColumnStatistics, val T, comp func(v, m T) int,
+	interp func(f float64, a, b T) T) error {
+
+	switch mn := stats.Min.(type) {
+	case nil: // first observation
 		stats.Min = val
 		stats.MinCount = 1
-	} else if mn, ok := stats.Min.(T); ok {
+	case T:
 		switch comp(val, mn) {
 		case -1: // new MINimum
 			stats.Min = val
@@ -405,24 +749,143 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db
 		case 0: // another of the same
 			stats.MinCount++
 		}
-	} else {
+	case unknown: // it was deleted, only full (re)scan can figure it out
+	default:
 		return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Min)
 	}
 
-	if stats.Max == nil {
+	switch mx := stats.Max.(type) {
+	case nil: // first observation (also would have set Min above)
 		stats.Max = val
 		stats.MaxCount = 1
-	} else if mx, ok := stats.Max.(T); ok {
+	case T:
 		switch comp(val, mx) {
 		case 1: // new MAXimum
 			stats.Max = val
 			stats.MaxCount = 1
 		case 0: // another of the same
 			stats.MaxCount++
 		}
-	} else {
+	case unknown: // it was deleted, only full (re)scan can figure it out
+	default:
+		return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Max)
+	}
+
+	if stats.MCVals == nil {
+		stats.MCVals = []any{} // insMCVs: = []T{val} ; stats.MCFreqs = []int{1}
+	}
+
+	var missed bool
+	if prev == nil || len(prev.MCVals) < statsCap { // first pass of full scan, pointless second pass, or no-context insert
+		stats.MCVals, missed, stats.MCFreqs = insMCVs(stats.MCVals, stats.MCFreqs, val, comp)
+		// vals, freqs, _ := insMCVs(convSlice[T](stats.MCVals), stats.MCFreqs, val, comp)
+		// stats.MCVals, stats.MCFreqs = convSlice2(vals), freqs
+	} else { // a second pass in a full scan
+		// The freqs for MCVals are complete, representing the entire table.
+		// Fill the spills struct instead, but EXCLUDING vals in MCVals,
+		// allowing possibly higher counts in later rows. Caller should merge
+		// back to the MCVals/MCFreqs after the complete second pass.
+
+		// When MCVals was an any (underlying []T) rather than a []any, we were
+		// able to simply assert to []T, but I switched to []any for other reasons.
+		// _, found := slices.BinarySearchFunc(stats.MCVals.([]T), val, comp)
+
+		loc := sort.Search(len(prev.MCVals), func(i int) bool {
+			v := prev.MCVals[i].(T)
+			return comp(v, val) != -1 // v[i] >= val
+		})
+		found := loc != len(prev.MCVals) && comp(prev.MCVals[loc].(T), val) == 0
+		// _, found := slices.BinarySearchFunc(convSlice[T](prev.MCVals), val, comp)
+		if !found { // not in previous scan
+			stats.MCVals, missed, stats.MCFreqs = insMCVs(stats.MCVals, stats.MCFreqs, val, comp)
+		} // else ignore this value, already counted in prev pass
+	}
+
+	// If the value was not included in the MCVs, it spills into the histogram.
+	if missed {
+		// we create the histogram only *after* MCVs have been collected so that
+		// the bounds can be chosen based on at least some observed values.
+		if stats.Histogram == nil {
+			var left, right T
+			if mn, ok := stats.Min.(T); ok {
+				left = mn
+			} else {
+				left = stats.MCVals[0].(T)
+			}
+			if mx, ok := stats.Max.(T); ok {
+				right = mx
+			} else {
+				right = stats.MCVals[len(stats.MCVals)-1].(T)
+			}
+			bounds := makeBounds(statsCap, left, right, comp, interp)
+			stats.Histogram = makeHisto(bounds, comp)
+		}
+
+		h := stats.Histogram.(histo[T])
+		h.ins(val)
+	}
+
+	return nil
+}
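To make the spill behavior concrete, a sketch with hypothetical values: once the MCV set hits statsCap, the next missed value triggers lazy histogram construction from the Min/Max seen so far, and spills there:

	cs := &sql.ColumnStatistics{}
	for i := int64(0); i < statsCap+50; i++ {
		_ = upColStatsWithInsert(cs, i) // all distinct values
	}
	// len(cs.MCVals) == statsCap; the remaining 50 observations are
	// counted by cs.Histogram.(histo[int64]) instead
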
+
+// If we've eliminated all of a value via delete, then the real Min/Max is
+// unknown, and it just cannot be used to affect selectivity. Further inserts
+// also cannot be compared with an unknown min/max. A rescan is needed to
+// identify the actual value. This type signals that case.
+type unknown struct{}
+
+// del is used to delete a non-NULL value.
+func del[T any](stats *sql.ColumnStatistics, val T, comp func(v, m T) int) error {
+
+	switch mn := stats.Min.(type) {
+	case nil: // should not happen
+		return errors.New("nil Min on del")
+	case T:
+		if comp(val, mn) == 0 {
+			stats.MinCount--
+			if stats.MinCount == 0 {
+				stats.Min = unknown{}
+			}
+		}
+	case unknown: // it was deleted, only full (re)scan can figure it out
+	default:
+		return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Min)
+	}
+
+	switch mx := stats.Max.(type) {
+	case nil: // should not happen
+		return errors.New("nil Max on del")
+	case T:
+		if comp(val, mx) == 0 {
+			stats.MaxCount--
+			if stats.MaxCount == 0 {
+				stats.Max = unknown{}
+			}
+		}
+	case unknown: // it was deleted, only full (re)scan can figure it out
+	default:
 		return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Max)
 	}
 
+	// Look for it in the MCVs
+	loc := sort.Search(len(stats.MCVals), func(i int) bool {
+		v := stats.MCVals[i].(T)
+		return comp(v, val) != -1 // v[i] >= val
+	})
+	found := loc != len(stats.MCVals) && comp(stats.MCVals[loc].(T), val) == 0
+	// loc, found := slices.BinarySearchFunc(convSlice[T](stats.MCVals), val, comp)
+	if found {
+		if stats.MCFreqs[loc] == 1 {
+			stats.MCVals = slices.Delete(stats.MCVals, loc, loc+1)
+			stats.MCFreqs = slices.Delete(stats.MCFreqs, loc, loc+1)
+		} else {
+			stats.MCFreqs[loc]--
+		}
+	} else {
+		// adjust histogram freq--
+		hist := stats.Histogram.(histo[T])
+		hist.rm(val)
+	}
+
 	return nil
 }
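A sketch of the unknown{} transition: deleting the last copy of the current extremum leaves Min/Max unusable until a rescan:

	cs := &sql.ColumnStatistics{}
	_ = upColStatsWithInsert(cs, int64(7))
	_ = upColStatsWithDelete(cs, int64(7))
	// cs.Min == unknown{} && cs.Max == unknown{}; only a full scan can recompute them
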
diff --git a/internal/sql/pg/stats_enc.go b/internal/sql/pg/stats_enc.go
new file mode 100644
index 000000000..8b7757365
--- /dev/null
+++ b/internal/sql/pg/stats_enc.go
@@ -0,0 +1,79 @@
+package pg
+
+import (
+	"encoding/gob"
+	"errors"
+	"io"
+	"reflect"
+
+	"github.com/kwilteam/kwil-db/common/sql"
+)
+
+// EncStats encodes a set of table statistics to an io.Writer.
+func EncStats(w io.Writer, statsGroup map[sql.TableRef]*sql.Statistics) error {
+	enc := gob.NewEncoder(w)
+	// TableRef1,Statistics1,TableRef2,Statistics2,...
+	for tblRef, stats := range statsGroup {
+		err := enc.Encode(tblRef)
+		if err != nil {
+			return err
+		}
+		err = enc.Encode(stats)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// DecStats decodes a set of table statistics from an io.Reader. Any type used
+// must be registered with the gob package.
+func DecStats(r io.Reader) (map[sql.TableRef]*sql.Statistics, error) {
+	dec := gob.NewDecoder(r)
+
+	// Read until EOF
+	out := map[sql.TableRef]*sql.Statistics{}
+	for {
+		var tblRef sql.TableRef
+		if err := dec.Decode(&tblRef); err != nil {
+			if errors.Is(err, io.EOF) {
+				break // return out, nil
+			}
+			return nil, err
+		}
+
+		stats := new(sql.Statistics)
+		if err := dec.Decode(&stats); err != nil {
+			return nil, err
+		}
+
+		if stats.RowCount > 0 {
+			// gob will leave an empty slice as nil, which is wrong for most fields
+			for i := range stats.ColumnStatistics {
+				cs := &stats.ColumnStatistics[i]
+				cs.Min = nilToEmptySlice(cs.Min)
+				cs.Max = nilToEmptySlice(cs.Max)
+				for i := range cs.MCVals {
+					cs.MCVals[i] = nilToEmptySlice(cs.MCVals[i])
+				}
+			}
+		}
+
+		out[tblRef] = stats
+	}
+
+	return out, nil
+}
+
+func nilToEmptySlice(s any) any {
+	rt := reflect.TypeOf(s)
+	if rt == nil {
+		return nil
+	}
+	if rt.Kind() == reflect.Slice &&
+		reflect.ValueOf(s).IsNil() {
+		st := reflect.SliceOf(rt.Elem())
+		return reflect.MakeSlice(st, 0, 0).Interface()
+	}
+	return s
+}
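On the "must be registered" note: gob refuses to encode or decode a concrete type held in an interface field (Min, Max, MCVals are all any) unless it has been registered. The registration set is not visible in this diff — the system.go hunk below adds encoding/gob to its imports, so presumably it lands there. A sketch of what it plausibly looks like (assumed, not the actual list; non-builtin types may also need to implement GobEncoder if their fields are unexported):

	func init() {
		gob.Register(int64(0))
		gob.Register([]byte(nil))
		gob.Register(&decimal.Decimal{})
		gob.Register(&types.Uint256{})
		gob.Register(&types.UUID{})
	}
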
mustParseUUID("3000857c-8671-4f4e-99bd-fcc621f9d3d1"), + mustParseUUID("9000857c-8671-4f4e-99bd-fcc621f9d3d1"), + }, + MCFreqs: []int{6, 9, 3, 1, 789}, + }, + }, + }, + tblRefs[4]: { // []byte + RowCount: 88, + ColumnStatistics: []sql.ColumnStatistics{ + { + NullCount: 42, + Min: []byte{}, // important distinction with non-nil empty slice + MinCount: 2, + Max: []byte{0xff, 0xff, 0xff}, + MaxCount: 1, + MCVals: []any{[]byte{0}, []byte{1}, []byte{2}, []byte{0xff, 0xff, 0xff}}, + MCFreqs: []int{2, 9, 3, 1}, + }, + }, + }, + } + + err := EncStats(&buf, in) + if err != nil { + t.Fatal(err) + } + + bts := buf.Bytes() // t.Logf("encoding length = %d", len(bts)) + + rd := bytes.NewReader(bts) + out, err := DecStats(rd) + if err != nil { + t.Fatal(err) + } + + if len(out) != len(in) { + t.Fatal("maps length not the same") + } + + for tblRef, stat := range in { + outStat, have := out[tblRef] + if !have { + t.Fatalf("output stats lack table %v", tblRef) + } + // require.Equal(t, stat, outStat) + if !reflect.DeepEqual(stat, outStat) { + t.Fatalf("%v: output stats (%v) != input stats (%v)", + tblRef, stat, outStat) + } + } +} diff --git a/internal/sql/pg/stats_test.go b/internal/sql/pg/stats_test.go index 7dd64f693..da1298464 100644 --- a/internal/sql/pg/stats_test.go +++ b/internal/sql/pg/stats_test.go @@ -5,38 +5,87 @@ package pg import ( "context" "fmt" + "slices" + "strconv" "testing" + "github.com/kwilteam/kwil-db/common/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestTableStats(t *testing.T) { +func mkTestTableDB(t *testing.T) *DB { ctx := context.Background() - db, err := NewDB(ctx, cfg) if err != nil { t.Fatal(err) } - defer db.Close() + t.Cleanup(func() { + defer db.Close() + }) + return db +} - tx, err := db.BeginTx(ctx) +func mkStatsTestTableTx(t *testing.T, db *DB) sql.PreparedTx { + ctx := context.Background() + tx, err := db.BeginPreparedTx(ctx) + // tx, err := db.BeginTx(ctx) if err != nil { t.Fatal(err) } - defer tx.Rollback(ctx) - tbl := "colcheck" + t.Cleanup(func() { + defer tx.Rollback(ctx) + if t.Failed() { + db.AutoCommit(true) + db.Execute(ctx, `drop table if exists `+tbl) + } + }) + _, err = tx.Execute(ctx, `drop table if exists `+tbl) if err != nil { t.Fatal(err) } _, err = tx.Execute(ctx, `create table if not exists `+tbl+ - ` (a int8 primary key, b int4 default 42, c text, d bytea, e numeric(20,5), f int8[], g uint256, h uint256[])`) + ` (a int8 primary key, b int8 default 42, c text, d bytea, e numeric(20,5), f int4[], g uint256, h uint256[])`) if err != nil { t.Fatal(err) } + return tx +} + +func TestStatsUpdates(t *testing.T) { + ctx := context.Background() + db := mkTestTableDB(t) + txOuter := mkStatsTestTableTx(t, db) + + txOuter.Execute(ctx, "--ping") + + tx, err := txOuter.BeginTx(ctx) + require.NoError(t, err) + t.Cleanup(func() { + tx.Rollback(ctx) + }) + + // insert some stuff + tbl := `colcheck` + _, err = tx.Execute(ctx, `INSERT INTO `+tbl+` VALUES(0, 123, 'asdf')`) + require.NoError(t, err) + + err = tx.Commit(ctx) + require.NoError(t, err) + + // commit tx +} + +func TestTableStats(t *testing.T) { + ctx := context.Background() + db := mkTestTableDB(t) + tx := mkStatsTestTableTx(t, db) + + tbl := "colcheck" + cols, err := ColumnInfo(ctx, tx, "", tbl) if err != nil { t.Fatal(err) @@ -44,11 +93,11 @@ func TestTableStats(t *testing.T) { wantCols := []ColInfo{ {Pos: 1, Name: "a", DataType: "bigint", Nullable: false}, - {Pos: 2, Name: "b", DataType: "integer", Nullable: true, defaultVal: "42"}, + {Pos: 2, Name: "b", 
DataType: "bigint", Nullable: true, defaultVal: "42"}, {Pos: 3, Name: "c", DataType: "text", Nullable: true}, {Pos: 4, Name: "d", DataType: "bytea", Nullable: true}, {Pos: 5, Name: "e", DataType: "numeric", Nullable: true}, - {Pos: 6, Name: "f", DataType: "bigint", Array: true, Nullable: true}, + {Pos: 6, Name: "f", DataType: "integer", Array: true, Nullable: true}, {Pos: 7, Name: "g", DataType: "uint256", Nullable: true}, {Pos: 8, Name: "h", DataType: "uint256", Array: true, Nullable: true}, } @@ -75,8 +124,76 @@ func TestTableStats(t *testing.T) { fmt.Println(stats.ColumnStatistics[4].Max) } +// Test_scanSineBig is similar to Test_updates_demo, but actually uses a DB, +// inserting data into a table and testing the TableStats function. +func Test_scanSineBig(t *testing.T) { + // Build the full set of values + // sine wave with 100 samples per periods, 100 periods + const numUpdates = 40000 + const samplesPerPeriod = 100 + const ampl = 200.0 // larger => more integer discretization + const amplSteps = 10 // "noise" with small ampl variations between periods + const amplInc = 2.0 // each step adds a multiple of this to the amplitude + vals := makeTestVals(numUpdates, samplesPerPeriod, amplSteps, ampl, amplInc) + + ctx := context.Background() + + db := mkTestTableDB(t) + tx := mkStatsTestTableTx(t, db) + tbl := `colcheck` + + for i, val := range vals { + _, err := tx.Execute(ctx, `INSERT INTO `+tbl+` VALUES($1,$2,$3);`, + i, val, strconv.FormatInt(val, 10)) + require.NoError(t, err) + } + + stats, err := TableStats(ctx, "", tbl, tx) + require.NoError(t, err) + + require.True(t, stats.RowCount == numUpdates) + + // check the MCVs for the int8 column + col := stats.ColumnStatistics[1] + + require.Equal(t, len(col.MCFreqs), statsCap) + require.Equal(t, len(col.MCVals), statsCap) + + _, ok := col.MCVals[0].(int64) + require.True(t, ok, "wrong value type") + + valsT := convSliceAsserted[int64](col.MCVals) + require.True(t, slices.IsSorted(valsT)) + + t.Log(valsT) + t.Log(col.MCFreqs) + + var totalFreqMCVs int + for _, f := range col.MCFreqs { + totalFreqMCVs += f + } + fracMCVs := float64(totalFreqMCVs) / numUpdates + t.Log(fracMCVs) + + require.Greater(t, totalFreqMCVs, statsCap) // not just all ones + require.LessOrEqual(t, totalFreqMCVs, numUpdates) + + hist := col.Histogram.(histo[int64]) + t.Log(hist) + var totalFreqHist int + for _, f := range hist.freqs { + totalFreqHist += f + } + fracHists := float64(totalFreqHist) / numUpdates + t.Log(fracHists) + + t.Log(fracMCVs + fracHists) + + t.Log(col.Min.(int64), col.Max.(int64)) +} + /*func TestScanBig(t *testing.T) { -// This test is commented, but helpful for benchmarking performance with a large table. + // This test is commented, but helpful for benchmarking performance with a large table. 
 	ctx := context.Background()
 	cfg := *cfg
@@ -97,13 +214,13 @@
 	defer tx.Rollback(ctx)
 	tbl := `giant`
-	cols, err := ColumnInfo(ctx, tx, tbl)
+	cols, err := ColumnInfo(ctx, tx, "", tbl)
 	if err != nil {
 		t.Fatal(err)
 	}
 	t.Logf("%#v", cols)
-	stats, err := TableStats(ctx, tbl, tx)
+	stats, err := TableStats(ctx, "", tbl, tx)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/internal/sql/pg/stats_update_test.go b/internal/sql/pg/stats_update_test.go
new file mode 100644
index 000000000..1e5d42fd5
--- /dev/null
+++ b/internal/sql/pg/stats_update_test.go
@@ -0,0 +1,121 @@
+package pg
+
+import (
+	"math"
+	"slices"
+	"testing"
+
+	"github.com/kwilteam/kwil-db/common/sql"
+	"github.com/stretchr/testify/require"
+)
+
+func makeTestVals(num, samplesPerPeriod, amplSteps int, ampl, amplInc float64) []int64 {
+	vals := make([]int64, num)
+	for i := 0; i < num; i++ {
+		p, f := math.Modf(float64(i) / float64(samplesPerPeriod))
+		f *= 2 * math.Pi
+		pMod := math.Mod(p, float64(amplSteps))
+		p = amplInc * pMod // small per-period variation in amplitude: amplInc * (0..amplSteps-1)
+		vals[i] = int64(math.Round((ampl + p) * math.Sin(f)))
+	}
+	return vals
+}
+
+// Test_updates_demo tests manual updates to a ColumnStatistics with the generic
+// up* functions, which is how statistics are kept updated in the logical
+// replication stream. Test_scanSineBig tests a full scan with TableStats.
+func Test_updates_demo(t *testing.T) {
+	stats := &sql.ColumnStatistics{}
+
+	// Make data to ingest that:
+	//  - uses the capacity of the MCVs
+	//  - has many repeats
+	//  - puts >90% (but not all) of values in MCVs
+	// sine wave with 100 samples per period, 100 periods
+	const numUpdates = 10000
+	const samplesPerPeriod = 100
+	const ampl = 600.0  // larger => more integer discretization
+	const amplSteps = 3 // "noise" with small ampl variations between periods
+	const amplInc = 2.2 // each step adds a multiple of this to the amplitude
+
+	// Build the full set of values
+	vals := makeTestVals(numUpdates, samplesPerPeriod, amplSteps, ampl, amplInc)
+
+	// ensure the test data exceeds the MCV cap
+	fullCounts := make(map[int64]int)
+	for _, v := range vals {
+		fullCounts[v]++
+	}
+	require.Greater(t, len(fullCounts), statsCap)
+
+	// insert the values one at a time
+	for _, v := range vals {
+		require.NoError(t, upColStatsWithInsert(stats, v))
+	}
+
+	maxVal := slices.Max(vals)
+
+	// min/max must be captured even if not in the MCVs
+	require.Equal(t, -maxVal, stats.Min)
+	require.Equal(t, maxVal, stats.Max)
+
+	hist := stats.Histogram.(histo[int64]) // t.Log("histogram:", hist)
+	histCount0 := hist.TotalCount()
+
+	// insert an outlier
+	require.NoError(t, upColStatsWithInsert(stats, int64(math.MaxInt64)))
+	require.Equal(t, int64(math.MaxInt64), stats.Max) // 9223372036854775807
+	require.Equal(t, histCount0+1, hist.TotalCount()) // one more in the histogram
+
+	// The MCVals slice should be sorted (ascending).
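+	// Keeping MCVals sorted is what makes range predicates cheap to estimate:
+	// a single slices.BinarySearch locates the cutoff index for a bound such
+	// as "WHERE v < 0", and summing MCFreqs below that index gives the MCV
+	// portion of the selectivity, as exercised below.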
+	mcVals := convSliceAsserted[int64](stats.MCVals)
+	require.Equal(t, statsCap, len(mcVals))
+	require.True(t, slices.IsSorted(mcVals))
+	// outlier must not be in MCVs, which is full
+	require.False(t, slices.Contains(mcVals, int64(math.MaxInt64)))
+
+	// remove the outlier
+	require.NoError(t, upColStatsWithDelete(stats, int64(math.MaxInt64)))
+	require.Equal(t, unknown{}, stats.Max)
+	require.Equal(t, histCount0, hist.TotalCount()) // back to the original counts
+
+	ltVal := int64(0) // WHERE v < 0, about half the values for a sine wave
+	loc, found := slices.BinarySearch(mcVals, ltVal)
+
+	require.True(t, found)
+	require.Equal(t, ltVal, mcVals[loc])
+
+	// for "WHERE v <= 0", advance loc past the found value
+	loc++
+
+	sumMCFreqs := func() int {
+		var mcvSum, nonPosFreqTotal int
+		freqs := make([]float64, len(stats.MCFreqs))
+		for i, f := range stats.MCFreqs {
+			freqs[i] = float64(f) / float64(numUpdates)
+			mcvSum += f
+			if i < loc {
+				nonPosFreqTotal += f
+			}
+		}
+
+		// t.Log("total freq of all MCVs:", float64(mcvSum)/float64(numUpdates))
+		// t.Log("sum of freqs where v<=0:", float64(nonPosFreqTotal)/float64(numUpdates))
+		return mcvSum
+	}
+
+	mcvSum := sumMCFreqs()
+	totalCounted := mcvSum + hist.TotalCount()
+
+	// every value must be counted in either the MCVs array or the histogram
+	require.Equal(t, numUpdates, totalCounted)
+
+	// remove vals one at a time
+	for _, v := range vals {
+		require.NoError(t, upColStatsWithDelete(stats, v))
+	}
+
+	mcvSum = sumMCFreqs()
+	require.Equal(t, 0, mcvSum)
+	require.Equal(t, 0, hist.TotalCount())
+}
diff --git a/internal/sql/pg/system.go b/internal/sql/pg/system.go
index 91c2f4641..81c90b437 100644
--- a/internal/sql/pg/system.go
+++ b/internal/sql/pg/system.go
@@ -6,11 +6,14 @@ package pg
 import (
 	"cmp"
 	"context"
+	"encoding/gob"
 	"errors"
 	"fmt"
+	"reflect"
 	"slices"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgtype"
@@ -244,6 +247,111 @@ type ColInfo struct {
 	defaultVal any
 }
 
+func convSliceAsserted[T any](s []any) []T {
+	out := make([]T, len(s))
+	for i := range s {
+		out[i] = s[i].(T)
+	}
+	return out
+}
+
+func convSliceInterfaces[T any](s []T) []any { //nolint
+	out := make([]any, len(s))
+	for i := range s {
+		out[i] = s[i]
+	}
+	return out
+}
+
+// joinSlices concatenates two slices of the same element type held in `any`
+// values. It exists for the case where values are stored as `vals any // []T`
+// rather than `vals []any`.
+func joinSlices(s1, s2 any) any { //nolint
+	rt := reflect.TypeOf(s1)
+	if rt == nil || rt.Kind() != reflect.Slice {
+		panic("not a slice")
+	}
+	if rt.Elem() != reflect.TypeOf(s2).Elem() {
+		panic("different element types")
+	}
+
+	rv1 := reflect.ValueOf(s1)
+	rv2 := reflect.ValueOf(s2)
+	l1, l2 := rv1.Len(), rv2.Len()
+	tot := l1 + l2
+
+	s := mkSlice(rt.Elem(), tot)
+	for i := 0; i < l1; i++ {
+		s.Index(i).Set(rv1.Index(i))
+	}
+	for i := 0; i < l2; i++ {
+		s.Index(i + l1).Set(rv2.Index(i))
+	}
+	return s.Interface()
+}
+
+func mkSlice(rt reflect.Type, l int) reflect.Value {
+	st := reflect.SliceOf(rt)
+	return reflect.MakeSlice(st, l, l)
+}
+
+// typeFor returns the reflect.Type that represents the type argument T. TODO:
+// Remove this in favor of reflect.TypeFor when Go 1.22 becomes the minimum
+// required version since it is not available in Go 1.21.
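+// For example, typeFor[int64]() returns the same reflect.Type as
+// reflect.TypeOf(int64(0)).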
+func typeFor[T any]() reflect.Type { //nolint
+	return reflect.TypeOf((*T)(nil)).Elem()
+}
+
+func statsVal(ct ColType) any {
+	// return reflect.New(statsValType(ct)).Interface()
+	switch ct {
+	case ColTypeInt:
+		return int64(0)
+	case ColTypeText:
+		return ""
+	case ColTypeBool:
+		return false
+	case ColTypeByteA:
+		return []byte{}
+	case ColTypeUUID:
+		return new(types.UUID)
+	case ColTypeNumeric:
+		return new(decimal.Decimal)
+	case ColTypeUINT256:
+		return new(types.Uint256)
+	case ColTypeFloat:
+		return float64(0)
+	case ColTypeTime:
+		return time.Time{}
+	default:
+		return nil
+	}
+}
+
+func statsValType(ct ColType) reflect.Type {
+	return reflect.TypeOf(statsVal(ct))
+	/*switch ct {
+	case ColTypeInt:
+		return typeFor[int64]()
+	case ColTypeText:
+		return typeFor[string]()
+	case ColTypeBool:
+		return typeFor[bool]()
+	case ColTypeByteA:
+		return typeFor[[]byte]()
+	case ColTypeUUID:
+		return typeFor[*types.UUID]()
+	case ColTypeNumeric:
+		return typeFor[*decimal.Decimal]()
+	case ColTypeUINT256:
+		return typeFor[*types.Uint256]()
+	case ColTypeFloat:
+		return typeFor[float64]()
+	case ColTypeTime:
+		return typeFor[time.Time]()
+	default:
+		return nil
+	}*/
+}
 
 func scanVal(ct ColType) any {
 	switch ct {
 	case ColTypeInt:
@@ -365,6 +473,13 @@ const (
 	ColTypeUnknown ColType = "unknown"
 )
 
+// init registers the pointer-based statistics value types with gob, which
+// must know concrete types by name to encode and decode them as interface
+// values (e.g. for EncStats and DecStats).
+func init() {
+	gob.RegisterName("kwil_"+string(ColTypeUUID), statsVal(ColTypeUUID))
+	gob.RegisterName("kwil_"+string(ColTypeNumeric), statsVal(ColTypeNumeric))
+	gob.RegisterName("kwil_"+string(ColTypeUINT256), statsVal(ColTypeUINT256))
+}
+
 func arrayType(ct ColType) ColType {
 	switch ct {
 	case ColTypeInt:
@@ -567,11 +682,6 @@ func columnInfo(ctx context.Context, conn *pgx.Conn, schema, tbl string) ([]ColI
 // ColumnInfo attempts to describe the columns of a table in a specified
 // PostgreSQL schema. The results are **as reported by information_schema.columns**.
-//
-// If the provided sql.Executor is also a ColumnInfoer, its ColumnInfo method
-// will be used. This is primarily for testing with a mocked DB transaction.
-// Otherwise, the Executor must be one of the transaction types created by this
-// package, which provide access to the underlying DB connection.
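+//
+// A sketch of typical use, where tx is a transaction from this package and ""
+// selects the default schema (the table name here is hypothetical):
+//
+//	cols, err := ColumnInfo(ctx, tx, "", "mytable")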
 func ColumnInfo(ctx context.Context, tx sql.Executor, schema, tbl string) ([]ColInfo, error) {
 	if ti, ok := tx.(conner); ok {
 		conn := ti.Conn()
diff --git a/internal/sql/pg/system_test.go b/internal/sql/pg/system_test.go
index 93f6ef389..0ee409de7 100644
--- a/internal/sql/pg/system_test.go
+++ b/internal/sql/pg/system_test.go
@@ -1,8 +1,14 @@
 package pg
 
 import (
+	"reflect"
 	"testing"
+	"time"
 
+	"github.com/kwilteam/kwil-db/core/types"
+	"github.com/kwilteam/kwil-db/core/types/decimal"
+
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -344,3 +350,190 @@ func TestSettingValidFnOR(t *testing.T) {
 		})
 	}
 }
+
+func Test_joinSlices(t *testing.T) {
+	var s0 any = []int{}
+	// t.Logf("%T: %v", s0, s0) // []int: []
+
+	s1t := []int{1, 2, 3, 4}
+	var s1 any = s1t
+	s2t := []int{7, 8, 9}
+	var s2 any = s2t
+
+	// var s3 []int // would not compile: joinSlices returns any([]int)
+	s3 := joinSlices(s1, s2)
+
+	assert.Equal(t, reflect.TypeOf(s0), reflect.TypeOf(s3))
+	t.Log(reflect.TypeOf(s3))
+
+	s3t, ok := s3.([]int)
+	if !ok {
+		t.Fatalf("not a []int: %T", s3)
+	}
+	// t.Logf("%T: %v", s3, s3) // []int: [1 2 3 4 7 8 9]
+
+	require.Len(t, s3t, len(s1t)+len(s2t))
+	require.EqualValues(t, append(s1t, s2t...), s3t)
+}
+
+func TestStatsVal(t *testing.T) {
+	tests := []struct {
+		name     string
+		colType  ColType
+		expected any
+	}{
+		{
+			name:     "ColTypeInt",
+			colType:  ColTypeInt,
+			expected: int64(0),
+		},
+		{
+			name:     "ColTypeText",
+			colType:  ColTypeText,
+			expected: "",
+		},
+		{
+			name:     "ColTypeBool",
+			colType:  ColTypeBool,
+			expected: false,
+		},
+		{
+			name:     "ColTypeByteA",
+			colType:  ColTypeByteA,
+			expected: []byte{},
+		},
+		{
+			name:     "ColTypeUUID",
+			colType:  ColTypeUUID,
+			expected: &types.UUID{},
+		},
+		{
+			name:     "ColTypeNumeric",
+			colType:  ColTypeNumeric,
+			expected: &decimal.Decimal{},
+		},
+		{
+			name:     "ColTypeUINT256",
+			colType:  ColTypeUINT256,
+			expected: &types.Uint256{},
+		},
+		{
+			name:     "ColTypeFloat",
+			colType:  ColTypeFloat,
+			expected: float64(0),
+		},
+		{
+			name:     "ColTypeTime",
+			colType:  ColTypeTime,
+			expected: time.Time{},
+		},
+		{
+			name:     "Unknown ColType",
+			colType:  ColType("asdfasdf"),
+			expected: nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := statsVal(tt.colType)
+			require.IsType(t, tt.expected, result)
+			if tt.expected != nil {
+				require.Equal(t, reflect.TypeOf(tt.expected), reflect.TypeOf(result))
+			}
+		})
+	}
+}
+
+func TestStatsValType(t *testing.T) {
+	tests := []struct {
+		name     string
+		colType  ColType
+		expected reflect.Type
+	}{
+		{
+			name:     "ColTypeInt",
+			colType:  ColTypeInt,
+			expected: reflect.TypeOf(int64(0)),
+		},
+		{
+			name:     "ColTypeText",
+			colType:  ColTypeText,
+			expected: reflect.TypeOf(""),
+		},
+		{
+			name:     "ColTypeBool",
+			colType:  ColTypeBool,
+			expected: reflect.TypeOf(false),
+		},
+		{
+			name:     "ColTypeByteA",
+			colType:  ColTypeByteA,
+			expected: reflect.TypeOf([]byte(nil)),
+		},
+		{
+			name:     "ColTypeUUID",
+			colType:  ColTypeUUID,
+			expected: reflect.TypeOf(&types.UUID{}),
+		},
+		{
+			name:     "ColTypeNumeric",
+			colType:  ColTypeNumeric,
+			expected: reflect.TypeOf(&decimal.Decimal{}),
+		},
+		{
+			name:     "ColTypeUINT256",
+			colType:  ColTypeUINT256,
+			expected: reflect.TypeOf(&types.Uint256{}),
+		},
+		{
+			name:     "ColTypeFloat",
+			colType:  ColTypeFloat,
+			expected: reflect.TypeOf(float64(0)),
+		},
+		{
+			name:     "ColTypeTime",
+			colType:  ColTypeTime,
+			expected: reflect.TypeOf(time.Time{}),
+		},
+		{
+			name:     "Unknown ColType",
+			colType:  ColType("unknown"),
+			expected: nil,
+		},
+	}
+
+	for _, tt := range tests {
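+		// each subtest pins the reflect.Type that statsValType reports for one ColType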
+		t.Run(tt.name, func(t *testing.T) {
+			result := statsValType(tt.colType)
+			if tt.expected == nil {
+				require.Nil(t, result)
+			} else {
+				require.Equal(t, tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestStatsValTypeConsistency(t *testing.T) {
+	allTypes := []ColType{
+		ColTypeInt,
+		ColTypeText,
+		ColTypeBool,
+		ColTypeByteA,
+		ColTypeUUID,
+		ColTypeNumeric,
+		ColTypeUINT256,
+		ColTypeFloat,
+		ColTypeTime,
+	}
+
+	for _, ct := range allTypes {
+		t.Run(string(ct), func(t *testing.T) {
+			valType := statsValType(ct)
+			val := statsVal(ct)
+			require.NotNil(t, valType)
+			require.NotNil(t, val)
+			require.Equal(t, valType, reflect.TypeOf(val))
+		})
+	}
+}
diff --git a/internal/sql/pg/types_test.go b/internal/sql/pg/types_test.go
index b4c0142d6..fd16407d0 100644
--- a/internal/sql/pg/types_test.go
+++ b/internal/sql/pg/types_test.go
@@ -5,8 +5,37 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/require"
+
+	"github.com/kwilteam/kwil-db/core/types"
+	"github.com/kwilteam/kwil-db/core/types/decimal"
 )
 
+// mustDecimal panics if the string cannot be converted to a decimal.
+func mustDecimal(s string) *decimal.Decimal {
+	d, err := decimal.NewFromString(s)
+	if err != nil {
+		panic(err)
+	}
+	return d
+}
+
+// mustParseUUID panics if the string cannot be parsed as a UUID.
+func mustParseUUID(s string) *types.UUID {
+	u, err := types.ParseUUID(s)
+	if err != nil {
+		panic(err)
+	}
+	return u
+}
+
+// mustUint256 panics if the string cannot be converted to a Uint256.
+func mustUint256(s string) *types.Uint256 {
+	u, err := types.Uint256FromString(s)
+	if err != nil {
+		panic(err)
+	}
+	return u
+}
+
 func Test_ArrayEncodeDecode(t *testing.T) {
 	arr := []string{"a", "b", "c"}
 	res, err := serializeArray(arr, 4, func(s string) ([]byte, error) {