Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed Dec 15, 2023
1 parent 9ff55c5 commit 5615c8a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
44 changes: 22 additions & 22 deletions pkg/btree/btree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ package btree

import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
"io"
"reflect"

"github.com/ChainSafe/gossamer/pkg/scale"

"golang.org/x/exp/constraints"

"github.com/tidwall/btree"
Expand All @@ -19,17 +20,17 @@ type Codec interface {
UnmarshalSCALE(reader io.Reader) error
}

// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items
// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the
// type of the items stored in the BTree in order to decode them.
type BTree struct {
// Tree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items
// stored in the BTree. This is needed during decoding because the Tree item is a generic type, and we need to know it
// at the time of decoding.
type Tree struct {
*btree.BTree
Comparator func(a, b interface{}) bool
ItemType reflect.Type
}

// MarshalSCALE encodes the BTree using SCALE.
func (bt BTree) MarshalSCALE() ([]byte, error) {
// MarshalSCALE encodes the Tree using SCALE.
func (bt Tree) MarshalSCALE() ([]byte, error) {
encodedLen, err := scale.Marshal(uint(bt.Len()))
if err != nil {
return nil, fmt.Errorf("failed to encode BTree length: %w", err)
Expand All @@ -50,8 +51,8 @@ func (bt BTree) MarshalSCALE() ([]byte, error) {
return append(encodedLen, encodedItems...), err
}

// UnmarshalSCALE decodes the BTree using SCALE.
func (bt BTree) UnmarshalSCALE(reader io.Reader) error {
// UnmarshalSCALE decodes the Tree using SCALE.
func (bt Tree) UnmarshalSCALE(reader io.Reader) error {
if bt.Comparator == nil {
return fmt.Errorf("comparator not found")
}
Expand All @@ -74,26 +75,26 @@ func (bt BTree) UnmarshalSCALE(reader io.Reader) error {
return nil
}

// Copy returns a copy of the BTree.
func (bt BTree) Copy() *BTree {
return &BTree{
// Copy returns a copy of the Tree.
func (bt Tree) Copy() *Tree {
return &Tree{
BTree: bt.BTree.Copy(),
Comparator: bt.Comparator,
ItemType: bt.ItemType,
}
}

// NewBTree creates a new BTree with the given comparator function.
func NewBTree[T any](comparator func(a, b any) bool) BTree {
// NewTree creates a new Tree with the given comparator function.
func NewTree[T any](comparator func(a, b any) bool) Tree {
elementType := reflect.TypeOf((*T)(nil)).Elem()
return BTree{
return Tree{
BTree: btree.New(comparator),
Comparator: comparator,
ItemType: elementType,
}
}

var _ Codec = (*BTree)(nil)
var _ Codec = (*Tree)(nil)

// Map is a wrapper around tidwall/btree.Map
type Map[K constraints.Ordered, V any] struct {
Expand All @@ -110,10 +111,9 @@ type mapItem[K constraints.Ordered, V any] struct {
func (btm Map[K, V]) MarshalSCALE() ([]byte, error) {
encodedLen, err := scale.Marshal(uint(btm.Len()))
if err != nil {
return nil, fmt.Errorf("failed to encode BTree length: %w", err)
return nil, fmt.Errorf("failed to encode Map length: %w", err)
}

// write each item in the tree
var (
pivot K
encodedItems []byte
Expand Down Expand Up @@ -155,11 +155,11 @@ func (btm Map[K, V]) UnmarshalSCALE(reader io.Reader) error {
slicePtr := reflect.New(sliceType)
encodedItems, err := io.ReadAll(reader)
if err != nil {
return fmt.Errorf("read BTree items: %w", err)
return fmt.Errorf("read Map items: %w", err)
}
err = scale.Unmarshal(encodedItems, slicePtr.Interface())
if err != nil {
return fmt.Errorf("decode BTree items: %w", err)
return fmt.Errorf("decode Map items: %w", err)
}

for i := 0; i < slicePtr.Elem().Len(); i++ {
Expand All @@ -176,8 +176,8 @@ func (btm Map[K, V]) Copy() Map[K, V] {
}
}

// NewBTreeMap creates a new Map with the given degree.
func NewBTreeMap[K constraints.Ordered, V any](degree int) Map[K, V] {
// NewMap creates a new Map with the given degree.
func NewMap[K constraints.Ordered, V any](degree int) Map[K, V] {
return Map[K, V]{
Map: btree.NewMap[K, V](degree),
Degree: degree,
Expand Down
15 changes: 8 additions & 7 deletions pkg/btree/btree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
package btree

import (
"github.com/ChainSafe/gossamer/pkg/scale"
"testing"

"github.com/ChainSafe/gossamer/pkg/scale"

"github.com/stretchr/testify/require"
)

Expand All @@ -22,8 +23,8 @@ func TestBTree_Codec(t *testing.T) {
return v1.Field1 < v2.Field1
}

// Create a BTree with 3 dummy items
tree := NewBTree[dummy](comparator)
// Create a Tree with 3 dummy items
tree := NewTree[dummy](comparator)
tree.Set(dummy{Field1: 1})
tree.Set(dummy{Field1: 2})
tree.Set(dummy{Field1: 3})
Expand All @@ -43,11 +44,11 @@ func TestBTree_Codec(t *testing.T) {
}
require.Equal(t, expectedEncoded, encoded)

expected := NewBTree[dummy](comparator)
expected := NewTree[dummy](comparator)
err = scale.Unmarshal(expectedEncoded, &expected)
require.NoError(t, err)

// Check that the expected BTree has the same items as the original
// Check that the expected Tree has the same items as the original
require.Equal(t, tree.Len(), expected.Len())
require.Equal(t, tree.ItemType, expected.ItemType)
require.Equal(t, tree.Min(), expected.Min())
Expand All @@ -58,7 +59,7 @@ func TestBTree_Codec(t *testing.T) {
}

func TestBTreeMap_Codec(t *testing.T) {
btreeMap := NewBTreeMap[uint32, dummy](32)
btreeMap := NewMap[uint32, dummy](32)
btreeMap.Set(uint32(1), dummy{Field1: 1})
btreeMap.Set(uint32(2), dummy{Field1: 2})
btreeMap.Set(uint32(3), dummy{Field1: 3})
Expand All @@ -77,7 +78,7 @@ func TestBTreeMap_Codec(t *testing.T) {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}
require.Equal(t, expectedEncoded, encoded)
expected := NewBTreeMap[uint32, dummy](32)
expected := NewMap[uint32, dummy](32)
err = scale.Unmarshal(expectedEncoded, &expected)
require.NoError(t, err)
require.Equal(t, btreeMap.Len(), expected.Len())
Expand Down

0 comments on commit 5615c8a

Please sign in to comment.