diff --git a/pkg/trie/inmemory/db_getter_mocks_test.go b/pkg/trie/inmemory/db_getter_mocks_test.go index 20b4c07cbf..33d27e49b6 100644 --- a/pkg/trie/inmemory/db_getter_mocks_test.go +++ b/pkg/trie/inmemory/db_getter_mocks_test.go @@ -3,10 +3,10 @@ // // Generated by this command: // -// mockgen -destination=db_getter_mocks_test.go -package=trie github.com/ChainSafe/gossamer/pkg/trie/db DBGetter +// mockgen -destination=db_getter_mocks_test.go -package=inmemory github.com/ChainSafe/gossamer/pkg/trie/db DBGetter // -// Package trie is a generated GoMock package. +// Package inmemory is a generated GoMock package. package inmemory import ( diff --git a/pkg/trie/inmemory/in_memory.go b/pkg/trie/inmemory/in_memory.go index 52d8830930..823c875081 100644 --- a/pkg/trie/inmemory/in_memory.go +++ b/pkg/trie/inmemory/in_memory.go @@ -100,7 +100,7 @@ func (t *InMemoryTrie) Snapshot() (newTrie *InMemoryTrie) { } } -// handleTrackedDeltas sets the pending deleted node hashes in +// HandleTrackedDeltas sets the pending deleted node hashes in // the trie deltas tracker if and only if success is true. func (t *InMemoryTrie) HandleTrackedDeltas(success bool, pendingDeltas tracking.Getter) { if !success || t.generation == 0 { diff --git a/pkg/trie/inmemory/mocks_generate_test.go b/pkg/trie/inmemory/mocks_generate_test.go new file mode 100644 index 0000000000..177c011537 --- /dev/null +++ b/pkg/trie/inmemory/mocks_generate_test.go @@ -0,0 +1,6 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package inmemory + +//go:generate mockgen -destination=db_getter_mocks_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/pkg/trie/db DBGetter diff --git a/pkg/trie/trie.go b/pkg/trie/trie.go index b4bb005b7e..373b75df5e 100644 --- a/pkg/trie/trie.go +++ b/pkg/trie/trie.go @@ -13,17 +13,23 @@ import ( // EmptyHash is the empty trie hash. 
var EmptyHash = common.MustBlake2bHash([]byte{0}) -type ChildTrieSupport interface { +type ChildTriesRead interface { GetChild(keyToChild []byte) (Trie, error) GetFromChild(keyToChild, key []byte) ([]byte, error) GetChildTries() map[common.Hash]Trie +} + +type ChildTriesWrite interface { PutIntoChild(keyToChild, key, value []byte) error DeleteChild(keyToChild []byte) (err error) ClearFromChild(keyToChild, key []byte) error } -type KVStore interface { +type KVStoreRead interface { Get(key []byte) []byte +} + +type KVStoreWrite interface { Put(key, value []byte) error Delete(key []byte) error } @@ -33,8 +39,11 @@ type TrieIterator interface { NextKey(key []byte) []byte } -type PrefixTrie interface { +type PrefixTrieRead interface { GetKeysWithPrefix(prefix []byte) (keysLE [][]byte) +} + +type PrefixTrieWrite interface { ClearPrefix(prefix []byte) (err error) ClearPrefixLimit(prefix []byte, limit uint32) ( deleted uint32, allDeleted bool, err error) @@ -54,13 +63,21 @@ type Hashable interface { Hash() (common.Hash, error) } -type Trie interface { - PrefixTrie - KVStore +type TrieRead interface { + fmt.Stringer + + KVStoreRead Hashable - ChildTrieSupport + ChildTriesRead + PrefixTrieRead TrieIterator - TrieDeltas +} + +type Trie interface { + TrieRead + ChildTriesWrite + PrefixTrieWrite + KVStoreWrite Versioned - fmt.Stringer + TrieDeltas } diff --git a/pkg/trie/triedb/child_tries.go b/pkg/trie/triedb/child_tries.go new file mode 100644 index 0000000000..aafe7bc72d --- /dev/null +++ b/pkg/trie/triedb/child_tries.go @@ -0,0 +1,21 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/trie" +) + +func (t *TrieDB) GetChild(keyToChild []byte) (trie.Trie, error) { + panic("not implemented yet") +} + +func (t *TrieDB) GetFromChild(keyToChild, key []byte) ([]byte, error) { + panic("not implemented yet") +} + +func (t *TrieDB) 
GetChildTries() map[common.Hash]trie.Trie { + panic("not implemented yet") +} diff --git a/pkg/trie/triedb/codec/decode.go b/pkg/trie/triedb/codec/decode.go new file mode 100644 index 0000000000..04e495a627 --- /dev/null +++ b/pkg/trie/triedb/codec/decode.go @@ -0,0 +1,166 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +var ( + ErrDecodeHashedStorageValue = errors.New("cannot decode hashed storage value") + ErrDecodeHashedValueTooShort = errors.New("hashed storage value too short") + ErrReadChildrenBitmap = errors.New("cannot read children bitmap") + // ErrDecodeChildHash is defined since no sentinel error is defined + // in the scale package. + ErrDecodeChildHash = errors.New("cannot decode child hash") + // ErrDecodeStorageValue is defined since no sentinel error is defined + // in the scale package. + ErrDecodeStorageValue = errors.New("cannot decode storage value") +) + +const hashLength = common.HashLength + +// Decode decodes a node from a reader. 
+// The encoding format is documented in the README.md +// of this package, and specified in the Polkadot spec at +// https://spec.polkadot.network/chap-state#defn-node-header +func Decode(reader io.Reader) (n Node, err error) { + variant, partialKeyLength, err := decodeHeader(reader) + if err != nil { + return nil, fmt.Errorf("decoding header: %w", err) + } + + if variant == emptyVariant { + return Empty{}, nil + } + + partialKey, err := decodeKey(reader, partialKeyLength) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + switch variant { + case leafVariant, leafWithHashedValueVariant: + n, err = decodeLeaf(reader, variant, partialKey) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case branchVariant, branchWithValueVariant, branchWithHashedValueVariant: + n, err = decodeBranch(reader, variant, partialKey) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + // this is a programming error, an unknown node variant should be caught by decodeHeader. + panic(fmt.Sprintf("not implemented for node variant %08b", variant)) + } +} + +// decodeBranch reads from a reader and decodes to a node branch. +// Note that we are not decoding the children nodes. 
+func decodeBranch(reader io.Reader, variant variant, partialKey []byte) ( + node Branch, err error) { + node = Branch{ + PartialKey: partialKey, + } + + var childrenBitmap uint16 + err = binary.Read(reader, binary.LittleEndian, &childrenBitmap) + if err != nil { + return Branch{}, fmt.Errorf("%w: %s", ErrReadChildrenBitmap, err) + } + + sd := scale.NewDecoder(reader) + + switch variant { + case branchWithValueVariant: + valueBytes := make([]byte, 0) + err := sd.Decode(&valueBytes) + if err != nil { + return Branch{}, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + + node.Value = NewInlineValue(valueBytes) + case branchWithHashedValueVariant: + hashedValue, err := decodeHashedValue(reader) + if err != nil { + return Branch{}, err + } + node.Value = NewHashedValue(hashedValue) + default: + // Do nothing, branch without value + } + + for i := 0; i < ChildrenCapacity; i++ { + // Skip this index if we don't have a child here + if (childrenBitmap>>i)&1 != 1 { + continue + } + + var hash []byte + err := sd.Decode(&hash) + if err != nil { + return Branch{}, fmt.Errorf("%w: at index %d: %s", + ErrDecodeChildHash, i, err) + } + + if len(hash) < hashLength { + node.Children[i] = NewInlineNode(hash) + } else { + node.Children[i] = NewHashedNode(hash) + } + } + + return node, nil +} + +// decodeLeaf reads from a reader and decodes to a leaf node. 
+func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf, err error) { + node = Leaf{ + PartialKey: partialKey, + } + + sd := scale.NewDecoder(reader) + + if variant == leafWithHashedValueVariant { + hashedValue, err := decodeHashedValue(reader) + if err != nil { + return Leaf{}, err + } + + node.Value = NewHashedValue(hashedValue) + return node, nil + } + + valueBytes := make([]byte, 0) + err = sd.Decode(&valueBytes) + if err != nil { + return Leaf{}, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + + node.Value = NewInlineValue(valueBytes) + + return node, nil +} + +func decodeHashedValue(reader io.Reader) ([]byte, error) { + buffer := make([]byte, hashLength) + n, err := reader.Read(buffer) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + if n < hashLength { + return nil, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, hashLength, n) + } + + return buffer, nil +} diff --git a/pkg/trie/triedb/codec/decode_test.go b/pkg/trie/triedb/codec/decode_test.go new file mode 100644 index 0000000000..12c868bafb --- /dev/null +++ b/pkg/trie/triedb/codec/decode_test.go @@ -0,0 +1,396 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + return scaleEncodeByteSlice(t, b) +} + +func scaleEncodeByteSlice(t *testing.T, b []byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func Test_Decode(t *testing.T) { + t.Parallel() + + hashedValue, err := common.Blake2bHash([]byte("test")) + assert.NoError(t, err) + + testCases := map[string]struct { + reader io.Reader + n Node + errWrapped error + errMessage string 
+ }{ + "no_data": { + reader: bytes.NewReader(nil), + errWrapped: io.EOF, + errMessage: "decoding header: reading header byte: EOF", + }, + "unknown_node_variant": { + reader: bytes.NewReader([]byte{0b0000_1000}), + errWrapped: ErrVariantUnknown, + errMessage: "decoding header: decoding header byte: node variant is unknown: for header byte 00001000", + }, + "empty_node": { + reader: bytes.NewReader([]byte{emptyVariant.bits}), + n: Empty{}, + }, + "leaf_decoding_error": { + reader: bytes.NewReader([]byte{ + leafVariant.bits | 1, // key length 1 + // missing key data byte + }), + errWrapped: io.EOF, + errMessage: "cannot decode key: " + + "reading from reader: EOF", + }, + "leaf_success": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {leafVariant.bits | 1}, // partial key length 1 + {9}, // key data + scaleEncodeBytes(t, 1, 2, 3), + }, nil)), + n: Leaf{ + PartialKey: []byte{9}, + Value: NewInlineValue([]byte{1, 2, 3}), + }, + }, + "branch_decoding_error": { + reader: bytes.NewReader([]byte{ + branchVariant.bits | 1, // key length 1 + // missing key data byte + }), + errWrapped: io.EOF, + errMessage: "cannot decode key: " + + "reading from reader: EOF", + }, + "branch_success": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {branchVariant.bits | 1}, // partial key length 1 + {9}, // key data + {0b0000_0000, 0b0000_0000}, // no children bitmap + }, nil)), + n: Branch{ + PartialKey: []byte{9}, + }, + }, + "leaf_with_hashed_value_success": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {leafWithHashedValueVariant.bits | 1}, // partial key length 1 + {9}, // key data + hashedValue.ToBytes(), + }, nil)), + n: Leaf{ + PartialKey: []byte{9}, + Value: NewHashedValue(hashedValue.ToBytes()), + }, + }, + "leaf_with_hashed_value_fail_too_short": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {leafWithHashedValueVariant.bits | 1}, // partial key length 1 + {9}, // key data + {0b0000_0000}, // less than 32bytes + }, nil)), + errWrapped: 
ErrDecodeHashedValueTooShort, + errMessage: "cannot decode leaf: hashed storage value too short: expected 32, got: 1", + }, + "branch_with_hashed_value_success": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {branchWithHashedValueVariant.bits | 1}, // partial key length 1 + {9}, // key data + {0b0000_0000, 0b0000_0000}, // no children bitmap + hashedValue.ToBytes(), + }, nil)), + n: Branch{ + PartialKey: []byte{9}, + Value: NewHashedValue(hashedValue.ToBytes()), + }, + }, + "branch_with_hashed_value_fail_too_short": { + reader: bytes.NewReader(bytes.Join([][]byte{ + {branchWithHashedValueVariant.bits | 1}, // partial key length 1 + {9}, // key data + {0b0000_0000, 0b0000_0000}, // no children bitmap + {0b0000_0000}, + }, nil)), + errWrapped: ErrDecodeHashedValueTooShort, + errMessage: "cannot decode branch: hashed storage value too short: expected 32, got: 1", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + n, err := Decode(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.n, n) + }) + } +} + +func Test_decodeBranch(t *testing.T) { + t.Parallel() + + const childHashLength = 32 + childHash := make([]byte, childHashLength) + for i := range childHash { + childHash[i] = byte(i) + } + scaleEncodedChildHash := scaleEncodeByteSlice(t, childHash) + + testCases := map[string]struct { + reader io.Reader + nodeVariant variant + partialKey []byte + branch Branch + errWrapped error + errMessage string + }{ + "children_bitmap_read_error": { + reader: bytes.NewBuffer([]byte{ + // missing children bitmap 2 bytes + }), + nodeVariant: branchVariant, + errWrapped: ErrReadChildrenBitmap, + errMessage: "cannot read children bitmap: EOF", + }, + "children_decoding_error": { + reader: bytes.NewBuffer([]byte{ + 0b0000_0000, 0b0000_0100, // children bitmap + // missing children scale encoded 
data + }), + nodeVariant: branchVariant, + partialKey: []byte{1}, + errWrapped: ErrDecodeChildHash, + errMessage: "cannot decode child hash: at index 10: decoding uint: reading byte: EOF", + }, + "success_for_branch_variant": { + reader: bytes.NewBuffer( + bytes.Join([][]byte{ + {0b0000_0000, 0b0000_0100}, // children bitmap + scaleEncodedChildHash, + }, nil), + ), + nodeVariant: branchVariant, + partialKey: []byte{1}, + branch: Branch{ + PartialKey: []byte{1}, + Children: [ChildrenCapacity]MerkleValue{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + HashedNode{ + Data: childHash, + }, + }, + }, + }, + "value_decoding_error_for_branch_with_value_variant": { + reader: bytes.NewBuffer( + bytes.Join([][]byte{ + {0b0000_0000, 0b0000_0100}, // children bitmap + // missing encoded branch storage value + }, nil), + ), + nodeVariant: branchWithValueVariant, + partialKey: []byte{1}, + errWrapped: ErrDecodeStorageValue, + errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", + }, + "success_for_branch_with_value": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + {0b0000_0000, 0b0000_0100}, // children bitmap + scaleEncodeBytes(t, 7, 8, 9), // branch storage value + scaleEncodedChildHash, + }, nil)), + nodeVariant: branchWithValueVariant, + partialKey: []byte{1}, + branch: Branch{ + PartialKey: []byte{1}, + Value: NewInlineValue([]byte{7, 8, 9}), + Children: [ChildrenCapacity]MerkleValue{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + HashedNode{ + Data: childHash, + }, + }, + }, + }, + "branch_with_inlined_node_decoding_error": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncodeBytes(t, 1), // branch storage value + {0}, // garbage inlined node + }, nil)), + nodeVariant: branchWithValueVariant, + partialKey: []byte{1}, + branch: Branch{ + PartialKey: []byte{1}, + Value: NewInlineValue([]byte{1}), + Children: [ChildrenCapacity]MerkleValue{ + InlineNode{ + Data: []byte{}, + }, 
+ }, + }, + }, + "branch_with_inlined_branch_and_leaf": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + {0b0000_0011, 0b0000_0000}, // children bitmap + // top level inlined leaf less than 32 bytes + scaleEncodeByteSlice(t, bytes.Join([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {2}, // key data + scaleEncodeBytes(t, 2), // storage value data + }, nil)), + // top level inlined branch less than 32 bytes + scaleEncodeByteSlice(t, bytes.Join([][]byte{ + {branchWithValueVariant.bits | 1}, // partial key length of 1 + {3}, // key data + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncodeBytes(t, 3), // branch storage value + // bottom level leaf + scaleEncodeByteSlice(t, bytes.Join([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {4}, // key data + scaleEncodeBytes(t, 4), // storage value data + }, nil)), + }, nil)), + }, nil)), + nodeVariant: branchVariant, + partialKey: []byte{1}, + branch: Branch{ + PartialKey: []byte{1}, + Children: [ChildrenCapacity]MerkleValue{ + InlineNode{ + Data: bytes.Join([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {2}, // key data + scaleEncodeBytes(t, 2), // storage value data + }, nil), + }, + InlineNode{ + Data: bytes.Join([][]byte{ + {branchWithValueVariant.bits | 1}, // partial key length of 1 + {3}, // key data + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncodeBytes(t, 3), // branch storage value + // bottom level leaf + scaleEncodeByteSlice(t, bytes.Join([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {4}, // key data + scaleEncodeBytes(t, 4), // storage value data + }, nil)), + }, nil), + }, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + branch, err := decodeBranch(testCase.reader, + testCase.nodeVariant, testCase.partialKey) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + 
assert.Equal(t, testCase.branch, branch) + }) + } +} + +func Test_decodeLeaf(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + variant variant + partialKey []byte + leaf Leaf + errWrapped error + errMessage string + }{ + "value_decoding_error": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + {255, 255}, // bad storage value data + }, nil)), + variant: leafVariant, + partialKey: []byte{9}, + errWrapped: ErrDecodeStorageValue, + errMessage: "cannot decode storage value: decoding uint: unknown prefix for compact uint: 255", + }, + "missing_storage_value_data": { + reader: bytes.NewBuffer([]byte{ + // missing storage value data + }), + variant: leafVariant, + partialKey: []byte{9}, + errWrapped: ErrDecodeStorageValue, + errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", + }, + "empty_storage_value_data": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + scaleEncodeByteSlice(t, []byte{}), // results to []byte{0} + }, nil)), + variant: leafVariant, + partialKey: []byte{9}, + leaf: Leaf{ + PartialKey: []byte{9}, + Value: NewInlineValue([]byte{}), + }, + }, + "success": { + reader: bytes.NewBuffer(bytes.Join([][]byte{ + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // storage value data + }, nil)), + variant: leafVariant, + partialKey: []byte{9}, + leaf: Leaf{ + PartialKey: []byte{9}, + Value: NewInlineValue([]byte{1, 2, 3, 4, 5}), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + leaf, err := decodeLeaf(testCase.reader, testCase.variant, testCase.partialKey) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.leaf, leaf) + }) + } +} diff --git a/pkg/trie/triedb/codec/header.go b/pkg/trie/triedb/codec/header.go new file mode 100644 index 0000000000..ae73010740 --- /dev/null +++ b/pkg/trie/triedb/codec/header.go @@ -0,0 +1,110 @@ +// 
Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "errors" + "fmt" + "io" +) + +var ( + ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than 2^16") + ErrVariantUnknown = errors.New("node variant is unknown") +) + +func decodeHeader(reader io.Reader) (nodeVariant variant, + partialKeyLength uint16, err error) { + buffer := make([]byte, 1) + _, err = reader.Read(buffer) + if err != nil { + return nodeVariant, 0, fmt.Errorf("reading header byte: %w", err) + } + + nodeVariant, partialKeyLengthHeader, err := decodeHeaderByte(buffer[0]) + if err != nil { + return variant{}, 0, fmt.Errorf("decoding header byte: %w", err) + } + + partialKeyLengthHeaderMask := nodeVariant.partialKeyLengthHeaderMask() + if partialKeyLengthHeaderMask == emptyVariant.bits { + // empty node or compact encoding which have no + // partial key. The partial key length mask is + // 0b0000_0000 since the variant mask is + // 0b1111_1111. + return nodeVariant, 0, nil + } + + partialKeyLength = uint16(partialKeyLengthHeader) + if partialKeyLengthHeader < partialKeyLengthHeaderMask { + // partial key length is contained in the first byte. + return nodeVariant, partialKeyLength, nil + } + + // the partial key length header byte is equal to its maximum + // possible value; this means the partial key length is greater + // than this (0 to 2^6 - 1 = 63) maximum value, and we need to + // accumulate the next bytes from the reader to get the full + // partial key length. 
+ // Specification: https://spec.polkadot.network/#defn-node-header + var previousKeyLength uint16 // used to track an eventual overflow + for { + _, err = reader.Read(buffer) + if err != nil { + return variant{}, 0, fmt.Errorf("reading key length: %w", err) + } + + previousKeyLength = partialKeyLength + partialKeyLength += uint16(buffer[0]) + + if partialKeyLength < previousKeyLength { + // the partial key can have a length up to 65535 which is the + // maximum uint16 value; therefore if we overflowed, we went over + // this maximum. + overflowed := maxPartialKeyLength - previousKeyLength + partialKeyLength + return variant{}, 0, fmt.Errorf("%w: overflowed by %d", ErrPartialKeyTooBig, overflowed) + } + + if buffer[0] < 255 { + // the end of the partial key length has been reached. + return nodeVariant, partialKeyLength, nil + } + } +} + +// variantsOrderedByBitMask is an array of all variants sorted +// in ascending order by the number of LHS set bits each variant mask has. +// See https://spec.polkadot.network/#defn-node-header +// WARNING: DO NOT MUTATE. +// This array is defined at global scope for performance +// reasons only, instead of having it locally defined in +// the decodeHeaderByte function below. +// For 7 variants, the performance is improved by ~20%. 
+var variantsOrderedByBitMask = [...]variant{ + leafVariant, // mask 1100_0000 + branchVariant, // mask 1100_0000 + branchWithValueVariant, // mask 1100_0000 + leafWithHashedValueVariant, // mask 1110_0000 + branchWithHashedValueVariant, // mask 1111_0000 + emptyVariant, // mask 1111_1111 + compactEncodingVariant, // mask 1111_1111 +} + +func decodeHeaderByte(header byte) (nodeVariant variant, + partialKeyLengthHeader byte, err error) { + var partialKeyLengthHeaderMask byte + for i := len(variantsOrderedByBitMask) - 1; i >= 0; i-- { + nodeVariant = variantsOrderedByBitMask[i] + variantBits := header & nodeVariant.mask + if variantBits != nodeVariant.bits { + continue + } + + partialKeyLengthHeaderMask = nodeVariant.partialKeyLengthHeaderMask() + partialKeyLengthHeader = header & partialKeyLengthHeaderMask + return nodeVariant, partialKeyLengthHeader, nil + } + + return invalidVariant, 0, fmt.Errorf("%w: for header byte %08b", ErrVariantUnknown, header) +} diff --git a/pkg/trie/triedb/codec/header_test.go b/pkg/trie/triedb/codec/header_test.go new file mode 100644 index 0000000000..f749fcaef9 --- /dev/null +++ b/pkg/trie/triedb/codec/header_test.go @@ -0,0 +1,206 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "errors" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +var errTest = errors.New("test error") + +func Test_decodeHeader(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reads []readCall + nodeVariant variant + partialKeyLength uint16 + errWrapped error + errMessage string + }{ + "first_byte_read_error": { + reads: []readCall{ + {buffArgCap: 1, err: errTest}, + }, + errWrapped: errTest, + errMessage: "reading header byte: test error", + }, + "header_byte_decoding_error": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{0b0000_1000}}, + }, + errWrapped: ErrVariantUnknown, + errMessage: "decoding header byte: node 
variant is unknown: for header byte 00001000", + }, + "partial_key_length_contained_in_first_byte": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1110}}, + }, + nodeVariant: leafVariant, + partialKeyLength: uint16(0b0011_1110), + }, + "long_partial_key_length_and_second_byte_read_error": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, err: errTest}, + }, + errWrapped: errTest, + errMessage: "reading key length: test error", + }, + "partial_key_length_spread_on_multiple_bytes": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, read: []byte{0b1111_1111}}, + {buffArgCap: 1, read: []byte{0b1111_0000}}, + }, + nodeVariant: leafVariant, + partialKeyLength: uint16(0b0011_1111 + 0b1111_1111 + 0b1111_0000), + }, + "partial_key_length_too_long": { + reads: repeatReadCall(readCall{ + buffArgCap: 1, + read: []byte{0b1111_1111}, + }, 258), + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than 2^16: overflowed by 254", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + reader := NewMockReader(ctrl) + var previousCall *gomock.Call + for _, readCall := range testCase.reads { + readCall := readCall // required variable pinning + byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) + call := reader.EXPECT().Read(byteSliceCapMatcher). 
+ DoAndReturn(func(b []byte) (n int, err error) { + copy(b, readCall.read) + return readCall.n, readCall.err + }) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + nodeVariant, partialKeyLength, err := decodeHeader(reader) + + assert.Equal(t, testCase.nodeVariant, nodeVariant) + assert.Equal(t, int(testCase.partialKeyLength), int(partialKeyLength)) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_decodeHeaderByte(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + header byte + nodeVariant variant + partialKeyLengthHeader byte + errWrapped error + errMessage string + }{ + "empty_variant_header": { + header: 0b0000_0000, + nodeVariant: emptyVariant, + partialKeyLengthHeader: 0b0000_0000, + }, + "branch_with_value_header": { + header: 0b1110_1001, + nodeVariant: branchWithValueVariant, + partialKeyLengthHeader: 0b0010_1001, + }, + "branch_header": { + header: 0b1010_1001, + nodeVariant: branchVariant, + partialKeyLengthHeader: 0b0010_1001, + }, + "leaf_header": { + header: 0b0110_1001, + nodeVariant: leafVariant, + partialKeyLengthHeader: 0b0010_1001, + }, + "leaf_containing_hashes_header": { + header: 0b0011_1001, + nodeVariant: leafWithHashedValueVariant, + partialKeyLengthHeader: 0b0001_1001, + }, + "branch_containing_hashes_header": { + header: 0b0001_1001, + nodeVariant: branchWithHashedValueVariant, + partialKeyLengthHeader: 0b0000_1001, + }, + "compact_encoding_header": { + header: 0b0000_0001, + nodeVariant: compactEncodingVariant, + partialKeyLengthHeader: 0b0000_0000, + }, + "unknown_variant_header": { + header: 0b0000_1000, + errWrapped: ErrVariantUnknown, + errMessage: "node variant is unknown: for header byte 00001000", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeVariant, partialKeyLengthHeader, + err := 
decodeHeaderByte(testCase.header) + + assert.Equal(t, testCase.nodeVariant, nodeVariant) + assert.Equal(t, testCase.partialKeyLengthHeader, partialKeyLengthHeader) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_variantsOrderedByBitMask(t *testing.T) { + t.Parallel() + + slice := make([]variant, len(variantsOrderedByBitMask)) + sortedSlice := make([]variant, len(variantsOrderedByBitMask)) + copy(slice, variantsOrderedByBitMask[:]) + copy(sortedSlice, variantsOrderedByBitMask[:]) + + sort.Slice(slice, func(i, j int) bool { + return slice[i].mask < slice[j].mask + }) + + assert.Equal(t, sortedSlice, slice) +} + +func Benchmark_decodeHeaderByte(b *testing.B) { + // For 7 variants defined in the variants array: + // With global scoped variants slice: + // 2.987 ns/op 0 B/op 0 allocs/op + // With locally scoped variants slice: + // 3.873 ns/op 0 B/op 0 allocs/op + header := leafVariant.bits | 0b0000_0001 + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = decodeHeaderByte(header) + } +} diff --git a/pkg/trie/triedb/codec/key.go b/pkg/trie/triedb/codec/key.go new file mode 100644 index 0000000000..5b0eccb7af --- /dev/null +++ b/pkg/trie/triedb/codec/key.go @@ -0,0 +1,37 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/pkg/trie/codec" +) + +const maxPartialKeyLength = ^uint16(0) + +var ErrReaderMismatchCount = errors.New("read unexpected number of bytes from reader") + +// decodeKey decodes a key from a reader. 
+func decodeKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { + if partialKeyLength == 0 { + return []byte{}, nil + } + + key := make([]byte, partialKeyLength/2+partialKeyLength%2) + n, err := reader.Read(key) + if err != nil { + return nil, fmt.Errorf("reading from reader: %w", err) + } else if n != len(key) { + return nil, fmt.Errorf("%w: read %d bytes instead of expected %d bytes", + ErrReaderMismatchCount, n, len(key)) + } + + // if the partialKeyLength is an odd number means that when parsing the key + // to nibbles it will contains a useless 0 in the first index, otherwise + // we can use the entire nibbles + return codec.KeyLEToNibbles(key)[partialKeyLength%2:], nil +} diff --git a/pkg/trie/triedb/codec/key_test.go b/pkg/trie/triedb/codec/key_test.go new file mode 100644 index 0000000000..e24502f3c3 --- /dev/null +++ b/pkg/trie/triedb/codec/key_test.go @@ -0,0 +1,139 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +type readCall struct { + buffArgCap int + read []byte + n int // number of bytes read + err error +} + +func repeatReadCall(base readCall, n int) (calls []readCall) { + calls = make([]readCall, n) + for i := range calls { + calls[i] = base + } + return calls +} + +var _ gomock.Matcher = (*byteSliceCapMatcher)(nil) + +type byteSliceCapMatcher struct { + capacity int +} + +func (b *byteSliceCapMatcher) Matches(x interface{}) bool { + slice, ok := x.([]byte) + if !ok { + return false + } + return cap(slice) == b.capacity +} + +func (b *byteSliceCapMatcher) String() string { + return fmt.Sprintf("slice with capacity %d", b.capacity) +} + +func newByteSliceCapMatcher(capacity int) *byteSliceCapMatcher { + return &byteSliceCapMatcher{ + capacity: capacity, + } +} + +func Test_decodeKey(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reads 
[]readCall + partialKeyLength uint16 + b []byte + errWrapped error + errMessage string + }{ + "zero_key_length": { + partialKeyLength: 0, + b: []byte{}, + }, + "short_key_length": { + reads: []readCall{ + {buffArgCap: 3, read: []byte{1, 2, 3}, n: 3}, + }, + partialKeyLength: 5, + b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, + }, + "key_read_error": { + reads: []readCall{ + {buffArgCap: 3, err: errTest}, + }, + partialKeyLength: 5, + errWrapped: errTest, + errMessage: "reading from reader: test error", + }, + + "key_read_bytes_count_mismatch": { + reads: []readCall{ + {buffArgCap: 3, n: 2}, + }, + partialKeyLength: 5, + errWrapped: ErrReaderMismatchCount, + errMessage: "read unexpected number of bytes from reader: read 2 bytes instead of expected 3 bytes", + }, + "long_key_length": { + reads: []readCall{ + {buffArgCap: 35, read: bytes.Repeat([]byte{7}, 35), n: 35}, // key data + }, + partialKeyLength: 70, + b: []byte{ + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + reader := NewMockReader(ctrl) + var previousCall *gomock.Call + for _, readCall := range testCase.reads { + readCall := readCall // required variable pinning + byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) + call := reader.EXPECT().Read(byteSliceCapMatcher). 
+ DoAndReturn(func(b []byte) (n int, err error) { + copy(b, readCall.read) + return readCall.n, readCall.err + }) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + b, err := decodeKey(reader, testCase.partialKeyLength) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.b, b) + }) + } +} diff --git a/pkg/trie/triedb/codec/mocks_generate_test.go b/pkg/trie/triedb/codec/mocks_generate_test.go new file mode 100644 index 0000000000..5deba89d3d --- /dev/null +++ b/pkg/trie/triedb/codec/mocks_generate_test.go @@ -0,0 +1,6 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +//go:generate mockgen -destination=reader_mock_test.go -package $GOPACKAGE io Reader diff --git a/pkg/trie/triedb/codec/node.go b/pkg/trie/triedb/codec/node.go new file mode 100644 index 0000000000..6f0154bac8 --- /dev/null +++ b/pkg/trie/triedb/codec/node.go @@ -0,0 +1,89 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +const ChildrenCapacity = 16 + +// MerkleValue is a helper enum to differentiate between inline and hashed nodes +// https://spec.polkadot.network/chap-state#defn-merkle-value +type MerkleValue interface { + isMerkleValue() + IsHashed() bool +} + +type ( + // Value bytes as stored in a trie node + InlineNode struct { + Data []byte + } + // Containing a hash used to lookup in db for real value + HashedNode struct { + Data []byte + } +) + +func (InlineNode) isMerkleValue() {} +func (InlineNode) IsHashed() bool { return false } +func (HashedNode) isMerkleValue() {} +func (HashedNode) IsHashed() bool { return true } + +func NewInlineNode(data []byte) MerkleValue { + return InlineNode{Data: data} +} + +func NewHashedNode(data []byte) MerkleValue { + return HashedNode{Data: data} +} + +// NodeValue is a helper enum to differentiate between inline and 
hashed values +type NodeValue interface { + isNodeValue() +} + +type ( + // Value bytes as stored in a trie node + InlineValue struct { + Data []byte + } + // Containing a hash used to lookup in db for real value + HashedValue struct { + Data []byte + } +) + +func (InlineValue) isNodeValue() {} +func (HashedValue) isNodeValue() {} + +func NewInlineValue(data []byte) NodeValue { + return InlineValue{Data: data} +} + +func NewHashedValue(data []byte) NodeValue { + return HashedValue{Data: data} +} + +// Node is the representation of a decoded node +type Node interface { + isNode() +} + +type ( + // Empty node + Empty struct{} + // Leaf always contains values + Leaf struct { + PartialKey []byte + Value NodeValue + } + // Branch could has or not has values + Branch struct { + PartialKey []byte + Children [16]MerkleValue + Value NodeValue + } +) + +func (Empty) isNode() {} +func (Leaf) isNode() {} +func (Branch) isNode() {} diff --git a/pkg/trie/triedb/codec/reader_mock_test.go b/pkg/trie/triedb/codec/reader_mock_test.go new file mode 100644 index 0000000000..5fe0007aae --- /dev/null +++ b/pkg/trie/triedb/codec/reader_mock_test.go @@ -0,0 +1,54 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Reader) +// +// Generated by this command: +// +// mockgen -destination=reader_mock_test.go -package codec io Reader +// + +// Package codec is a generated GoMock package. +package codec + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockReader is a mock of Reader interface. +type MockReader struct { + ctrl *gomock.Controller + recorder *MockReaderMockRecorder +} + +// MockReaderMockRecorder is the mock recorder for MockReader. +type MockReaderMockRecorder struct { + mock *MockReader +} + +// NewMockReader creates a new mock instance. 
// --- pkg/trie/triedb/codec/variants.go ---

// variant identifies a node header variant by its significant bits and the
// bit mask that isolates those bits within the header byte.
type variant struct {
	bits byte
	mask byte
}

// Node variants
// See https://spec.polkadot.network/#defn-node-header
var (
	leafVariant = variant{ // leaf 01
		bits: 0b0100_0000,
		mask: 0b1100_0000,
	}
	branchVariant = variant{ // branch 10
		bits: 0b1000_0000,
		mask: 0b1100_0000,
	}
	branchWithValueVariant = variant{ // branch 11
		bits: 0b1100_0000,
		mask: 0b1100_0000,
	}
	leafWithHashedValueVariant = variant{ // leaf containing hashes 001
		bits: 0b0010_0000,
		mask: 0b1110_0000,
	}
	branchWithHashedValueVariant = variant{ // branch containing hashes 0001
		bits: 0b0001_0000,
		mask: 0b1111_0000,
	}
	emptyVariant = variant{ // empty 0000 0000
		bits: 0b0000_0000,
		mask: 0b1111_1111,
	}
	// The comment previously read "0001 0000", which did not match the
	// value below; the compact encoding bits are 0000 0001.
	compactEncodingVariant = variant{ // compact encoding 0000 0001
		bits: 0b0000_0001,
		mask: 0b1111_1111,
	}
	// invalidVariant has a zero mask so it matches any header byte;
	// it acts as the fallback variant.
	invalidVariant = variant{
		bits: 0b0000_0000,
		mask: 0b0000_0000,
	}
)

// partialKeyLengthHeaderMask returns the partial key length
// header bit mask corresponding to the variant header bit mask.
// For example for the leaf variant with variant mask 1100_0000,
// the partial key length header mask returned is 0011_1111.
func (v variant) partialKeyLengthHeaderMask() byte {
	return ^v.mask
}

// String returns a human-readable name for the variant.
func (v variant) String() string {
	switch v {
	case leafVariant:
		return "Leaf"
	case leafWithHashedValueVariant:
		return "LeafWithHashedValue"
	case branchVariant:
		return "Branch"
	case branchWithValueVariant:
		return "BranchWithValue"
	case branchWithHashedValueVariant:
		return "BranchWithHashedValue"
	case emptyVariant:
		return "Empty"
	case compactEncodingVariant:
		return "Compact"
	case invalidVariant:
		return "Invalid"
	default:
		return "Not reachable"
	}
}
"test": make([]byte, 40), + } + + for k, v := range entries { + inMemoryTrie.Put([]byte(k), v) + } + + err := inMemoryTrie.WriteDirty(db) + assert.NoError(t, err) + + root, err := inMemoryTrie.Hash() + assert.NoError(t, err) + + trieDB := NewTrieDB(root, db) + + for k, v := range entries { + value := trieDB.Get([]byte(k)) + assert.Equal(t, v, value) + } + + assert.Equal(t, root, trieDB.MustHash()) + }) +} + +func TestTrieDB_Lookup(t *testing.T) { + t.Run("root_not_exists_in_db", func(t *testing.T) { + db := newTestDB(t) + trieDB := NewTrieDB(trie.EmptyHash, db) + + value, err := trieDB.lookup([]byte("test")) + assert.Nil(t, value) + assert.ErrorIs(t, err, ErrIncompleteDB) + }) +} diff --git a/pkg/trie/triedb/iterator.go b/pkg/trie/triedb/iterator.go new file mode 100644 index 0000000000..a069b49f09 --- /dev/null +++ b/pkg/trie/triedb/iterator.go @@ -0,0 +1,16 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +// Entries returns all the key-value pairs in the trie as a map of keys to values +// where the keys are encoded in Little Endian. +func (t *TrieDB) Entries() (keyValueMap map[string][]byte) { + panic("not implemented yet") +} + +// NextKey returns the next key in the trie in lexicographic order. +// It returns nil if no next key is found. 
+func (t *TrieDB) NextKey(key []byte) []byte { + panic("not implemented yet") +} diff --git a/pkg/trie/triedb/print.go b/pkg/trie/triedb/print.go new file mode 100644 index 0000000000..cba5d31506 --- /dev/null +++ b/pkg/trie/triedb/print.go @@ -0,0 +1,18 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "fmt" + + "github.com/ChainSafe/gossamer/lib/common" +) + +func (t *TrieDB) String() string { + if t.rootHash == common.EmptyHash { + return "empty" + } + + return fmt.Sprintf("TrieDB: %v", t.rootHash) +} diff --git a/pkg/trie/triedb/triedb.go b/pkg/trie/triedb/triedb.go new file mode 100644 index 0000000000..40d7ace18c --- /dev/null +++ b/pkg/trie/triedb/triedb.go @@ -0,0 +1,173 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "bytes" + "errors" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/trie" + nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" + "github.com/ChainSafe/gossamer/pkg/trie/db" + + "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" +) + +var ErrIncompleteDB = errors.New("incomplete database") + +// TrieDB is a DB-backed patricia merkle trie implementation +// using lazy loading to fetch nodes +type TrieDB struct { + rootHash common.Hash + db db.DBGetter +} + +// NewTrieDB creates a new TrieDB using the given root and db +func NewTrieDB(rootHash common.Hash, db db.DBGetter) *TrieDB { + return &TrieDB{ + rootHash: rootHash, + db: db, + } +} + +// Hash returns the hashed root of the trie. +func (t *TrieDB) Hash() (common.Hash, error) { + // This is trivial since it is a read only trie, but will change when we + // support writes + return t.rootHash, nil +} + +// MustHash returns the hashed root of the trie. +// It panics if it fails to hash the root node. 
+func (t *TrieDB) MustHash() common.Hash { + h, err := t.Hash() + if err != nil { + panic(err) + } + + return h +} + +// Get returns the value in the node of the trie +// which matches its key with the key given. +// Note the key argument is given in little Endian format. +func (t *TrieDB) Get(key []byte) []byte { + val, err := t.lookup(key) + if err != nil { + return nil + } + return val +} + +// GetKeysWithPrefix returns all keys in little Endian +// format from nodes in the trie that have the given little +// Endian formatted prefix in their key. +func (t *TrieDB) GetKeysWithPrefix(prefix []byte) (keysLE [][]byte) { + panic("not implemented yet") +} + +// Internal methods + +func (t *TrieDB) lookup(key []byte) ([]byte, error) { + keyNibbles := nibbles.KeyLEToNibbles(key) + return t.lookupWithoutCache(keyNibbles) +} + +// lookupWithoutCache traverse nodes loading then from DB until reach the one +// we are looking for. +func (t *TrieDB) lookupWithoutCache(nibbleKey []byte) ([]byte, error) { + // Start from root node and going downwards + partialKey := nibbleKey + hash := t.rootHash[:] + + // Iterates through non inlined nodes + for { + // Get node from DB + nodeData, err := t.db.Get(hash) + + if err != nil { + return nil, ErrIncompleteDB + } + + InlinedChildrenIterator: + for { + // Decode node + reader := bytes.NewReader(nodeData) + decodedNode, err := codec.Decode(reader) + if err != nil { + return nil, err + } + + var nextNode codec.MerkleValue + + switch n := decodedNode.(type) { + case codec.Empty: + return nil, nil + case codec.Leaf: + // We are in the node we were looking for + if bytes.Equal(partialKey, n.PartialKey) { + return t.loadValue(partialKey, n.Value) + } + return nil, nil + case codec.Branch: + nodePartialKey := n.PartialKey + + // This is unusual but could happen if for some reason one + // branch has a hashed child node that points to a node that + // doesn't share the prefix we are expecting + if !bytes.HasPrefix(partialKey, nodePartialKey) { 
+ return nil, nil + } + + // We are in the node we were looking for + if bytes.Equal(partialKey, nodePartialKey) { + if n.Value != nil { + return t.loadValue(partialKey, n.Value) + } + return nil, nil + } + + // This is not the node we were looking for but it might be in + // one of its children + childIdx := int(partialKey[len(nodePartialKey)]) + nextNode = n.Children[childIdx] + if nextNode == nil { + return nil, nil + } + + // Advance the partial key consuming the part we already checked + partialKey = partialKey[len(nodePartialKey)+1:] + } + + // Next node could be inlined or hashed (pointer to a node) + // https://spec.polkadot.network/chap-state#defn-merkle-value + switch merkleValue := nextNode.(type) { + case codec.HashedNode: + // If it's hashed we set the hash to look for it in next loop + hash = merkleValue.Data + break InlinedChildrenIterator + case codec.InlineNode: + // If it is inlined we just need to decode it in the next loop + nodeData = merkleValue.Data + } + } + } +} + +// loadValue gets the value from the node, if it is inlined we can return it +// directly. But if it is hashed (V1) we have to look up for its value in the DB +func (t *TrieDB) loadValue(prefix []byte, value codec.NodeValue) ([]byte, error) { + switch v := value.(type) { + case codec.InlineValue: + return v.Data, nil + case codec.HashedValue: + prefixedKey := bytes.Join([][]byte{prefix, v.Data}, nil) + return t.db.Get(prefixedKey) + default: + panic("unreachable") + } +} + +var _ trie.TrieRead = (*TrieDB)(nil)