Skip to content

Commit

Permalink
types: uint256 and uuid marshaling fixes (#925)
Browse files Browse the repository at this point in the history
  • Loading branch information
jchappelow authored Aug 21, 2024
1 parent e7ac91f commit 4ece583
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 13 deletions.
63 changes: 57 additions & 6 deletions core/types/uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package types
import (
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"fmt"
"math/big"

Expand All @@ -14,6 +16,9 @@ import (
// extra methods for usage in Postgres.
type Uint256 struct {
base uint256.Int // not exporting massive method set, which also has params and returns of holiman types
// Null indicates if this is a NULL value in a SQL table. This approach is
// typical in most sql.Valuers, which precludes using a nil pointer to
// indicate a NULL value.
Null bool
}

Expand All @@ -22,7 +27,9 @@ func Uint256FromInt(i uint64) *Uint256 {
return &Uint256{base: *uint256.NewInt(i)}
}

// Uint256FromString creates a new Uint256 from a string.
// Uint256FromString creates a new Uint256 from a string. A Uint256 representing
// a NULL value should be created with a literal (&Uint256{ Null: true }) or via
// of the unmarshal / scan methods.
func Uint256FromString(s string) (*Uint256, error) {
i, err := uint256.FromDecimal(s)
if err != nil {
Expand All @@ -33,11 +40,17 @@ func Uint256FromString(s string) (*Uint256, error) {

// Uint256FromBig creates a new Uint256 from a big.Int.
func Uint256FromBig(i *big.Int) (*Uint256, error) {
if i == nil {
return &Uint256{Null: true}, nil
}
return Uint256FromString(i.String())
}

// Uint256FromBytes creates a new Uint256 from a byte slice.
func Uint256FromBytes(b []byte) (*Uint256, error) {
if b == nil {
return &Uint256{Null: true}, nil
} // zero length non-null is for the actual value 0
bigInt := new(big.Int).SetBytes(b)
return Uint256FromBig(bigInt)
}
Expand All @@ -54,10 +67,6 @@ func (u Uint256) ToBig() *big.Int {
return u.base.ToBig()
}

func (u Uint256) MarshalJSON() ([]byte, error) {
return []byte(u.base.String()), nil // ? json ?
}

func (u *Uint256) Clone() *Uint256 {
v := *u
return &v
Expand All @@ -71,8 +80,29 @@ func CmpUint256(u, v *Uint256) int {
return u.Cmp(v)
}

var _ json.Marshaler = Uint256{}
var _ json.Marshaler = (*Uint256)(nil)

func (u Uint256) MarshalJSON() ([]byte, error) {
if u.Null {
return []byte("null"), nil
}
return []byte(`"` + u.base.String() + `"`), nil
}

var _ json.Unmarshaler = (*Uint256)(nil)

func (u *Uint256) UnmarshalJSON(b []byte) error {
u2, err := Uint256FromString(string(b))
var str string
if err := json.Unmarshal(b, &str); err != nil {
return err
}
if str == "" { // JSON data was null or ""
u.Null = true
u.base.Clear()
return nil
}
u2, err := Uint256FromString(str)
if err != nil {
return err
}
Expand All @@ -81,6 +111,27 @@ func (u *Uint256) UnmarshalJSON(b []byte) error {
return nil
}

var _ encoding.BinaryMarshaler = Uint256{}
var _ encoding.BinaryMarshaler = (*Uint256)(nil)

func (u Uint256) MarshalBinary() ([]byte, error) {
if u.Null {
return nil, nil
}
return u.base.Bytes(), nil
}

var _ encoding.BinaryUnmarshaler = (*Uint256)(nil)

func (u *Uint256) UnmarshalBinary(data []byte) error {
if data == nil {
*u = Uint256{Null: true}
return nil
} // len(data) == 0 is the actual value 0
u.base.SetBytes(data) // u.base, _ = uint256.FromBig(new(big.Int).SetBytes(buf))
return nil
}

// Value implements the driver.Valuer interface.
func (u Uint256) Value() (driver.Value, error) {
if u.Null {
Expand Down
97 changes: 97 additions & 0 deletions core/types/uint256_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package types

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// func intHex(v int64) []byte

const hugeIntStrP1 = "18446744073709551616" // 1 + math.MaxUint64
const hugeIntStrX10 = "184467440737095516150" // 10 * math.MaxUint64

func TestUint256BinaryMarshaling(t *testing.T) {
tests := []struct {
name string
val string
expected []byte
}{
{
name: "small int",
val: "123",
expected: []byte{0x7b},
},
{
name: "zero",
val: "0",
expected: []byte{}, // optimized to empty slice
},
{
name: "null",
val: "", // special case
expected: nil,
},
{
name: "just bigger than uint64",
val: hugeIntStrP1,
expected: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
name: "much than uint64",
val: hugeIntStrX10,
expected: []byte{0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xf6},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u *Uint256
if tt.val == "" {
u = &Uint256{Null: true}
} else {
var err error
u, err = Uint256FromString(tt.val)
require.NoError(t, err)
}

marshaled, err := u.MarshalBinary()
require.NoError(t, err)
assert.Equal(t, tt.expected, marshaled)

var unmarshaled Uint256
err = unmarshaled.UnmarshalBinary(marshaled)
require.NoError(t, err)

assert.Equal(t, u.String(), unmarshaled.String())
})
}
}

func TestUint256JSONRoundTrip(t *testing.T) {
for _, str := range []string{"12345", "0", "", hugeIntStrX10} {
var original *Uint256
if str == "" {
original = &Uint256{Null: true}
} else {
var err error
original, err = Uint256FromString(str)
require.NoError(t, err)
}

marshaled, err := original.MarshalJSON()
require.NoError(t, err)
if len(str) > 0 {
require.Equal(t, `"`+str+`"`, string(marshaled))
} else {
require.Equal(t, string(marshaled), "null")
}

var unmarshaled Uint256
err = unmarshaled.UnmarshalJSON(marshaled)
require.NoError(t, err)

assert.Equal(t, original.String(), unmarshaled.String())
}
}
9 changes: 7 additions & 2 deletions core/types/uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,16 @@ func (u *UUID) Bytes() []byte {
return u[:]
}

// Over json, we want to send uuids as strings
var _ json.Marshaler = UUID{}
var _ json.Marshaler = (*UUID)(nil)

// MarshalJSON implements json.Marshaler.
func (u UUID) MarshalJSON() ([]byte, error) {
return json.Marshal(u.String())
return []byte(`"` + u.String() + `"`), nil
}

var _ json.Unmarshaler = (*UUID)(nil)

func (u *UUID) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
Expand Down
13 changes: 8 additions & 5 deletions core/types/uuid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"

"github.com/kwilteam/kwil-db/core/types"
)

Expand All @@ -17,7 +19,7 @@ func Test_UUID(t *testing.T) {
}
}

func Test_UUIDJSON(t *testing.T) {
func Test_UUIDJSONRoundTrip(t *testing.T) {
seed := []byte("test")

uuid := types.NewUUIDV5(seed)
Expand All @@ -27,12 +29,13 @@ func Test_UUIDJSON(t *testing.T) {
t.Fatal(err)
}

t.Log(uuid)
assert.Equal(t, `"24aa70cf-0e18-57c9-b449-da8c9db37821"`, string(b))

var uuid3 types.UUID
err = json.Unmarshal(b, &uuid3)
var uuidBack types.UUID
err = json.Unmarshal(b, &uuidBack)
if err != nil {
t.Fatal(err)
}
t.Log(uuid3) // 00000000-0000-0000-0000-000000000000

assert.Equal(t, *uuid, uuidBack)
}

0 comments on commit 4ece583

Please sign in to comment.