diff --git a/go.mod b/go.mod index 8503708b3d..0560e16350 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/spf13/viper v1.18.1 github.com/stretchr/testify v1.8.4 github.com/tetratelabs/wazero v1.1.0 + github.com/tidwall/btree v1.7.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 go.uber.org/mock v0.3.0 golang.org/x/crypto v0.16.0 diff --git a/go.sum b/go.sum index d1d9dd77d1..0b1b43ecf1 100644 --- a/go.sum +++ b/go.sum @@ -686,6 +686,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/timwu20/go-substrate-rpc-client/v4 v4.0.0-20231110032757-3d8e441b7303 h1:FX7wMjDD0sWGWsC9k+stJaYwThbaq6BDT7ArlInU0KI= github.com/timwu20/go-substrate-rpc-client/v4 v4.0.0-20231110032757-3d8e441b7303/go.mod h1:1p5145LS4BacYYKFstnHScydK9MLjZ15l72v8mbngPQ= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= diff --git a/pkg/btree/btree.go b/pkg/btree/btree.go new file mode 100644 index 0000000000..3cf2d85ac5 --- /dev/null +++ b/pkg/btree/btree.go @@ -0,0 +1,187 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package btree + +import ( + "fmt" + "io" + "reflect" + + "github.com/ChainSafe/gossamer/pkg/scale" + + "golang.org/x/exp/constraints" + + "github.com/tidwall/btree" +) + +type Codec interface { + MarshalSCALE() ([]byte, error) + UnmarshalSCALE(reader io.Reader) error +} + +// Tree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items +// stored in the BTree. This is needed during decoding because the Tree item is a generic type, and we need to know it +// at the time of decoding. +type Tree struct { + *btree.BTree + Comparator func(a, b interface{}) bool + ItemType reflect.Type +} + +// MarshalSCALE encodes the Tree using SCALE. +func (bt Tree) MarshalSCALE() ([]byte, error) { + encodedLen, err := scale.Marshal(uint(bt.Len())) + if err != nil { + return nil, fmt.Errorf("failed to encode BTree length: %w", err) + } + + var encodedItems []byte + bt.Ascend(nil, func(item interface{}) bool { + var encodedItem []byte + encodedItem, err = scale.Marshal(item) + if err != nil { + return false + } + + encodedItems = append(encodedItems, encodedItem...) + return true + }) + + return append(encodedLen, encodedItems...), err +} + +// UnmarshalSCALE decodes the Tree using SCALE. +func (bt Tree) UnmarshalSCALE(reader io.Reader) error { + if bt.Comparator == nil { + return fmt.Errorf("comparator not found") + } + + sliceType := reflect.SliceOf(bt.ItemType) + slicePtr := reflect.New(sliceType) + encodedItems, err := io.ReadAll(reader) + if err != nil { + return fmt.Errorf("read BTree items: %w", err) + } + err = scale.Unmarshal(encodedItems, slicePtr.Interface()) + if err != nil { + return fmt.Errorf("decode BTree items: %w", err) + } + + for i := 0; i < slicePtr.Elem().Len(); i++ { + item := slicePtr.Elem().Index(i).Interface() + bt.Set(item) + } + return nil +} + +// Copy returns a copy of the Tree. +func (bt Tree) Copy() *Tree { + return &Tree{ + BTree: bt.BTree.Copy(), + Comparator: bt.Comparator, + ItemType: bt.ItemType, + } +} + +// NewTree creates a new Tree with the given comparator function. +func NewTree[T any](comparator func(a, b any) bool) Tree { + elementType := reflect.TypeOf((*T)(nil)).Elem() + return Tree{ + BTree: btree.New(comparator), + Comparator: comparator, + ItemType: elementType, + } +} + +var _ Codec = (*Tree)(nil) + +// Map is a wrapper around tidwall/btree.Map +type Map[K constraints.Ordered, V any] struct { + *btree.Map[K, V] + Degree int +} + +type mapItem[K constraints.Ordered, V any] struct { + Key K + Value V +} + +// MarshalSCALE encodes the Map using SCALE. +func (btm Map[K, V]) MarshalSCALE() ([]byte, error) { + encodedLen, err := scale.Marshal(uint(btm.Len())) + if err != nil { + return nil, fmt.Errorf("failed to encode Map length: %w", err) + } + + var ( + pivot K + encodedItems []byte + ) + btm.Ascend(pivot, func(key K, value V) bool { + var ( + encodedKey []byte + encodedValue []byte + ) + encodedKey, err = scale.Marshal(key) + if err != nil { + return false + } + + encodedValue, err = scale.Marshal(value) + if err != nil { + return false + } + + encodedItems = append(encodedItems, encodedKey...) + encodedItems = append(encodedItems, encodedValue...) + return true + }) + + return append(encodedLen, encodedItems...), err +} + +// UnmarshalSCALE decodes the Map using SCALE. +func (btm Map[K, V]) UnmarshalSCALE(reader io.Reader) error { + if btm.Degree == 0 { + return fmt.Errorf("nothing to decode into") + } + + if btm.Map == nil { + btm.Map = btree.NewMap[K, V](btm.Degree) + } + + sliceType := reflect.SliceOf(reflect.TypeOf((*mapItem[K, V])(nil)).Elem()) + slicePtr := reflect.New(sliceType) + encodedItems, err := io.ReadAll(reader) + if err != nil { + return fmt.Errorf("read Map items: %w", err) + } + err = scale.Unmarshal(encodedItems, slicePtr.Interface()) + if err != nil { + return fmt.Errorf("decode Map items: %w", err) + } + + for i := 0; i < slicePtr.Elem().Len(); i++ { + item := slicePtr.Elem().Index(i).Interface().(mapItem[K, V]) + btm.Map.Set(item.Key, item.Value) + } + return nil +} + +// Copy returns a copy of the Map. +func (btm Map[K, V]) Copy() Map[K, V] { + return Map[K, V]{ + Map: btm.Map.Copy(), + } +} + +// NewMap creates a new Map with the given degree. +func NewMap[K constraints.Ordered, V any](degree int) Map[K, V] { + return Map[K, V]{ + Map: btree.NewMap[K, V](degree), + Degree: degree, + } +} + +var _ Codec = (*Map[int, string])(nil) diff --git a/pkg/btree/btree_test.go b/pkg/btree/btree_test.go new file mode 100644 index 0000000000..3d73e4187b --- /dev/null +++ b/pkg/btree/btree_test.go @@ -0,0 +1,85 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package btree + +import ( + "testing" + + "github.com/ChainSafe/gossamer/pkg/scale" + + "github.com/stretchr/testify/require" +) + +type dummy struct { + Field1 uint32 + Field2 [32]byte +} + +func TestBTree_Codec(t *testing.T) { + comparator := func(a, b interface{}) bool { + v1 := a.(dummy) + v2 := b.(dummy) + return v1.Field1 < v2.Field1 + } + + // Create a Tree with 3 dummy items + tree := NewTree[dummy](comparator) + tree.Set(dummy{Field1: 1}) + tree.Set(dummy{Field1: 2}) + tree.Set(dummy{Field1: 3}) + encoded, err := scale.Marshal(tree) + require.NoError(t, err) + + //let mut btree = Map::::new(); + //btree.insert(1, Hash::zero()); + //btree.insert(2, Hash::zero()); + //btree.insert(3, Hash::zero()); + //let encoded = btree.encode(); + //println!("encoded: {:?}", encoded); + expectedEncoded := []byte{12, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + require.Equal(t, expectedEncoded, encoded) + + expected := NewTree[dummy](comparator) + err = scale.Unmarshal(expectedEncoded, &expected) + require.NoError(t, err) + + // Check that the expected Tree has the same items as the original + require.Equal(t, tree.Len(), expected.Len()) + require.Equal(t, tree.ItemType, expected.ItemType) + require.Equal(t, tree.Min(), expected.Min()) + require.Equal(t, tree.Max(), expected.Max()) + require.Equal(t, tree.Get(dummy{Field1: 1}), expected.Get(dummy{Field1: 1})) + require.Equal(t, tree.Get(dummy{Field1: 2}), expected.Get(dummy{Field1: 2})) + require.Equal(t, tree.Get(dummy{Field1: 3}), expected.Get(dummy{Field1: 3})) +} + +func TestBTreeMap_Codec(t *testing.T) { + btreeMap := NewMap[uint32, dummy](32) + btreeMap.Set(uint32(1), dummy{Field1: 1}) + btreeMap.Set(uint32(2), dummy{Field1: 2}) + btreeMap.Set(uint32(3), dummy{Field1: 3}) + encoded, err := scale.Marshal(btreeMap) + require.NoError(t, err) + + //let mut btree = Map::::new(); + //btree.insert(1, (1, Hash::zero())); + //btree.insert(2, (2, Hash::zero())); + //btree.insert(3, (3, Hash::zero())); + //let encoded = btree.encode(); + //println!("encoded: {:?}", encoded); + expectedEncoded := []byte{12, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + require.Equal(t, expectedEncoded, encoded) + expected := NewMap[uint32, dummy](32) + err = scale.Unmarshal(expectedEncoded, &expected) + require.NoError(t, err) + require.Equal(t, btreeMap.Len(), expected.Len()) +} diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index bcd99c8d1c..1c3d965346 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -113,16 +113,8 @@ type decodeState struct { } func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { - unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() - if dstv.CanAddr() && dstv.Addr().Type().Implements(unmarshalerType) { - methodVal := dstv.Addr().MethodByName("UnmarshalSCALE") - values := methodVal.Call([]reflect.Value{reflect.ValueOf(ds.Reader)}) - if !values[0].IsNil() { - errIn := values[0].Interface() - err := errIn.(error) - return err - } - return + if unmarshaler, ok := dstv.Addr().Interface().(Unmarshaler); ok { + return unmarshaler.UnmarshalSCALE(ds.Reader) } in := dstv.Interface() diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 713f3a7bce..c5c219b521 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -588,9 +588,10 @@ func Test_decodeState_Unmarshaller(t *testing.T) { Middle: uint32(2), Last: 3, } - bytes := MustMarshal(expected) + encoded := MustMarshal(expected) ms := myStruct{} - Unmarshal(bytes, &ms) + err := Unmarshal(encoded, &ms) + assert.NoError(t, err) assert.Equal(t, expected, ms) type myParentStruct struct { @@ -603,9 +604,10 @@ func Test_decodeState_Unmarshaller(t *testing.T) { Middle: expected, Last: 3, } - bytes = MustMarshal(expectedParent) + encoded = MustMarshal(expectedParent) mps := myParentStruct{} - Unmarshal(bytes, &mps) + err = Unmarshal(encoded, &mps) + assert.NoError(t, err) assert.Equal(t, expectedParent, mps) } @@ -615,8 +617,8 @@ func Test_decodeState_Unmarshaller_Error(t *testing.T) { Middle: uint32(2), Last: 3, } - bytes := MustMarshal(expected) + encoded := MustMarshal(expected) mse := myStructError{} - err := Unmarshal(bytes, &mse) - assert.Error(t, err, "eh?") + err := Unmarshal(encoded, &mse) + assert.Error(t, err) }