From 85a82e9e5738cddc020b840ca5beaf1521bf008e Mon Sep 17 00:00:00 2001 From: Jonathan Chappelow Date: Tue, 30 Jul 2024 11:25:22 -0500 Subject: [PATCH] pg: colinfo, scan vals, and QueryScan funcs This adds the ability to scan query results into provided variables instead of relying on pgx Row.Values() to choose the type. This provides some foundational components for table statistics collection. The sql.QueryScanner interface is the advanced version of Execute that uses caller-provided scan values and a function to run for each scanned row: // QueryScanner represents a type that provides the ability to execute an SQL // statement, where for each row: // // 1. result values are scanned into the variables in the scans slice // 2. the provided function is then called // // The function would typically capture the variables in the scans slice, // allowing it to operate on the values. For instance, append the values to // slices allocated by the caller, or perform reduction operations like // sum/mean/min/etc. // // NOTE: This method may end up being included in the Tx interface alongside // Executor since all of the concrete transaction implementations provided by // this package implement this method. type QueryScanner interface { QueryScanFn(ctx context.Context, stmt string, scans []any, fn func() error, args ...any) error } Each transaction type in the pg package satisfies the sql.QueryScanner interface. The pg.QueryRowFunc function executes an SQL statement, handling the rows and returned values as described by the sql.QueryScanner interface. The pg.QueryRowFuncAny is similar to pg.QueryRowFunc, except that no scan values slice is provided. The provided function is called for each row of the result. The caller does not determine the types of the Go variables in the values slice. In this way it behaves similarly to Execute, but provides "for each row" semantics so that every row does not need to be loaded into memory. 
Table statistics collection: beginning with a simplified sql.Statistics struct based on the types proposed in the initial unmerged query cost branch, the pg package provides the following new methods aimed at the (relatively expensive) collection of ground truth table statistics: - RowCount provides an exact row count - colStats computes column-wise statistics - TableStats uses the above functions to build a *sql.Statistics for a table. These methods will not be used routinely. We will have incremental updates, but there are cases where a full scan may be needed to obtain the ground truth statistics. pg: decimal and uint256 use pgNumericToDecimal helper Use the pgNumericToDecimal helper to reuse the logic to convert from pgtype.Numeric to either our decimal.Decimal or types.Uint256 in the recent pgtype decoding added to the query helper for interpreting the values returned by row.Values() in pgx.CollectRows. types,decimal: sql scan/value for uint256 and decimal and arrays nulls with uint256 and decimal deps: update pgx module from 5.5.5 to 5.6.0 --- common/sql/sql.go | 19 ++ common/sql/statistics.go | 73 ++++++ core/types/decimal/decimal.go | 65 +++-- core/types/uint256.go | 77 ++++-- go.mod | 2 +- go.sum | 4 +- internal/sql/pg/conn.go | 7 + internal/sql/pg/db_live_test.go | 431 +++++++++++++++++++++++++++++++- internal/sql/pg/query.go | 93 ++++++- internal/sql/pg/repl_test.go | 45 ++-- internal/sql/pg/stats.go | 428 +++++++++++++++++++++++++++++++ internal/sql/pg/stats_test.go | 157 ++++++++++++ internal/sql/pg/system.go | 382 ++++++++++++++++++++++++++++ internal/sql/pg/tx.go | 27 ++ internal/sql/pg/types.go | 121 ++++----- test/go.mod | 2 +- test/go.sum | 4 +- 17 files changed, 1792 insertions(+), 145 deletions(-) create mode 100644 common/sql/statistics.go create mode 100644 internal/sql/pg/stats.go create mode 100644 internal/sql/pg/stats_test.go diff --git a/common/sql/sql.go b/common/sql/sql.go index 21b9a0d62..5dd00df02 100644 --- a/common/sql/sql.go 
+++ b/common/sql/sql.go @@ -35,6 +35,25 @@ type Executor interface { Execute(ctx context.Context, stmt string, args ...any) (*ResultSet, error) } +// QueryScanner represents a type that provides the ability to execute an SQL +// statement, where for each row: +// +// 1. result values are scanned into the variables in the scans slice +// 2. the provided function is then called +// +// The function would typically capture the variables in the scans slice, +// allowing it to operate on the values. For instance, append the values to +// slices allocated by the caller, or perform reduction operations like +// sum/mean/min/etc. +// +// NOTE: This method may end up being included in the Tx interface alongside +// Executor since all of the concrete transaction implementations provided by +// this package implement this method. +type QueryScanner interface { + QueryScanFn(ctx context.Context, stmt string, + scans []any, fn func() error, args ...any) error +} + // TxMaker is an interface that creates a new transaction. In the context of the // recursive Tx interface, is creates a nested transaction. type TxMaker interface { diff --git a/common/sql/statistics.go b/common/sql/statistics.go new file mode 100644 index 000000000..7b38ea61c --- /dev/null +++ b/common/sql/statistics.go @@ -0,0 +1,73 @@ +package sql + +// NOTE: this file is TRANSITIONAL! These types are lifted from the +// unmerged internal/engine/costs/datatypes package. + +import ( + "fmt" + "strings" +) + +// Statistics contains statistics about a table or a Plan. A Statistics can be +// derived directly from the underlying table, or derived from the statistics of +// its children. 
+type Statistics struct { + RowCount int64 + + ColumnStatistics []ColumnStatistics + + //Selectivity, for plan statistics +} + +func (s *Statistics) String() string { + var st strings.Builder + fmt.Fprintf(&st, "RowCount: %d", s.RowCount) + if len(s.ColumnStatistics) > 0 { + fmt.Fprintln(&st, "") + } + for i, cs := range s.ColumnStatistics { + fmt.Fprintf(&st, " Column %d:\n", i) + fmt.Fprintf(&st, " - Min/Max = %v / %v\n", cs.Min, cs.Max) + fmt.Fprintf(&st, " - NULL count = %v\n", cs.NullCount) + } + return st.String() +} + +type ValCount struct { + Val any + Count int +} + +// ColumnStatistics contains statistics about a column. +type ColumnStatistics struct { + NullCount int64 + + Min any + MinCount int + + Max any + MaxCount int + + // MCVs are the most common values. It should be sorted by the value. It + // should also be limited capacity, which means scan order has to be + // deterministic since we have to throw out same-frequency observations. + // (crap) Solution: multi-pass scan, merge lists, continue until no higher + // freq values observed? OR when capacity reached, use a histogram? Do not + // throw away MCVs, just start putting additional observations in to the + // histogram instead. + // MCVs []ValCount + // MCVs map[cmp.Ordered] + + // MCVals []any + // MCFreqs []int + + // DistinctCount is harder. For example, unless we sub-sample + // (deterministically), tracking distinct values could involve a data + // structure with the same number of elements as rows in the table. + DistinctCount int64 + + AvgSize int64 // maybe: length of text, length of array, otherwise not used for scalar? 
+ + // without histogram, we can make uniformity assumption to simplify the cost model + //Histogram []HistogramBucket +} diff --git a/core/types/decimal/decimal.go b/core/types/decimal/decimal.go index a1388b1cd..ebab23ebe 100644 --- a/core/types/decimal/decimal.go +++ b/core/types/decimal/decimal.go @@ -119,7 +119,7 @@ func inferPrecisionAndScale(s string) (precision, scale uint16) { s = strings.TrimLeft(s, "-+") parts := strings.Split(s, ".") - // remove 0s from the left part, siince 001.23 is the same as 1.23 + // remove 0s from the left part, since 001.23 is the same as 1.23 parts[0] = strings.TrimLeft(parts[0], "0") intPart := uint16(len(parts[0])) @@ -279,9 +279,26 @@ func (d *Decimal) Sign() int { return d.dec.Sign() } +func (d Decimal) NaN() bool { + switch d.dec.Form { + case apd.NaN, apd.NaNSignaling: + return true + } + return false +} + +func (d Decimal) Inf() bool { + return d.dec.Form == apd.Infinite +} + // Value implements the database/sql/driver.Valuer interface. It converts d to a // string. func (d Decimal) Value() (driver.Value, error) { + // NOTE: we're currently (ab)using the NaN case to handle scanning of NULL + // values. Match that here. We may want something different though. + if d.dec.Form == apd.NaN { + return nil, nil + } return d.dec.Value() } @@ -289,7 +306,30 @@ var _ driver.Valuer = &Decimal{} // Scan implements the database/sql.Scanner interface. 
func (d *Decimal) Scan(src interface{}) error { - return d.dec.Scan(src) + if src == nil { + *d = Decimal{ + dec: apd.Decimal{Form: apd.NaN}, + } + return nil + } + + s, ok := src.(string) + if !ok { + var dec apd.Decimal + err := dec.Scan(src) + if err != nil { + return err + } + s = dec.String() + } + + // set scale and prec from the string + d2, err := NewFromString(s) + if err != nil { + return err + } + *d = *d2 + return nil } var _ sql.Scanner = &Decimal{} @@ -425,6 +465,7 @@ func Cmp(x, y *Decimal) (int64, error) { } return z.Int64() + // return x.dec.Cmp(&y.dec) } // CheckPrecisionAndScale checks if the precision and scale are valid. @@ -459,23 +500,3 @@ func (da DecimalArray) Value() (driver.Value, error) { } var _ driver.Valuer = (*DecimalArray)(nil) - -// Scan implements the sql.Scanner interface. -func (da *DecimalArray) Scan(src interface{}) error { - switch s := src.(type) { - case []string: - *da = make(DecimalArray, len(s)) - for i, str := range s { - d, err := NewFromString(str) - if err != nil { - return err - } - - (*da)[i] = d - } - - return nil - } - - return fmt.Errorf("cannot convert %T to DecimalArray", src) -} diff --git a/core/types/uint256.go b/core/types/uint256.go index 6a63d7d1c..6291a4057 100644 --- a/core/types/uint256.go +++ b/core/types/uint256.go @@ -1,6 +1,7 @@ package types import ( + "database/sql" "database/sql/driver" "fmt" "math/big" @@ -12,12 +13,13 @@ import ( // It is mostly a wrapper around github.com/holiman/uint256.Int, but includes // extra methods for usage in Postgres. type Uint256 struct { - uint256.Int + base uint256.Int // not exporting massive method set, which also has params and returns of holiman types + Null bool } // Uint256FromInt creates a new Uint256 from an int. func Uint256FromInt(i uint64) *Uint256 { - return &Uint256{Int: *uint256.NewInt(i)} + return &Uint256{base: *uint256.NewInt(i)} } // Uint256FromString creates a new Uint256 from a string. 
@@ -26,7 +28,7 @@ func Uint256FromString(s string) (*Uint256, error) { if err != nil { return nil, err } - return &Uint256{Int: *i}, nil + return &Uint256{base: *i}, nil } // Uint256FromBig creates a new Uint256 from a big.Int. @@ -40,8 +42,33 @@ func Uint256FromBytes(b []byte) (*Uint256, error) { return Uint256FromBig(bigInt) } +func (u Uint256) String() string { + return u.base.String() +} + +func (u Uint256) Bytes() []byte { + return u.base.Bytes() +} + +func (u Uint256) ToBig() *big.Int { + return u.base.ToBig() +} + func (u Uint256) MarshalJSON() ([]byte, error) { - return []byte(u.String()), nil + return []byte(u.base.String()), nil // ? json ? +} + +func (u *Uint256) Clone() *Uint256 { + v := *u + return &v +} + +func (u *Uint256) Cmp(v *Uint256) int { + return u.base.Cmp(&v.base) +} + +func CmpUint256(u, v *Uint256) int { + return u.Cmp(v) } func (u *Uint256) UnmarshalJSON(b []byte) error { @@ -50,16 +77,20 @@ func (u *Uint256) UnmarshalJSON(b []byte) error { return err } - u.Int = u2.Int + u.base = u2.base return nil } // Value implements the driver.Valuer interface. func (u Uint256) Value() (driver.Value, error) { + if u.Null { + return nil, nil + } return u.String(), nil } var _ driver.Valuer = Uint256{} +var _ driver.Valuer = (*Uint256)(nil) // Scan implements the sql.Scanner interface. func (u *Uint256) Scan(src interface{}) error { @@ -70,21 +101,28 @@ func (u *Uint256) Scan(src interface{}) error { return err } - u.Int = u2.Int + u.base = u2.base + u.Null = false + return nil + + case nil: + u.Null = true + u.base.Clear() return nil } return fmt.Errorf("cannot convert %T to Uint256", src) } -var _ driver.Valuer = (*Uint256)(nil) -var _ driver.Valuer = (*Uint256)(nil) +var _ sql.Scanner = (*Uint256)(nil) // Uint256Array is an array of Uint256s. type Uint256Array []*Uint256 // Value implements the driver.Valuer interface. 
func (ua Uint256Array) Value() (driver.Value, error) { + // Even when implementing pgtype.ArrayGetter we still need this, so that the + // pgx driver can use it's wrapSliceEncodePlan. strs := make([]string, len(ua)) for i, u := range ua { strs[i] = u.String() @@ -95,21 +133,8 @@ func (ua Uint256Array) Value() (driver.Value, error) { var _ driver.Valuer = (*Uint256Array)(nil) -// Scan implements the sql.Scanner interface. -func (ua *Uint256Array) Scan(src interface{}) error { - switch s := src.(type) { - case []string: - *ua = make(Uint256Array, len(s)) - for i, str := range s { - u, err := Uint256FromString(str) - if err != nil { - return err - } - - (*ua)[i] = u - } - return nil - } - - return fmt.Errorf("cannot convert %T to Uint256Array", src) -} +// Uint256Array is a slice of Scanners. pgx at least is smart enough to make +// this work automatically! +// Another approach is to implement pgx.ArraySetter and pgx.ArrayGetter like +// similar in effect to: +// type Uint256Array pgtype.FlatArray[*Uint256] diff --git a/go.mod b/go.mod index 6ac34a571..9a06e8958 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 - github.com/jackc/pgx/v5 v5.5.5 + github.com/jackc/pgx/v5 v5.6.0 github.com/jpillora/backoff v1.0.0 github.com/kwilteam/kwil-db/core v0.2.0 github.com/kwilteam/kwil-db/parse v0.2.0-beta.1 diff --git a/go.sum b/go.sum index d1fd7c1d0..bbe338159 100644 --- a/go.sum +++ b/go.sum @@ -245,8 +245,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod 
h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= diff --git a/internal/sql/pg/conn.go b/internal/sql/pg/conn.go index 366ded8b4..86a9fdd63 100644 --- a/internal/sql/pg/conn.go +++ b/internal/sql/pg/conn.go @@ -292,6 +292,13 @@ func (p *Pool) Close() error { // BeginTx starts a read-write transaction. It is an error to call this twice // without first closing the initial transaction. func (p *Pool) BeginTx(ctx context.Context) (sql.Tx, error) { + return p.begin(ctx) +} + +// begin is the unexported version of BeginTx that returns a concrete type +// instead of an interface, which is required of the exported method to satisfy +// the sql.TxMaker interface. 
+func (p *Pool) begin(ctx context.Context) (*nestedTx, error) { tx, err := p.writer.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadWrite, IsoLevel: pgx.ReadCommitted, diff --git a/internal/sql/pg/db_live_test.go b/internal/sql/pg/db_live_test.go index ffc848e0b..20a58f935 100644 --- a/internal/sql/pg/db_live_test.go +++ b/internal/sql/pg/db_live_test.go @@ -4,8 +4,11 @@ package pg import ( "bytes" + "cmp" "context" "fmt" + "reflect" + "slices" "strconv" "strings" "sync" @@ -14,10 +17,13 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/core/types/decimal" - "github.com/stretchr/testify/require" - // "github.com/kwilteam/kwil-db/internal/conv" ) func TestMain(m *testing.M) { @@ -47,6 +53,411 @@ var ( } ) +func TestColumnInfo(t *testing.T) { + ctx := context.Background() + + db, err := NewPool(ctx, &cfg.PoolConfig) + require.NoError(t, err) + defer db.Close() + + tx, err := db.BeginTx(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + // Make a temporary table to describe with ColumnInfo. 
+ + tbl := "colcheck" + _, err = tx.Execute(ctx, `drop table if exists `+tbl) + require.NoError(t, err) + _, err = tx.Execute(ctx, `create table if not exists `+tbl+ + ` (a int8 not null, b int4 default 42, c text, + d bytea, e numeric(20,5), f int8[], g uint256)`) + require.NoError(t, err) + + cols, err := ColumnInfo(ctx, tx, "", tbl) + if err != nil { + t.Fatal(err) + } + + wantCols := []ColInfo{ + {Pos: 1, Name: "a", DataType: "bigint", Nullable: false}, + {Pos: 2, Name: "b", DataType: "integer", Nullable: true, defaultVal: "42"}, + {Pos: 3, Name: "c", DataType: "text", Nullable: true}, + {Pos: 4, Name: "d", DataType: "bytea", Nullable: true}, + {Pos: 5, Name: "e", DataType: "numeric", Nullable: true}, + {Pos: 6, Name: "f", DataType: "bigint", Array: true, Nullable: true}, + {Pos: 7, Name: "g", DataType: "uint256", Nullable: true}, + } + + assert.Equal(t, wantCols, cols) // t.Logf("%#v", cols) +} + +func TestQueryRowFunc(t *testing.T) { + ctx := context.Background() + + db, err := NewPool(ctx, &cfg.PoolConfig) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.BeginTx(ctx) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(ctx) + + tbl := "colcheck" + _, err = tx.Execute(ctx, `drop table if exists `+tbl) + if err != nil { + t.Fatal(err) + } + _, err = tx.Execute(ctx, `create table if not exists `+tbl+ + ` (a int8 not null, b int4 default 42, c text, + d bytea, e numeric(20,3), f int8[], g uint256, h uint256[])`) + if err != nil { + t.Fatal(err) + } + + cols, err := ColumnInfo(ctx, tx, "", tbl) + if err != nil { + t.Fatal(err) + } + + // 10 * math.MaxUint64 + hugeIntStr := "184467440737095516150" + hugeInt, err := types.Uint256FromString(hugeIntStr) + require.NoError(t, err) + + stmt := fmt.Sprintf(`insert into %[1]s values (5, null, 'a', '\xabab', 12.5, `+ + `'{2,3,4}', %[2]s::uint256, '{%[2]s,4,3}'::uint256[])`, tbl, hugeIntStr) + _, err = tx.Execute(ctx, stmt) + if err != nil { + t.Fatal(err) + } + + // First get the scan values 
with (*ColInfo).scanVal. + + wantRTs := []reflect.Type{ + typeFor[*pgtype.Int8](), + typeFor[*pgtype.Int8](), + typeFor[*pgtype.Text](), + typeFor[*[]uint8](), + typeFor[*decimal.Decimal](), + typeFor[*pgtype.Array[pgtype.Int8]](), + typeFor[*types.Uint256](), + typeFor[*types.Uint256Array](), + } + + var scans []any + for i, col := range cols { + sv := col.scanVal() + // t.Logf("scanval: %v (%T)", sv, sv) + scans = append(scans, sv) + + gotRT := reflect.TypeOf(sv) + if wantRTs[i] != gotRT { + t.Errorf("wrong type %v, wanted %v", gotRT, wantRTs[i]) + } + } + + // Then use QueryRowFunc with the scan vals. + + wantDec, err := decimal.NewFromString("12.500") // numeric(x,3)! + require.NoError(t, err) + if wantDec.Scale() != 3 { + t.Fatalf("scale of decimal does not match column def: %v", wantDec) + } + + wantScans := []any{ + &pgtype.Int8{Int64: 5, Valid: true}, + &pgtype.Int8{Int64: 0, Valid: false}, + &pgtype.Text{String: "a", Valid: true}, + &[]uint8{0xab, 0xab}, + wantDec, // this seems way easier as long as we're internal: &pgtype.Numeric{Int: big.NewInt(1200000), Exp: -5, NaN: false, InfinityModifier: 0, Valid: true}, + &pgtype.Array[pgtype.Int8]{ + Elements: []pgtype.Int8{{Int64: 2, Valid: true}, {Int64: 3, Valid: true}, {Int64: 4, Valid: true}}, + Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, + Valid: true, + }, + hugeInt, + &types.Uint256Array{hugeInt, types.Uint256FromInt(4), types.Uint256FromInt(3)}, + } + + err = QueryRowFunc(ctx, tx, `SELECT * FROM `+tbl, scans, + func() error { + for i, val := range scans { + // t.Logf("%#v (%T)", val, val) + assert.Equal(t, wantScans[i], val) + } + return nil + }, + ) + if err != nil { + t.Fatal(err) + } +} + +func TestNULL(t *testing.T) { + ctx := context.Background() + + db, err := NewPool(ctx, &cfg.PoolConfig) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.begin(ctx) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(ctx) + + tbl := "colcheck" + _, err = tx.Execute(ctx, 
`drop table if exists `+tbl) + if err != nil { + t.Fatal(err) + } + _, err = tx.Execute(ctx, `create table if not exists `+tbl+` (a int8, b int4)`) + if err != nil { + t.Fatal(err) + } + + insB := int64(6) + _, err = tx.Execute(ctx, fmt.Sprintf(`insert into `+tbl+` values (null, %d)`, insB)) + if err != nil { + t.Fatal(err) + } + + sql := `select a, b from ` + tbl + + res, err := tx.Execute(ctx, sql) + require.NoError(t, err) + + // no type for NULL values, just a nil interface{} + a := res.Rows[0][0] + t.Logf("%v (%T)", a, a) // () + require.Equal(t, reflect.TypeOf(a), reflect.Type(nil)) + + // only non-NULL values get a type + b := res.Rows[0][1] + t.Logf("%v (%T)", b, b) // 6 (int64) + require.Equal(t, reflect.TypeOf(b), typeFor[int64]()) + + // Now with scan vals + + // Cannot select a NULL value with pointers to vanilla types + var av, bv int64 + scans := []any{&av, &bv} + err = tx.QueryScanFn(ctx, sql, scans, func() error { return nil }) + // require.Error(t, err) + require.ErrorContains(t, err, "cannot scan NULL into *int64") + + // Can Scan NULL values with pgtype.Int8 with a Valid bool field. + var avn, bvn pgtype.Int8 + scans = []any{&avn, &bvn} + err = tx.QueryScanFn(ctx, sql, scans, func() error { return nil }) + require.NoError(t, err) + + require.False(t, avn.Valid) // Valid=false for NULL + require.True(t, bvn.Valid) + + require.Equal(t, avn.Int64, int64(0)) + require.Equal(t, bvn.Int64, insB) +} + +// typeFor returns the reflect.Type that represents the type argument T. TODO: +// Remove this in favor of reflect.TypeFor when Go 1.22 becomes the minimum +// required version since it is not available in Go 1.21. 
+func typeFor[T any]() reflect.Type { + return reflect.TypeOf((*T)(nil)).Elem() +} + +func TestScanVal(t *testing.T) { + cols := []ColInfo{ + {Pos: 1, Name: "a", DataType: "bigint", Nullable: false}, + {Pos: 2, Name: "b", DataType: "integer", Nullable: true, defaultVal: "42"}, + {Pos: 3, Name: "c", DataType: "text", Nullable: true}, + {Pos: 4, Name: "d", DataType: "bytea", Nullable: true}, + {Pos: 5, Name: "e", DataType: "numeric", Nullable: true}, + {Pos: 6, Name: "f", DataType: "uint256", Nullable: true}, + + {Pos: 7, Name: "aa", DataType: "bigint", Array: true, Nullable: false}, + {Pos: 8, Name: "ba", DataType: "integer", Array: true, Nullable: true}, + {Pos: 9, Name: "ca", DataType: "text", Array: true, Nullable: true}, + {Pos: 10, Name: "da", DataType: "bytea", Array: true, Nullable: true}, + {Pos: 11, Name: "ea", DataType: "numeric", Array: true, Nullable: true}, + {Pos: 12, Name: "fa", DataType: "uint256", Array: true, Nullable: true}, + } + var scans []any + for _, col := range cols { + scans = append(scans, col.scanVal()) + } + // for _, val := range scans { t.Logf("%#v (%T)", val, val) } + + // want pointers to these base types + var ba []byte + var i8 pgtype.Int8 + var txt pgtype.Text + var num decimal.Decimal // pgtype.Numeric + var u256 types.Uint256 + + // want pointers to these slices for array types + // var ia []pgtype.Int8 + // var ta []pgtype.Text + // var baa [][]byte + // var na []pgtype.Numeric + var ia pgtype.Array[pgtype.Int8] + var ta pgtype.Array[pgtype.Text] + var baa pgtype.Array[[]byte] + var na decimal.DecimalArray // pgtype.Array[pgtype.Numeric] + var u256a types.Uint256Array + + wantScans := []any{&i8, &i8, &txt, &ba, &num, &u256, + &ia, &ia, &ta, &baa, &na, &u256a} + + assert.Equal(t, wantScans, scans) +} + +func TestQueryRowFuncAny(t *testing.T) { + ctx := context.Background() + + db, err := NewDB(ctx, cfg) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.BeginTx(ctx) + if err != nil { + t.Fatal(err) + } + 
defer tx.Rollback(ctx) + + tbl := "colcheck" + _, err = tx.Execute(ctx, `drop table if exists `+tbl) + if err != nil { + t.Fatal(err) + } + + _, err = tx.Execute(ctx, `create table if not exists `+tbl+ + ` (a int8 not null, b int4, c text, d bytea, e numeric(20,5), f int8[])`) + if err != nil { + t.Fatal(err) + } + numCols := 6 + + _, err = tx.Execute(ctx, `insert into `+tbl+ + ` values (5, null, 'a', '\xabab', 12, '{2,3,4}'), `+ + ` (9, 2, 'b', '\xee', 0.9876, '{99}')`) + if err != nil { + t.Fatal(err) + } + + wantTypes := []reflect.Type{ // same for each row scanned, when non-null + typeFor[int64](), + typeFor[int64](), + typeFor[string](), + typeFor[[]byte](), + typeFor[*decimal.Decimal](), + typeFor[[]int64](), + } + mustDec := func(s string) *decimal.Decimal { + d, err := decimal.NewFromString(s) + require.NoError(t, err) + return d + } + wantVals := [][]any{ + {int64(5), nil, "a", []byte{0xab, 0xab}, mustDec("12.00000"), []int64{2, 3, 4}}, + {int64(9), int64(2), "b", []byte{0xee}, mustDec("0.98760"), []int64{99}}, + } + + var rowNum int + err = QueryRowFuncAny(ctx, tx, `SELECT * FROM `+tbl, + func(vals []any) error { + require.Len(t, vals, numCols) + t.Logf("%#v", vals) // e.g. []interface {}{1, "a", "bigint", "YES", interface {}(nil)} + for i, v := range vals { + if v != nil { + require.Equal(t, wantTypes[i], reflect.TypeOf(v), + "it was %T not %v", v, wantTypes[i].String()) + } + require.Equal(t, wantVals[rowNum][i], v) + // t.Logf("%d: %v (%T)", i, v, v) + } + rowNum++ + return nil + }, + ) + if err != nil { + t.Fatal(err) + } + + var colInfo []ColInfo + + // To test QueryRowFuncAny, get some column info. 
+ stmt := `SELECT ordinal_position, column_name, is_nullable + FROM information_schema.columns + WHERE table_name = '` + tbl + `' ORDER BY ordinal_position ASC` + numCols = 3 //based on stmt + + // NOTE: + // - OID 19 pertains to information_schema.sql_identifier, which scans as text + // - OID 1043 pertains to varchar, which can scan as text + wantTypes = []reflect.Type{ // same for each row scanned + typeFor[int64](), // ordinal_position + typeFor[string](), // column_name + typeFor[string](), // is_nullable has boolean semantics but values of "YES"/"NO" + } + wantVals = [][]any{ + {int64(1), "a", "NO"}, + {int64(2), "b", "YES"}, + {int64(3), "c", "YES"}, + {int64(4), "d", "YES"}, + {int64(5), "e", "YES"}, + {int64(6), "f", "YES"}, + } + + rowNum = 0 + err = QueryRowFuncAny(ctx, tx, stmt, func(vals []any) error { + require.Len(t, vals, numCols) + // t.Logf("%#v", vals) // e.g. []interface {}{1, "a", "bigint", "YES", interface {}(nil)} + for i, v := range vals { + require.Equal(t, reflect.TypeOf(v), wantTypes[i]) + require.Equal(t, v, wantVals[rowNum][i]) + // t.Logf("%d: %v (%T)", i, v, v) + } + rowNum++ + return nil + }) + if err != nil { + t.Fatal(err) + } + + // Now the QueryScanFn method and QueryScanner interface with scan vars. + scanner := tx.(sql.QueryScanner) + var pos int + var colName, isNullable string + scans := []any{&pos, &colName, &isNullable} + err = scanner.QueryScanFn(ctx, stmt, scans, func() error { + colInfo = append(colInfo, ColInfo{ + Pos: pos, + Name: colName, + Nullable: strings.EqualFold(isNullable, "yes"), + }) + return nil + }) + if err != nil { + t.Fatal(err) + } + + slices.SortFunc(colInfo, func(a, b ColInfo) int { + return cmp.Compare(a.Pos, b.Pos) + }) + + // now actually check the expected values! 
+} + // TestRollbackPreparedTxns tests the rollbackPreparedTxns in the following // cases: // @@ -543,6 +954,22 @@ func TestTypeRoundtrip(t *testing.T) { require.Len(t, res.Rows[0], 1) require.EqualValues(t, want, res.Rows[0][0]) + + // verify NULL value handling + _, err = tx.Execute(ctx, "DELETE FROM test", QueryModeExec) + require.NoError(t, err) + + _, err = tx.Execute(ctx, "INSERT INTO test (val) VALUES (NULL)") + require.NoError(t, err) + + res, err = tx.Execute(ctx, "SELECT val FROM test", QueryModeExec) + require.NoError(t, err) + + require.Len(t, res.Columns, 1) + require.Len(t, res.Rows, 1) + require.Len(t, res.Rows[0], 1) + + require.EqualValues(t, nil, res.Rows[0][0]) }) } } diff --git a/internal/sql/pg/query.go b/internal/sql/pg/query.go index cd2ecf573..d9d7a9505 100644 --- a/internal/sql/pg/query.go +++ b/internal/sql/pg/query.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/kwilteam/kwil-db/common/sql" @@ -182,13 +183,12 @@ func query(ctx context.Context, oidToDataType map[uint32]*datatype, cq connQuery resSet := &sql.ResultSet{} var oids []uint32 for _, colInfo := range rows.FieldDescriptions() { - // fmt.Println(colInfo.DataTypeOID, colInfo.DataTypeSize) - // NOTE: if the column Name is "?column?", then colInfo.TableOID is // probably zero, meaning not a column of a table, e.g. the result of an // aggregate function, or just returning the a bound argument directly. // AND no AS was used. resSet.Columns = append(resSet.Columns, colInfo.Name) + // NOTE: for a domain (alias) this will be the OID of the underlying type oids = append(oids, colInfo.DataTypeOID) } @@ -235,3 +235,92 @@ func queryTx(ctx context.Context, oidToDataType map[uint32]*datatype, dbTx txBeg return resSet, err } + +func queryRowFunc(ctx context.Context, conn *pgx.Conn, sql string, + scans []any, fn func() error, args ...any) error { + rows, _ := conn.Query(ctx, sql, args...) 
+ _, err := pgx.ForEachRow(rows, scans, fn) + return err +} + +// QueryRowFunc will attempt to execute an SQL statement, handling the rows and +// returned values as described by the sql.QueryScanner interface. If the +// provided Executor is also a sql.QueryScanner, that method will be used, +// otherwise, it will attempt to use the underlying DB connection. The latter is +// supported for all concrete transaction types in this package as well as +// instances of the pgx.Tx interface. +func QueryRowFunc(ctx context.Context, tx sql.Executor, stmt string, + scans []any, fn func() error, args ...any) error { + switch ti := tx.(type) { + case sql.QueryScanner: + return ti.QueryScanFn(ctx, stmt, scans, fn, args...) + case conner: + conn := ti.Conn() + return queryRowFunc(ctx, conn, stmt, scans, fn, args...) + } + return errors.New("cannot query with scan values") +} + +// QueryRowFuncAny is similar to QueryRowFunc, except that no scan values slice +// is provided. The provided function is called for each row of the result. The +// caller does not determine the types of the Go variables in the values slice. +// In this way it behaves similar to Execute, but providing "for each row" +// functionality so that every row does not need to be loaded into memory. See +// also QueryRowFunc, which allows the caller to provide a scan values slice. +func QueryRowFuncAny(ctx context.Context, tx sql.Executor, stmt string, + fn func([]any) error, args ...any) error { + conner, ok := tx.(conner) + if !ok { + return errors.New("no conn access") + } + conn := conner.Conn() + return queryRowFuncAny(ctx, conn, stmt, fn, args...) +} + +func queryRowFuncAny(ctx context.Context, conn *pgx.Conn, stmt string, + fn func(vals []any) error, args ...any) error { + oidTypes := oidTypesMap(conn.TypeMap()) + + rows, _ := conn.Query(ctx, stmt, args...) 
+ fields := rows.FieldDescriptions() + var oids []uint32 + for _, f := range fields { + // NOTE: for a domain (constrained alias) this will be the OID of the underlying type + oids = append(oids, f.DataTypeOID) + } + defer rows.Close() + + for rows.Next() { + pgxVals, err := rows.Values() + if err != nil { + return err + } + + // Decode the values into Kwil or native types. + decVals := make([]any, len(pgxVals)) + for i, pgVal := range pgxVals { + decVal, err := decodeFromPGVal(pgVal, oids[i], oidTypes) + if err != nil { + if !errors.Is(err, ErrUnsupportedOID) { + return err + } + + switch pgVal.(type) { // let native (sql/driver.Value) types pass + case int64, float64, bool, []byte, string, time.Time, nil: + default: // reject anything else unrecognized + return err + } + decVal = pgVal // use as-is + } + + decVals[i] = decVal + } + + err = fn(decVals) + if err != nil { + return err + } + } + + return rows.Err() +} diff --git a/internal/sql/pg/repl_test.go b/internal/sql/pg/repl_test.go index d5650f79e..be3d49d4d 100644 --- a/internal/sql/pg/repl_test.go +++ b/internal/sql/pg/repl_test.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "encoding/hex" - "errors" "fmt" "sync" "testing" @@ -80,34 +79,26 @@ func Test_repl(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for { - select { - case cid := <-commitChan: - _, commitHash, err := decodeCommitPayload(cid) - if err != nil { - t.Errorf("invalid commit payload encoding: %v", err) - return - } - // t.Logf("Commit HASH: %x\n", commitHash) - if !bytes.Equal(commitHash, wantCommitHash) { - t.Errorf("commit hash mismatch, got %x, wanted %x", commitHash, wantCommitHash) - } - quit() - case err := <-errChan: - if errors.Is(err, context.Canceled) { - return - } - if errors.Is(err, context.DeadlineExceeded) { - t.Error("timeout") - return - } - if err != nil { - t.Error(err) - quit() - } + defer quit() + + for cid := range commitChan { + _, commitHash, err := decodeCommitPayload(cid) + if err != nil { + t.Errorf("invalid 
commit payload encoding: %v", err) return } + // t.Logf("Commit HASH: %x\n", commitHash) + if !bytes.Equal(commitHash, wantCommitHash) { + t.Errorf("commit hash mismatch, got %x, wanted %x", commitHash, wantCommitHash) + } + + return // receive only once in this test } + + // commitChan was closed before receive (not expected in this test) + t.Error(<-errChan) + + return }() tx, err := connQ.Begin(ctx) @@ -131,6 +122,6 @@ func Test_repl(t *testing.T) { t.Fatal(err) } - wg.Wait() + wg.Wait() // to receive the commit id or an error connQ.Close(ctx) } diff --git a/internal/sql/pg/stats.go b/internal/sql/pg/stats.go new file mode 100644 index 000000000..d0a2f9e96 --- /dev/null +++ b/internal/sql/pg/stats.go @@ -0,0 +1,428 @@ +package pg + +import ( + "bytes" + "cmp" + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/kwilteam/kwil-db/common/sql" + "github.com/kwilteam/kwil-db/core/types" + "github.com/kwilteam/kwil-db/core/types/decimal" +) + +// RowCount gets a precise row count for the named fully qualified table. If the +// Executor satisfies the RowCounter interface, that method will be used +// directly. Otherwise a simple select query is used. +func RowCount(ctx context.Context, qualifiedTable string, db sql.Executor) (int64, error) { + stmt := fmt.Sprintf(`SELECT count(1) FROM %s`, qualifiedTable) + res, err := db.Execute(ctx, stmt) + if err != nil { + return 0, fmt.Errorf("unable to count rows: %w", err) + } + if len(res.Rows) != 1 || len(res.Rows[0]) != 1 { + return 0, errors.New("exactly one value not returned by row count query") + } + count, ok := sql.Int64(res.Rows[0][0]) + if !ok { + return 0, fmt.Errorf("no row count for %s", qualifiedTable) + } + return count, nil +} + +// TableStatser is an interface that the implementation of a sql.Executor may +// implement. 
+type TableStatser interface { + TableStats(ctx context.Context, schema, table string) (*sql.Statistics, error) +} + +// TableStats collects deterministic statistics for a table. If schema is empty, +// the "public" schema is assumed. This method is used to obtain the ground +// truth statistics for a table; incremental statistics updates should be +// preferred when possible. If the sql.Executor implementation is a +// TableStatser, it's method is used directly. This is primarily to allow a stub +// DB for testing. +func TableStats(ctx context.Context, schema, table string, db sql.Executor) (*sql.Statistics, error) { + if ts, ok := db.(TableStatser); ok { + return ts.TableStats(ctx, schema, table) + } + + if schema == "" { + schema = "public" + } + qualifiedTable := schema + "." + table + + count, err := RowCount(ctx, qualifiedTable, db) + if err != nil { + return nil, err + } + // TODO: We needs a schema-table stats database so we don't ever have to do + // a full table scan for column stats. + + colInfo, err := ColumnInfo(ctx, db, schema, table) + if err != nil { + return nil, err + } + + // Column statistics + colStats, err := colStats(ctx, qualifiedTable, colInfo, db) + if err != nil { + return nil, err + } + + return &sql.Statistics{ + RowCount: count, + ColumnStatistics: colStats, + }, nil +} + +// rough outline for postgresql extension w/ a full stats function: +// +// - function: collect_stats(tablename) +// - iterate over each row, perform computations defined in the extension code +// - SPI_connect() -> SPI_cursor_open(... query ...) -> SPI_cursor_fetch -> +// SPI_processed -> SPI_tuptable -> SPI_getbinval + +// colStats collects column-wise statistics for the specified table, using the +// provided column definitions to instantiate scan values used by the full scan +// that iterates over all rows of the table. 
+func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db sql.Executor) ([]sql.ColumnStatistics, error) { + // rowCount is unused now, and can seemingly be computed via the scan + // itself, but I intend to use it in for more complex statistics building algos. + + // https://wiki.postgresql.org/wiki/Retrieve_primary_key_columns + getIndBase := `SELECT a.attname::text, i.indexrelid::int8 + FROM pg_index i + JOIN pg_attribute a ON a.attnum = ANY(i.indkey) AND a.attrelid = i.indrelid + WHERE i.indrelid = '` + qualifiedTable + `'::regclass` + // use primary key columns first + getPK := getIndBase + ` AND i.indisprimary;` + // then unique+not-null index cols? + // getUniqueInds := getIndBase + ` AND i.indisunique;` + res, err := db.Execute(ctx, getPK) + if err != nil { + return nil, err + } + // IMPORTANT NOTE: if the iteration over all rows of the table involves *no* + // ORDER BY clause, the scan order is not guaranteed. This should be an + // error for tables where stats must be deterministic. 
+ // + // if len(res.Rows) == 0 { + // return nil, errors.New("no suitable orderby column") + // } + pkCols := make([]string, len(res.Rows)) + for i, row := range res.Rows { + pkCols[i] = row[0].(string) + } + + numCols := len(colInfo) + colTypes := make([]ColType, numCols) + for i := range colInfo { + colTypes[i] = colInfo[i].Type() + } + + colStats := make([]sql.ColumnStatistics, numCols) + + // iterate over all rows (select *) + var scans []any + for _, col := range colInfo { + scans = append(scans, col.scanVal()) + } + stmt := `SELECT * FROM ` + qualifiedTable + if len(pkCols) > 0 { + stmt += ` ORDER BY ` + strings.Join(pkCols, ",") + } + err = QueryRowFunc(ctx, db, stmt, scans, + func() error { + var err error + for i, val := range scans { + stat := &colStats[i] + if val == nil { // with QueryRowFuncAny and vals []any, or with QueryRowFunc where scans are native type pointers + stat.NullCount++ + continue + } + + // TODO: do something with array types (num elements stats????) + + switch colTypes[i] { + case ColTypeInt: // use int64 in stats + var valInt int64 + switch it := val.(type) { + case interface{ Int64Value() (pgtype.Int8, error) }: // several of the pgtypes int types + i8, err := it.Int64Value() + if err != nil { + return fmt.Errorf("bad int64: %T", val) + } + if !i8.Valid { + stat.NullCount++ + continue + } + valInt = i8.Int64 + + default: + var ok bool + valInt, ok = sql.Int64(val) + if !ok { + return fmt.Errorf("not int: %T", val) + } + } + + ins(stat, valInt, cmp.Compare[int64]) + + case ColTypeText: // use string in stats + valStr, null, ok := TextValue(val) // val.(string) + if !ok { + return fmt.Errorf("not string: %T", val) + } + if null { + stat.NullCount++ + continue + } + + ins(stat, valStr, strings.Compare) + + case ColTypeByteA: // use []byte in stats + var valBytea []byte + switch vt := val.(type) { + // Presently we're just using []byte, not pgtype.Array, but + // might need to for NULL... 
+ + // case *pgtype.Array[byte]: + // if !vt.Valid { + // stat.NullCount++ + // continue + // } + // valBytea = vt.Elements + // case pgtype.Array[byte]: + // if !vt.Valid { + // stat.NullCount++ + // continue + // } + // valBytea = vt.Elements + case *[]byte: + if vt == nil || *vt == nil { + stat.NullCount++ + continue + } + valBytea = slices.Clone(*vt) + case []byte: + if vt == nil { + stat.NullCount++ + continue + } + valBytea = slices.Clone(vt) + default: + return fmt.Errorf("not bytea: %T", val) + } + + ins(stat, valBytea, bytes.Compare) + + case ColTypeBool: // use bool in stats + var b bool + switch v := val.(type) { + case *pgtype.Bool: + if !v.Valid { + stat.NullCount++ + continue + } + b = v.Bool + case pgtype.Bool: + if !v.Valid { + stat.NullCount++ + continue + } + b = v.Bool + case *bool: + b = *v + case bool: + b = v + + default: + return fmt.Errorf("invalid bool (%T)", val) + } + + ins(stat, b, cmpBool) + + case ColTypeNumeric: // use *decimal.Decimal in stats + var dec *decimal.Decimal + switch v := val.(type) { + case *pgtype.Numeric: + if !v.Valid { + stat.NullCount++ + continue + } + if v.NaN { + continue + } + + dec, err = pgNumericToDecimal(*v) + if err != nil { + continue + } + + case pgtype.Numeric: + if !v.Valid { + stat.NullCount++ + continue + } + if v.NaN { + continue + } + + dec, err = pgNumericToDecimal(v) + if err != nil { + continue + } + + case *decimal.Decimal: + if v.NaN() { // we're pretending this is NULL by our sql.Scanner's convetion + stat.NullCount++ + continue + } + if v != nil { + v2 := *v // clone! 
+ v = &v2 + } + dec = v + case decimal.Decimal: + if v.NaN() { // we're pretending this is NULL by our sql.Scanner's convetion + stat.NullCount++ + continue + } + v2 := v + dec = &v2 + } + + ins(stat, dec, cmpDecimal) + + case ColTypeUINT256: + v, ok := val.(*types.Uint256) + if !ok { + return fmt.Errorf("not a *types.Uint256: %T", val) + } + + if v.Null { + stat.NullCount++ + continue + } + + ins(stat, v.Clone(), types.CmpUint256) + + case ColTypeFloat: // we don't want, don't have + var varFloat float64 + switch v := val.(type) { + case *pgtype.Float8: + if !v.Valid { + stat.NullCount++ + continue + } + varFloat = v.Float64 + case *pgtype.Float4: + if !v.Valid { + stat.NullCount++ + continue + } + varFloat = float64(v.Float32) + case pgtype.Float8: + if !v.Valid { + stat.NullCount++ + continue + } + varFloat = v.Float64 + case pgtype.Float4: + if !v.Valid { + stat.NullCount++ + continue + } + varFloat = float64(v.Float32) + case float32: + varFloat = float64(v) + case float64: + varFloat = v + case *float32: + varFloat = float64(*v) + case *float64: + varFloat = *v + + default: + return fmt.Errorf("invalid float (%T)", val) + } + + ins(stat, varFloat, cmp.Compare[float64]) + + case ColTypeUUID: + fallthrough // TODO + default: // arrays and such + // fmt.Println("unknown", colTypes[i]) + } + } + + return nil + }, + ) + if err != nil { + return nil, err + } + + return colStats, nil +} + +func cmpBool(a, b bool) int { + if b { + if a { // true == true + return 0 + } + return -1 // false < true + } + if a { + return 1 // true > false + } + return 0 // false == false +} + +func cmpDecimal(val, mm *decimal.Decimal) int { + d, err := val.Cmp(mm) + if err != nil { + panic(fmt.Sprintf("%s: (nan decimal?) 
%v or %v", err, val, mm)) + } + return d +} + +func ins[T any](stats *sql.ColumnStatistics, val T, comp func(v, m T) int) error { + if stats.Min == nil { + stats.Min = val + stats.MinCount = 1 + } else if mn, ok := stats.Min.(T); ok { + switch comp(val, mn) { + case -1: // new MINimum + stats.Min = val + stats.MinCount = 1 + case 0: // another of the same + stats.MinCount++ + } + } else { + return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Min) + } + + if stats.Max == nil { + stats.Max = val + stats.MaxCount = 1 + } else if mx, ok := stats.Max.(T); ok { + switch comp(val, mx) { + case 1: // new MAXimum + stats.Max = val + stats.MaxCount = 1 + case 0: // another of the same + stats.MaxCount++ + } + } else { + return fmt.Errorf("invalid stats value type %T for tuple of type %T", val, stats.Max) + } + + return nil +} diff --git a/internal/sql/pg/stats_test.go b/internal/sql/pg/stats_test.go new file mode 100644 index 000000000..7dd64f693 --- /dev/null +++ b/internal/sql/pg/stats_test.go @@ -0,0 +1,157 @@ +//go:build pglive + +package pg + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTableStats(t *testing.T) { + ctx := context.Background() + + db, err := NewDB(ctx, cfg) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.BeginTx(ctx) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(ctx) + + tbl := "colcheck" + _, err = tx.Execute(ctx, `drop table if exists `+tbl) + if err != nil { + t.Fatal(err) + } + _, err = tx.Execute(ctx, `create table if not exists `+tbl+ + ` (a int8 primary key, b int4 default 42, c text, d bytea, e numeric(20,5), f int8[], g uint256, h uint256[])`) + if err != nil { + t.Fatal(err) + } + + cols, err := ColumnInfo(ctx, tx, "", tbl) + if err != nil { + t.Fatal(err) + } + + wantCols := []ColInfo{ + {Pos: 1, Name: "a", DataType: "bigint", Nullable: false}, + {Pos: 2, Name: "b", DataType: "integer", 
Nullable: true, defaultVal: "42"}, + {Pos: 3, Name: "c", DataType: "text", Nullable: true}, + {Pos: 4, Name: "d", DataType: "bytea", Nullable: true}, + {Pos: 5, Name: "e", DataType: "numeric", Nullable: true}, + {Pos: 6, Name: "f", DataType: "bigint", Array: true, Nullable: true}, + {Pos: 7, Name: "g", DataType: "uint256", Nullable: true}, + {Pos: 8, Name: "h", DataType: "uint256", Array: true, Nullable: true}, + } + + assert.Equal(t, wantCols, cols) + // t.Logf("%#v", cols) + + _, err = tx.Execute(ctx, `insert into `+tbl+` values `+ + `(5, null, '', '\xabab', 12.6, '{99}', 30, '{}'), `+ + `(-1, 0, 'B', '\x01', -7, '{1, 2}', 20, '{184467440737095516150}'), `+ + `(3, 1, null, '\x', 8.1, NULL, NULL, NULL), `+ + `(0, 0, 'Q', NULL, NULL, NULL, NULL, NULL), `+ + `(7, -4, 'c', '\x0001', 0.3333, '{2,3,4}', 40, '{5,4,3}')`) + if err != nil { + t.Fatal(err) + } + + stats, err := TableStats(ctx, "", tbl, tx) + require.NoError(t, err) + + t.Log(stats) + + fmt.Println(stats.ColumnStatistics[4].Min) + fmt.Println(stats.ColumnStatistics[4].Max) +} + +/*func TestScanBig(t *testing.T) { +// This test is commented, but helpful for benchmarking performance with a large table. 
+ ctx := context.Background() + + cfg := *cfg + cfg.User = "kwild" + cfg.Pass = "kwild" + cfg.DBName = "kwil_test_db" + + db, err := NewPool(ctx, &cfg.PoolConfig) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.BeginTx(ctx) + if err != nil { + t.Fatal(err) + } + defer tx.Rollback(ctx) + + tbl := `giant` + cols, err := ColumnInfo(ctx, tx, tbl) + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", cols) + + stats, err := TableStats(ctx, tbl, tx) + if err != nil { + t.Fatal(err) + } + + t.Log(stats) +}*/ + +func TestCmpBool(t *testing.T) { + tests := []struct { + name string + a bool + b bool + expected int + }{ + {"true_true", true, true, 0}, + {"false_false", false, false, 0}, + {"true_false", true, false, 1}, + {"false_true", false, true, -1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cmpBool(tt.a, tt.b) + assert.Equal(t, tt.expected, result, "cmpBool(%v, %v) = %v; want %v", tt.a, tt.b, result, tt.expected) + }) + } +} + +func TestCmpBoolSymmetry(t *testing.T) { + booleans := []bool{true, false} + + for _, a := range booleans { + for _, b := range booleans { + t.Run(fmt.Sprintf("a=%v,b=%v", a, b), func(t *testing.T) { + result1 := cmpBool(a, b) + result2 := cmpBool(b, a) + assert.Equal(t, -result2, result1, "cmpBool(%v, %v) and cmpBool(%v, %v) are not symmetric", a, b, b, a) + }) + } + } +} + +func TestCmpBoolTransitivity(t *testing.T) { + a, b, c := false, true, true + + ab := cmpBool(a, b) + bc := cmpBool(b, c) + ac := cmpBool(a, c) + + assert.True(t, (ab < 0 && bc <= 0) == (ac < 0), "cmpBool lacks transitivity") +} diff --git a/internal/sql/pg/system.go b/internal/sql/pg/system.go index 280dfc98f..91c2f4641 100644 --- a/internal/sql/pg/system.go +++ b/internal/sql/pg/system.go @@ -4,14 +4,19 @@ package pg // and system settings of a postgres instance to be used by kwild. 
import ( + "cmp" "context" "errors" "fmt" + "slices" "strconv" "strings" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/kwilteam/kwil-db/common/sql" + "github.com/kwilteam/kwil-db/core/types" + "github.com/kwilteam/kwil-db/core/types/decimal" ) const ( @@ -198,6 +203,383 @@ var settingValidations = map[string]settingValidFn{ "server_encoding": wantStringFn("UTF8"), } +// TextValue recognizes types used to return SQL TEXT values from SQL queries. +// Depending on the type and value, the value may represent a NULL as indicated +// by the null return. +func TextValue(val any) (txt string, null bool, ok bool) { + switch str := val.(type) { + case string: + return str, false, true + case *string: + if str == nil { // NULL + return "", true, true + } + return *str, false, true + case pgtype.Text: // pgtype.Text uses the Valid field for NULL values + return str.String, !str.Valid, true + case *pgtype.Text: + return str.String, !str.Valid, true + } + return "", false, false +} + +// ColInfo is used when ingesting column descriptions from PostgreSQL, such as +// from information_schema.column. Use the Type method to return a canonical +// ColType. +type ColInfo struct { + Pos int + Name string + // DataType is the string reported by information_schema.column for the + // column. Use the Type() method to return the ColType. + DataType string + Array bool + Nullable bool + + // The default value is not a Kwil type, so not exported. We could remove + // this, but it is helpful for debugging in this package. A bool like + // HasDefault could be good for export. Getting the actual OID for decoding + // into a Kwil type involves messier queries that join on several pg_* + // tables, so unless this would be particularly helpful for consumers we + // won't go that route yet. 
+ defaultVal any +} + +func scanVal(ct ColType) any { + switch ct { + case ColTypeInt: + return new(pgtype.Int8) + case ColTypeText: + return new(pgtype.Text) + case ColTypeBool: + return new(pgtype.Bool) + case ColTypeByteA: + return new([]byte) // this is nil-able + case ColTypeUUID: + return new(pgtype.UUID) + case ColTypeNumeric: + // pgtype.Numeric or decimal.Decimal would work. pgtype.Numeric is way + // easier to work with and instantiate, but using our types here helps + // test their scanners/valuers. + return new(decimal.Decimal) + case ColTypeUINT256: + return new(types.Uint256) + case ColTypeFloat: + return new(pgtype.Float8) + case ColTypeTime: + return new(pgtype.Timestamp) + default: + var v any + return &v + } +} + +func scanArrayVal(ct ColType) any { + switch ct { + case ColTypeInt: + return pgArray[pgtype.Int8]() + case ColTypeText: + return pgArray[pgtype.Text]() + case ColTypeBool: + return pgArray[pgtype.Bool]() + case ColTypeByteA: // [][]byte + return pgArray[[]byte]() + case ColTypeUUID: + return pgArray[pgtype.UUID]() + case ColTypeNumeric: + // pgArray is also simpler and more efficient, but as long as we + // explicitly define array types, we should test them. + return new(decimal.DecimalArray) + case ColTypeUINT256: + return new(types.Uint256Array) + case ColTypeFloat: + return pgArray[pgtype.Float8]() + case ColTypeTime: + return pgArray[pgtype.Timestamp]() + default: + return new([]any) + } +} + +func (ci *ColInfo) baseScanVal() any { + return scanVal(ci.baseType()) +} + +func pgArray[T any]() *pgtype.Array[T] { + return &pgtype.Array[T]{} +} + +// ScanVal returns an instance of a suitable type into which a result value may +// be scanned (in the sql.Scanner sense). If left to the DB driver, it may not +// be the most suitable type. This method uses the ColType associations defined +// in this package. 
+// +// Note that this is obviously only applicable to result values from column +// expressions rather than other expressions like arithmetic or aggregates. When +// using QueryRowFunc in such cases, the appropriate type would be determined +// based on prior knowledge of the statement. +func (ci *ColInfo) scanVal() any { + val := ci.baseScanVal() // pointer to instance of the type + if ci.Array { // return pointer to slice of the type + return scanArrayVal(ci.baseType()) + + // A pgtype.Array is the best option overall, particularly for handling + // NULL entries, but it is possible to instantiate native slices of the + // base type's scan valueWS: + + // rt := reflect.TypeOf(val).Elem() + // st := reflect.SliceOf(rt) + // return reflect.New(st).Interface() + + // sl := reflect.MakeSlice(st, 0, 0) + // return sl.Interface() + } + return val +} + +// ColType is the type used to enumerate various known column types (and arrays +// of those types). These are used to describe tables characterized by the +// ColumnInfo function, and to support its ScanVal method. 
+type ColType string + +const ( + ColTypeInt ColType = "int" + ColTypeText ColType = "text" + ColTypeBool ColType = "bool" + ColTypeByteA ColType = "bytea" + ColTypeUUID ColType = "uuid" + ColTypeNumeric ColType = "numeric" + ColTypeUINT256 ColType = "uint256" + ColTypeFloat ColType = "float" + ColTypeTime ColType = "timestamp" + + ColTypeIntArray ColType = "int[]" + ColTypeTextArray ColType = "text[]" + ColTypeBoolArray ColType = "bool[]" + ColTypeByteAArray ColType = "bytea[]" + ColTypeUUIDArray ColType = "uuid[]" + ColTypeNumericArray ColType = "numeric[]" + ColTypeUINT256Array ColType = "uint256[]" + ColTypeFloatArray ColType = "float[]" + ColTypeTimeArray ColType = "timestamp[]" + + ColTypeUnknown ColType = "unknown" +) + +func arrayType(ct ColType) ColType { + switch ct { + case ColTypeInt: + return ColTypeIntArray + case ColTypeText: + return ColTypeTextArray + case ColTypeBool: + return ColTypeBoolArray + case ColTypeByteA: + return ColTypeByteAArray + case ColTypeUUID: + return ColTypeUUIDArray + case ColTypeNumeric: + return ColTypeNumericArray + case ColTypeUINT256: + return ColTypeUINT256Array + case ColTypeFloat: + return ColTypeFloatArray + case ColTypeTime: + return ColTypeTimeArray + default: + return ColTypeUnknown + } +} + +// Type returns the canonical ColType based on the DataType, which is the +// type string reported by PostgreSQL from information_schema.columns. +func (ci *ColInfo) Type() ColType { + baseType := ci.baseType() + if ci.Array { + return arrayType(baseType) + } + return baseType +} + +func (ci *ColInfo) baseType() ColType { + // TODO: merge into since switch or map when this has settled. 
+ if ci.IsInt() { + return ColTypeInt + } + if ci.IsText() { + return ColTypeText + } + if ci.IsBool() { + return ColTypeBool + } + if ci.IsByteA() { + return ColTypeByteA + } + if ci.IsNumeric() { + return ColTypeNumeric + } + if ci.IsUINT256() { + return ColTypeUINT256 + } + if ci.IsFloat() { + return ColTypeFloat + } + if ci.IsUUID() { + return ColTypeUUID + } + if ci.IsTime() { + return ColTypeTime + } + return ColTypeUnknown +} + +// The following methods recognize the DataType values as reported by "regtype" +// values in the information_schema.columns PostgreSQL system table. Use the +// Type method to obtain the canonical ColType. + +func (ci *ColInfo) IsInt() bool { + switch strings.ToLower(ci.DataType) { + case "bigint", "integer", "smallint", "int", "int2", "int4", "int8": + return true + } + return false +} + +func (ci *ColInfo) IsText() bool { + switch strings.ToLower(ci.DataType) { + case "text", "varchar": + return true + } + return false +} + +func (ci *ColInfo) IsFloat() bool { + switch strings.ToLower(ci.DataType) { + case "double precision", "single precision", "float32", "float64": + return true + } + return false +} + +func (ci *ColInfo) IsBool() bool { + switch strings.ToLower(ci.DataType) { + case "boolean", "bool": + return true + } + return false +} + +func (ci *ColInfo) IsUUID() bool { + switch strings.ToLower(ci.DataType) { + case "uuid": + return true + } + return false + +} + +func (ci *ColInfo) IsByteA() bool { + switch strings.ToLower(ci.DataType) { + case "bytea": + return true + } + return false +} + +func (ci *ColInfo) IsUINT256() bool { + switch strings.ToLower(ci.DataType) { + case "uint256": + return true + } + return false +} + +func (ci *ColInfo) IsNumeric() bool { + dt := strings.ToLower(ci.DataType) + return strings.HasPrefix(dt, "numeric") // all numeric, including plain or with prec/scale +} + +func (ci *ColInfo) IsTime() bool { + dt := strings.ToLower(ci.DataType) + return strings.HasPrefix(dt, "timestamp") // includes 
timestamptz +} + +func columnInfo(ctx context.Context, conn *pgx.Conn, schema, tbl string) ([]ColInfo, error) { + var colInfo []ColInfo + + if schema == "" { + schema = "public" // otherwise we can get multiple rows + } + + dbName := conn.Config().Database + + // get column data types + sql := `SELECT ordinal_position, column_name, + data_type, udt_name::regtype, domain_name::regtype, + is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '` + tbl + `' AND table_schema = '` + schema + `' + AND table_catalog = '` + dbName + `'` + + var worked bool + + var pos int + var domainName pgtype.Text // null in Valid bool + var colName, dataType, typeOrArray, isNullable string + var colDefault any + scans := []any{&pos, &colName, &typeOrArray, &dataType, &domainName, &isNullable, &colDefault} + err := queryRowFunc(ctx, conn, sql, scans, func() error { + isArray := strings.EqualFold(typeOrArray, "ARRAY") + if domainName.Valid && domainName.String != "" { + dataType = domainName.String + } + var wasArr bool + dataType, wasArr = strings.CutSuffix(dataType, "[]") + if isArray && !wasArr { + return errors.New("inconsistent array typing") + } + + colInfo = append(colInfo, ColInfo{ + Pos: pos, + Name: colName, + DataType: dataType, + Array: isArray, + Nullable: strings.EqualFold(isNullable, "YES"), + defaultVal: colDefault, + }) + + worked = true + return nil + }) + if err != nil { + return nil, err + } + + if !worked { + return nil, fmt.Errorf("no results for table %s.%s", schema, tbl) + } + + slices.SortFunc(colInfo, func(a, b ColInfo) int { + return cmp.Compare(a.Pos, b.Pos) + }) + + return colInfo, nil +} + +// ColumnInfo attempts to describe the columns of a table in a specified +// PostgreSQL schema. The results are **as reported by information_schema.column**. +// +// If the provided sql.Executor is also a ColumnInfoer, its ColumnInfo method +// will be used. This is primarily for testing with a mocked DB transaction. 
+// Otherwise, the Executor must be one of the transaction types created by this +// package, which provide access to the underlying DB connection. +func ColumnInfo(ctx context.Context, tx sql.Executor, schema, tbl string) ([]ColInfo, error) { + if ti, ok := tx.(conner); ok { + conn := ti.Conn() + return columnInfo(ctx, conn, schema, tbl) + } + return nil, errors.New("cannot get column info") +} + func verifySettings(ctx context.Context, conn *pgx.Conn) error { checkSettings := make([]string, 0, len(settingValidations)) for name := range settingValidations { diff --git a/internal/sql/pg/tx.go b/internal/sql/pg/tx.go index 15df8f456..cb0c62f74 100644 --- a/internal/sql/pg/tx.go +++ b/internal/sql/pg/tx.go @@ -25,6 +25,7 @@ type nestedTx struct { } var _ common.Tx = (*nestedTx)(nil) +var _ common.QueryScanner = (*nestedTx)(nil) // BeginTx creates a new transaction with the same access mode as the parent. // Internally this is savepoint, which allows rollback to the innermost @@ -53,6 +54,14 @@ func (tx *nestedTx) Execute(ctx context.Context, stmt string, args ...any) (*com return query(ctx, tx.oidTypes, tx.Tx, stmt, args...) } +// QueryScanFn satisfies sql.QueryScanner. +func (tx *nestedTx) QueryScanFn(ctx context.Context, sql string, + scans []any, fn func() error, args ...any) error { + + conn := tx.Conn() + return queryRowFunc(ctx, conn, sql, scans, fn, args...) +} + // AccessMode returns the access mode of the transaction. func (tx *nestedTx) AccessMode() common.AccessMode { return tx.accessMode @@ -68,6 +77,16 @@ type dbTx struct { accessMode common.AccessMode } +// conner is a db or tx type that provides access to the underlying *pgx.Conn. +// All of the transaction types in this package should be conners. This is a +// subset of the pg.Tx interface. +type conner interface{ Conn() *pgx.Conn } + +var _ conner = (pgx.Tx)(nil) + +var _ conner = (*dbTx)(nil) +var _ conner = (*nestedTx)(nil) + // Precommit creates a prepared transaction for a two-phase commit. 
An ID // derived from the updates is return. This must be called before Commit. Either // Commit or Rollback must follow. It takes a writer to write the full changeset to. @@ -111,6 +130,8 @@ type readTx struct { subscribers *syncmap.Map[int64, chan<- string] } +var _ conner = (*readTx)(nil) + // Commit is a no-op for read-only transactions. // It will unconditionally return the connection to the pool. func (tx *readTx) Commit(ctx context.Context) error { @@ -163,6 +184,12 @@ func (d *delayedReadTx) Execute(ctx context.Context, stmt string, args ...any) ( return d.tx.Execute(ctx, stmt, args...) } +var _ conner = (*nestedTx)(nil) + +func (d *delayedReadTx) Conn() *pgx.Conn { + return d.tx.Conn() +} + func (d *delayedReadTx) Commit(ctx context.Context) error { if d.tx == nil { return nil diff --git a/internal/sql/pg/types.go b/internal/sql/pg/types.go index 40aea1b73..434941b0b 100644 --- a/internal/sql/pg/types.go +++ b/internal/sql/pg/types.go @@ -3,6 +3,7 @@ package pg import ( "encoding/binary" "encoding/hex" + "errors" "fmt" "math/big" "reflect" @@ -10,6 +11,7 @@ import ( "strings" "github.com/jackc/pgx/v5/pgtype" + "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/core/types/decimal" ) @@ -31,6 +33,8 @@ var ( kwilTypeToDataType = map[types.DataType]*datatype{} ) +var ErrUnsupportedOID = errors.New("unsupported OID") + // registerOIDs registers all of the data types that we support in Postgres. 
func registerDatatype(scalar *datatype, array *datatype) { for _, match := range scalar.Matches { @@ -119,11 +123,39 @@ type datatype struct { DeserializeChangeset func([]byte) (any, error) } +var ErrNaN = errors.New("NaN") + +func pgNumericToDecimal(num pgtype.Numeric) (*decimal.Decimal, error) { + if num.NaN { // TODO: create a decimal.Decimal that supports NaN + return nil, ErrNaN + } + if !num.Valid { + return nil, errors.New("invalid or null") // TODO: create a decimal.Decimal that supports NULL + } + + i, e := num.Int, num.Exp + + // Kwil's decimal semantics do not allow negative scale (only shift decimal + // left), so if the exponent is positive we need to apply it to the integer. + if e > 0 { + // i * 10^e + z := new(big.Int) + z.Exp(big.NewInt(10), big.NewInt(int64(e)), nil) + z.Mul(z, i) + i, e = z, 0 + } + + // Really this could be uint256, which is same underlying type (a domain) as + // Numeric. If the caller needs to know, that has to happen differently. + return decimal.NewFromBigInt(i, e) +} + var ( textType = &datatype{ KwilType: types.TextType, Matches: []reflect.Type{reflect.TypeOf("")}, OID: func(*pgtype.Map) uint32 { return pgtype.TextOID }, + ExtraOIDs: []uint32{pgtype.VarcharOID}, EncodeInferred: defaultEncodeDecode, Decode: defaultEncodeDecode, SerializeChangeset: func(value string) ([]byte, error) { @@ -198,28 +230,11 @@ var ( ExtraOIDs: []uint32{pgtype.Int2OID, pgtype.Int4OID}, EncodeInferred: defaultEncodeDecode, Decode: func(a any) (any, error) { - switch v := a.(type) { - case int: - return int64(v), nil - case int8: - return int64(v), nil - case int16: - return int64(v), nil - case int32: - return int64(v), nil - case int64: - return v, nil - case uint: - return int64(v), nil - case uint16: - return int64(v), nil - case uint32: - return int64(v), nil - case uint64: - return int64(v), nil - default: + v, ok := sql.Int64(a) + if !ok { return nil, fmt.Errorf("unexpected type %T", a) } + return v, nil }, SerializeChangeset: func(value 
string) ([]byte, error) { intVal, err := strconv.ParseInt(value, 10, 64) @@ -467,25 +482,7 @@ var ( return nil, fmt.Errorf("expected pgtype.Numeric, got %T", a) } - if pgType.NaN { - return "NaN", nil - } - - // if we give postgres a number such as 5000, it will return it as 5 with exponent 3. - // Since kwil's decimal semantics do not allow negative scale, we need to multiply - // the number by 10^exp to get the correct value. - if pgType.Exp > 0 { - z := new(big.Int) - z.Exp(big.NewInt(10), big.NewInt(int64(pgType.Exp)), nil) - z.Mul(z, pgType.Int) - pgType.Int = z - pgType.Exp = 0 - } - - // there is a bit of an edge case here, where uint256 can be returned. - // since most results simply get returned to the user via JSON, it doesn't - // matter too much right now, so we'll leave it as-is. - return decimal.NewFromBigInt(pgType.Int, pgType.Exp) + return pgNumericToDecimal(pgType) }, SerializeChangeset: func(value string) ([]byte, error) { // parse to ensure it is a valid decimal, then re-encode it to ensure it is in the correct format. @@ -581,17 +578,18 @@ var ( return nil, fmt.Errorf("expected pgtype.Numeric, got %T", a) } - // if the number ends in 0s, it will have an exponent, so we need to multiply - // the number by 10^exp to get the correct value. - if pgType.Exp > 0 { - z := new(big.Int) - z.Exp(big.NewInt(10), big.NewInt(int64(pgType.Exp)), nil) - z.Mul(z, pgType.Int) - pgType.Int = z - pgType.Exp = 0 + if pgType.Exp == 0 { + return types.Uint256FromBig(pgType.Int) } - return types.Uint256FromBig(pgType.Int) + dec, err := pgNumericToDecimal(pgType) + if err != nil { + return nil, err + } + if dec.Exp() != 0 { + return nil, errors.New("fractional numeric") + } + return types.Uint256FromBig(dec.BigInt()) }, SerializeChangeset: func(value string) ([]byte, error) { // parse to ensure it is a valid uint256, then re-encode it to ensure it is in the correct format. 
@@ -771,7 +769,20 @@ func encodeToPGType(oids *pgtype.Map, values ...any) ([]any, []uint32, error) { // a nil value with the void OID. var voidOID = uint32(2278) -// decodeFromPGType decodes several pgx types to their corresponding Go types. +func decodeFromPGVal(val any, oid uint32, oidToDataType map[uint32]*datatype) (any, error) { + if val == nil { + return nil, nil + } + + dt, ok := oidToDataType[oid] + if !ok { + return nil, fmt.Errorf("%w: %d", ErrUnsupportedOID, oid) + } + + return dt.Decode(val) +} + +// decodeFromPG decodes several pgx types to their corresponding Go types. // It is capable of detecting special Kwil types and decoding them to their // corresponding Go types. func decodeFromPG(vals []any, oids []uint32, oidToDataType map[uint32]*datatype) ([]any, error) { @@ -781,17 +792,7 @@ func decodeFromPG(vals []any, oids []uint32, oidToDataType map[uint32]*datatype) continue } - if vals[i] == nil { - results = append(results, nil) - continue - } - - dt, ok := oidToDataType[oid] - if !ok { - return nil, fmt.Errorf("unsupported oid %d", oid) - } - - decoded, err := dt.Decode(vals[i]) + decoded, err := decodeFromPGVal(vals[i], oid, oidToDataType) if err != nil { return nil, err } diff --git a/test/go.mod b/test/go.mod index f9b8389d3..a748679ea 100644 --- a/test/go.mod +++ b/test/go.mod @@ -148,7 +148,7 @@ require ( github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jmhodges/levigo v1.0.0 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect diff --git a/test/go.sum b/test/go.sum index b30906fb8..b72c61a71 100644 --- a/test/go.sum +++ b/test/go.sum @@ -438,8 +438,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI 
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=