From 9a145aaf8eb71c0446457ff5de830e56e420057b Mon Sep 17 00:00:00 2001 From: Brennan Lamey Date: Fri, 17 Jan 2025 12:19:36 -0600 Subject: [PATCH] core: added scanning for results added tests for float added scan to method renamed decimal functions to be more clear fix lint --- core/types/data_types.go | 31 +- core/types/{decimal => }/decimal.go | 67 +-- core/types/{decimal => }/decimal_test.go | 67 ++- core/types/encoded_value_test.go | 13 +- core/types/payloads.go | 15 +- core/types/results.go | 247 +++++++++- core/types/results_test.go | 506 ++++++++++++++++++++ core/types/uuid.go | 9 + node/engine/functions.go | 4 +- node/engine/interpreter/interpreter_test.go | 15 +- node/engine/interpreter/sql_test.go | 5 +- node/engine/interpreter/values.go | 83 ++-- node/engine/interpreter/values_test.go | 33 +- node/engine/parse/antlr.go | 5 +- node/engine/parse/ast.go | 3 +- node/engine/pg_generate/generate.go | 3 +- node/pg/db_live_test.go | 36 +- node/pg/stats.go | 10 +- node/pg/system.go | 6 +- node/pg/types.go | 31 +- node/services/jsonrpc/openrpc/reflect.go | 3 +- node/utils/conv/conv.go | 23 +- node/utils/conv/conv_test.go | 9 +- 23 files changed, 984 insertions(+), 240 deletions(-) rename core/types/{decimal => }/decimal.go (87%) rename core/types/{decimal => }/decimal_test.go (90%) diff --git a/core/types/data_types.go b/core/types/data_types.go index f6f6ac7b0..936d0fb54 100644 --- a/core/types/data_types.go +++ b/core/types/data_types.go @@ -7,8 +7,6 @@ import ( "regexp" "strconv" "strings" - - "github.com/kwilteam/kwil-db/core/types/decimal" ) // DataType is a data type. @@ -144,7 +142,7 @@ func (c *DataType) PGScalar() (string, error) { scalar = "UINT256" case NumericStr: if !c.HasMetadata() { - return "", errors.New("decimal type requires metadata") + return "", errors.New("numeric type requires metadata") } else { scalar = fmt.Sprintf("NUMERIC(%d,%d)", c.Metadata[0], c.Metadata[1]) } @@ -174,7 +172,7 @@ func (c *DataType) Clean() error { if !c.HasMetadata() { return fmt.Errorf("type %s requires metadata", c.Name) } - err := decimal.CheckPrecisionAndScale(c.Metadata[0], c.Metadata[1]) + err := CheckDecimalPrecisionAndScale(c.Metadata[0], c.Metadata[1]) if err != nil { return err } @@ -263,14 +261,14 @@ var ( Name: uuidStr, } UUIDArrayType = ArrayType(UUIDType) - // DecimalType contains 1,0 metadata. + // NumericType contains 1,0 metadata. // For type detection, users should prefer compare a datatype - // name with the DecimalStr constant. - DecimalType = &DataType{ + // name with the NumericStr constant. + NumericType = &DataType{ Name: NumericStr, Metadata: [2]uint16{0, 0}, // unspecified precision and scale } - DecimalArrayType = ArrayType(DecimalType) + NumericArrayType = ArrayType(NumericType) Uint256Type = &DataType{ Name: uint256Str, // TODO: delete } @@ -307,9 +305,9 @@ const ( nullStr = "null" ) -// NewDecimalType creates a new fixed point decimal type. -func NewDecimalType(precision, scale uint16) (*DataType, error) { - err := decimal.CheckPrecisionAndScale(precision, scale) +// NewNumericType creates a new fixed point numeric type. +func NewNumericType(precision, scale uint16) (*DataType, error) { + err := CheckDecimalPrecisionAndScale(precision, scale) if err != nil { return nil, err } @@ -351,14 +349,14 @@ func ParseDataType(s string) (*DataType, error) { var metadata [2]uint16 if rawMetadata != "" { metadata = [2]uint16{} - // only decimal types can have metadata + // only numeric types can have metadata if baseName != NumericStr { - return nil, fmt.Errorf("metadata is only allowed for decimal type") + return nil, fmt.Errorf("metadata is only allowed for numeric type") } parts := strings.Split(rawMetadata, ",") - // can be either DECIMAL(10,5) or just DECIMAL - if len(parts) != 2 && len(parts) != 0 { + // must be either NUMERIC(10,5) + if len(parts) != 2 { return nil, fmt.Errorf("invalid metadata format: %s", rawMetadata) } for i, part := range parts { @@ -366,6 +364,9 @@ func ParseDataType(s string) (*DataType, error) { if err != nil { return nil, fmt.Errorf("invalid metadata value: %s", part) } + if num > int(maxPrecision) { + return nil, fmt.Errorf("precision must be less than %d", maxPrecision) + } metadata[i] = uint16(num) } } diff --git a/core/types/decimal/decimal.go b/core/types/decimal.go similarity index 87% rename from core/types/decimal/decimal.go rename to core/types/decimal.go index e7c490e3a..0b9cf7189 100644 --- a/core/types/decimal/decimal.go +++ b/core/types/decimal.go @@ -2,7 +2,7 @@ // It is mostly a wrapper around github.com/cockroachdb/apd/v3, with some // functionality that makes it easier to use in the context of Kwil. It enforces // certain semantics of Postgres's decimal, such as precision and scale. -package decimal +package types import ( "database/sql" @@ -44,9 +44,9 @@ type Decimal struct { precision uint16 } -// NewExplicit creates a new Decimal from a string, with an explicit precision and scale. +// NewDecimalExplicit creates a new Decimal from a string, with an explicit precision and scale. // The precision must be between 1 and 1000, and the scale must be between 0 and precision. -func NewExplicit(s string, precision, scale uint16) (*Decimal, error) { +func NewDecimalExplicit(s string, precision, scale uint16) (*Decimal, error) { dec := &Decimal{} if err := dec.SetPrecisionAndScale(precision, scale); err != nil { @@ -60,16 +60,25 @@ func NewExplicit(s string, precision, scale uint16) (*Decimal, error) { return dec, nil } -// NewFromString creates a new Decimal from a string. It automatically infers the precision and scale. -func NewFromString(s string) (*Decimal, error) { +// ParseDecimal creates a new Decimal from a string. It automatically infers the precision and scale. +func ParseDecimal(s string) (*Decimal, error) { inferredPrecision, inferredScale := inferPrecisionAndScale(s) - return NewExplicit(s, inferredPrecision, inferredScale) + return NewDecimalExplicit(s, inferredPrecision, inferredScale) } -// NewFromBigInt creates a new Decimal from a big.Int and an exponent. +// MustParseDecimal is like ParseDecimal but panics if the string cannot be parsed. +func MustParseDecimal(s string) *Decimal { + dec, err := ParseDecimal(s) + if err != nil { + panic(err) + } + return dec +} + +// NewDecimalFromBigInt creates a new Decimal from a big.Int and an exponent. // The negative of the exponent is the scale of the decimal. -func NewFromBigInt(i *big.Int, exp int32) (*Decimal, error) { +func NewDecimalFromBigInt(i *big.Int, exp int32) (*Decimal, error) { if exp > 0 { i2 := big.NewInt(10) i2.Exp(i2, big.NewInt(int64(exp)), nil) @@ -98,15 +107,15 @@ func NewFromBigInt(i *big.Int, exp int32) (*Decimal, error) { dec.precision = dec.scale } - if err := CheckPrecisionAndScale(dec.precision, dec.scale); err != nil { + if err := CheckDecimalPrecisionAndScale(dec.precision, dec.scale); err != nil { return nil, err } return dec, nil } -// NewNaN creates a new NaN Decimal. -func NewNaN() *Decimal { +// NewNaNDecimal creates a new NaN Decimal. +func NewNaNDecimal() *Decimal { return &Decimal{ dec: apd.Decimal{ Form: apd.NaN, @@ -193,7 +202,7 @@ func (d *Decimal) setScale(scale uint16) error { // SetPrecisionAndScale sets the precision and scale of the decimal. // The precision must be between 1 and 1000, and the scale must be between 0 and precision. func (d *Decimal) SetPrecisionAndScale(precision, scale uint16) error { - if err := CheckPrecisionAndScale(precision, scale); err != nil { + if err := CheckDecimalPrecisionAndScale(precision, scale); err != nil { return err } @@ -340,7 +349,7 @@ func (d *Decimal) Scan(src interface{}) error { } // set scale and prec from the string - d2, err := NewFromString(s) + d2, err := ParseDecimal(s) if err != nil { return err } @@ -468,44 +477,44 @@ func (d *Decimal) enforceScale() error { return err } -// Add adds two decimals together. +// DecimalAdd adds two decimals together. // It will return a decimal with maximum precision and scale. -func Add(x, y *Decimal) (*Decimal, error) { +func DecimalAdd(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Add) } -// Sub subtracts y from x. +// DecimalSub subtracts y from x. // It will return a decimal with maximum precision and scale. -func Sub(x, y *Decimal) (*Decimal, error) { +func DecimalSub(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Sub) } -// Mul multiplies two decimals together. +// DecimalMul multiplies two decimals together. // It will return a decimal with maximum precision and scale. -func Mul(x, y *Decimal) (*Decimal, error) { +func DecimalMul(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Mul) } -// Div divides x by y. +// DecimalDiv divides x by y. // It will return a decimal with maximum precision and scale. -func Div(x, y *Decimal) (*Decimal, error) { +func DecimalDiv(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Quo) } -// Mod returns the remainder of x divided by y. +// DecimalMod returns the remainder of x divided by y. // It will return a decimal with maximum precision and scale. -func Mod(x, y *Decimal) (*Decimal, error) { +func DecimalMod(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Rem) } -// Pow raises x to the power of y. -func Pow(x, y *Decimal) (*Decimal, error) { +// DecimalPow raises x to the power of y. +func DecimalPow(x, y *Decimal) (*Decimal, error) { return mathOp(x, y, context.Pow) } -// Cmp compares two decimals. +// DecimalCmp compares two decimals. // It returns -1 if x < y, 0 if x == y, and 1 if x > y. -func Cmp(x, y *Decimal) (int64, error) { +func DecimalCmp(x, y *Decimal) (int64, error) { z := apd.New(0, 0) _, err := context.Cmp(z, &x.dec, &y.dec) if err != nil { @@ -516,8 +525,8 @@ func Cmp(x, y *Decimal) (int64, error) { // return x.dec.Cmp(&y.dec) } -// CheckPrecisionAndScale checks if the precision and scale are valid. -func CheckPrecisionAndScale(precision, scale uint16) error { +// CheckDecimalPrecisionAndScale checks if the precision and scale are valid. +func CheckDecimalPrecisionAndScale(precision, scale uint16) error { if precision < 1 { return fmt.Errorf("precision must be at least 1: %d", precision) } diff --git a/core/types/decimal/decimal_test.go b/core/types/decimal_test.go similarity index 90% rename from core/types/decimal/decimal_test.go rename to core/types/decimal_test.go index b5f25e194..41a08a2a6 100644 --- a/core/types/decimal/decimal_test.go +++ b/core/types/decimal_test.go @@ -1,4 +1,4 @@ -package decimal_test +package types_test import ( "errors" @@ -6,10 +6,9 @@ import ( "math/big" "testing" + "github.com/kwilteam/kwil-db/core/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/kwilteam/kwil-db/core/types/decimal" ) func Test_NewParsedDecimal(t *testing.T) { @@ -114,7 +113,7 @@ func Test_NewParsedDecimal(t *testing.T) { // test cases for decimal creation for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d, err := decimal.NewExplicit(tt.decimal, tt.prec, tt.scale) + d, err := types.NewDecimalExplicit(tt.decimal, tt.prec, tt.scale) if tt.err { require.Errorf(t, err, "result: %v", d) return @@ -179,7 +178,7 @@ func Test_DecimalParsing(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d, err := decimal.NewFromString(tt.in) + d, err := types.ParseDecimal(tt.in) if tt.err { require.Error(t, err) return @@ -198,10 +197,10 @@ func Test_MulDecimal(t *testing.T) { a := "123.456" b := "2.000" - decA, err := decimal.NewFromString(a) + decA, err := types.ParseDecimal(a) require.NoError(t, err) - decB, err := decimal.NewFromString(b) + decB, err := types.ParseDecimal(b) require.NoError(t, err) decMul, err := decA.Mul(decA, decB) @@ -210,23 +209,23 @@ func Test_MulDecimal(t *testing.T) { assert.Equal(t, "246.912", decMul.String()) // overflow - decA, err = decimal.NewFromString("123.456") + decA, err = types.ParseDecimal("123.456") require.NoError(t, err) - decB, err = decimal.NewFromString("10.000") + decB, err = types.ParseDecimal("10.000") require.NoError(t, err) _, err = decA.Mul(decA, decB) require.Error(t, err) // handle the overflow error - decA, err = decimal.NewFromString("123.456") + decA, err = types.ParseDecimal("123.456") require.NoError(t, err) - decB, err = decimal.NewFromString("10.000") + decB, err = types.ParseDecimal("10.000") require.NoError(t, err) - res := decimal.Decimal{} + res := types.Decimal{} err = res.SetPrecisionAndScale(6, 2) require.NoError(t, err) @@ -303,8 +302,8 @@ func Test_DecimalMath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var a *decimal.Decimal - var b *decimal.Decimal + var a *types.Decimal + var b *types.Decimal // greatestScale is the greatest scale of the two decimals var greatestScale uint16 @@ -312,10 +311,10 @@ func Test_DecimalMath(t *testing.T) { // since their pointers get shared between tests. reset := func() { var err error - a, err = decimal.NewFromString(tt.a) + a, err = types.ParseDecimal(tt.a) require.NoError(t, err) - b, err = decimal.NewFromString(tt.b) + b, err = types.ParseDecimal(tt.b) require.NoError(t, err) if a.Scale() > b.Scale() { @@ -354,7 +353,7 @@ func Test_DecimalMath(t *testing.T) { // reset() - pow, err := decimal.Pow(a, b) + pow, err := types.DecimalPow(a, b) switch v := tt.pow.(type) { case string: @@ -371,8 +370,8 @@ func Test_DecimalMath(t *testing.T) { // eq checks that a decimal is equal to a string. // It will round the decimal to the given scale. -func eq(t *testing.T, dec *decimal.Decimal, want string, round uint16) { - dec2, err := decimal.NewFromString(want) +func eq(t *testing.T, dec *types.Decimal, want string, round uint16) { + dec2, err := types.ParseDecimal(want) require.NoError(t, err) old := dec.String() @@ -396,7 +395,7 @@ func eq(t *testing.T, dec *decimal.Decimal, want string, round uint16) { } func Test_AdjustPrecAndScale(t *testing.T) { - a, err := decimal.NewFromString("111.111") + a, err := types.ParseDecimal("111.111") require.NoError(t, err) err = a.SetPrecisionAndScale(9, 6) @@ -416,13 +415,13 @@ func Test_AdjustPrecAndScale(t *testing.T) { } func Test_AdjustScaleMath(t *testing.T) { - a, err := decimal.NewFromString("111.111") + a, err := types.ParseDecimal("111.111") require.NoError(t, err) err = a.SetPrecisionAndScale(6, 3) require.NoError(t, err) - b, err := decimal.NewFromString("222.22") + b, err := types.ParseDecimal("222.22") require.NoError(t, err) _, err = a.Add(a, b) @@ -436,7 +435,7 @@ func Test_AdjustScaleMath(t *testing.T) { require.Equal(t, "333.33", a.String()) - c, err := decimal.NewFromString("30.22") + c, err := types.ParseDecimal("30.22") require.NoError(t, err) _, err = a.Sub(a, c) @@ -446,7 +445,7 @@ func Test_AdjustScaleMath(t *testing.T) { } func Test_RemoveScale(t *testing.T) { - a, err := decimal.NewFromString("111.111") + a, err := types.ParseDecimal("111.111") require.NoError(t, err) err = a.SetPrecisionAndScale(6, 2) @@ -503,10 +502,10 @@ func Test_DecimalCmp(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a, err := decimal.NewFromString(tt.a) + a, err := types.ParseDecimal(tt.a) require.NoError(t, err) - b, err := decimal.NewFromString(tt.b) + b, err := types.ParseDecimal(tt.b) require.NoError(t, err) cmp, err := a.Cmp(b) @@ -576,7 +575,7 @@ func Test_BigAndExp(t *testing.T) { bigInt, ok := new(big.Int).SetString(tt.big, 10) require.True(t, ok) - d, err := decimal.NewFromBigInt(bigInt, tt.exp) + d, err := types.NewDecimalFromBigInt(bigInt, tt.exp) if tt.wantErr { require.Error(t, err) return @@ -629,14 +628,14 @@ func TestDecimalBinaryMarshaling(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d, err := decimal.NewExplicit(tt.decimal, tt.prec, tt.scale) + d, err := types.NewDecimalExplicit(tt.decimal, tt.prec, tt.scale) require.NoError(t, err) marshaled, err := d.MarshalBinary() require.NoError(t, err) assert.Equal(t, tt.expected, marshaled) - var unmarshaled decimal.Decimal + var unmarshaled types.Decimal err = unmarshaled.UnmarshalBinary(marshaled) require.NoError(t, err) @@ -672,7 +671,7 @@ func TestDecimalBinaryUnmarshalingErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var d decimal.Decimal + var d types.Decimal err := d.UnmarshalBinary(tt.input) require.Error(t, err) assert.Contains(t, err.Error(), tt.expectedErr) @@ -681,13 +680,13 @@ func TestDecimalBinaryUnmarshalingErrors(t *testing.T) { } func TestDecimalBinaryRoundTrip(t *testing.T) { - original, err := decimal.NewFromString("12345.6789") + original, err := types.ParseDecimal("12345.6789") require.NoError(t, err) marshaled, err := original.MarshalBinary() require.NoError(t, err) - var unmarshaled decimal.Decimal + var unmarshaled types.Decimal err = unmarshaled.UnmarshalBinary(marshaled) require.NoError(t, err) @@ -698,14 +697,14 @@ func TestDecimalBinaryRoundTrip(t *testing.T) { func TestDecimalJSONRoundTrip(t *testing.T) { str := "12345.6789" - original, err := decimal.NewFromString(str) + original, err := types.ParseDecimal(str) require.NoError(t, err) marshaled, err := original.MarshalJSON() require.NoError(t, err) require.Equal(t, `"`+str+`"`, string(marshaled)) - var unmarshaled decimal.Decimal + var unmarshaled types.Decimal err = unmarshaled.UnmarshalJSON(marshaled) require.NoError(t, err) diff --git a/core/types/encoded_value_test.go b/core/types/encoded_value_test.go index 3b430c69a..2d83b77d4 100644 --- a/core/types/encoded_value_test.go +++ b/core/types/encoded_value_test.go @@ -5,7 +5,6 @@ import ( "math/big" "testing" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -45,13 +44,13 @@ func TestEncodedValue_EdgeCases(t *testing.T) { // THIS IS INCORRECT WITH scientific notation e.g 1e-28 t.Run("encode decimal with max precision", func(t *testing.T) { - d, err := decimal.NewFromBigInt(new(big.Int).SetInt64(1), -6) + d, err := NewDecimalFromBigInt(new(big.Int).SetInt64(1), -6) require.NoError(t, err) ev, err := EncodeValue(d) require.NoError(t, err) decoded, err := ev.Decode() require.NoError(t, err) - assert.Equal(t, d.String(), decoded.(*decimal.Decimal).String()) + assert.Equal(t, d.String(), decoded.(*Decimal).String()) }) t.Run("encode mixed array types should fail", func(t *testing.T) { @@ -103,9 +102,9 @@ func TestEncodedValue_EdgeCases(t *testing.T) { }) t.Run("encode/decode decimal array", func(t *testing.T) { - d1, _ := decimal.NewFromString("100") - d2, _ := decimal.NewFromString("200") - arr := decimal.DecimalArray{d1, d2} + d1, _ := ParseDecimal("100") + d2, _ := ParseDecimal("200") + arr := DecimalArray{d1, d2} ev, err := EncodeValue(arr) require.NoError(t, err) @@ -113,7 +112,7 @@ func TestEncodedValue_EdgeCases(t *testing.T) { decoded, err := ev.Decode() require.NoError(t, err) - decodedArr, ok := decoded.(decimal.DecimalArray) + decodedArr, ok := decoded.(DecimalArray) require.True(t, ok) assert.Equal(t, arr[0].String(), decodedArr[0].String()) assert.Equal(t, arr[1].String(), decodedArr[1].String()) diff --git a/core/types/payloads.go b/core/types/payloads.go index 691614d49..ea6f89d28 100644 --- a/core/types/payloads.go +++ b/core/types/payloads.go @@ -20,7 +20,6 @@ import ( "strconv" "github.com/kwilteam/kwil-db/core/crypto" - "github.com/kwilteam/kwil-db/core/types/decimal" ) // PayloadType is the type of payload @@ -631,7 +630,7 @@ func (e *EncodedValue) Decode() (any, error) { case Uint256Type.Name: return Uint256FromBytes(data) case NumericStr: - return decimal.NewFromString(string(data)) + return ParseDecimal(string(data)) default: return nil, fmt.Errorf("cannot decode type %s", typeName) } @@ -711,14 +710,14 @@ func (e *EncodedValue) Decode() (any, error) { } arrAny = arr case NumericStr: - arr := make(decimal.DecimalArray, 0, len(e.Data)) + arr := make(DecimalArray, 0, len(e.Data)) for _, elem := range e.Data { dec, err := decodeScalar(elem, e.Type.Name, true) if err != nil { return nil, err } - arr = append(arr, dec.(*decimal.Decimal)) + arr = append(arr, dec.(*Decimal)) } arrAny = arr default: @@ -780,15 +779,15 @@ func EncodeValue(v any) (*EncodedValue, error) { case nil: // since we quick return for nil, we can only reach this point if the type is nil // and we are in an array return nil, nil, fmt.Errorf("cannot encode nil in type array") - case *decimal.Decimal: - decTyp, err := NewDecimalType(t.Precision(), t.Scale()) + case *Decimal: + decTyp, err := NewNumericType(t.Precision(), t.Scale()) if err != nil { return nil, nil, err } return []byte(t.String()), decTyp, nil - case decimal.Decimal: - decTyp, err := NewDecimalType(t.Precision(), t.Scale()) + case Decimal: + decTyp, err := NewNumericType(t.Precision(), t.Scale()) if err != nil { return nil, nil, err } diff --git a/core/types/results.go b/core/types/results.go index a385c0f20..ecffe4731 100644 --- a/core/types/results.go +++ b/core/types/results.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "math" + "reflect" + "strconv" ) type TxCode uint16 @@ -148,6 +150,12 @@ func (e *Event) UnmarshalBinary(data []byte) error { return nil } +// CallResult is the result of a procedure call. +type CallResult struct { + QueryResult *QueryResult `json:"query_result"` + Logs []string `json:"logs"` +} + // QueryResult is the result of a SQL query or action. type QueryResult struct { ColumnNames []string `json:"column_names"` @@ -168,8 +176,239 @@ func (qr *QueryResult) ExportToStringMap() []map[string]string { return res } -// CallResult is the result of a procedure call. -type CallResult struct { - QueryResult *QueryResult `json:"query_result"` - Logs []string `json:"logs"` +// Scan scans a value from the query result. +// It accepts a slice of pointers to values, and a function that will be called +// for each row in the result set. +// The passed values can be of type *string, *int64, *int, *bool, *[]byte, *UUID, *Decimal, +// *[]string, *[]int64, *[]int, *[]bool, *[]*int64, *[]*int, *[]*bool, *[]*UUID, *[]*Decimal, +// *[]UUID, *[]Decimal, *[][]byte, or *[]*[]byte. +func (q *QueryResult) Scan(vals []any, fn func() error) error { + + for _, row := range q.Values { + if err := ScanTo(row, vals); err != nil { + return err + } + if err := fn(); err != nil { + return err + } + } + + return nil +} + +// ScanTo scans the src values into the dst values. +func ScanTo(src []any, dst []any) error { + if len(src) != len(dst) { + return fmt.Errorf("expected %d columns, got %d", len(dst), len(src)) + } + + for j, col := range src { + // if the column val is nil, we skip it. + // If it is an array, we need to + typeOf := reflect.TypeOf(col) + if col == nil { + continue + } else if typeOf.Kind() == reflect.Slice && typeOf.Elem().Kind() != reflect.Uint8 { + if err := convertArray(col, dst[j]); err != nil { + return err + } + continue + } else if typeOf.Kind() == reflect.Slice && typeOf.Elem().Kind() == reflect.Uint8 { + if err := convertScalar(col, dst[j]); err != nil { + return err + } + continue + } else if typeOf.Kind() == reflect.Map { + return fmt.Errorf("cannot scan value into map type: %T", dst[j]) + } else { + if err := convertScalar(col, dst[j]); err != nil { + return err + } + } + } + + return nil +} + +func convertArray(src any, dst any) error { + arr, ok := src.([]any) + if !ok { + return fmt.Errorf("unexpected JSON array type: %T", src) + } + + switch v := dst.(type) { + case *[]string: + return convArr(arr, v) + case *[]*string: + return convPtrArr(arr, v) + case *[]int64: + return convArr(arr, v) + case *[]int: + return convArr(arr, v) + case *[]bool: + return convArr(arr, v) + case *[]*int64: + return convPtrArr(arr, v) + case *[]*int: + return convPtrArr(arr, v) + case *[]*bool: + return convPtrArr(arr, v) + case *[]*UUID: + return convPtrArr(arr, v) + case *[]UUID: + return convArr(arr, v) + case *[]*Decimal: + return convPtrArr(arr, v) + case *[]Decimal: + return convArr(arr, v) + case *[][]byte: + return convArr(arr, v) + case *[]*[]byte: + return convPtrArr(arr, v) + default: + return fmt.Errorf("unexpected scan type: %T", dst) + } +} + +func convArr[T any](src []any, dst *[]T) error { + dst2 := make([]T, len(src)) // we dont set the new slice to dst until we know we can convert all values + for i, val := range src { + if err := convertScalar(val, &dst2[i]); err != nil { + return err + } + } + *dst = dst2 + return nil +} + +func convPtrArr[T any](src []any, dst *[]*T) error { + dst2 := make([]*T, len(src)) // we dont set the new slice to dst until we know we can convert all values + for i, val := range src { + if val == nil { + continue + } + + s := new(T) + + err := convertScalar(val, s) + if err != nil { + return err + } + + dst2[i] = s + } + *dst = dst2 + return nil +} + +// convertScalar converts a scalar value to the specified type. +// It converts the source value to a string, then parses it into the specified type. +func convertScalar(src any, dst any) error { + var null bool + if src == nil { + null = true + } + str, err := stringify(src) + if err != nil { + return err + } + switch v := dst.(type) { + case *string: + if null { + return nil + } + *v = str + return nil + case *int64: + if null { + return nil + } + i, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return err + } + *v = i + return nil + case *int: + if null { + return nil + } + i, err := strconv.Atoi(str) + if err != nil { + return err + } + *v = i + return nil + case *bool: + if null { + return nil + } + b, err := strconv.ParseBool(str) + if err != nil { + return err + } + *v = b + return nil + case *[]byte: + if null { + return nil + } + *v = []byte(str) + return nil + case *UUID: + if null { + return nil + } + + if len([]byte(str)) == 16 { + *v = UUID([]byte(str)) + return nil + } + + u, err := ParseUUID(str) + if err != nil { + return err + } + *v = *u + return nil + case *Decimal: + if null { + return nil + } + + dec, err := ParseDecimal(str) + if err != nil { + return err + } + *v = *dec + return nil + default: + return fmt.Errorf("unexpected scan type: %T", dst) + } +} + +// stringify converts a value as a string. +// It only expects values returned from JSON marshalling. +// It does NOT expect slices/arrays (except for []byte) or maps +func stringify(v any) (str string, err error) { + switch val := v.(type) { + case string: + return val, nil + case []byte: + return string(val), nil + case int64: + return strconv.FormatInt(val, 10), nil + case int: + return strconv.Itoa(val), nil + case float64: + return strconv.FormatFloat(val, 'f', -1, 64), nil + case bool: + return strconv.FormatBool(val), nil + case nil: + return "", nil + case float32: + return strconv.FormatFloat(float64(val), 'f', -1, 32), nil + default: + return "", fmt.Errorf("unexpected type: %T", v) + } } diff --git a/core/types/results_test.go b/core/types/results_test.go index 929cb05a7..b34d29b52 100644 --- a/core/types/results_test.go +++ b/core/types/results_test.go @@ -2,7 +2,12 @@ package types import ( "encoding/binary" + "errors" + "reflect" + "strconv" "testing" + + "github.com/stretchr/testify/assert" ) func TestTxResultMarshalUnmarshal(t *testing.T) { @@ -127,3 +132,504 @@ func TestTxResultMarshalUnmarshal(t *testing.T) { } }) } + +// errTestAny is a special error type used within tests if we want +// to signal that we just want any error, and dont care about the +// specific error type. +var errTestAny = errors.New("any test error") + +func TestQueryResultScanScalars(t *testing.T) { + type testcase struct { + name string + rawval any // the value received from json unmarshalling + // all of the "exp" (expect) values are the expected results + // of scanning the rawval into the corresponding type. + // They should be one of 3 values: the core type, nil, or error + expString any + expInt64 any + expInt any + expBool any + expBytes any + expDec any + expUUID any + } + + tests := []testcase{ + { + name: "string", + rawval: "hello", + expString: "hello", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: []byte("hello"), + expDec: strconv.ErrSyntax, + expUUID: errTestAny, + }, + { + name: "int64", + rawval: int64(123), + expString: "123", + expInt64: int64(123), + expInt: int(123), + expBool: strconv.ErrSyntax, + expBytes: []byte("123"), + expDec: *MustParseDecimal("123"), + expUUID: errTestAny, + }, + { + name: "int", + rawval: int(123), + expString: "123", + expInt64: int64(123), + expInt: int(123), + expBool: strconv.ErrSyntax, + expBytes: []byte("123"), + expDec: *MustParseDecimal("123"), + expUUID: errTestAny, + }, + { + name: "int string", + // this is a string that looks like an int + rawval: "123", + expString: "123", + expInt64: int64(123), + expInt: int(123), + expBool: strconv.ErrSyntax, + expBytes: []byte("123"), + expDec: *MustParseDecimal("123"), + expUUID: errTestAny, + }, + { + name: "bool", + rawval: true, + expString: "true", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: true, + expBytes: []byte("true"), + expDec: strconv.ErrSyntax, + expUUID: errTestAny, + }, + { + name: "bytes", + rawval: []byte("hello"), + expString: "hello", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: []byte("hello"), + expDec: strconv.ErrSyntax, + expUUID: errTestAny, + }, + { + name: "bytes (16 bytes)", + rawval: MustParseUUID("12345678-1234-1234-1234-123456789abc").Bytes(), + expString: string(MustParseUUID("12345678-1234-1234-1234-123456789abc").Bytes()), + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: MustParseUUID("12345678-1234-1234-1234-123456789abc").Bytes(), + expDec: errTestAny, + expUUID: *MustParseUUID("12345678-1234-1234-1234-123456789abc"), + }, + { + name: "decimal", + rawval: "123.456", + expString: "123.456", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: []byte("123.456"), + expDec: *MustParseDecimal("123.456"), + expUUID: errTestAny, + }, + { + name: "uuid", + rawval: "12345678-1234-1234-1234-123456789abc", + expString: "12345678-1234-1234-1234-123456789abc", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: []byte("12345678-1234-1234-1234-123456789abc"), + expDec: errTestAny, + expUUID: *MustParseUUID("12345678-1234-1234-1234-123456789abc"), + }, + { + name: "nil", + // this is a nil value + rawval: nil, + expString: nil, + expInt64: nil, + expInt: nil, + expBool: nil, + expBytes: nil, + expDec: nil, + expUUID: nil, + }, + { + name: "float", + rawval: float64(123.456), + expString: "123.456", + expInt64: strconv.ErrSyntax, + expInt: strconv.ErrSyntax, + expBool: strconv.ErrSyntax, + expBytes: []byte("123.456"), + expDec: *MustParseDecimal("123.456"), + expUUID: errTestAny, + }, + { + name: "round float", + rawval: float32(123), + expString: "123", + expInt64: int64(123), + expInt: int(123), + expBool: strconv.ErrSyntax, + expBytes: []byte("123"), + expDec: *MustParseDecimal("123"), + expUUID: errTestAny, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + qr := &QueryResult{ + Values: [][]any{{tt.rawval}}, + } + checkType[string](t, qr, tt.expString) + checkType[int64](t, qr, tt.expInt64) + checkType[int](t, qr, tt.expInt) + checkType[bool](t, qr, tt.expBool) + checkType[[]byte](t, qr, tt.expBytes) + checkType[Decimal](t, qr, tt.expDec) + checkType[UUID](t, qr, tt.expUUID) + }) + } +} + +func checkType[T any](t *testing.T, q *QueryResult, want any) { + var name string + _, wantErr := want.(error) + if want != nil && !wantErr { + typeof := reflect.TypeOf(want) + isPtr := false + if typeof.Kind() == reflect.Ptr { + isPtr = true + typeof = typeof.Elem() + } + name = typeof.String() + if isPtr { + name = "*" + name + } + } else if wantErr { + name = "error" + } else { + name = "nil" + } + t.Logf("testing type %T, expecting %s", *new(T), name) + + v := new(T) + err := q.Scan([]any{v}, func() error { + return nil + }) + + switch want := want.(type) { + case nil: + assert.NoError(t, err) + + newNil := new(T) + assert.EqualValues(t, newNil, v) + case error: + if want == errTestAny { + assert.Error(t, err) + } else { + assert.ErrorIs(t, err, want) + } + + newNil := new(T) + assert.EqualValues(t, newNil, v) + case T: + assert.NoError(t, err) + assert.EqualValues(t, want, *v) + default: + t.Fatalf("unexpected want type %T", want) + } +} + +func TestQueryResultScanArrays(t *testing.T) { + type testcase struct { + name string + rawval any // the value received from json unmarshalling + // all of the "exp" (expect) values are the expected results + // of scanning the rawval into the corresponding type. + // They should be one of 3 values: the core type, nil, or error. + expStringArr any + expStringArrPtr any + expInt64Arr any + expInt64ArrPtr any + expIntArr any + expIntArrPtr any + expBoolArr any + expBoolArrPtr any + expBytesArr any + expBytesArrPtr any + expDecArr any + expDecArrPtr any + expUUIDArr any + expUUIDArrPtr any + } + + tests := []testcase{ + { + name: "string", + rawval: []any{"hello", "world", nil}, + expStringArr: []string{"hello", "world", ""}, + expStringArrPtr: ptrArr[string]("hello", "world", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("hello"), []byte("world"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("hello"), []byte("world"), nil), + expDecArr: errTestAny, + expDecArrPtr: errTestAny, + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "int64", + rawval: []any{int64(123), int64(456), nil}, + expStringArr: []string{"123", "456", ""}, + expStringArrPtr: ptrArr[string]("123", "456", nil), + expInt64Arr: []int64{int64(123), int64(456), 0}, + expInt64ArrPtr: ptrArr[int64](int64(123), int64(456), nil), + expIntArr: []int{123, 456, 0}, + expIntArrPtr: ptrArr[int](123, 456, nil), + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("123"), []byte("456"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("123"), []byte("456"), nil), + expDecArr: []Decimal{*MustParseDecimal("123"), *MustParseDecimal("456"), {}}, + expDecArrPtr: ptrArr[Decimal](*MustParseDecimal("123"), *MustParseDecimal("456"), nil), + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "int", + rawval: []any{int(123), int(456), nil}, + expStringArr: []string{"123", "456", ""}, + expStringArrPtr: ptrArr[string]("123", "456", nil), + expInt64Arr: []int64{int64(123), int64(456), 0}, + expInt64ArrPtr: ptrArr[int64](int64(123), int64(456), nil), + expIntArr: []int{123, 456, 0}, + expIntArrPtr: ptrArr[int](123, 456, nil), + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("123"), []byte("456"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("123"), []byte("456"), nil), + expDecArr: []Decimal{*MustParseDecimal("123"), *MustParseDecimal("456"), {}}, + expDecArrPtr: ptrArr[Decimal](*MustParseDecimal("123"), *MustParseDecimal("456"), nil), + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "bool", + rawval: []any{true, false, nil}, + expStringArr: []string{"true", "false", ""}, + expStringArrPtr: ptrArr[string]("true", "false", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: []bool{true, false, false}, + expBoolArrPtr: ptrArr[bool](true, false, nil), + expBytesArr: [][]byte{[]byte("true"), []byte("false"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("true"), []byte("false"), nil), + expDecArr: errTestAny, + expDecArrPtr: errTestAny, + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "bytes", + rawval: []any{[]byte("hello"), []byte("world"), nil}, + expStringArr: []string{"hello", "world", ""}, + expStringArrPtr: ptrArr[string]("hello", "world", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("hello"), []byte("world"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("hello"), []byte("world"), nil), + expDecArr: errTestAny, + expDecArrPtr: errTestAny, + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "decimal", + rawval: []any{"123.456", "789.012", nil}, + expStringArr: []string{"123.456", "789.012", ""}, + expStringArrPtr: ptrArr[string]("123.456", "789.012", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("123.456"), []byte("789.012"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("123.456"), []byte("789.012"), nil), + expDecArr: []Decimal{*MustParseDecimal("123.456"), *MustParseDecimal("789.012"), {}}, + expDecArrPtr: ptrArr[Decimal](*MustParseDecimal("123.456"), *MustParseDecimal("789.012"), nil), + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + { + name: "uuid", + rawval: []any{"12345678-1234-1234-1234-123456789abc", "12345678-1234-1234-1234-123456789def", nil}, + expStringArr: []string{"12345678-1234-1234-1234-123456789abc", "12345678-1234-1234-1234-123456789def", ""}, + expStringArrPtr: ptrArr[string]("12345678-1234-1234-1234-123456789abc", "12345678-1234-1234-1234-123456789def", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("12345678-1234-1234-1234-123456789abc"), []byte("12345678-1234-1234-1234-123456789def"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("12345678-1234-1234-1234-123456789abc"), []byte("12345678-1234-1234-1234-123456789def"), nil), + expDecArr: errTestAny, + expDecArrPtr: errTestAny, + expUUIDArr: []UUID{*MustParseUUID("12345678-1234-1234-1234-123456789abc"), *MustParseUUID("12345678-1234-1234-1234-123456789def"), {}}, + expUUIDArrPtr: ptrArr[UUID](*MustParseUUID("12345678-1234-1234-1234-123456789abc"), *MustParseUUID("12345678-1234-1234-1234-123456789def"), nil), + }, + { + name: "all nil values", + rawval: []any{nil, nil, nil}, + expStringArr: []string{"", "", ""}, + expStringArrPtr: ptrArr[string](nil, nil, nil), + expInt64Arr: []int64{0, 0, 0}, + expInt64ArrPtr: ptrArr[int64](nil, nil, nil), + expIntArr: []int{0, 0, 0}, + expIntArrPtr: ptrArr[int](nil, nil, nil), + expBoolArr: []bool{false, false, false}, + expBoolArrPtr: ptrArr[bool](nil, nil, nil), + expBytesArr: [][]byte{nil, nil, nil}, + expBytesArrPtr: ptrArr[[]byte](nil, nil, nil), + expDecArr: []Decimal{{}, {}, {}}, + expDecArrPtr: ptrArr[Decimal](nil, nil, nil), + expUUIDArr: []UUID{{}, {}, {}}, + expUUIDArrPtr: ptrArr[UUID](nil, nil, nil), + }, + { + name: "nil", + rawval: nil, + expStringArr: nil, + expStringArrPtr: nil, + expInt64Arr: nil, + expInt64ArrPtr: nil, + expIntArr: nil, + expIntArrPtr: nil, + expBoolArr: nil, + expBoolArrPtr: nil, + expBytesArr: nil, + expBytesArrPtr: nil, + expDecArr: nil, + expDecArrPtr: nil, + expUUIDArr: nil, + expUUIDArrPtr: nil, + }, + { + name: "float", + rawval: []any{float64(123.456), float64(789), nil}, + expStringArr: []string{"123.456", "789", ""}, + expStringArrPtr: ptrArr[string]("123.456", "789", nil), + expInt64Arr: errTestAny, + expInt64ArrPtr: errTestAny, + expIntArr: errTestAny, + expIntArrPtr: errTestAny, + expBoolArr: errTestAny, + expBoolArrPtr: errTestAny, + expBytesArr: [][]byte{[]byte("123.456"), []byte("789"), nil}, + expBytesArrPtr: ptrArr[[]byte]([]byte("123.456"), []byte("789"), nil), + expDecArr: []Decimal{*MustParseDecimal("123.456"), *MustParseDecimal("789"), {}}, + expDecArrPtr: ptrArr[Decimal](*MustParseDecimal("123.456"), *MustParseDecimal("789"), nil), + expUUIDArr: errTestAny, + expUUIDArrPtr: errTestAny, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + qr := &QueryResult{ + Values: [][]any{{tt.rawval}}, + } + checkType[[]string](t, qr, tt.expStringArr) + checkType[[]*string](t, qr, tt.expStringArrPtr) + checkType[[]int64](t, qr, tt.expInt64Arr) + checkType[[]*int64](t, qr, tt.expInt64ArrPtr) + checkType[[]int](t, qr, tt.expIntArr) + checkType[[]*int](t, qr, tt.expIntArrPtr) + checkType[[]bool](t, qr, tt.expBoolArr) + checkType[[]*bool](t, qr, tt.expBoolArrPtr) + checkType[[][]byte](t, qr, tt.expBytesArr) + checkType[[]*[]byte](t, qr, tt.expBytesArrPtr) + checkType[[]Decimal](t, qr, tt.expDecArr) + checkType[[]*Decimal](t, qr, tt.expDecArrPtr) + checkType[[]UUID](t, qr, tt.expUUIDArr) + checkType[[]*UUID](t, qr, tt.expUUIDArrPtr) + }) + } +} + +// Im checking here that users are capable of detecting zero length +// arrays vs null arrays +func TestScanArrayNullability(t *testing.T) { + v := new([]string) + qr := &QueryResult{ + Values: [][]any{{[]any{}}}, + } + err := qr.Scan([]any{v}, func() error { + return nil + }) + assert.NoError(t, err) + + assert.True(t, *v != nil) + assert.Len(t, *v, 0) + + v = new([]string) + v2 := new([]string) + qr = &QueryResult{ + Values: [][]any{{[]any{"a"}, nil}}, + } + err = qr.Scan([]any{v, v2}, func() error { + return nil + }) + assert.NoError(t, err) + + assert.True(t, *v != nil) + assert.Len(t, *v, 1) + + assert.False(t, *v2 != nil) +} + +func ptrArr[T any](v ...any) []*T { + out := make([]*T, len(v)) + for i, b := range v { + if b == nil { + out[i] = nil + continue + } + + convV, ok := b.(T) + if !ok { + panic("invalid type") + } + + out[i] = &convV + } + return out +} diff --git a/core/types/uuid.go b/core/types/uuid.go index f1997d0df..c73d983d8 100644 --- a/core/types/uuid.go +++ b/core/types/uuid.go @@ -58,6 +58,15 @@ func ParseUUID(s string) (*UUID, error) { return &u2, nil } +// MustParseUUID parses a uuid from a string and panics on error +func MustParseUUID(s string) *UUID { + u, err := ParseUUID(s) + if err != nil { + panic(err) + } + return u +} + // String returns the string representation of the uuid func (u UUID) String() string { return uuid.UUID(u).String() diff --git a/node/engine/functions.go b/node/engine/functions.go index 13c470baf..a60d4a511 100644 --- a/node/engine/functions.go +++ b/node/engine/functions.go @@ -937,12 +937,12 @@ var ( func init() { var err error - decimal1000, err = types.NewDecimalType(1000, 0) + decimal1000, err = types.NewNumericType(1000, 0) if err != nil { panic(fmt.Sprintf("failed to create decimal type: 1000, 0: %v", err)) } - decimal16_6, err = types.NewDecimalType(16, 6) + decimal16_6, err = types.NewNumericType(16, 6) if err != nil { panic(fmt.Sprintf("failed to create decimal type: 16, 6: %v", err)) } diff --git a/node/engine/interpreter/interpreter_test.go b/node/engine/interpreter/interpreter_test.go index ff2e75182..a58082ab8 100644 --- a/node/engine/interpreter/interpreter_test.go +++ b/node/engine/interpreter/interpreter_test.go @@ -13,7 +13,6 @@ import ( "github.com/kwilteam/kwil-db/common" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/extensions/precompiles" "github.com/kwilteam/kwil-db/node/engine" "github.com/kwilteam/kwil-db/node/engine/interpreter" @@ -699,8 +698,8 @@ func Test_SQL(t *testing.T) { for j, val := range row { // if it is a numeric, we should do a special comparison if test.results[i][j] != nil { - if decVal, ok := test.results[i][j].(*decimal.Decimal); ok { - cmp, err := decVal.Cmp(val.(*decimal.Decimal)) + if decVal, ok := test.results[i][j].(*types.Decimal); ok { + cmp, err := decVal.Cmp(val.(*types.Decimal)) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -780,7 +779,7 @@ func Test_Roundtrip(t *testing.T) { { name: "decimal_array", datatype: "DECIMAL(70,5)[]", - value: []*decimal.Decimal{mustExplicitDecimal("100.101", 70, 5), mustExplicitDecimal("200.202", 70, 5)}, + value: []*types.Decimal{mustExplicitDecimal("100.101", 70, 5), mustExplicitDecimal("200.202", 70, 5)}, }, { name: "uuid_array", @@ -1761,7 +1760,7 @@ func (t *testPrecompile) makeGetMethod(datatype *types.DataType) precompiles.Met } func mustDecType(precision, scale uint16) *types.DataType { - t, err := types.NewDecimalType(precision, scale) + t, err := types.NewNumericType(precision, scale) if err != nil { panic(err) } @@ -2019,7 +2018,7 @@ func Test_Extensions(t *testing.T) { {"bool_array", []bool{true}}, {"bytea_array", [][]byte{{1, 2, 3}}}, {"uuid_array", []*types.UUID{mustUUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")}}, - {"numeric_array", []*decimal.Decimal{mustExplicitDecimal("1.23", 10, 2)}}, + {"numeric_array", []*types.Decimal{mustExplicitDecimal("1.23", 10, 2)}}, } { err = adminCall("test_ext", "get_"+get.key, []any{get.key}, exact(get.value)) require.NoErrorf(t, err, "key: %s", get.key) @@ -2480,8 +2479,8 @@ func eq(a, b any) error { return nil } -func mustExplicitDecimal(s string, prec, scale uint16) *decimal.Decimal { - d, err := decimal.NewExplicit(s, prec, scale) +func mustExplicitDecimal(s string, prec, scale uint16) *types.Decimal { + d, err := types.NewDecimalExplicit(s, prec, scale) if err != nil { panic(err) } diff --git a/node/engine/interpreter/sql_test.go b/node/engine/interpreter/sql_test.go index abc8c71de..276d91cff 100644 --- a/node/engine/interpreter/sql_test.go +++ b/node/engine/interpreter/sql_test.go @@ -10,7 +10,6 @@ import ( "github.com/kwilteam/kwil-db/common" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/extensions/precompiles" "github.com/kwilteam/kwil-db/node/engine" "github.com/kwilteam/kwil-db/node/pg" @@ -268,7 +267,7 @@ func Test_built_in_sql(t *testing.T) { "strarr": mustNewVal([]string{"a", "b", "c"}), "intarr": mustNewVal([]int{1, 2, 3}), "boolarr": mustNewVal([]bool{true, false, true}), - "decarr": mustNewVal([]*decimal.Decimal{mustDec("1.23"), mustDec("4.56")}), + "decarr": mustNewVal([]*types.Decimal{mustDec("1.23"), mustDec("4.56")}), "uuidarr": mustNewVal([]*types.UUID{mustUUID("c7b6a54c-392c-48f9-803d-31cb97e76052"), mustUUID("c7b6a54c-392c-48f9-803d-31cb97e76053")}), "blobarr": mustNewVal([][]byte{{1, 2, 3}, {4, 5, 6}}), } @@ -1068,7 +1067,7 @@ func Test_Metadata(t *testing.T) { } func mustDecType(prec, scale uint16) *types.DataType { - dt, err := types.NewDecimalType(prec, scale) + dt, err := types.NewNumericType(prec, scale) if err != nil { panic(err) } diff --git a/node/engine/interpreter/values.go b/node/engine/interpreter/values.go index c11774508..cf43ff0c4 100644 --- a/node/engine/interpreter/values.go +++ b/node/engine/interpreter/values.go @@ -10,7 +10,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/engine" ) @@ -115,13 +114,13 @@ func init() { }, }, valueMapping{ - KwilType: types.DecimalType, + KwilType: types.NumericType, ZeroValue: func(t *types.DataType) (value, error) { if !t.HasMetadata() { return nil, fmt.Errorf("cannot create zero value of decimal type with zero precision and scale") } - dec, err := decimal.NewFromString("0") + dec, err := types.ParseDecimal("0") if err != nil { return nil, err } @@ -197,7 +196,7 @@ func init() { }, }, valueMapping{ - KwilType: types.DecimalArrayType, + KwilType: types.NumericArrayType, ZeroValue: func(t *types.DataType) (value, error) { if !t.HasMetadata() { return nil, fmt.Errorf("cannot create zero value of decimal type with zero precision and scale") @@ -220,7 +219,7 @@ func init() { prec := t.Metadata[0] scale := t.Metadata[1] - arr := newNullDecArr(types.DecimalArrayType) + arr := newNullDecArr(types.NumericArrayType) arr.metadata = &precAndScale{prec, scale} return arr, nil }, @@ -344,9 +343,9 @@ func newValue(v any) (value, error) { return makeUUID(v), nil case types.UUID: return makeUUID(&v), nil - case *decimal.Decimal: + case *types.Decimal: return makeDecimal(v), nil - case decimal.Decimal: + case types.Decimal: return makeDecimal(&v), nil case []int64: pgInts := make([]pgtype.Int8, len(v)) @@ -464,7 +463,7 @@ func newValue(v any) (value, error) { return &blobArrayValue{ singleDimArray: newValidArr(pgBlobs), }, nil - case []*decimal.Decimal: + case []*types.Decimal: pgDecs := make([]pgtype.Numeric, len(v)) for i, val := range v { pgDecs[i] = pgTypeFromDec(val) @@ -695,7 +694,7 @@ func (i *int8Value) Cast(t *types.DataType) (value, error) { return nil, castErr(errors.New("cannot cast int to decimal array")) } - dec, err := decimal.NewFromString(fmt.Sprint(i.Int64)) + dec, err := types.ParseDecimal(fmt.Sprint(i.Int64)) if err != nil { return nil, castErr(err) } @@ -818,7 +817,7 @@ func (s *textValue) Cast(t *types.DataType) (value, error) { return nil, castErr(errors.New("cannot cast text to decimal array")) } - dec, err := decimal.NewFromString(s.String) + dec, err := types.ParseDecimal(s.String) if err != nil { return nil, castErr(err) } @@ -1168,7 +1167,7 @@ func (u *uuidValue) Cast(t *types.DataType) (value, error) { } } -func pgTypeFromDec(d *decimal.Decimal) pgtype.Numeric { +func pgTypeFromDec(d *types.Decimal) pgtype.Numeric { if d == nil { return pgtype.Numeric{ Valid: false, @@ -1196,16 +1195,16 @@ func pgTypeFromDec(d *decimal.Decimal) pgtype.Numeric { } } -func decFromPgType(n pgtype.Numeric, meta *precAndScale) (*decimal.Decimal, error) { +func decFromPgType(n pgtype.Numeric, meta *precAndScale) (*types.Decimal, error) { if n.NaN { - return decimal.NewNaN(), nil + return types.NewNaNDecimal(), nil } if !n.Valid { // we should never get here, but just in case return nil, fmt.Errorf("internal bug: null decimal") } - dec, err := decimal.NewFromBigInt(n.Int, n.Exp) + dec, err := types.NewDecimalFromBigInt(n.Int, n.Exp) if err != nil { return nil, err } @@ -1220,7 +1219,7 @@ func decFromPgType(n pgtype.Numeric, meta *precAndScale) (*decimal.Decimal, erro return dec, nil } -func makeDecimal(d *decimal.Decimal) *decimalValue { +func makeDecimal(d *types.Decimal) *decimalValue { if d == nil { return &decimalValue{ Numeric: pgtype.Numeric{ @@ -1248,7 +1247,7 @@ func (d *decimalValue) Null() bool { return !d.Valid } -func (d *decimalValue) dec() (*decimal.Decimal, error) { +func (d *decimalValue) dec() (*types.Decimal, error) { if d.NaN { return nil, fmt.Errorf("NaN") } @@ -1257,7 +1256,7 @@ func (d *decimalValue) dec() (*decimal.Decimal, error) { return nil, fmt.Errorf("internal bug: null decimal") } - d2, err := decimal.NewFromBigInt(d.Int, d.Exp) + d2, err := types.NewDecimalFromBigInt(d.Int, d.Exp) if err != nil { return nil, err } @@ -1327,20 +1326,20 @@ func (d *decimalValue) Arithmetic(v scalarValue, op engine.ArithmeticOp) (scalar return nil, err } - var d2 *decimal.Decimal + var d2 *types.Decimal switch op { case engine.ADD: - d2, err = decimal.Add(dec1, dec2) + d2, err = types.DecimalAdd(dec1, dec2) case engine.SUB: - d2, err = decimal.Sub(dec1, dec2) + d2, err = types.DecimalSub(dec1, dec2) case engine.MUL: - d2, err = decimal.Mul(dec1, dec2) + d2, err = types.DecimalMul(dec1, dec2) case engine.DIV: - d2, err = decimal.Div(dec1, dec2) + d2, err = types.DecimalDiv(dec1, dec2) case engine.EXP: - d2, err = decimal.Pow(dec1, dec2) + d2, err = types.DecimalPow(dec1, dec2) case engine.MOD: - d2, err = decimal.Mod(dec1, dec2) + d2, err = types.DecimalMod(dec1, dec2) default: return nil, fmt.Errorf("%w: unexpected operator id %d for decimal", engine.ErrArithmetic, op) } @@ -1383,10 +1382,10 @@ func (d *decimalValue) Unary(op engine.UnaryOp) (scalarValue, error) { func (d *decimalValue) Type() *types.DataType { if d.metadata == nil { - return types.DecimalType + return types.NumericType } - t := types.DecimalType.Copy() + t := types.NumericType.Copy() t.Metadata = *d.metadata return t } @@ -1650,8 +1649,8 @@ func (a *int8ArrayValue) Cast(t *types.DataType) (value, error) { return nil, castErr(errors.New("cannot cast int array to decimal")) } - return castArrWithPtr(a, func(i int64) (*decimal.Decimal, error) { - return decimal.NewExplicit(strconv.FormatInt(i, 10), t.Metadata[0], t.Metadata[1]) + return castArrWithPtr(a, func(i int64) (*types.Decimal, error) { + return types.NewDecimalExplicit(strconv.FormatInt(i, 10), t.Metadata[0], t.Metadata[1]) }, newDecArrFn(t)) } @@ -1740,8 +1739,8 @@ func (a *textArrayValue) Cast(t *types.DataType) (value, error) { return nil, castErr(errors.New("cannot cast text array to decimal")) } - return castArrWithPtr(a, func(s string) (*decimal.Decimal, error) { - return decimal.NewExplicit(s, t.Metadata[0], t.Metadata[1]) + return castArrWithPtr(a, func(s string) (*types.Decimal, error) { + return types.NewDecimalExplicit(s, t.Metadata[0], t.Metadata[1]) }, newDecArrFn(t)) } @@ -1919,13 +1918,13 @@ func newNullDecArr(t *types.DataType) *decimalArrayValue { // newDecArrFn returns a function that creates a new DecimalArrayValue. // It is used for type casting. -func newDecArrFn(t *types.DataType) func(d []*decimal.Decimal) *decimalArrayValue { - return func(d []*decimal.Decimal) *decimalArrayValue { +func newDecArrFn(t *types.DataType) func(d []*types.Decimal) *decimalArrayValue { + return func(d []*types.Decimal) *decimalArrayValue { return newDecimalArrayValue(d, t) } } -func newDecimalArrayValue(d []*decimal.Decimal, t *types.DataType) *decimalArrayValue { +func newDecimalArrayValue(d []*types.Decimal, t *types.DataType) *decimalArrayValue { vals := make([]pgtype.Numeric, len(d)) for i, v := range d { var newDec pgtype.Numeric @@ -2057,10 +2056,10 @@ func (a *decimalArrayValue) Set(i int32, v scalarValue) error { func (a *decimalArrayValue) Type() *types.DataType { if a.metadata == nil { - return types.DecimalArrayType + return types.NumericArrayType } - t := types.DecimalArrayType.Copy() + t := types.NumericArrayType.Copy() t.Metadata = *a.metadata return t } @@ -2070,7 +2069,7 @@ func (a *decimalArrayValue) RawValue() any { return nil } - res := make([]*decimal.Decimal, len(a.Elements)) + res := make([]*types.Decimal, len(a.Elements)) for i, v := range a.Elements { if v.Valid { dec, err := decFromPgType(v, a.metadata) @@ -2096,7 +2095,7 @@ func (a *decimalArrayValue) Cast(t *types.DataType) (value, error) { } // otherwise, we need to alter the precision and scale - res := make([]*decimal.Decimal, a.Len()) + res := make([]*types.Decimal, a.Len()) for i := int32(1); i <= a.Len(); i++ { v, err := a.Get(i) if err != nil { @@ -2110,7 +2109,7 @@ func (a *decimalArrayValue) Cast(t *types.DataType) (value, error) { // we need to make a copy of the decimal because SetPrecisionAndScale // will modify the decimal in place. - dec2, err := decimal.NewExplicit(dec.String(), dec.Precision(), dec.Scale()) + dec2, err := types.NewDecimalExplicit(dec.String(), dec.Precision(), dec.Scale()) if err != nil { return nil, err } @@ -2128,10 +2127,10 @@ func (a *decimalArrayValue) Cast(t *types.DataType) (value, error) { switch *t { case *types.TextArrayType: - return castArr(a, func(d *decimal.Decimal) (string, error) { return d.String(), nil }, newTextArrayValue) + return castArr(a, func(d *types.Decimal) (string, error) { return d.String(), nil }, newTextArrayValue) case *types.IntArrayType: - return castArr(a, func(d *decimal.Decimal) (int64, error) { return d.Int64() }, newIntArr) - case *types.DecimalArrayType: + return castArr(a, func(d *types.Decimal) (int64, error) { return d.Int64() }, newIntArr) + case *types.NumericArrayType: return a, nil default: return nil, castErr(fmt.Errorf("cannot cast decimal array to %s", t)) @@ -2571,7 +2570,7 @@ func parseValue(s string, t *types.DataType) (value, error) { } if t.Name == types.NumericStr { - dec, err := decimal.NewExplicit(s, t.Metadata[0], t.Metadata[1]) + dec, err := types.NewDecimalExplicit(s, t.Metadata[0], t.Metadata[1]) if err != nil { return nil, err } diff --git a/node/engine/interpreter/values_test.go b/node/engine/interpreter/values_test.go index 7872921dd..2b51bec89 100644 --- a/node/engine/interpreter/values_test.go +++ b/node/engine/interpreter/values_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/engine" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -153,8 +152,8 @@ func Test_Arithmetic(t *testing.T) { // It handles the semantics of comparing decimal values. func eq(t *testing.T, a, b any) { // if the values are decimals, we need to compare them manually - if aDec, ok := a.(*decimal.Decimal); ok { - bDec, ok := b.(*decimal.Decimal) + if aDec, ok := a.(*types.Decimal); ok { + bDec, ok := b.(*types.Decimal) require.True(t, ok) rec, err := aDec.Cmp(bDec) @@ -163,8 +162,8 @@ func eq(t *testing.T, a, b any) { return } - if aDec, ok := a.([]*decimal.Decimal); ok { - bDec, ok := b.([]*decimal.Decimal) + if aDec, ok := a.([]*types.Decimal); ok { + bDec, ok := b.([]*types.Decimal) require.True(t, ok) require.Len(t, aDec, len(bDec)) @@ -297,8 +296,8 @@ func Test_Comparison(t *testing.T) { }, { name: "decimal-array", - a: []*decimal.Decimal{mustDec("1.00"), mustDec("2.00"), mustDec("3.00")}, - b: []*decimal.Decimal{mustDec("1.00"), mustDec("2.00"), mustDec("3.00")}, + a: []*types.Decimal{mustDec("1.00"), mustDec("2.00"), mustDec("3.00")}, + b: []*types.Decimal{mustDec("1.00"), mustDec("2.00"), mustDec("3.00")}, eq: true, gt: engine.ErrComparison, lt: engine.ErrComparison, @@ -426,9 +425,9 @@ func Test_Cast(t *testing.T) { blobArr any } - mDec := func(dec string) *decimal.Decimal { + mDec := func(dec string) *types.Decimal { // all decimals will be precision 10, scale 5 - d, err := decimal.NewFromString(dec) + d, err := types.ParseDecimal(dec) require.NoError(t, err) err = d.SetPrecisionAndScale(10, 5) @@ -436,8 +435,8 @@ func Test_Cast(t *testing.T) { return d } - mDecArr := func(decimals ...string) []*decimal.Decimal { - var res []*decimal.Decimal + mDecArr := func(decimals ...string) []*types.Decimal { + var res []*types.Decimal for _, dec := range decimals { res = append(res, mDec(dec)) } @@ -587,7 +586,7 @@ func Test_Cast(t *testing.T) { eq(t, want, res.RawValue()) } - decimalType, err := types.NewDecimalType(10, 5) + decimalType, err := types.NewNumericType(10, 5) require.NoError(t, err) decArrType := decimalType.Copy() @@ -800,7 +799,7 @@ func Test_Array(t *testing.T) { // this test tests setting null values to an array of different types func Test_ArrayNull(t *testing.T) { - decType, err := types.NewDecimalType(10, 5) + decType, err := types.NewNumericType(10, 5) require.NoError(t, err) decType.IsArray = true for _, dt := range []*types.DataType{ @@ -829,16 +828,16 @@ func ptrArr[T any](arr []T) []*T { return res } -func mustDec(dec string) *decimal.Decimal { - d, err := decimal.NewFromString(dec) +func mustDec(dec string) *types.Decimal { + d, err := types.ParseDecimal(dec) if err != nil { panic(err) } return d } -func mustExplicitDecimal(dec string, precision, scale uint16) *decimal.Decimal { - d, err := decimal.NewExplicit(dec, precision, scale) +func mustExplicitDecimal(dec string, precision, scale uint16) *types.Decimal { + d, err := types.NewDecimalExplicit(dec, precision, scale) if err != nil { panic(err) } diff --git a/node/engine/parse/antlr.go b/node/engine/parse/antlr.go index aed6235fc..9850e8b10 100644 --- a/node/engine/parse/antlr.go +++ b/node/engine/parse/antlr.go @@ -10,7 +10,6 @@ import ( antlr "github.com/antlr4-go/antlr/v4" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/core/types/validation" "github.com/kwilteam/kwil-db/node/engine/parse/gen" ) @@ -287,13 +286,13 @@ func (s *schemaVisitor) VisitDecimal_literal(ctx *gen.Decimal_literalContext) an // our decimal library can parse the decimal, so we simply pass it there txt := ctx.GetText() - dec, err := decimal.NewFromString(txt) + dec, err := types.ParseDecimal(txt) if err != nil { s.errs.RuleErr(ctx, err, "invalid decimal literal: %s", txt) return unknownExpression(ctx) } - typ, err := types.NewDecimalType(dec.Precision(), dec.Scale()) + typ, err := types.NewNumericType(dec.Precision(), dec.Scale()) if err != nil { s.errs.RuleErr(ctx, err, "invalid decimal literal: %s", txt) return unknownExpression(ctx) diff --git a/node/engine/parse/ast.go b/node/engine/parse/ast.go index e2687f43e..5fa3e6cb7 100644 --- a/node/engine/parse/ast.go +++ b/node/engine/parse/ast.go @@ -7,7 +7,6 @@ import ( antlr "github.com/antlr4-go/antlr/v4" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" ) // this file contains the ASTs for SQL, DDL, and actions. @@ -97,7 +96,7 @@ func literalToString(value any) (string, error) { str.WriteString(fmt.Sprint(v)) case *types.Uint256: str.WriteString(v.String()) - case *decimal.Decimal: + case *types.Decimal: str.WriteString(v.String()) case bool: // for bool type if v { diff --git a/node/engine/pg_generate/generate.go b/node/engine/pg_generate/generate.go index bc95a54bc..e627f6969 100644 --- a/node/engine/pg_generate/generate.go +++ b/node/engine/pg_generate/generate.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/engine" "github.com/kwilteam/kwil-db/node/engine/parse" ) @@ -1130,7 +1129,7 @@ func formatPGLiteral(value any) (string, error) { str.WriteString(v.String()) case *types.Uint256: str.WriteString(v.String()) - case *decimal.Decimal: + case *types.Decimal: str.WriteString(v.String()) case bool: // for bool type if v { diff --git a/node/pg/db_live_test.go b/node/pg/db_live_test.go index feb54cd08..88e685cf7 100644 --- a/node/pg/db_live_test.go +++ b/node/pg/db_live_test.go @@ -17,12 +17,10 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/types/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -142,7 +140,7 @@ func TestQueryRowFunc(t *testing.T) { reflect.TypeFor[*pgtype.Int8](), reflect.TypeFor[*pgtype.Text](), reflect.TypeFor[*[]uint8](), - reflect.TypeFor[*decimal.Decimal](), + reflect.TypeFor[*types.Decimal](), reflect.TypeFor[*pgtype.Array[pgtype.Int8]](), reflect.TypeFor[*types.Uint256](), reflect.TypeFor[*types.Uint256Array](), @@ -162,7 +160,7 @@ func TestQueryRowFunc(t *testing.T) { // Then use QueryRowFunc with the scan vals. - wantDec, err := decimal.NewFromString("12.500") // numeric(x,3)! + wantDec, err := types.ParseDecimal("12.500") // numeric(x,3)! require.NoError(t, err) if wantDec.Scale() != 3 { t.Fatalf("scale of decimal does not match column def: %v", wantDec) @@ -291,7 +289,7 @@ func TestScanVal(t *testing.T) { var ba []byte var i8 pgtype.Int8 var txt pgtype.Text - var num decimal.Decimal // pgtype.Numeric + var num types.Decimal // pgtype.Numeric var u256 types.Uint256 // want pointers to these slices for array types @@ -302,7 +300,7 @@ func TestScanVal(t *testing.T) { var ia pgtype.Array[pgtype.Int8] var ta pgtype.Array[pgtype.Text] var baa pgtype.Array[[]byte] - var na decimal.DecimalArray // pgtype.Array[pgtype.Numeric] + var na types.DecimalArray // pgtype.Array[pgtype.Numeric] var u256a types.Uint256Array wantScans := []any{&i8, &i8, &txt, &ba, &num, &u256, @@ -351,11 +349,11 @@ func TestQueryRowFuncAny(t *testing.T) { reflect.TypeFor[int64](), reflect.TypeFor[string](), reflect.TypeFor[[]byte](), - reflect.TypeFor[*decimal.Decimal](), + reflect.TypeFor[*types.Decimal](), reflect.TypeFor[[]int64](), } - mustDec := func(s string) *decimal.Decimal { - d, err := decimal.NewFromString(s) + mustDec := func(s string) *types.Decimal { + d, err := types.ParseDecimal(s) require.NoError(t, err) return d } @@ -864,8 +862,8 @@ func TestTypeRoundtrip(t *testing.T) { }, { typ: "decimal(6,4)[]", - val: decimal.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, - want: decimal.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, + val: types.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, + want: types.DecimalArray{mustDecimal("12.4223"), mustDecimal("22.4425"), mustDecimal("23.7423")}, }, { typ: "uint256[]", @@ -967,8 +965,8 @@ func TestTypeRoundtrip(t *testing.T) { } // mustDecimal panics if the string cannot be converted to a decimal. -func mustDecimal(s string) *decimal.Decimal { - d, err := decimal.NewFromString(s) +func mustDecimal(s string) *types.Decimal { + d, err := types.ParseDecimal(s) if err != nil { panic(err) } @@ -1053,12 +1051,12 @@ func Test_Changesets(t *testing.T) { val2: []byte("world"), arrayVal2: [][]byte{[]byte("d"), []byte("e"), []byte("f")}, }, - &changesetTestcase[*decimal.Decimal, decimal.DecimalArray]{ + &changesetTestcase[*types.Decimal, types.DecimalArray]{ datatype: "decimal(6,3)", val: mustDecimal("123.456"), - arrayVal: decimal.DecimalArray{mustDecimal("123.456"), mustDecimal("123.456"), mustDecimal("123.456")}, + arrayVal: types.DecimalArray{mustDecimal("123.456"), mustDecimal("123.456"), mustDecimal("123.456")}, val2: mustDecimal("123.457"), - arrayVal2: decimal.DecimalArray{mustDecimal("123.457"), mustDecimal("123.457"), mustDecimal("123.457")}, + arrayVal2: types.DecimalArray{mustDecimal("123.457"), mustDecimal("123.457"), mustDecimal("123.457")}, }, &changesetTestcase[*types.UUID, types.UUIDArray]{ datatype: "uuid", @@ -1631,7 +1629,7 @@ func Test_ParseUnixTimestamp(t *testing.T) { require.Len(t, res.Rows, 1) require.Len(t, res.Rows[0], 1) - expected, err := decimal.NewFromString("1718114052.123456") + expected, err := types.ParseDecimal("1718114052.123456") require.NoError(t, err) require.EqualValues(t, expected, res.Rows[0][0]) diff --git a/node/pg/stats.go b/node/pg/stats.go index 7d051c793..2d56eec59 100644 --- a/node/pg/stats.go +++ b/node/pg/stats.go @@ -10,9 +10,7 @@ import ( "strings" "github.com/jackc/pgx/v5/pgtype" - "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/types/sql" ) @@ -249,7 +247,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db ins(stat, b, cmpBool) case ColTypeNumeric: // use *decimal.Decimal in stats - var dec *decimal.Decimal + var dec *types.Decimal switch v := val.(type) { case *pgtype.Numeric: if !v.Valid { @@ -279,7 +277,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db continue } - case *decimal.Decimal: + case *types.Decimal: if v.NaN() { // we're pretending this is NULL by our sql.Scanner's convetion stat.NullCount++ continue @@ -289,7 +287,7 @@ func colStats(ctx context.Context, qualifiedTable string, colInfo []ColInfo, db v = &v2 } dec = v - case decimal.Decimal: + case types.Decimal: if v.NaN() { // we're pretending this is NULL by our sql.Scanner's convetion stat.NullCount++ continue @@ -385,7 +383,7 @@ func cmpBool(a, b bool) int { return 0 // false == false } -func cmpDecimal(val, mm *decimal.Decimal) int { +func cmpDecimal(val, mm *types.Decimal) int { d, err := val.Cmp(mm) if err != nil { panic(fmt.Sprintf("%s: (nan decimal?) %v or %v", err, val, mm)) diff --git a/node/pg/system.go b/node/pg/system.go index c454456de..06cd942af 100644 --- a/node/pg/system.go +++ b/node/pg/system.go @@ -14,9 +14,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/types/sql" ) @@ -261,7 +259,7 @@ func scanVal(ct ColType) any { // pgtype.Numeric or decimal.Decimal would work. pgtype.Numeric is way // easier to work with and instantiate, but using our types here helps // test their scanners/valuers. - return new(decimal.Decimal) + return new(types.Decimal) case ColTypeUINT256: return new(types.Uint256) case ColTypeFloat: @@ -289,7 +287,7 @@ func scanArrayVal(ct ColType) any { case ColTypeNumeric: // pgArray is also simpler and more efficient, but as long as we // explicitly define array types, we should test them. - return new(decimal.DecimalArray) + return new(types.DecimalArray) case ColTypeUINT256: return new(types.Uint256Array) case ColTypeFloat: diff --git a/node/pg/types.go b/node/pg/types.go index 55de6e0d4..327b24037 100644 --- a/node/pg/types.go +++ b/node/pg/types.go @@ -12,7 +12,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/types/sql" ) @@ -125,7 +124,7 @@ type datatype struct { var ErrNaN = errors.New("NaN") -func pgNumericToDecimal(num pgtype.Numeric) (*decimal.Decimal, error) { +func pgNumericToDecimal(num pgtype.Numeric) (*types.Decimal, error) { if num.NaN { // TODO: create a decimal.Decimal that supports NaN return nil, ErrNaN } @@ -147,7 +146,7 @@ func pgNumericToDecimal(num pgtype.Numeric) (*decimal.Decimal, error) { // Really this could be uint256, which is same underlying type (a domain) as // Numeric. If the caller needs to know, that has to happen differently. - return decimal.NewFromBigInt(i, e) + return types.NewDecimalFromBigInt(i, e) } var ( @@ -456,15 +455,15 @@ var ( } decimalType = &datatype{ - KwilType: types.DecimalType, - Matches: []reflect.Type{reflect.TypeOf(decimal.Decimal{}), reflect.TypeOf(&decimal.Decimal{})}, + KwilType: types.NumericType, + Matches: []reflect.Type{reflect.TypeOf(types.Decimal{}), reflect.TypeOf(&types.Decimal{})}, OID: func(*pgtype.Map) uint32 { return pgtype.NumericOID }, EncodeInferred: func(v any) (any, error) { - var dec *decimal.Decimal + var dec *types.Decimal switch v := v.(type) { - case decimal.Decimal: + case types.Decimal: dec = &v - case *decimal.Decimal: + case *types.Decimal: dec = v default: return nil, fmt.Errorf("unexpected type encoding decimal %T", v) @@ -486,7 +485,7 @@ var ( }, SerializeChangeset: func(value string) ([]byte, error) { // parse to ensure it is a valid decimal, then re-encode it to ensure it is in the correct format. - dec, err := decimal.NewFromString(value) + dec, err := types.ParseDecimal(value) if err != nil { return nil, err } @@ -494,16 +493,16 @@ var ( return []byte(dec.String()), nil }, DeserializeChangeset: func(b []byte) (any, error) { - return decimal.NewFromString(string(b)) + return types.ParseDecimal(string(b)) }, } decimalArrayType = &datatype{ - KwilType: types.DecimalArrayType, - Matches: []reflect.Type{reflect.TypeOf(decimal.DecimalArray{})}, + KwilType: types.NumericArrayType, + Matches: []reflect.Type{reflect.TypeOf(types.DecimalArray{})}, OID: func(*pgtype.Map) uint32 { return pgtype.NumericArrayOID }, EncodeInferred: func(v any) (any, error) { - val, ok := v.(decimal.DecimalArray) + val, ok := v.(types.DecimalArray) if !ok { return nil, fmt.Errorf("expected DecimalArray, got %T", v) } @@ -525,19 +524,19 @@ var ( return nil, fmt.Errorf("expected []any, got %T", a) } - vals := make(decimal.DecimalArray, len(arr)) + vals := make(types.DecimalArray, len(arr)) for i, v := range arr { val, err := decimalType.Decode(v) if err != nil { return nil, err } - vals[i] = val.(*decimal.Decimal) + vals[i] = val.(*types.Decimal) } return vals, nil }, SerializeChangeset: arrayFromChildFunc(2, decimalType.SerializeChangeset), - DeserializeChangeset: deserializeArrayFn[*decimal.Decimal](2, decimalType.DeserializeChangeset), + DeserializeChangeset: deserializeArrayFn[*types.Decimal](2, decimalType.DeserializeChangeset), } uint256Type = &datatype{ diff --git a/node/services/jsonrpc/openrpc/reflect.go b/node/services/jsonrpc/openrpc/reflect.go index 2c946a223..d24d99698 100644 --- a/node/services/jsonrpc/openrpc/reflect.go +++ b/node/services/jsonrpc/openrpc/reflect.go @@ -10,7 +10,6 @@ import ( "unicode/utf8" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" ) type MethodDefinition struct { @@ -128,7 +127,7 @@ func typeToSchemaType(t reflect.Type) string { return "string" case reflect.TypeFor[types.Uint256](): // MarshalJSON also makes JSON string return "string" - case reflect.TypeFor[decimal.Decimal](): // MarshalJSON also makes JSON string + case reflect.TypeFor[types.Decimal](): // MarshalJSON also makes JSON string return "string" case reflect.TypeFor[[]byte](): // A regular []byte field is a base64 string. diff --git a/node/utils/conv/conv.go b/node/utils/conv/conv.go index 4ff54d052..89dad175e 100644 --- a/node/utils/conv/conv.go +++ b/node/utils/conv/conv.go @@ -8,7 +8,6 @@ import ( "unicode/utf8" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" ) func String(a any) (string, error) { @@ -222,9 +221,9 @@ func Uint256(a any) (*types.Uint256, error) { return types.Uint256FromBig(b) case string: return types.Uint256FromString(a) - case *decimal.Decimal: + case *types.Decimal: return types.Uint256FromString(a.String()) - case decimal.Decimal: + case types.Decimal: return types.Uint256FromString(a.String()) case int, int8, int16, int32, int64: return types.Uint256FromString(fmt.Sprint(a)) @@ -243,23 +242,23 @@ func Uint256(a any) (*types.Uint256, error) { } // Decimal converts a value to a Decimal. -func Decimal(a any) (*decimal.Decimal, error) { +func Decimal(a any) (*types.Decimal, error) { switch a := a.(type) { - case *decimal.Decimal: + case *types.Decimal: return a, nil case string: - return decimal.NewFromString(a) + return types.ParseDecimal(a) case *types.Uint256: - return decimal.NewFromString(a.String()) + return types.ParseDecimal(a.String()) case types.Uint256: - return decimal.NewFromString(a.String()) + return types.ParseDecimal(a.String()) case int, int8, int16, int32, int64: - return decimal.NewFromString(fmt.Sprint(a)) + return types.ParseDecimal(fmt.Sprint(a)) case fmt.Stringer: - return decimal.NewFromString(a.String()) + return types.ParseDecimal(a.String()) case nil: // return decimal.NewFromBigInt(big.NewInt(0), 0) - return decimal.NewFromString("0") + return types.ParseDecimal("0") } str, err := String(a) @@ -267,5 +266,5 @@ func Decimal(a any) (*decimal.Decimal, error) { return nil, err } - return decimal.NewFromString(str) + return types.ParseDecimal(str) } diff --git a/node/utils/conv/conv_test.go b/node/utils/conv/conv_test.go index 316f205ad..6a6d1c212 100644 --- a/node/utils/conv/conv_test.go +++ b/node/utils/conv/conv_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/types/decimal" "github.com/kwilteam/kwil-db/node/utils/conv" "github.com/stretchr/testify/require" ) @@ -327,7 +326,7 @@ func Test_Decimal(t *testing.T) { tests := []struct { name string arg any - want *decimal.Decimal + want *types.Decimal wantErr bool }{ { @@ -372,8 +371,8 @@ func Test_Decimal(t *testing.T) { } } -func mustDecimal(s string) *decimal.Decimal { - d, err := decimal.NewFromString(s) +func mustDecimal(s string) *types.Decimal { + d, err := types.ParseDecimal(s) if err != nil { panic(err) } @@ -848,7 +847,7 @@ func TestDecimalAdditionalCases(t *testing.T) { tests := []struct { name string arg any - want *decimal.Decimal + want *types.Decimal wantErr bool }{ {