From 2e8ca6f06bb288d1669f2e8e61dac5a728158b95 Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Wed, 13 Nov 2024 14:43:01 +0700 Subject: [PATCH] refactor(pkg/trie/triedb): introduce `HashDB` interface and integrate into `TrieDB` with added iterators (#4315) Co-authored-by: Haiko Schol --- internal/client/state-db/noncanonical.go | 20 +- internal/client/state-db/pruning.go | 22 +- internal/hash-db/hash_db.go | 54 +++ internal/memory-db/memory_db.go | 242 ++++++++++ internal/memory-db/memory_db_test.go | 84 ++++ internal/primitives/core/hashing/hashing.go | 15 + pkg/trie/triedb/README.md | 4 +- pkg/trie/triedb/cache.go | 59 ++- pkg/trie/triedb/cache_bench_test.go | 12 +- pkg/trie/triedb/cache_test.go | 46 ++ pkg/trie/triedb/iterator.go | 214 ++++++++- pkg/trie/triedb/iterator_test.go | 242 +++++++++- pkg/trie/triedb/lookup.go | 198 ++++++-- pkg/trie/triedb/lookup_test.go | 11 +- pkg/trie/triedb/nibbles/nibbles.go | 2 +- pkg/trie/triedb/nibbles/nibbleslice.go | 19 +- pkg/trie/triedb/node.go | 74 +-- pkg/trie/triedb/node_storage.go | 4 +- pkg/trie/triedb/proof/generate_test.go | 4 +- pkg/trie/triedb/proof/proof.go | 9 +- pkg/trie/triedb/proof/proof_test.go | 9 +- pkg/trie/triedb/proof/util_test.go | 88 +--- pkg/trie/triedb/recorder.go | 2 +- pkg/trie/triedb/recorder_test.go | 13 +- pkg/trie/triedb/triedb.go | 209 +++++---- pkg/trie/triedb/triedb_iterator_test.go | 15 +- pkg/trie/triedb/triedb_test.go | 476 ++++++++++++++++---- pkg/trie/triedb/util_test.go | 100 +--- 28 files changed, 1730 insertions(+), 517 deletions(-) create mode 100644 internal/hash-db/hash_db.go create mode 100644 internal/memory-db/memory_db.go create mode 100644 internal/memory-db/memory_db_test.go create mode 100644 pkg/trie/triedb/cache_test.go diff --git a/internal/client/state-db/noncanonical.go b/internal/client/state-db/noncanonical.go index 237e0d44da..9d039d526a 100644 --- a/internal/client/state-db/noncanonical.go +++ b/internal/client/state-db/noncanonical.go @@ -146,10 +146,10 @@ func 
(nco *nonCanonicalOverlay[BlockHash, Key]) Insert( }) nco.lastCanonicalized = &lastCanonicalized } else if nco.lastCanonicalized != nil { - if number < frontBlockNumber || number > frontBlockNumber+uint64(nco.levels.Len()) { + if number < frontBlockNumber || number > frontBlockNumber+uint64(nco.levels.Len()) { //nolint:gosec log.Printf( "TRACE: Failed to insert block %v, current is %v .. %v)\n", - number, frontBlockNumber, frontBlockNumber+uint64(nco.levels.Len())) + number, frontBlockNumber, frontBlockNumber+uint64(nco.levels.Len())) //nolint:gosec return CommitSet[Key]{}, ErrInvalidBlockNumber } // check for valid parent if inserting on second level or higher @@ -163,13 +163,13 @@ func (nco *nonCanonicalOverlay[BlockHash, Key]) Insert( } var level overlayLevel[BlockHash, Key] = newOverlayLevel[BlockHash, Key]() var levelIndex int - if nco.levels.Len() == 0 || number == frontBlockNumber+uint64(nco.levels.Len()) { + if nco.levels.Len() == 0 || number == frontBlockNumber+uint64(nco.levels.Len()) { //nolint:gosec nco.levels.PushBack(newOverlayLevel[BlockHash, Key]()) level = nco.levels.Back() levelIndex = nco.levels.Len() - 1 } else { - level = nco.levels.At(int(number - frontBlockNumber)) - levelIndex = int(number - frontBlockNumber) + level = nco.levels.At(int(number - frontBlockNumber)) //nolint:gosec + levelIndex = int(number - frontBlockNumber) //nolint:gosec } if len(level.blocks) >= int(maxBlocksPerLevel) { @@ -221,10 +221,10 @@ func (nco *nonCanonicalOverlay[BlockHash, Key]) Insert( func (nco *nonCanonicalOverlay[BlockHash, Key]) discardJournals( levelIndex uint, discardedJournals *[][]byte, hash BlockHash) { - if levelIndex >= uint(nco.levels.Len()) { + if levelIndex >= uint(nco.levels.Len()) { //nolint:gosec return } - level := nco.levels.At(int(levelIndex)) + level := nco.levels.At(int(levelIndex)) //nolint:gosec for _, overlay := range level.blocks { parent, ok := nco.parents[overlay.hash] if !ok { @@ -418,7 +418,7 @@ func (nco 
*nonCanonicalOverlay[BlockHash, Key]) Remove(hash BlockHash) *CommitSe } } } - overlay := level.remove(uint(index)) + overlay := level.remove(uint(index)) //nolint:gosec nco.levels.Set(levelIndex, level) commit.Meta.Deleted = append(commit.Meta.Deleted, overlay.journalKey) delete(nco.parents, overlay.hash) @@ -496,7 +496,7 @@ func (ol *overlayLevel[BlockHash, Key]) push(overlay blockOverlay[BlockHash, Key } func (ol *overlayLevel[BlockHash, Key]) availableIndex() uint64 { - return uint64(bits.TrailingZeros64(^ol.usedIndices)) + return uint64(bits.TrailingZeros64(^ol.usedIndices)) //nolint:gosec } func (ol *overlayLevel[BlockHash, Key]) remove(index uint) blockOverlay[BlockHash, Key] { @@ -639,7 +639,7 @@ func discardDescendants[BlockHash Hash, Key Hash]( panic("there is a parent entry for each entry in levels; qed") } if h == hash { - index = uint(i) + index = uint(i) //nolint:gosec overlay := level.remove(index) numPinned := discardDescendants(remainder, values, parents, pinned, pinnedInsertions, overlay.hash) if _, ok := pinned[overlay.hash]; ok { diff --git a/internal/client/state-db/pruning.go b/internal/client/state-db/pruning.go index f33f0b53c4..96cbaeb5ff 100644 --- a/internal/client/state-db/pruning.go +++ b/internal/client/state-db/pruning.go @@ -24,10 +24,10 @@ const defaultMaxBlockConstraint uint32 = 256 // the death list. // The changes are journaled in the DB. type pruningWindow[BlockHash Hash, Key Hash] struct { - /// A queue of blocks keep tracking keys that should be deleted for each block in the - /// pruning window. + // A queue of blocks keep tracking keys that should be deleted for each block in the + // pruning window. queue deathRowQueue[BlockHash, Key] - /// Block number that is next to be pruned. + // Block number that is next to be pruned. 
base uint64 } @@ -156,9 +156,9 @@ type deathRowQueue[BlockHash Hash, Key Hash] interface { } type inMemDeathRowQueue[BlockHash Hash, Key Hash] struct { - /// A queue of keys that should be deleted for each block in the pruning window. + // A queue of keys that should be deleted for each block in the pruning window. deathRows deque.Deque[deathRow[BlockHash, Key]] - /// An index that maps each key from `death_rows` to block number. + // An index that maps each key from `death_rows` to block number. deathIndex map[Key]uint64 } @@ -207,11 +207,11 @@ func (drqim *inMemDeathRowQueue[BlockHash, Key]) Import( block, ok := drqim.deathIndex[k] if ok { delete(drqim.deathIndex, k) - delete(drqim.deathRows.At(int(block-base)).deleted, k) + delete(drqim.deathRows.At(int(block-base)).deleted, k) //nolint:gosec } } // add new keys - importedBlock := base + uint64(drqim.deathRows.Len()) + importedBlock := base + uint64(drqim.deathRows.Len()) //nolint:gosec deletedMap := make(map[Key]any) for _, k := range deleted { drqim.deathIndex[k] = importedBlock @@ -236,7 +236,7 @@ func (drqim *inMemDeathRowQueue[BlockHash, Key]) PopFront(base uint64) (*deathRo // Check if the block at the given `index` of the queue exist // it is the caller's responsibility to ensure `index` won't be out of bounds func (drqim *inMemDeathRowQueue[BlockHash, Key]) HaveBlock(hash BlockHash, index uint) haveBlock { - if drqim.deathRows.At(int(index)).hash == hash { + if drqim.deathRows.At(int(index)).hash == hash { //nolint:gosec return haveBlockYes } return haveBlockNo @@ -244,7 +244,7 @@ func (drqim *inMemDeathRowQueue[BlockHash, Key]) HaveBlock(hash BlockHash, index // Return the number of block in the pruning window func (drqim *inMemDeathRowQueue[BlockHash, Key]) Len(base uint64) uint64 { - return uint64(drqim.deathRows.Len()) + return uint64(drqim.deathRows.Len()) //nolint:gosec } // Get the hash of the next pruning block @@ -276,8 +276,8 @@ func toPruningJournalKey(block uint64) []byte { type haveBlock 
uint const ( - /// Definitely don't have this block. + // Definitely don't have this block. haveBlockNo haveBlock = iota - /// Definitely has this block + // Definitely has this block haveBlockYes ) diff --git a/internal/hash-db/hash_db.go b/internal/hash-db/hash_db.go new file mode 100644 index 0000000000..4b2d891a6d --- /dev/null +++ b/internal/hash-db/hash_db.go @@ -0,0 +1,54 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package hashdb + +import "golang.org/x/exp/constraints" + +// A trie node prefix, it is the nibble path from the trie root +// to the trie node. +// For a node containing no partial key value it is the full key. +// For a value node or node containing a partial key, it is the full key minus its node partial +// nibbles (the node key can be split into prefix and node partial). +// Therefore it is always the leftmost portion of the node key, so its internal representation +// is a non expanded byte slice followed by a last padded byte representation. +// The padded byte is an optional padded value. +type Prefix struct { + Key []byte + Padded *byte +} + +// An empty prefix constant. +// Can be use when the prefix is not used internally or for root nodes. +var EmptyPrefix = Prefix{} + +// Hasher is an interface describing an object that can hash a slice of bytes. Used to abstract +// other types over the hashing algorithm. Defines a single hash method and an +// Out associated type with the necessary bounds. +type Hasher[Out constraints.Ordered] interface { + // Compute the hash of the provided slice of bytes returning the Out type of the Hasher. + Hash(x []byte) Out +} + +// HashDB is an interface modelling datastore keyed by a hash defined by the Hasher. +type HashDB[Hash comparable] interface { + // Look up a given hash into the bytes that hash to it, returning None if the + // hash is not known. + Get(key Hash, prefix Prefix) []byte + + // Check for the existence of a hash-key. 
+ Contains(key Hash, prefix Prefix) bool + + // Insert a datum item into the DB and return the datum's hash for a later lookup. Insertions + // are counted and the equivalent number of remove()s must be performed before the data + // is considered dead. + Insert(prefix Prefix, value []byte) Hash + + // Like Insert(), except you provide the key and the data is all moved. + Emplace(key Hash, prefix Prefix, value []byte) + + // Remove a datum previously inserted. Insertions can be "owed" such that the same number of + // inserts may happen without the data eventually being inserted into the DB. + // It can be "owed" more than once. + Remove(key Hash, prefix Prefix) +} diff --git a/internal/memory-db/memory_db.go b/internal/memory-db/memory_db.go new file mode 100644 index 0000000000..803c1563ae --- /dev/null +++ b/internal/memory-db/memory_db.go @@ -0,0 +1,242 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package memorydb + +import ( + "maps" + + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" + "golang.org/x/exp/constraints" +) + +type dataRC struct { + Data []byte + RC int32 +} + +type Hash interface { + constraints.Ordered + Bytes() []byte +} + +type Value interface { + ~[]byte +} + +// Reference-counted memory-based [hashdb.HashDB] implementation. 
+type MemoryDB[H Hash, Hasher hashdb.Hasher[H], Key constraints.Ordered, KF KeyFunction[H, Key]] struct { + data map[Key]dataRC + hashedNullNode H + nullNodeData []byte +} + +func NewMemoryDB[H Hash, Hasher hashdb.Hasher[H], Key constraints.Ordered, KF KeyFunction[H, Key]]( + data []byte, +) MemoryDB[H, Hasher, Key, KF] { + return newMemoryDBFromNullNode[H, Hasher, Key, KF](data, data) +} + +func newMemoryDBFromNullNode[H Hash, Hasher hashdb.Hasher[H], Key constraints.Ordered, KF KeyFunction[H, Key], T Value]( + nullKey []byte, + nullNodeData T, +) MemoryDB[H, Hasher, Key, KF] { + return MemoryDB[H, Hasher, Key, KF]{ + data: make(map[Key]dataRC), + hashedNullNode: (*new(Hasher)).Hash(nullKey), + nullNodeData: nullNodeData, + } +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Clone() MemoryDB[H, Hasher, Key, KF] { + return MemoryDB[H, Hasher, Key, KF]{ + data: maps.Clone(mdb.data), + hashedNullNode: mdb.hashedNullNode, + nullNodeData: mdb.nullNodeData, + } +} + +// Purge all zero-referenced data from the database. +func (mdb *MemoryDB[H, Hasher, Key, KF]) Purge() { + for k, val := range mdb.data { + if val.RC == 0 { + delete(mdb.data, k) + } + } +} + +// Return the internal key-value Map, clearing the current state. +func (mdb *MemoryDB[H, Hasher, Key, KF]) Drain() map[Key]dataRC { + data := mdb.data + mdb.data = make(map[Key]dataRC) + return data +} + +// Grab the raw information associated with a key. Returns None if the key +// doesn't exist. +// +// Even when Some is returned, the data is only guaranteed to be useful +// when the refs > 0. +func (mdb *MemoryDB[H, Hasher, Key, KF]) raw(key H, prefix hashdb.Prefix) *dataRC { + if key == mdb.hashedNullNode { + return &dataRC{mdb.nullNodeData, 1} + } + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + return &data + } + return nil +} + +// Consolidate all the entries of other into self. 
+func (mdb *MemoryDB[H, Hasher, Key, KF]) Consolidate(other *MemoryDB[H, Hasher, Key, KF]) { + for key, value := range other.Drain() { + entry, ok := mdb.data[key] + if ok { + if entry.RC < 0 { + entry.Data = value.Data + } + + entry.RC += value.RC + mdb.data[key] = entry + } else { + mdb.data[key] = dataRC{ + Data: value.Data, + RC: value.RC, + } + } + } +} + +// Remove an element and delete it from storage if reference count reaches zero. +// If the value was purged, return the old value. +func (mdb *MemoryDB[H, Hasher, Key, KF]) removeAndPurge(key H, prefix hashdb.Prefix) []byte { + if key == mdb.hashedNullNode { + return nil + } + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + if data.RC == 1 { + delete(mdb.data, kfKey) + return data.Data + } + data.RC -= 1 + mdb.data[kfKey] = data + return nil + } + mdb.data[kfKey] = dataRC{RC: -1} + return nil +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Get(key H, prefix hashdb.Prefix) []byte { + if key == mdb.hashedNullNode { + return mdb.nullNodeData + } + + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + if data.RC > 0 { + return data.Data + } + } + return nil +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Contains(key H, prefix hashdb.Prefix) bool { + if key == mdb.hashedNullNode { + return true + } + + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + if data.RC > 0 { + return true + } + } + return false +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Emplace(key H, prefix hashdb.Prefix, value []byte) { + if string(mdb.nullNodeData) == string(value) { + return + } + + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + if data.RC <= 0 { + data.Data = value + } + data.RC += 1 + mdb.data[kfKey] = data + } else { + mdb.data[kfKey] = dataRC{value, 1} + } +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Insert(prefix hashdb.Prefix, value []byte) H { + if string(mdb.nullNodeData) == string(value) { + return 
mdb.hashedNullNode + } + + key := (*new(Hasher)).Hash(value) + mdb.Emplace(key, prefix, value) + return key +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Remove(key H, prefix hashdb.Prefix) { + if key == mdb.hashedNullNode { + return + } + + kfKey := (*new(KF)).Key(key, prefix) + data, ok := mdb.data[kfKey] + if ok { + data.RC -= 1 + mdb.data[kfKey] = data + } else { + mdb.data[kfKey] = dataRC{RC: -1} + } +} + +func (mdb *MemoryDB[H, Hasher, Key, KF]) Keys() map[Key]int32 { + keyCounts := make(map[Key]int32) + for key, drc := range mdb.data { + if drc.RC != 0 { + keyCounts[key] = drc.RC + } + } + return keyCounts +} + +type KeyFunction[Hash constraints.Ordered, Key any] interface { + Key(hash Hash, prefix hashdb.Prefix) Key +} + +// Key function that only uses the hash +type HashKey[H Hash] struct{} + +func (HashKey[Hash]) Key(hash Hash, prefix hashdb.Prefix) Hash { + return hash +} + +// Key function that concatenates prefix and hash. +type PrefixedKey[H Hash] struct{} + +func (PrefixedKey[H]) Key(key H, prefix hashdb.Prefix) string { + return string(NewPrefixedKey(key, prefix)) +} + +// Derive a database key from hash value of the node (key) and the node prefix. +func NewPrefixedKey[H Hash](key H, prefix hashdb.Prefix) []byte { + prefixedKey := prefix.Key + if prefix.Padded != nil { + prefixedKey = append(prefixedKey, *prefix.Padded) + } + prefixedKey = append(prefixedKey, key.Bytes()...) 
+ return prefixedKey +} diff --git a/internal/memory-db/memory_db_test.go b/internal/memory-db/memory_db_test.go new file mode 100644 index 0000000000..b0100cc4d4 --- /dev/null +++ b/internal/memory-db/memory_db_test.go @@ -0,0 +1,84 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package memorydb + +import ( + "testing" + + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/core/hashing" + "github.com/stretchr/testify/assert" +) + +var ( + _ KeyFunction[hash.H256, hash.H256] = HashKey[hash.H256]{} + _ KeyFunction[hash.H256, string] = PrefixedKey[hash.H256]{} +) + +// Blake2-256 Hash implementation. +type Keccak256 struct{} + +// Produce the hash of some byte-slice. +func (k256 Keccak256) Hash(s []byte) hash.H256 { + h := hashing.Keccak256(s) + return hash.H256(h[:]) +} + +func TestMemoryDB_RemoveAndPurge(t *testing.T) { + helloBytes := []byte("Hello world!") + helloKey := Keccak256{}.Hash(helloBytes) + + m := NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + m.Remove(helloKey, hashdb.EmptyPrefix) + assert.Equal(t, int32(-1), m.raw(helloKey, hashdb.EmptyPrefix).RC) + m.Purge() + assert.Equal(t, int32(-1), m.raw(helloKey, hashdb.EmptyPrefix).RC) + m.Insert(hashdb.EmptyPrefix, helloBytes) + assert.Equal(t, int32(0), m.raw(helloKey, hashdb.EmptyPrefix).RC) + m.Purge() + assert.Nil(t, m.raw(helloKey, hashdb.EmptyPrefix)) + + m = NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + assert.Nil(t, m.removeAndPurge(helloKey, hashdb.EmptyPrefix)) + assert.Equal(t, int32(-1), m.raw(helloKey, hashdb.EmptyPrefix).RC) + m.Insert(hashdb.EmptyPrefix, helloBytes) + m.Insert(hashdb.EmptyPrefix, helloBytes) + assert.Equal(t, int32(1), m.raw(helloKey, hashdb.EmptyPrefix).RC) + assert.Equal(t, helloBytes, m.removeAndPurge(helloKey, hashdb.EmptyPrefix)) + assert.Nil(t, 
m.raw(helloKey, hashdb.EmptyPrefix)) + assert.Nil(t, m.removeAndPurge(helloKey, hashdb.EmptyPrefix)) +} + +func TestMemoryDB_Consolidate(t *testing.T) { + main := NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + other := NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + removeKey := other.Insert(hashdb.EmptyPrefix, []byte("doggo")) + main.Remove(removeKey, hashdb.EmptyPrefix) + + insertKey := other.Insert(hashdb.EmptyPrefix, []byte("arf")) + main.Emplace(insertKey, hashdb.EmptyPrefix, []byte("arf")) + + negativeRemoveKey := other.Insert(hashdb.EmptyPrefix, []byte("negative")) + other.Remove(negativeRemoveKey, hashdb.EmptyPrefix) + other.Remove(negativeRemoveKey, hashdb.EmptyPrefix) + main.Remove(negativeRemoveKey, hashdb.EmptyPrefix) + + main.Consolidate(&other) + + assert.Equal(t, &dataRC{[]byte("doggo"), 0}, main.raw(removeKey, hashdb.EmptyPrefix)) + assert.Equal(t, &dataRC{[]byte("arf"), 2}, main.raw(insertKey, hashdb.EmptyPrefix)) + assert.Equal(t, &dataRC{[]byte("negative"), -2}, main.raw(negativeRemoveKey, hashdb.EmptyPrefix)) +} + +func TestMemoryDB_DefaultWorks(t *testing.T) { + db := NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + hashedNullNode := Keccak256{}.Hash([]byte{0}) + assert.Equal(t, hashedNullNode, db.Insert(hashdb.EmptyPrefix, []byte{0})) + + db2 := NewMemoryDB[hash.H256, Keccak256, hash.H256, HashKey[hash.H256]]([]byte{0}) + root := db2.hashedNullNode + assert.True(t, db2.Contains(root, hashdb.EmptyPrefix)) + assert.True(t, db.Contains(root, hashdb.EmptyPrefix)) +} diff --git a/internal/primitives/core/hashing/hashing.go b/internal/primitives/core/hashing/hashing.go index 6cb31194e9..086edef4ab 100644 --- a/internal/primitives/core/hashing/hashing.go +++ b/internal/primitives/core/hashing/hashing.go @@ -5,6 +5,7 @@ package hashing import ( "golang.org/x/crypto/blake2b" + "golang.org/x/crypto/sha3" ) // BlakeTwo256 returns a Blake2 256-bit hash of the input 
data @@ -22,3 +23,17 @@ func BlakeTwo256(data []byte) [32]byte { copy(arr[:], encoded) return arr } + +// Keccak256 returns the keccak256 hash of the input data +func Keccak256(data []byte) [32]byte { + h := sha3.NewLegacyKeccak256() + _, err := h.Write(data) + if err != nil { + panic(err) + } + + hash := h.Sum(nil) + var buf = [32]byte{} + copy(buf[:], hash) + return buf +} diff --git a/pkg/trie/triedb/README.md b/pkg/trie/triedb/README.md index 60d7f95380..083a568063 100644 --- a/pkg/trie/triedb/README.md +++ b/pkg/trie/triedb/README.md @@ -10,7 +10,7 @@ It offers functionalities for writing and reading operations and uses lazy loadi - **Reads**: Basic functions to get data from the trie. - **Lazy Loading**: Load data on demand. - **Caching**: Enhances search performance. -- **Compatibility**: Works with any database implementing the `db.RWDatabase` interface and any cache implementing the `Cache` interface. +- **Compatibility**: Works with any database implementing the `hashdb.HashDB` interface and any cache implementing the `TrieCache` interface. - **Merkle proofs**: Create and verify merkle proofs. - **Iterator**: Traverse the trie keys in order. 
@@ -29,7 +29,7 @@ trie := triedb.NewEmptyTrieDB(db) To insert a key and its associated value: ```go -err := trie.Put([]byte("key"), []byte("value")) +err := trie.Set([]byte("key"), []byte("value")) ``` ### Get Data diff --git a/pkg/trie/triedb/cache.go b/pkg/trie/triedb/cache.go index b5517ccfa4..ee4e24f208 100644 --- a/pkg/trie/triedb/cache.go +++ b/pkg/trie/triedb/cache.go @@ -5,6 +5,7 @@ package triedb import ( "bytes" + "unsafe" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" @@ -86,7 +87,7 @@ func (nho HashCachedNodeHandle[H]) ChildReference() ChildReference { return HashChildReference[H](nho) } func (nho InlineCachedNodeHandle[H]) ChildReference() ChildReference { - encoded := nho.CachedNode.encoded() + encoded := nho.CachedNode.Encoded() store := (*new(H)) if len(encoded) > store.Length() { panic("Invalid inline node handle") @@ -106,7 +107,7 @@ func newCachedNodeHandleFromMerkleValue[H hash.Hash, Hasher hash.Hasher[H]]( if err != nil { return nil, err } - cachedNode, err := newCachedNodeFromNode[H, Hasher](node) + cachedNode, err := NewCachedNodeFromNode[H, Hasher](node) if err != nil { return nil, err } @@ -129,11 +130,12 @@ type CachedNodeTypes[H hash.Hash] interface { // Cached nodes. 
type CachedNode[H hash.Hash] interface { - data() []byte // nil means there is no data + Data() []byte // nil means there is no data dataHash() *H children() []child[H] partialKey() *nibbles.NibbleSlice - encoded() []byte + Encoded() []byte + ByteSize() uint } type ( @@ -161,15 +163,15 @@ type ( } ) -func (EmptyCachedNode[H]) data() []byte { return nil } //nolint:unused -func (no LeafCachedNode[H]) data() []byte { return no.Value.data() } //nolint:unused -func (no BranchCachedNode[H]) data() []byte { //nolint:unused +func (EmptyCachedNode[H]) Data() []byte { return nil } +func (no LeafCachedNode[H]) Data() []byte { return no.Value.data() } +func (no BranchCachedNode[H]) Data() []byte { if no.Value != nil { return no.Value.data() } return nil } -func (no ValueCachedNode[H]) data() []byte { return no.Value } //nolint:unused +func (no ValueCachedNode[H]) Data() []byte { return no.Value } func (EmptyCachedNode[H]) dataHash() *H { return nil } //nolint:unused func (no LeafCachedNode[H]) dataHash() *H { return no.Value.dataHash() } //nolint:unused @@ -203,8 +205,8 @@ func (no LeafCachedNode[H]) partialKey() *nibbles.NibbleSlice { return &no.Par func (no BranchCachedNode[H]) partialKey() *nibbles.NibbleSlice { return &no.PartialKey } //nolint:unused func (no ValueCachedNode[H]) partialKey() *nibbles.NibbleSlice { return nil } //nolint:unused -func (EmptyCachedNode[H]) encoded() []byte { return []byte{EmptyTrieBytes} } //nolint:unused -func (no LeafCachedNode[H]) encoded() []byte { //nolint:unused +func (EmptyCachedNode[H]) Encoded() []byte { return []byte{EmptyTrieBytes} } +func (no LeafCachedNode[H]) Encoded() []byte { encodingBuffer := bytes.NewBuffer(nil) err := NewEncodedLeaf(no.PartialKey.Right(), no.PartialKey.Len(), no.Value.EncodedValue(), encodingBuffer) if err != nil { @@ -212,7 +214,7 @@ func (no LeafCachedNode[H]) encoded() []byte { //nolint:unused } return encodingBuffer.Bytes() } -func (no BranchCachedNode[H]) encoded() []byte { //nolint:unused +func (no 
BranchCachedNode[H]) Encoded() []byte { encodingBuffer := bytes.NewBuffer(nil) children := [16]ChildReference{} for i, ch := range no.Children { @@ -236,9 +238,38 @@ func (no BranchCachedNode[H]) encoded() []byte { //nolint:unused } return encodingBuffer.Bytes() } -func (no ValueCachedNode[H]) encoded() []byte { return no.Value } //nolint:unused +func (no ValueCachedNode[H]) Encoded() []byte { return no.Value } + +func (no EmptyCachedNode[H]) ByteSize() uint { return (uint)(unsafe.Sizeof(no)) } +func (no LeafCachedNode[H]) ByteSize() uint { + return (uint)(unsafe.Sizeof(no)) + uint(len(no.PartialKey.Inner())+len(no.Value.data())) //nolint:gosec +} +func (no BranchCachedNode[H]) ByteSize() uint { + selfSize := (uint)(unsafe.Sizeof(no)) + var childSize = func(children [16]CachedNodeHandle) (size uint) { + for _, child := range children { + if child == nil { + continue + } + switch child := child.(type) { + case HashCachedNodeHandle[H]: + case InlineCachedNodeHandle[H]: + size = size + child.CachedNode.ByteSize() + default: + panic("unreachable") + } + } + return + } + size := selfSize + uint(len(no.PartialKey.Inner())) + childSize(no.Children) + if no.Value != nil { + size = size + uint(len(no.Value.data())) + } + return size +} +func (no ValueCachedNode[H]) ByteSize() uint { return (uint)(unsafe.Sizeof(no)) + uint(len(no.Value)) } -func newCachedNodeFromNode[H hash.Hash, Hasher hash.Hasher[H]](n codec.EncodedNode) (CachedNode[H], error) { +func NewCachedNodeFromNode[H hash.Hash, Hasher hash.Hasher[H]](n codec.EncodedNode) (CachedNode[H], error) { switch n := n.(type) { case codec.Empty: return EmptyCachedNode[H]{}, nil @@ -321,7 +352,7 @@ func (ecv ExistingCachedValue[H]) hash() *H { return &ecv.Hash } //nolint:un // // The interface consists of two cache levels, first the trie nodes cache and then the value cache. // The trie nodes cache, as the name indicates, is for caching trie nodes as [CachedNode]. These -// trie nodes are referenced by their hash. 
The value cache is caching [CachedValue]s and these +// trie nodes are referenced by their hash. The value cache is caching [CachedValue]'s and these // are referenced by the key to look them up in the trie. As multiple different tries can have // different values under the same key, it is up to the cache implementation to ensure that the // correct value is returned. As each trie has a different root, this root can be used to diff --git a/pkg/trie/triedb/cache_bench_test.go b/pkg/trie/triedb/cache_bench_test.go index 39104e912b..6e9b96f302 100644 --- a/pkg/trie/triedb/cache_bench_test.go +++ b/pkg/trie/triedb/cache_bench_test.go @@ -24,12 +24,12 @@ func Benchmark_ValueCache(b *testing.B) { } version := trie.V1 - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for k, v := range entries { - require.NoError(b, trie.Put([]byte(k), v)) + require.NoError(b, trie.Set([]byte(k), v)) } err := trie.commit() require.NoError(b, err) @@ -73,12 +73,12 @@ func Benchmark_NodesCache(b *testing.B) { } version := trie.V1 - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for k, v := range entries { - require.NoError(b, trie.Put([]byte(k), v)) + require.NoError(b, trie.Set([]byte(k), v)) } err := trie.commit() require.NoError(b, err) @@ -90,7 +90,7 @@ func Benchmark_NodesCache(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // Iterate through all keys - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) require.NoError(b, err) for entry, err := iter.NextItem(); entry != nil && err == nil; entry, err = iter.NextItem() { } @@ -104,7 +104,7 @@ func Benchmark_NodesCache(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // Iterate through all keys - iter, err := newRawIterator(trieDB) + iter, err := 
NewTrieDBRawIterator(trieDB) require.NoError(b, err) for entry, err := iter.NextItem(); entry != nil && err == nil; entry, err = iter.NextItem() { } diff --git a/pkg/trie/triedb/cache_test.go b/pkg/trie/triedb/cache_test.go new file mode 100644 index 0000000000..49e8c6ef5b --- /dev/null +++ b/pkg/trie/triedb/cache_test.go @@ -0,0 +1,46 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "bytes" + "testing" + + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ByteSize(t *testing.T) { + var childHash hash.H256 = runtime.BlakeTwo256{}.Hash([]byte{0}) + encodedBranch := codec.Branch{ + PartialKey: nibbles.NewNibbles([]byte{1}), + Value: codec.InlineValue([]byte{7, 8, 9}), + Children: [codec.ChildrenCapacity]codec.MerkleValue{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + codec.HashedNode[hash.H256]{Hash: childHash}, + }, + } + + cachedNode, err := NewCachedNodeFromNode[hash.H256, runtime.BlakeTwo256](encodedBranch) + require.NoError(t, err) + assert.Equal(t, 308, int(cachedNode.ByteSize())) + + encodedBranch = codec.Branch{ + PartialKey: nibbles.NewNibbles([]byte{1}), + Value: codec.InlineValue(bytes.Repeat([]byte{1}, 33)), + Children: [codec.ChildrenCapacity]codec.MerkleValue{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + codec.HashedNode[hash.H256]{Hash: childHash}, + }, + } + cachedNode, err = NewCachedNodeFromNode[hash.H256, runtime.BlakeTwo256](encodedBranch) + require.NoError(t, err) + assert.Equal(t, 308+30, int(cachedNode.ByteSize())) +} diff --git a/pkg/trie/triedb/iterator.go b/pkg/trie/triedb/iterator.go index 83a7cb1ff3..93089cf963 100644 --- a/pkg/trie/triedb/iterator.go +++ 
b/pkg/trie/triedb/iterator.go @@ -6,6 +6,7 @@ package triedb import ( "bytes" "fmt" + "iter" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" @@ -116,7 +117,7 @@ func (ri rawItem[H]) extractKey() *extractedKey { } } -type rawIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { +type TrieDBRawIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { // Forward trail of nodes to visit. trail []crumb[H] // Forward iteration key nibbles of the current node. @@ -125,9 +126,9 @@ type rawIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { } // Create a new iterator. -func newRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( +func NewTrieDBRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( db *TrieDB[H, Hasher], -) (*rawIterator[H, Hasher], error) { +) (*TrieDBRawIterator[H, Hasher], error) { rootNode, rootHash, err := db.getNodeOrLookup( codec.HashedNode[H]{Hash: db.rootHash}, nibbles.Prefix{}, @@ -137,7 +138,7 @@ func newRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( return nil, err } - r := rawIterator[H, Hasher]{ + r := TrieDBRawIterator[H, Hasher]{ db: db, } r.descend(rootNode, rootHash) @@ -145,10 +146,10 @@ func newRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( } // Create a new iterator, but limited to a given prefix. -func newPrefixedRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( +func NewPrefixedTrieDBRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( db *TrieDB[H, Hasher], prefix []byte, -) (*rawIterator[H, Hasher], error) { - iter, err := newRawIterator(db) +) (*TrieDBRawIterator[H, Hasher], error) { + iter, err := NewTrieDBRawIterator(db) if err != nil { return nil, err } @@ -162,10 +163,10 @@ func newPrefixedRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( // Create a new iterator, but limited to a given prefix. // It then do a seek operation from prefixed context (using seek lose // prefix context by default). 
-func newPrefixedRawIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( +func NewPrefixedTrieDBRawIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( db *TrieDB[H, Hasher], prefix []byte, seek []byte, -) (*rawIterator[H, Hasher], error) { - iter, err := newRawIterator(db) +) (*TrieDBRawIterator[H, Hasher], error) { + iter, err := NewTrieDBRawIterator(db) if err != nil { return nil, err } @@ -177,7 +178,7 @@ func newPrefixedRawIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( } // Descend into a node. -func (ri *rawIterator[H, Hasher]) descend(node codec.EncodedNode, nodeHash *H) { +func (ri *TrieDBRawIterator[H, Hasher]) descend(node codec.EncodedNode, nodeHash *H) { ri.trail = append(ri.trail, crumb[H]{ hash: nodeHash, status: statusEntering{}, @@ -191,7 +192,7 @@ func (ri *rawIterator[H, Hasher]) descend(node codec.EncodedNode, nodeHash *H) { // share its prefix with the node. // This indicates if there is still nodes to iterate over in the case // where we limit iteration to key as a prefix. -func (ri *rawIterator[H, Hasher]) seek(keyBytes []byte, fwd bool) (bool, error) { +func (ri *TrieDBRawIterator[H, Hasher]) seek(keyBytes []byte, fwd bool) (bool, error) { ri.trail = nil ri.keyNibbles.Clear() key := nibbles.NewNibbles(keyBytes) @@ -272,7 +273,7 @@ func (ri *rawIterator[H, Hasher]) seek(keyBytes []byte, fwd bool) (bool, error) // Advance the iterator into a prefix, no value out of the prefix will be accessed // or returned after this operation. -func (ri *rawIterator[H, Hasher]) prefix(prefix []byte, fwd bool) error { +func (ri *TrieDBRawIterator[H, Hasher]) prefix(prefix []byte, fwd bool) error { found, err := ri.seek(prefix, fwd) if err != nil { return err @@ -291,7 +292,7 @@ func (ri *rawIterator[H, Hasher]) prefix(prefix []byte, fwd bool) error { // Advance the iterator into a prefix, no value out of the prefix will be accessed // or returned after this operation. 
-func (ri *rawIterator[H, Hasher]) prefixThenSeek(prefix []byte, seek []byte) error { +func (ri *TrieDBRawIterator[H, Hasher]) prefixThenSeek(prefix []byte, seek []byte) error { if len(prefix) == 0 { // Theres no prefix, so just seek. _, err := ri.seek(seek, true) @@ -357,7 +358,7 @@ func (ri *rawIterator[H, Hasher]) prefixThenSeek(prefix []byte, seek []byte) err // Must be called with the same db as when the iterator was created. // // Specify fwd to indicate the direction of the iteration (true for forward). -func (ri *rawIterator[H, Hasher]) nextRawItem(fwd bool) (*rawItem[H], error) { +func (ri *TrieDBRawIterator[H, Hasher]) nextRawItem(fwd bool) (*rawItem[H], error) { for { if len(ri.trail) == 0 { return nil, nil @@ -431,7 +432,7 @@ func (ri *rawIterator[H, Hasher]) nextRawItem(fwd bool) (*rawItem[H], error) { // Fetches the next trie item. // // Must be called with the same db as when the iterator was created. -func (ri *rawIterator[H, Hasher]) NextItem() (*TrieItem, error) { +func (ri *TrieDBRawIterator[H, Hasher]) NextItem() (*TrieItem, error) { for { rawItem, err := ri.nextRawItem(true) if err != nil { @@ -467,7 +468,188 @@ func (ri *rawIterator[H, Hasher]) NextItem() (*TrieItem, error) { } } +// Fetches the next key. +// +// Must be called with the same `db` as when the iterator was created. +func (ri *TrieDBRawIterator[H, Hasher]) NextKey() ([]byte, error) { + for { + rawItem, err := ri.nextRawItem(true) + if err != nil { + return nil, err + } + if rawItem == nil { + return nil, nil + } + extracted := rawItem.extractKey() + if extracted == nil { + continue + } + key := extracted.Key + maybeExtraNibble := extracted.Padding + + if maybeExtraNibble != nil { + return nil, fmt.Errorf("ValueAtIncompleteKey: %v %v", key, *maybeExtraNibble) + } + return key, nil + } +} + type TrieItem struct { Key []byte Value []byte } + +// A trie iterator that also supports random access (`seek()`). 
+type TrieIterator[H hash.Hash, Item any] interface { + Seek(key []byte) error + Next() (Item, error) + // Items() iter.Seq2[Item, error] +} + +// Iterator for going through all values in the trie in pre-order traversal order. +type TrieDBIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { + db *TrieDB[H, Hasher] + rawIter TrieDBRawIterator[H, Hasher] +} + +// Create a new iterator. +func NewTrieDBIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], +) (*TrieDBIterator[H, Hasher], error) { + rawIter, err := NewTrieDBRawIterator(db) + if err != nil { + return nil, err + } + return &TrieDBIterator[H, Hasher]{ + db: db, + rawIter: *rawIter, + }, nil +} + +// Create a new iterator, but limited to a given prefix. +func NewPrefixedTrieDBIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, +) (*TrieDBIterator[H, Hasher], error) { + rawIter, err := NewPrefixedTrieDBRawIterator(db, prefix) + if err != nil { + return nil, err + } + return &TrieDBIterator[H, Hasher]{ + db: db, + rawIter: *rawIter, + }, nil +} + +// Create a new iterator, but limited to a given prefix. +// It then do a seek operation from prefixed context (using `seek` lose +// prefix context by default). 
+func NewPrefixedTrieDBIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, startAt []byte, +) (*TrieDBIterator[H, Hasher], error) { + rawIter, err := NewPrefixedTrieDBRawIteratorThenSeek(db, prefix, startAt) + if err != nil { + return nil, err + } + return &TrieDBIterator[H, Hasher]{ + db: db, + rawIter: *rawIter, + }, nil +} + +func (tdbi *TrieDBIterator[H, Hasher]) Seek(key []byte) error { + _, err := tdbi.rawIter.seek(key, true) + return err +} + +func (tdbi *TrieDBIterator[H, Hasher]) Next() (*TrieItem, error) { + return tdbi.rawIter.NextItem() +} + +func (tdbi *TrieDBIterator[H, Hasher]) Items() iter.Seq2[TrieItem, error] { + return func(yield func(TrieItem, error) bool) { + for { + item, err := tdbi.Next() + if err != nil { + return + } + if item == nil { + return + } + if !yield(*item, err) { + return + } + } + } +} + +// Iterator for going through all of key with values in the trie in pre-order traversal order. +type TrieDBKeyIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { + rawIter TrieDBRawIterator[H, Hasher] +} + +// Create a new iterator. +func NewTrieDBKeyIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], +) (*TrieDBKeyIterator[H, Hasher], error) { + rawIter, err := NewTrieDBRawIterator(db) + if err != nil { + return nil, err + } + return &TrieDBKeyIterator[H, Hasher]{ + rawIter: *rawIter, + }, nil +} + +// Create a new iterator, but limited to a given prefix. +func NewPrefixedTrieDBKeyIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, +) (*TrieDBKeyIterator[H, Hasher], error) { + rawIter, err := NewPrefixedTrieDBRawIterator(db, prefix) + if err != nil { + return nil, err + } + return &TrieDBKeyIterator[H, Hasher]{ + rawIter: *rawIter, + }, nil +} + +// Create a new iterator, but limited to a given prefix. +// It then do a seek operation from prefixed context (using `seek` lose +// prefix context by default). 
+func NewPrefixedTrieDBKeyIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, startAt []byte, +) (*TrieDBKeyIterator[H, Hasher], error) { + rawIter, err := NewPrefixedTrieDBRawIteratorThenSeek(db, prefix, startAt) + if err != nil { + return nil, err + } + return &TrieDBKeyIterator[H, Hasher]{ + rawIter: *rawIter, + }, nil +} + +func (tdbki *TrieDBKeyIterator[H, Hasher]) Seek(key []byte) error { + _, err := tdbki.rawIter.seek(key, true) + return err +} + +func (tdbki *TrieDBKeyIterator[H, Hasher]) Next() ([]byte, error) { + return tdbki.rawIter.NextKey() +} + +func (tdbki *TrieDBKeyIterator[H, Hasher]) Items() iter.Seq2[[]byte, error] { + return func(yield func([]byte, error) bool) { + for { + key, err := tdbki.Next() + if err != nil { + return + } + if key == nil { + return + } + if !yield(key, err) { + return + } + } + } +} diff --git a/pkg/trie/triedb/iterator_test.go b/pkg/trie/triedb/iterator_test.go index c16c437d2e..31b342c950 100644 --- a/pkg/trie/triedb/iterator_test.go +++ b/pkg/trie/triedb/iterator_test.go @@ -1,6 +1,5 @@ // Copyright 2024 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only - package triedb import ( @@ -10,9 +9,10 @@ import ( "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func Test_rawIterator(t *testing.T) { +func Test_TrieDBRawIterator(t *testing.T) { entries := map[string][]byte{ "no": make([]byte, 1), "noot": make([]byte, 2), @@ -25,18 +25,18 @@ func Test_rawIterator(t *testing.T) { "bigbigvalue": make([]byte, 66), } - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trieDB.SetVersion(trie.V1) for k, v := range entries { - err := trieDB.Put([]byte(k), v) + err := trieDB.Set([]byte(k), v) assert.NoError(t, err) } assert.NoError(t, trieDB.commit()) 
t.Run("iterate_over_all_raw_items", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) i := 0 @@ -52,7 +52,7 @@ func Test_rawIterator(t *testing.T) { }) t.Run("iterate_over_all_entries", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) i := 0 @@ -70,7 +70,7 @@ func Test_rawIterator(t *testing.T) { }) t.Run("seek", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) found, err := iter.seek([]byte("no"), true) @@ -89,7 +89,7 @@ func Test_rawIterator(t *testing.T) { }) t.Run("seek_leaf", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) found, err := iter.seek([]byte("dimartiro"), true) @@ -102,8 +102,26 @@ func Test_rawIterator(t *testing.T) { assert.Equal(t, "dimartiro", string(item.Key)) }) + t.Run("seek_leaf_using_prefix", func(t *testing.T) { + iter, err := NewPrefixedTrieDBRawIterator(trieDB, []byte("dimar")) + assert.NoError(t, err) + + key, err := iter.NextKey() + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, "dimartiro", string(key)) + + key, err = iter.NextKey() + assert.NoError(t, err) + assert.Nil(t, key) + + key, err = iter.NextKey() + assert.NoError(t, err) + assert.Nil(t, key) + }) + t.Run("iterate_over_all_prefixed_entries", func(t *testing.T) { - iter, err := newPrefixedRawIterator(trieDB, []byte("no")) + iter, err := NewPrefixedTrieDBRawIterator(trieDB, []byte("no")) assert.NoError(t, err) i := 0 @@ -121,7 +139,7 @@ func Test_rawIterator(t *testing.T) { }) t.Run("prefixed_raw_iterator", func(t *testing.T) { - iter, err := newPrefixedRawIterator(trieDB, []byte("noot")) + iter, err := NewPrefixedTrieDBRawIterator(trieDB, []byte("noot")) assert.NoError(t, err) item, err := iter.NextItem() @@ -131,7 +149,7 @@ func Test_rawIterator(t 
*testing.T) { }) t.Run("iterate_over_all_prefixed_entries_then_seek", func(t *testing.T) { - iter, err := newPrefixedRawIteratorThenSeek(trieDB, []byte("no"), []byte("noot")) + iter, err := NewPrefixedTrieDBRawIteratorThenSeek(trieDB, []byte("no"), []byte("noot")) assert.NoError(t, err) i := 0 @@ -148,3 +166,205 @@ func Test_rawIterator(t *testing.T) { assert.Equal(t, 4, i) }) } + +func TestTrieDBIterator(t *testing.T) { + entries := map[string][]byte{ + "no": make([]byte, 1), + "noot": make([]byte, 2), + "not": make([]byte, 3), + "notable": make([]byte, 4), + "notification": make([]byte, 5), + "test": make([]byte, 6), + "dimartiro": make([]byte, 7), + "bigvalue": make([]byte, 33), + "bigbigvalue": make([]byte, 66), + } + + db := NewMemoryDB() + trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trieDB.SetVersion(trie.V1) + + for k, v := range entries { + err := trieDB.Set([]byte(k), v) + assert.NoError(t, err) + } + assert.NoError(t, trieDB.commit()) + + t.Run("iterate_over_all_keys", func(t *testing.T) { + iter, err := NewTrieDBIterator(trieDB) + require.NoError(t, err) + + i := 0 + for { + item, err := iter.Next() + require.NoError(t, err) + if item == nil { + break + } + i++ + } + require.Equal(t, len(entries), i) + }) + + t.Run("seek", func(t *testing.T) { + iter, err := NewTrieDBIterator(trieDB) + assert.NoError(t, err) + + err = iter.Seek([]byte("no")) + assert.NoError(t, err) + + item, err := iter.Next() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, []byte("no"), item.Key) + assert.Equal(t, make([]byte, 1), item.Value) + + item, err = iter.Next() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, []byte("noot"), item.Key) + assert.Equal(t, make([]byte, 2), item.Value) + }) + + t.Run("iterate_over_all_keys_using_Seq", func(t *testing.T) { + iter, err := NewTrieDBIterator(trieDB) + require.NoError(t, err) + + i := 0 + for _, err := range iter.Items() { + require.NoError(t, err) + i++ + } + require.Equal(t, 
len(entries), i) + }) + + t.Run("iterate_over_all_prefixed_entries", func(t *testing.T) { + iter, err := NewPrefixedTrieDBIterator(trieDB, []byte("no")) + require.NoError(t, err) + + i := 0 + for item, err := range iter.Items() { + require.NoError(t, err) + require.Equal(t, entries[string(item.Key)], item.Value) + i++ + } + require.Equal(t, 5, i) + }) + + t.Run("iterate_over_all_prefixed_entries_then_seek", func(t *testing.T) { + iter, err := NewPrefixedTrieDBIteratorThenSeek(trieDB, []byte("no"), []byte("noot")) + assert.NoError(t, err) + + i := 0 + for item, err := range iter.Items() { + assert.NoError(t, err) + require.Equal(t, entries[string(item.Key)], item.Value) + i++ + } + assert.Equal(t, 4, i) + }) +} + +func TestTrieDBKeyIterator(t *testing.T) { + entries := map[string][]byte{ + "no": make([]byte, 1), + "noot": make([]byte, 2), + "not": make([]byte, 3), + "notable": make([]byte, 4), + "notification": make([]byte, 5), + "test": make([]byte, 6), + "dimartiro": make([]byte, 7), + "bigvalue": make([]byte, 33), + "bigbigvalue": make([]byte, 66), + } + + db := NewMemoryDB() + trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trieDB.SetVersion(trie.V1) + + for k, v := range entries { + err := trieDB.Set([]byte(k), v) + assert.NoError(t, err) + } + assert.NoError(t, trieDB.commit()) + + t.Run("iterate_over_all_keys", func(t *testing.T) { + iter, err := NewTrieDBKeyIterator(trieDB) + require.NoError(t, err) + + i := 0 + for { + item, err := iter.Next() + require.NoError(t, err) + if item == nil { + break + } + i++ + } + require.Equal(t, len(entries), i) + }) + + t.Run("seek", func(t *testing.T) { + iter, err := NewTrieDBKeyIterator(trieDB) + assert.NoError(t, err) + + err = iter.Seek([]byte("no")) + assert.NoError(t, err) + + item, err := iter.Next() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, []byte("no"), item) + + item, err = iter.Next() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, []byte("noot"), 
item) + }) + + t.Run("iterate_over_all_keys_using_Seq", func(t *testing.T) { + iter, err := NewTrieDBKeyIterator(trieDB) + require.NoError(t, err) + + i := 0 + for key, err := range iter.Items() { + require.NoError(t, err) + if key != nil { + i++ + } + } + require.Equal(t, len(entries), i) + }) + + t.Run("iterate_over_all_prefixed_entries", func(t *testing.T) { + iter, err := NewPrefixedTrieDBKeyIterator(trieDB, []byte("no")) + require.NoError(t, err) + + i := 0 + for item, err := range iter.Items() { + require.NoError(t, err) + if item == nil { + continue + } + require.Contains(t, entries, string(item)) + i++ + } + require.Equal(t, 5, i) + }) + + t.Run("iterate_over_all_prefixed_entries_then_seek", func(t *testing.T) { + iter, err := NewPrefixedTrieDBKeyIteratorThenSeek(trieDB, []byte("no"), []byte("noot")) + assert.NoError(t, err) + + i := 0 + for item, err := range iter.Items() { + assert.NoError(t, err) + if item == nil { + continue + } + require.Contains(t, entries, string(item)) + i++ + } + assert.Equal(t, 4, i) + }) +} diff --git a/pkg/trie/triedb/lookup.go b/pkg/trie/triedb/lookup.go index b42ad26f06..f6577e2048 100644 --- a/pkg/trie/triedb/lookup.go +++ b/pkg/trie/triedb/lookup.go @@ -8,8 +8,8 @@ import ( "fmt" "slices" + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" @@ -21,7 +21,7 @@ type Query[Item any] func(data []byte) Item // Trie lookup helper object. 
type TrieLookup[H hash.Hash, Hasher hash.Hasher[H], QueryItem any] struct { // db to query from - db db.DBGetter + db hashdb.HashDB[H] // hash to start at hash H // optional cache to speed up the db lookups @@ -36,7 +36,7 @@ type TrieLookup[H hash.Hash, Hasher hash.Hasher[H], QueryItem any] struct { // NewTrieLookup is constructor for [TrieLookup] func NewTrieLookup[H hash.Hash, Hasher hash.Hasher[H], QueryItem any]( - db db.DBGetter, + db hashdb.HashDB[H], hash H, cache TrieCache[H], recorder TrieRecorder, @@ -57,6 +57,158 @@ func (l *TrieLookup[H, Hasher, QueryItem]) recordAccess(access TrieAccess) { } } +// Look up the merkle value (hash) of the node that is the closest descendant for the provided +// key. +// +// When the provided key leads to a node, then the merkle value (hash) of that node +// is returned. However, if the key does not lead to a node, then the merkle value +// of the closest descendant is returned. `None` if no such descendant exists. +func (l *TrieLookup[H, Hasher, QueryItem]) LookupFirstDescendant( + fullKey []byte, nibbleKey nibbles.Nibbles, +) (MerkleValue[H], error) { + partial := nibbleKey + hash := l.hash + var keyNibbles uint + + // this loop iterates through non-inline nodes. 
+ var depth uint + for { + var nodeData []byte + + var getCachedNode = func() (CachedNode[H], error) { + data := l.db.Get(hash, hashdb.Prefix(nibbleKey.Mid(keyNibbles).Left())) + if data == nil { + if depth == 0 { + return nil, ErrInvalidStateRoot + } else { + return nil, ErrIncompleteDB + } + } + + reader := bytes.NewReader(data) + decoded, err := codec.Decode[H](reader) + if err != nil { + return nil, err + } + + owned, err := NewCachedNodeFromNode[H, Hasher](decoded) + if err != nil { + return nil, err + } + nodeData = data + return owned, nil + } + + var node CachedNode[H] + if l.cache != nil { + n, err := l.cache.GetOrInsertNode(hash, getCachedNode) + if err != nil { + return nil, err + } + + l.recordAccess(CachedNodeAccess[H]{Hash: hash, Node: node}) + node = n + } else { + n, err := getCachedNode() + if err != nil { + return nil, err + } + + l.recordAccess(EncodedNodeAccess[H]{Hash: hash, EncodedNode: nodeData}) + node = n + } + + // this loop iterates through all inline children (usually max 1) + // without incrementing the depth. + var isInline bool + inlineLoop: + for { + var nextNode CachedNodeHandle + switch node := node.(type) { + case LeafCachedNode[H]: + // The leaf slice can be longer than remainder of the provided key + // (descendent), but not the other way around. + if !node.PartialKey.StartsWithNibbles(partial) { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + return nil, nil //nolint:nilnil + } + + if partial.Len() != node.PartialKey.Len() { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + } + + if isInline { + return NodeMerkleValue(nodeData), nil + } + return HashMerkleValue[H]{Hash: hash}, nil + case BranchCachedNode[H]: + // Not enough remainder key to continue the search. + if partial.Len() < node.PartialKey.Len() { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + + // Branch slice starts with the remainder key, there's nothing to + // advance. 
+ if node.PartialKey.StartsWithNibbles(partial) { + if isInline { + return NodeMerkleValue(nodeData), nil + } + return HashMerkleValue[H]{Hash: hash}, nil + } + return nil, nil //nolint:nilnil + } + + // Partial key is longer or equal than the branch slice. + // Ensure partial key starts with the branch slice. + if !partial.StartsWithNibbleSlice(node.PartialKey) { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + return nil, nil //nolint:nilnil + } + + // Partial key starts with the branch slice. + if partial.Len() == node.PartialKey.Len() { + if node.Value != nil { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + } + + if isInline { + return NodeMerkleValue(nodeData), nil + } + return HashMerkleValue[H]{Hash: hash}, nil + } + + child := node.Children[partial.At(node.PartialKey.Len())] + if child != nil { + partial = partial.Mid(node.PartialKey.Len() + 1) + keyNibbles += node.PartialKey.Len() + 1 + nextNode = child + } else { + l.recordAccess(NonExistingNodeAccess{fullKey}) + return nil, nil //nolint:nilnil + } + case EmptyCachedNode[H]: + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + return nil, nil //nolint:nilnil + default: + panic("unreachable") + } + + // check if new node data is inline or hash. + switch nextNode := nextNode.(type) { + case HashCachedNodeHandle[H]: + hash = nextNode.Hash + break inlineLoop + case InlineCachedNodeHandle[H]: + node = nextNode.CachedNode + isInline = true + default: + panic("unreachable") + } + } + + depth++ + } +} + // Look up the given fullKey. // If the value is found, it will be passed to the [Query] associated to [TrieLookup]. 
// @@ -192,7 +344,7 @@ type loadCachedNodeValueFunc[H hash.Hash, R any] func( prefix nibbles.Prefix, fullKey []byte, cache TrieCache[H], - db db.DBGetter, + db hashdb.HashDB[H], recorder TrieRecorder, ) (R, error) @@ -212,9 +364,8 @@ func lookupWithCacheInternal[H hash.Hash, Hasher hash.Hasher[H], R, QueryItem an var depth uint for { node, err := cache.GetOrInsertNode(hash, func() (CachedNode[H], error) { - prefixedKey := append(nibbleKey.Mid(keyNibbles).Left().JoinedBytes(), hash.Bytes()...) - nodeData, err := l.db.Get(prefixedKey) - if err != nil { + nodeData := l.db.Get(hash, hashdb.Prefix(nibbleKey.Mid(keyNibbles).Left())) + if nodeData == nil { if depth == 0 { return nil, ErrInvalidStateRoot } else { @@ -227,7 +378,7 @@ func lookupWithCacheInternal[H hash.Hash, Hasher hash.Hasher[H], R, QueryItem an return nil, err } - return newCachedNodeFromNode[H, Hasher](decoded) + return NewCachedNodeFromNode[H, Hasher](decoded) }) if err != nil { return nil, err @@ -307,7 +458,7 @@ type loadValueFunc[H hash.Hash, QueryItem, R any] func( v codec.EncodedValue, prefix nibbles.Prefix, fullKey []byte, - db db.DBGetter, + db hashdb.HashDB[H], recorder TrieRecorder, query Query[QueryItem], ) (R, error) @@ -329,9 +480,8 @@ func lookupWithoutCache[H hash.Hash, Hasher hash.Hasher[H], QueryItem, R any]( var depth uint for { - prefixedKey := append(nibbleKey.Mid(keyNibbles).Left().JoinedBytes(), hash.Bytes()...) 
- nodeData, err := l.db.Get(prefixedKey) - if err != nil { + nodeData := l.db.Get(hash, hashdb.Prefix(nibbleKey.Mid(keyNibbles).Left())) + if nodeData == nil { if depth == 0 { return nil, ErrInvalidStateRoot } else { @@ -408,6 +558,7 @@ func lookupWithoutCache[H hash.Hash, Hasher hash.Hasher[H], QueryItem, R any]( } case codec.Empty: l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) + return nil, nil default: panic("unreachable") } @@ -454,7 +605,7 @@ func loadCachedNodeValue[H hash.Hash]( prefix nibbles.Prefix, fullKey []byte, cache TrieCache[H], - db db.DBGetter, + db hashdb.HashDB[H], recorder TrieRecorder, ) (valueHash[H], error) { switch v := v.(type) { @@ -465,10 +616,9 @@ func loadCachedNodeValue[H hash.Hash]( return valueHash[H](v), nil case NodeCachedNodeValue[H]: node, err := cache.GetOrInsertNode(v.Hash, func() (CachedNode[H], error) { - prefixedKey := append(prefix.JoinedBytes(), v.Hash.Bytes()...) - val, err := db.Get(prefixedKey) - if err != nil { - return nil, err + val := db.Get(v.Hash, hashdb.Prefix(prefix)) + if val == nil { + return nil, ErrIncompleteDB } return ValueCachedNode[H]{Value: val, Hash: v.Hash}, nil }) @@ -511,7 +661,7 @@ func loadValue[H hash.Hash, QueryItem any]( v codec.EncodedValue, prefix nibbles.Prefix, fullKey []byte, - db db.DBGetter, + db hashdb.HashDB[H], recorder TrieRecorder, query Query[QueryItem], ) (qi QueryItem, err error) { @@ -522,13 +672,9 @@ func loadValue[H hash.Hash, QueryItem any]( } return query(v), nil case codec.HashedValue[H]: - prefixedKey := append(prefix.JoinedBytes(), v.Hash.Bytes()...) 
- val, err := db.Get(prefixedKey) - if err != nil { - return qi, err - } + val := db.Get(v.Hash, hashdb.Prefix(prefix)) if val == nil { - return qi, fmt.Errorf("%w: %s", ErrIncompleteDB, prefixedKey) + return qi, fmt.Errorf("%w: %s", ErrIncompleteDB, fullKey) } if recorder != nil { @@ -559,7 +705,7 @@ func (l *TrieLookup[H, Hasher, QueryItem]) LookupHash(fullKey []byte) (*H, error v codec.EncodedValue, _ nibbles.Prefix, fullKey []byte, - _ db.DBGetter, + _ hashdb.HashDB[H], recorder TrieRecorder, _ Query[QueryItem], ) (H, error) { @@ -619,7 +765,7 @@ func (l *TrieLookup[H, Hasher, QueryItem]) lookupHashWithCache( _ nibbles.Prefix, fullKey []byte, _ TrieCache[H], - _ db.DBGetter, + _ hashdb.HashDB[H], recorder TrieRecorder, ) (valueHash[H], error) { switch value := value.(type) { diff --git a/pkg/trie/triedb/lookup_test.go b/pkg/trie/triedb/lookup_test.go index 76860062b6..8bfdb6040b 100644 --- a/pkg/trie/triedb/lookup_test.go +++ b/pkg/trie/triedb/lookup_test.go @@ -6,6 +6,7 @@ package triedb import ( "testing" + memorydb "github.com/ChainSafe/gossamer/internal/memory-db" "github.com/ChainSafe/gossamer/internal/primitives/core/hash" "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" @@ -16,9 +17,11 @@ import ( func TestTrieDB_Lookup(t *testing.T) { t.Run("root_not_exists_in_db", func(t *testing.T) { - db := newTestDB(t) + db := memorydb.NewMemoryDB[ + hash.H256, runtime.BlakeTwo256, hash.H256, memorydb.HashKey[hash.H256], + ]([]byte("not0")) empty := runtime.BlakeTwo256{}.Hash([]byte{0}) - lookup := NewTrieLookup[hash.H256, runtime.BlakeTwo256, []byte](db, empty, nil, nil, nil) + lookup := NewTrieLookup[hash.H256, runtime.BlakeTwo256, []byte](&db, empty, nil, nil, nil) value, err := lookup.Lookup([]byte("test")) assert.Nil(t, value) @@ -39,7 +42,7 @@ func (*trieCacheImpl) GetNode(hash hash.H256) CachedNode[hash.H256] { return nil func Test_TrieLookup_lookupValueWithCache(t *testing.T) { cache := &trieCacheImpl{} - 
inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256]( inmemoryDB, WithCache[hash.H256, runtime.BlakeTwo256](cache), @@ -57,7 +60,7 @@ func Test_TrieLookup_lookupValueWithCache(t *testing.T) { } for k, v := range entries { - require.NoError(t, trieDB.Put([]byte(k), v)) + require.NoError(t, trieDB.Set([]byte(k), v)) } err := trieDB.commit() diff --git a/pkg/trie/triedb/nibbles/nibbles.go b/pkg/trie/triedb/nibbles/nibbles.go index 690cffb727..b63ae89ade 100644 --- a/pkg/trie/triedb/nibbles/nibbles.go +++ b/pkg/trie/triedb/nibbles/nibbles.go @@ -100,7 +100,7 @@ func (n Nibbles) Left() Prefix { split := n.offset / NibblesPerByte ix := uint8(n.offset % NibblesPerByte) if ix == 0 { - return Prefix{Key: n.data[:split]} + return Prefix{Key: slices.Clone(n.data[:split])} } padded := PadLeft(n.data[split]) return Prefix{Key: slices.Clone(n.data[:split]), Padded: &padded} diff --git a/pkg/trie/triedb/nibbles/nibbleslice.go b/pkg/trie/triedb/nibbles/nibbleslice.go index 21aa282a62..81986e2a66 100644 --- a/pkg/trie/triedb/nibbles/nibbleslice.go +++ b/pkg/trie/triedb/nibbles/nibbleslice.go @@ -200,7 +200,7 @@ func (n NibbleSlice) asNibbles() *Nibbles { return nil } -// Return an iterator over [NibbleSlice] bytes representation. +// Return an iterator over [Partial] bytes representation. 
func (n NibbleSlice) Right() []byte { requirePadding := n.Len()%NibblesPerByte != 0 var ix uint @@ -240,3 +240,20 @@ func (n NibbleSlice) NodeKey() NodeKey { func (n NibbleSlice) Inner() []byte { return n.inner } + +func (n NibbleSlice) StartsWithNibbles(other Nibbles) bool { + if n.Len() < other.Len() { + return false + } + + nAsNibbles := n.asNibbles() + if nAsNibbles != nil { + return nAsNibbles.StartsWith(other) + } + for i := uint(0); i < other.Len(); i++ { + if n.At(i) != other.At(i) { + return false + } + } + return true +} diff --git a/pkg/trie/triedb/node.go b/pkg/trie/triedb/node.go index ad07800a3c..5d8fa88b9b 100644 --- a/pkg/trie/triedb/node.go +++ b/pkg/trie/triedb/node.go @@ -8,8 +8,8 @@ import ( "fmt" "io" + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" "github.com/ChainSafe/gossamer/pkg/scale" - "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" @@ -30,7 +30,7 @@ type ( // newValueRef is a value that will be stored in the db newValueRef[H hash.Hash] struct { - hash H + hash *H data []byte } ) @@ -39,38 +39,39 @@ type ( func newEncodedValue[H hash.Hash]( value nodeValue, partial *nibbles.Nibbles, childF onChildStoreFn, ) (codec.EncodedValue, error) { - switch v := value.(type) { - case inline: - return codec.InlineValue(v), nil - case valueRef[H]: - return codec.HashedValue[H]{Hash: v.hash}, nil - case newValueRef[H]: - // Store value in db + v, ok := value.(newValueRef[H]) + if ok { childRef, err := childF(newNodeToEncode{value: v.data}, partial, nil) if err != nil { return nil, err } - - // Check and get new new value hash + var newHash H switch cr := childRef.(type) { case HashChildReference[H]: - empty := *new(H) - if cr.Hash == empty { - panic("new external value are always added before encoding a node") - } - - if v.hash != empty { - if v.hash != cr.Hash { - panic("hash mismatch") 
- } - } else { - v.hash = cr.Hash - } + newHash = cr.Hash default: panic("value node can never be inlined") } + if v.hash != nil { + if *v.hash != newHash { + panic("shouldn't happen") + } + } else { + v.hash = &newHash + value = v + } + } + switch v := value.(type) { + case inline: + return codec.InlineValue(v), nil + case valueRef[H]: return codec.HashedValue[H]{Hash: v.hash}, nil + case newValueRef[H]: + if v.hash == nil { + panic("New external value are always added before encoding a node") + } + return codec.HashedValue[H]{Hash: *v.hash}, nil default: panic("unreachable") } @@ -95,7 +96,7 @@ func (vr valueRef[H]) equal(other nodeValue) bool { } } -func (vr newValueRef[H]) getHash() H { +func (vr newValueRef[H]) getHash() *H { return vr.hash } func (vr newValueRef[H]) equal(other nodeValue) bool { @@ -108,9 +109,9 @@ func (vr newValueRef[H]) equal(other nodeValue) bool { } func NewValue[H hash.Hash](data []byte, threshold int) nodeValue { - if len(data) >= threshold { + if len(data) > threshold { return newValueRef[H]{ - hash: *new(H), + hash: nil, data: data, } } @@ -140,22 +141,27 @@ func newValueFromCachedNodeValue[H hash.Hash](val CachedNodeValue[H]) nodeValue } } -func inMemoryFetchedValue[H hash.Hash](value nodeValue, prefix []byte, db db.DBGetter) ([]byte, error) { +func inMemoryFetchedValue[H hash.Hash]( + value nodeValue, + prefix nibbles.Prefix, + db hashdb.HashDB[H], + recorder TrieRecorder, + fullKey []byte, +) ([]byte, error) { switch v := value.(type) { case inline: return v, nil case newValueRef[H]: return v.data, nil case valueRef[H]: - prefixedKey := bytes.Join([][]byte{prefix, v.hash.Bytes()}, nil) - value, err := db.Get(prefixedKey) - if err != nil { - return nil, err + value := db.Get(v.hash, hashdb.Prefix(prefix)) + if value == nil { + return nil, ErrIncompleteDB } - if value != nil { - return value, nil + if recorder != nil { + recorder.Record(ValueAccess[H]{Hash: v.hash, Value: value, FullKey: fullKey}) } - return value, ErrIncompleteDB + 
return value, nil default: panic("unreachable") } diff --git a/pkg/trie/triedb/node_storage.go b/pkg/trie/triedb/node_storage.go index a67fe75ab3..728175313e 100644 --- a/pkg/trie/triedb/node_storage.go +++ b/pkg/trie/triedb/node_storage.go @@ -11,7 +11,7 @@ import ( var EmptyNode = []byte{0} -// StorageHandle is a pointer to a node contained in `NodeStorage` +// StorageHandle is a pointer to a node contained in nodeStorage type storageHandle int // NodeHandle is an enum for the different types of nodes that can be stored in @@ -119,7 +119,7 @@ func (ns *nodeStorage[H]) destroy(handle storageHandle) StoredNode { idx := int(handle) ns.freeIndices.PushBack(idx) oldNode := ns.nodes[idx] - ns.nodes[idx] = nil + ns.nodes[idx] = NewStoredNode{Empty{}} return oldNode } diff --git a/pkg/trie/triedb/proof/generate_test.go b/pkg/trie/triedb/proof/generate_test.go index f70a934730..ce1b323ef4 100644 --- a/pkg/trie/triedb/proof/generate_test.go +++ b/pkg/trie/triedb/proof/generate_test.go @@ -140,11 +140,11 @@ func Test_NewProof(t *testing.T) { for name, testCase := range testCases { t.Run(name, func(t *testing.T) { // Build trie - inmemoryDB := NewMemoryDB(triedb.EmptyNode) + inmemoryDB := NewMemoryDB() triedb := triedb.NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.entries { - triedb.Put(entry.Key, entry.Value) + triedb.Set(entry.Key, entry.Value) } root := triedb.MustHash() diff --git a/pkg/trie/triedb/proof/proof.go b/pkg/trie/triedb/proof/proof.go index faaed8adf1..80fd7c8d2c 100644 --- a/pkg/trie/triedb/proof/proof.go +++ b/pkg/trie/triedb/proof/proof.go @@ -7,8 +7,8 @@ import ( "bytes" "errors" + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" @@ -18,7 +18,7 @@ import ( type 
MerkleProof[H hash.Hash, Hasher hash.Hasher[H]] [][]byte func NewMerkleProof[H hash.Hash, Hasher hash.Hasher[H]]( - db db.RWDatabase, trieVersion trie.TrieLayout, rootHash H, keys []string) ( + db hashdb.HashDB[H], trieVersion trie.TrieLayout, rootHash H, keys []string) ( proof MerkleProof[H, Hasher], err error) { // Sort and deduplicate keys keys = sortAndDeduplicateKeys(keys) @@ -44,7 +44,10 @@ func NewMerkleProof[H hash.Hash, Hasher hash.Hasher[H]]( recorder := triedb.NewRecorder[H]() trie := triedb.NewTrieDB[H, Hasher](rootHash, db, triedb.WithRecorder[H, Hasher](recorder)) trie.SetVersion(trieVersion) - trie.Get(key) + _, err = trie.Get(key) + if err != nil { + return nil, err + } recordedNodes := NewIterator(recorder.Drain()) diff --git a/pkg/trie/triedb/proof/proof_test.go b/pkg/trie/triedb/proof/proof_test.go index eded85a722..7fc295bb89 100644 --- a/pkg/trie/triedb/proof/proof_test.go +++ b/pkg/trie/triedb/proof/proof_test.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/triedb" + "github.com/stretchr/testify/require" ) @@ -142,12 +143,12 @@ func Test_GenerateAndVerify(t *testing.T) { for _, trieVersion := range trieVersions { t.Run(fmt.Sprintf("%s_%s", name, trieVersion.String()), func(t *testing.T) { // Build trie - inmemoryDB := NewMemoryDB(triedb.EmptyNode) + inmemoryDB := NewMemoryDB() triedb := triedb.NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) triedb.SetVersion(trieVersion) for _, entry := range testCase.entries { - triedb.Put(entry.Key, entry.Value) + triedb.Set(entry.Key, entry.Value) } root := triedb.MustHash() @@ -160,9 +161,11 @@ func Test_GenerateAndVerify(t *testing.T) { // Verify proof items := make([]proofItem, len(testCase.keys)) for i, key := range testCase.keys { + val, err := triedb.Get([]byte(key)) + require.NoError(t, err) items[i] = proofItem{ key: []byte(key), - value: triedb.Get([]byte(key)), + 
value: val, } } err = proof.Verify(trieVersion, root.Bytes(), items) diff --git a/pkg/trie/triedb/proof/util_test.go b/pkg/trie/triedb/proof/util_test.go index 2cb0c8e94b..0e41363f1c 100644 --- a/pkg/trie/triedb/proof/util_test.go +++ b/pkg/trie/triedb/proof/util_test.go @@ -4,87 +4,15 @@ package proof import ( - "bytes" - - "github.com/ChainSafe/gossamer/internal/database" + memorydb "github.com/ChainSafe/gossamer/internal/memory-db" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" "github.com/ChainSafe/gossamer/internal/primitives/runtime" - "github.com/ChainSafe/gossamer/pkg/trie/db" ) -// MemoryDB is an in-memory implementation of the Database interface backed by a -// map. It uses blake2b as hashing algorithm -type MemoryDB struct { - data map[string][]byte - hashedNullNode []byte - nullNodeData []byte -} - -func memoryDBFromNullNode(nullKey, nullNodeData []byte) *MemoryDB { - return &MemoryDB{ - data: make(map[string][]byte), - hashedNullNode: runtime.BlakeTwo256{}.Hash(nullKey).Bytes(), - nullNodeData: nullNodeData, - } -} - -func NewMemoryDB(data []byte) *MemoryDB { - return memoryDBFromNullNode(data, data) -} - -func (db *MemoryDB) emplace(key []byte, value []byte) { - if bytes.Equal(value, db.nullNodeData) { - return - } - - db.data[string(key)] = value -} - -func (db *MemoryDB) Get(key []byte) ([]byte, error) { - dbKey := key - if bytes.Equal(dbKey, db.hashedNullNode) { - return db.nullNodeData, nil - } - if value, has := db.data[string(dbKey)]; has { - return value, nil - } - - return nil, nil -} - -func (db *MemoryDB) Put(key []byte, value []byte) error { - dbKey := key - db.emplace(dbKey, value) - return nil -} - -func (db *MemoryDB) Del(key []byte) error { - dbKey := key - delete(db.data, string(dbKey)) - return nil -} - -func (db *MemoryDB) Flush() error { - return nil -} - -func (db *MemoryDB) NewBatch() database.Batch { - return &MemoryBatch{db} -} - -var _ db.RWDatabase = &MemoryDB{} - -type MemoryBatch struct { - *MemoryDB -} - 
-func (b *MemoryBatch) Close() error { - return nil +func NewMemoryDB() *memorydb.MemoryDB[ + hash.H256, runtime.BlakeTwo256, hash.H256, memorydb.HashKey[hash.H256]] { + db := memorydb.NewMemoryDB[ + hash.H256, runtime.BlakeTwo256, hash.H256, memorydb.HashKey[hash.H256], + ]([]byte{0}) + return &db } - -func (*MemoryBatch) Reset() {} - -func (b *MemoryBatch) ValueSize() int { - return 1 -} - -var _ database.Batch = &MemoryBatch{} diff --git a/pkg/trie/triedb/recorder.go b/pkg/trie/triedb/recorder.go index 591958d25d..bfb0f9b7ad 100644 --- a/pkg/trie/triedb/recorder.go +++ b/pkg/trie/triedb/recorder.go @@ -146,7 +146,7 @@ func (r *Recorder[H]) Record(access TrieAccess) { case EncodedNodeAccess[H]: r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.EncodedNode}) case CachedNodeAccess[H]: - r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.Node.encoded()}) + r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.Node.Encoded()}) case ValueAccess[H]: r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.Value}) r.recordedKeys[string(a.FullKey)] = RecordedValue diff --git a/pkg/trie/triedb/recorder_test.go b/pkg/trie/triedb/recorder_test.go index 11cc375589..5ac62b8f2d 100644 --- a/pkg/trie/triedb/recorder_test.go +++ b/pkg/trie/triedb/recorder_test.go @@ -14,15 +14,14 @@ import ( // Tests results are based on // https://github.com/dimartiro/substrate-trie-test/blob/master/src/substrate_trie_test.rs func TestRecorder(t *testing.T) { - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) - + inmemoryDB := NewMemoryDB() triedb := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) - triedb.Put([]byte("pol"), []byte("polvalue")) - triedb.Put([]byte("polka"), []byte("polkavalue")) - triedb.Put([]byte("polkadot"), []byte("polkadotvalue")) - triedb.Put([]byte("go"), []byte("govalue")) - triedb.Put([]byte("gossamer"), []byte("gossamervalue")) + triedb.Set([]byte("pol"), []byte("polvalue")) + triedb.Set([]byte("polka"), 
[]byte("polkavalue")) + triedb.Set([]byte("polkadot"), []byte("polkadotvalue")) + triedb.Set([]byte("go"), []byte("govalue")) + triedb.Set([]byte("gossamer"), []byte("gossamervalue")) // Commit and get root root := triedb.MustHash() diff --git a/pkg/trie/triedb/triedb.go b/pkg/trie/triedb/triedb.go index 0cacf23f25..7e7314ab8f 100644 --- a/pkg/trie/triedb/triedb.go +++ b/pkg/trie/triedb/triedb.go @@ -7,11 +7,11 @@ import ( "bytes" "errors" "fmt" + "slices" "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/db" - "github.com/ChainSafe/gossamer/internal/database" + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" @@ -40,11 +40,18 @@ func WithRecorder[H hash.Hash, Hasher hash.Hasher[H]](r TrieRecorder) TrieDBOpts } } +type TrieLayout = trie.TrieLayout + +var ( + V0 = trie.V0 + V1 = trie.V1 +) + // TrieDB is a DB-backed patricia merkle trie implementation // using lazy loading to fetch nodes type TrieDB[H hash.Hash, Hasher hash.Hasher[H]] struct { rootHash H - db db.RWDatabase + db hashdb.HashDB[H] version trie.TrieLayout // rootHandle is an in-memory-trie-like representation of the node // references and new inserted nodes in the trie @@ -54,7 +61,7 @@ type TrieDB[H hash.Hash, Hasher hash.Hasher[H]] struct { storage nodeStorage[H] // deathRow is a set of nodes that we want to delete from db // uses string since it's comparable []byte - deathRow map[string]interface{} + deathRow map[string]hashPrefix[H] // Optional cache to speed up the db lookups cache TrieCache[H] // Optional recorder for recording trie accesses @@ -62,15 +69,20 @@ type TrieDB[H hash.Hash, Hasher hash.Hasher[H]] struct { } func NewEmptyTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( - db db.RWDatabase, opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, Hasher] { + db hashdb.HashDB[H], opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, 
Hasher] { hasher := *new(Hasher) root := hasher.Hash([]byte{0}) return NewTrieDB[H, Hasher](root, db, opts...) } +type hashPrefix[H hash.Hash] struct { + Hash H + nibbles.Prefix +} + // NewTrieDB creates a new TrieDB using the given root and db func NewTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( - rootHash H, db db.RWDatabase, opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, Hasher] { + rootHash H, db hashdb.HashDB[H], opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, Hasher] { rootHandle := persisted[H]{rootHash} trieDB := &TrieDB[H, Hasher]{ @@ -79,7 +91,7 @@ func NewTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( db: db, storage: newNodeStorage[H](), rootHandle: rootHandle, - deathRow: make(map[string]interface{}), + deathRow: make(map[string]hashPrefix[H]), } for _, opt := range opts { @@ -89,7 +101,7 @@ func NewTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( return trieDB } -func (t *TrieDB[H, Hasher]) SetVersion(v trie.TrieLayout) { +func (t *TrieDB[H, Hasher]) SetVersion(v TrieLayout) { if v < t.version { panic("cannot regress trie version") } @@ -123,17 +135,13 @@ func (t *TrieDB[H, Hasher]) MustHash() H { // Get returns the value in the node of the trie // which matches its key with the key given. // Note the key argument is given in little Endian format. -func (t *TrieDB[H, Hasher]) Get(key []byte) []byte { - val, err := t.lookup(key, t.rootHandle) - if err != nil { - return nil - } - - return val +func (t *TrieDB[H, Hasher]) Get(key []byte) ([]byte, error) { + return t.lookup(key, t.rootHandle) } func (t *TrieDB[H, Hasher]) lookup(fullKey []byte, handle NodeHandle) ([]byte, error) { - prefix := fullKey + // prefix only use for value node access, so this is always correct. 
+ prefix := nibbles.Prefix{Key: fullKey} partialKey := nibbles.NewNibbles(fullKey) for { var partialIdx uint @@ -162,14 +170,14 @@ func (t *TrieDB[H, Hasher]) lookup(fullKey []byte, handle NodeHandle) ([]byte, e return nil, nil case Leaf[H]: if nibbles.NewNibblesFromNodeKey(n.partialKey).Equal(partialKey) { - return inMemoryFetchedValue[H](n.value, prefix, t.db) + return inMemoryFetchedValue[H](n.value, prefix, t.db, t.recorder, fullKey) } else { return nil, nil } case Branch[H]: slice := nibbles.NewNibblesFromNodeKey(n.partialKey) if slice.Equal(partialKey) { - return inMemoryFetchedValue[H](n.value, prefix, t.db) + return inMemoryFetchedValue[H](n.value, prefix, t.db, t.recorder, fullKey) } else if partialKey.StartsWith(slice) { idx := partialKey.At(slice.Len()) child := n.children[idx] @@ -195,13 +203,8 @@ func (t *TrieDB[H, Hasher]) getNodeOrLookup( var nodeData []byte switch nodeHandle := nodeHandle.(type) { case codec.HashedNode[H]: - prefixedKey := append(partialKey.JoinedBytes(), nodeHandle.Hash.Bytes()...) - var err error - nodeData, err = t.db.Get(prefixedKey) - if err != nil { - return nil, nil, err - } - if len(nodeData) == 0 { + nodeData = t.db.Get(nodeHandle.Hash, hashdb.Prefix(partialKey)) + if nodeData == nil { if partialKey.Key == nil && partialKey.Padded == nil { return nil, nil, fmt.Errorf("%w: %v", ErrInvalidStateRoot, nodeHandle.Hash) } @@ -219,18 +222,14 @@ func (t *TrieDB[H, Hasher]) getNodeOrLookup( return nil, nil, err } - if recordAccess { - t.recordAccess(EncodedNodeAccess[H]{Hash: t.rootHash, EncodedNode: nodeData}) + if recordAccess && nodeHash != nil { + t.recordAccess(EncodedNodeAccess[H]{Hash: *nodeHash, EncodedNode: nodeData}) } return decoded, nodeHash, nil } func (t *TrieDB[H, Hasher]) fetchValue(hash H, prefix nibbles.Prefix) ([]byte, error) { - prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) 
- value, err := t.db.Get(prefixedKey) - if err != nil { - return nil, err - } + value := t.db.Get(hash, hashdb.Prefix(prefix)) if value == nil { return nil, fmt.Errorf("%w: %v", ErrIncompleteDB, hash) } @@ -276,13 +275,15 @@ func (t *TrieDB[H, Hasher]) insert(keyNibbles nibbles.Nibbles, value []byte) err return nil } -// Put inserts the given key / value pair into the trie -func (t *TrieDB[H, Hasher]) Put(key, value []byte) error { - return t.insert(nibbles.NewNibbles(key), value) +// Set inserts the given key / value pair into the trie +func (t *TrieDB[H, Hasher]) Set(key, value []byte) error { + copiedKey := append([]byte{}, key...) + copiedValue := append([]byte{}, value...) + return t.insert(nibbles.NewNibbles(copiedKey), copiedValue) } // insertAt inserts the given key / value pair into the node referenced by the -// node handle `handle` +// node handle func (t *TrieDB[H, Hasher]) insertAt( handle NodeHandle, keyNibbles *nibbles.Nibbles, @@ -358,7 +359,7 @@ type inspectResult struct { changed bool } -// inspect inspects the given node `stored` and calls the `inspector` function +// inspect inspects the given node stored and calls the inspector function // then returns the new node and a boolean indicating if the node has changed func (t *TrieDB[H, Hasher]) inspect( stored StoredNode, @@ -366,7 +367,7 @@ func (t *TrieDB[H, Hasher]) inspect( inspector func(Node, *nibbles.Nibbles) (action, error), ) (*inspectResult, error) { // shallow copy since key will change offset through inspector - currentKey := *key + currentKey := key.Clone() switch n := stored.(type) { case NewStoredNode: res, err := inspector(n.node, key) @@ -393,11 +394,11 @@ func (t *TrieDB[H, Hasher]) inspect( return &inspectResult{CachedStoredNode[H]{a.node, n.hash}, false}, nil case replaceNode: prefixedKey := append(currentKey.Left().JoinedBytes(), n.hash.Bytes()...) 
- t.deathRow[string(prefixedKey)] = nil + t.deathRow[string(prefixedKey)] = hashPrefix[H]{Hash: n.hash, Prefix: currentKey.Left()} return &inspectResult{NewStoredNode(a), true}, nil case deleteNode: prefixedKey := append(currentKey.Left().JoinedBytes(), n.hash.Bytes()...) - t.deathRow[string(prefixedKey)] = nil + t.deathRow[string(prefixedKey)] = hashPrefix[H]{Hash: n.hash, Prefix: currentKey.Left()} return nil, nil default: panic("unreachable") @@ -489,7 +490,7 @@ func (t *TrieDB[H, Hasher]) fix(branch Branch[H], key nibbles.Nibbles) (Node, er childNode = n.node case CachedStoredNode[H]: prefixedKey := append(childPrefix.JoinedBytes(), n.hash.Bytes()...) - t.deathRow[string(prefixedKey)] = nil + t.deathRow[string(prefixedKey)] = hashPrefix[H]{Hash: n.hash, Prefix: childPrefix} childNode = n.node } @@ -534,7 +535,7 @@ func combineKey(start nodeKey, end nodeKey) nodeKey { return start } -// removeInspector removes the key node from the given node `stored` +// removeInspector removes the key node from the given node stored func (t *TrieDB[H, Hasher]) removeInspector( stored Node, keyNibbles *nibbles.Nibbles, oldValue *nodeValue, ) (action, error) { @@ -622,7 +623,7 @@ func (t *TrieDB[H, Hasher]) removeInspector( } } -// insertInspector inserts the new key / value pair into the given node `stored` +// insertInspector inserts the new key / value pair into the given node stored func (t *TrieDB[H, Hasher]) insertInspector( stored Node, keyNibbles *nibbles.Nibbles, value []byte, oldValue *nodeValue, ) (action, error) { @@ -805,15 +806,15 @@ func (t *TrieDB[H, Hasher]) replaceOldValue( switch oldv := storedValue.(type) { case valueRef[H]: hash := oldv.getHash() - if hash != (*new(H)) { - prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) - t.deathRow[string(prefixedKey)] = nil - } + prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) 
+ t.deathRow[string(prefixedKey)] = hashPrefix[H]{Hash: hash, Prefix: prefix} + case newValueRef[H]: hash := oldv.getHash() - if hash != (*new(H)) { + if hash != nil { + hash := *hash prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) - t.deathRow[string(prefixedKey)] = nil + t.deathRow[string(prefixedKey)] = hashPrefix[H]{Hash: hash, Prefix: prefix} } } *oldValue = storedValue @@ -822,9 +823,8 @@ func (t *TrieDB[H, Hasher]) replaceOldValue( // lookup node in DB and add it in storage, return storage handle func (t *TrieDB[H, Hasher]) lookupNode(hash H, key nibbles.Prefix) (storageHandle, error) { var newNode = func() (Node, error) { - prefixedKey := append(key.JoinedBytes(), hash.Bytes()...) - encodedNode, err := t.db.Get(prefixedKey) - if err != nil { + encodedNode := t.db.Get(hash, hashdb.Prefix(key)) + if encodedNode == nil { return nil, ErrIncompleteDB } @@ -832,7 +832,7 @@ func (t *TrieDB[H, Hasher]) lookupNode(hash H, key nibbles.Prefix) (storageHandl return newNodeFromEncoded[H](hash, encodedNode, &t.storage) } - // We only check the `cache` for a node with `get_node` and don't insert + // We only check the cache for a node with GetNode and don't insert // the node if it wasn't there, because in substrate we only access the node while computing // a new trie (aka some branch). We assume that this node isn't that important // to have it being cached. 
@@ -868,22 +868,12 @@ func (t *TrieDB[H, Hasher]) commit() error { logger.Debug("Committing trie changes to db") logger.Debugf("%d nodes to remove from db", len(t.deathRow)) - dbBatch := t.db.NewBatch() - defer func() { - if err := dbBatch.Close(); err != nil { - logger.Criticalf("cannot close triedb commit batcher: %w", err) - } - }() - - for hash := range t.deathRow { - err := dbBatch.Del([]byte(hash)) - if err != nil { - return err - } + for _, hp := range t.deathRow { + t.db.Remove(hp.Hash, hashdb.Prefix(hp.Prefix)) } // Reset deathRow - t.deathRow = make(map[string]interface{}) + t.deathRow = make(map[string]hashPrefix[H]) var handle storageHandle switch h := t.rootHandle.(type) { @@ -911,17 +901,12 @@ func (t *TrieDB[H, Hasher]) commit() error { mov := k.AppendOptionalSliceAndNibble(partialKey, childIndex) switch n := node.(type) { case newNodeToEncode: - hash := (*new(Hasher)).Hash(n.value) - prefixedKey := append(k.Prefix().JoinedBytes(), hash.Bytes()...) - err := dbBatch.Put(prefixedKey, n.value) - if err != nil { - return nil, err - } + hash := t.db.Insert(hashdb.Prefix(k.Prefix()), n.value) t.cacheValue(k.Inner(), n.value, hash) k.DropLasts(mov) return HashChildReference[H]{hash}, nil case trieNodeToEncode: - result, err := t.commitChild(dbBatch, n.child, &k) + result, err := t.commitChild(n.child, &k) if err != nil { return nil, err } @@ -938,18 +923,13 @@ func (t *TrieDB[H, Hasher]) commit() error { return err } - hash := (*new(Hasher)).Hash(encodedNode) - err = dbBatch.Put(hash.Bytes(), encodedNode) - if err != nil { - return err - } + hash := t.db.Insert(hashdb.EmptyPrefix, encodedNode) t.rootHash = hash t.cacheNode(hash, encodedNode, fullKey) t.rootHandle = persisted[H]{t.rootHash} - // Flush all db changes - return dbBatch.Flush() + return nil case CachedStoredNode[H]: t.rootHash = stored.hash t.rootHandle = inMemory( @@ -963,7 +943,6 @@ func (t *TrieDB[H, Hasher]) commit() error { // Commit a node by hashing it and writing it to the db. 
func (t *TrieDB[H, Hasher]) commitChild( - dbBatch database.Batch, child NodeHandle, prefixKey *nibbles.NibbleSlice, ) (ChildReference, error) { @@ -991,18 +970,12 @@ func (t *TrieDB[H, Hasher]) commitChild( mov := prefixKey.AppendOptionalSliceAndNibble(partialKey, childIndex) switch n := node.(type) { case newNodeToEncode: - hash := (*new(Hasher)).Hash(n.value) - prefixedKey := append(prefixKey.Prefix().JoinedBytes(), hash.Bytes()...) - err := dbBatch.Put(prefixedKey, n.value) - if err != nil { - panic("inserting in db") - } - + hash := t.db.Insert(hashdb.Prefix(prefix.Prefix()), n.value) t.cacheValue(prefixKey.Inner(), n.value, hash) prefixKey.DropLasts(mov) return HashChildReference[H]{hash}, nil case trieNodeToEncode: - result, err := t.commitChild(dbBatch, n.child, prefixKey) + result, err := t.commitChild(n.child, prefixKey) if err != nil { return nil, err } @@ -1021,12 +994,7 @@ func (t *TrieDB[H, Hasher]) commitChild( // Not inlined node if len(encoded) >= (*new(H)).Length() { - hash := (*new(Hasher)).Hash(encoded) - prefixedKey := append(prefixKey.Prefix().JoinedBytes(), hash.Bytes()...) 
- err := dbBatch.Put(prefixedKey, encoded) - if err != nil { - return nil, err - } + hash := t.db.Insert(hashdb.Prefix(prefixKey.Prefix()), encoded) t.cacheNode(hash, encoded, fullKey) @@ -1065,7 +1033,7 @@ func cacheChildValues[H hash.Hash]( key.Append(*pk) } - if d := c.data(); d != nil { + if d := c.Data(); d != nil { if h := c.dataHash(); h != nil { *valuesToCache = append(*valuesToCache, valueToCache[H]{ KeyBytes: key.Inner(), @@ -1093,7 +1061,7 @@ func (t *TrieDB[H, Hasher]) cacheNode(hash H, encoded []byte, fullKey *nibbles.N if err != nil { return nil, err } - return newCachedNodeFromNode[H, Hasher](decoded) + return NewCachedNodeFromNode[H, Hasher](decoded) }) if err != nil { panic("Just encoded the node, so it should decode without any errors; qed") @@ -1102,7 +1070,7 @@ func (t *TrieDB[H, Hasher]) cacheNode(hash H, encoded []byte, fullKey *nibbles.N valuesToCache := []valueToCache[H]{} // If the given node has data attached, the fullKey is the full key to this node. if fullKey != nil { - if v := node.data(); v != nil { + if v := node.Data(); v != nil { if h := node.dataHash(); h != nil { valuesToCache = append(valuesToCache, valueToCache[H]{ KeyBytes: fullKey.Inner(), @@ -1145,7 +1113,7 @@ func (t *TrieDB[H, Hasher]) cacheValue(fullKey []byte, value []byte, hash H) { panic("this should never happen") } if node != nil { - val = node.data() + val = node.Data() } if val != nil { @@ -1181,3 +1149,46 @@ func GetWith[H hash.Hash, Hasher hash.Hasher[H], QueryItem any]( ) return lookup.Lookup(key) } + +func (t *TrieDB[H, Hasher]) LookupFirstDescendant(key []byte) (MerkleValue[H], error) { + lookup := NewTrieLookup[H, Hasher]( + t.db, t.rootHash, t.cache, t.recorder, func([]byte) any { return nil }, + ) + return lookup.LookupFirstDescendant(key, nibbles.NewNibbles(slices.Clone(key))) +} + +func (t *TrieDB[H, Hasher]) Iterator() (TrieIterator[H, *TrieItem], error) { + return NewTrieDBIterator(t) +} + +func (t *TrieDB[H, Hasher]) KeyIterator() (TrieIterator[H, 
[]byte], error) { + return NewTrieDBKeyIterator(t) +} + +type MerkleValues[H any] interface { + NodeMerkleValue | HashMerkleValue[H] + MerkleValue[H] +} + +// Either the hash or value of a node depending on its size. +// +// If the size of the node value is bigger or equal than 32 bytes the hash is +// returned. +type MerkleValue[H any] interface { + isMerkleValue() +} + +// The merkle value is the node data itself when the +// node data byte length is less than or equal to 32 bytes. +// +// Note: The case of inline nodes. +type NodeMerkleValue []byte + +func (NodeMerkleValue) isMerkleValue() {} + +// The merkle value is the hash of the node. +type HashMerkleValue[H any] struct { + Hash H +} + +func (HashMerkleValue[H]) isMerkleValue() {} diff --git a/pkg/trie/triedb/triedb_iterator_test.go b/pkg/trie/triedb/triedb_iterator_test.go index fc050bcc05..c39a0c3ea4 100644 --- a/pkg/trie/triedb/triedb_iterator_test.go +++ b/pkg/trie/triedb/triedb_iterator_test.go @@ -6,6 +6,7 @@ package triedb import ( "testing" + "github.com/ChainSafe/gossamer/internal/database" "github.com/ChainSafe/gossamer/internal/primitives/core/hash" "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" @@ -13,6 +14,12 @@ import ( "github.com/stretchr/testify/assert" ) +func newTestDB(t assert.TestingT) database.Table { + db, err := database.NewPebble("", true) + assert.NoError(t, err) + return database.NewTable(db, "trie") +} + func TestIterator(t *testing.T) { db := newTestDB(t) inMemoryTrie := inmemory.NewEmptyTrie() @@ -37,11 +44,11 @@ func TestIterator(t *testing.T) { root, err := inMemoryTrie.Hash() assert.NoError(t, err) - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for k, v := range entries { - err := trieDB.Put([]byte(k), v) + err := trieDB.Set([]byte(k), v) assert.NoError(t, err) } assert.NoError(t, trieDB.commit()) @@ 
-50,7 +57,7 @@ func TestIterator(t *testing.T) { assert.Equal(t, root.ToBytes(), trieDB.rootHash.Bytes()) t.Run("iterate_over_all_entries", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) expected := inMemoryTrie.NextKey([]byte{}) @@ -69,7 +76,7 @@ func TestIterator(t *testing.T) { }) t.Run("iterate_after_seeking", func(t *testing.T) { - iter, err := newRawIterator(trieDB) + iter, err := NewTrieDBRawIterator(trieDB) assert.NoError(t, err) found, err := iter.seek([]byte("not"), true) diff --git a/pkg/trie/triedb/triedb_test.go b/pkg/trie/triedb/triedb_test.go index 2642275415..ff7a2f9262 100644 --- a/pkg/trie/triedb/triedb_test.go +++ b/pkg/trie/triedb/triedb_test.go @@ -8,6 +8,7 @@ import ( "slices" "testing" + hashdb "github.com/ChainSafe/gossamer/internal/hash-db" "github.com/ChainSafe/gossamer/internal/primitives/core/hash" "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" @@ -352,23 +353,27 @@ func TestInsertions(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() // Setup trie - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { - require.NoError(t, trie.Put(entry.Key, entry.Value)) + require.NoError(t, trie.Set(entry.Key, entry.Value)) } // Add new key-value pair - err := trie.Put(testCase.key, testCase.value) + err := trie.Set(testCase.key, testCase.value) require.NoError(t, err) if !testCase.dontCheck { // Check values for keys for _, entry := range testCase.trieEntries { - require.Equal(t, entry.Value, trie.Get(entry.Key)) + val, err := trie.Get(entry.Key) + require.NoError(t, err) + require.Equal(t, entry.Value, val) } } - require.Equal(t, testCase.value, trie.Get(testCase.key)) + val, err := trie.Get(testCase.key) + require.NoError(t, err) + require.Equal(t, 
testCase.value, val) // Check we have what we expect assert.Equal(t, testCase.stored.nodes, trie.storage.nodes) @@ -405,7 +410,7 @@ func TestDeletes(t *testing.T) { "empty_trie": { key: []byte{1}, expected: nodeStorage[hash.H256]{ - nodes: []StoredNode{nil}, + nodes: []StoredNode{NewStoredNode{Empty{}}}, }, }, "delete_leaf": { @@ -417,7 +422,7 @@ func TestDeletes(t *testing.T) { }, key: []byte{1}, expected: nodeStorage[hash.H256]{ - nodes: []StoredNode{nil}, + nodes: []StoredNode{NewStoredNode{Empty{}}}, }, }, "delete_branch": { @@ -434,7 +439,7 @@ func TestDeletes(t *testing.T) { key: []byte{1}, expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ - nil, + NewStoredNode{Empty{}}, NewStoredNode{ Leaf[hash.H256]{ partialKey: nodeKey{Data: []byte{1, 0}}, @@ -489,11 +494,11 @@ func TestDeletes(t *testing.T) { t.Parallel() // Setup trie - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { - assert.NoError(t, trie.Put(entry.Key, entry.Value)) + assert.NoError(t, trie.Set(entry.Key, entry.Value)) } // Remove key @@ -576,7 +581,7 @@ func TestInsertAfterDelete(t *testing.T) { t.Parallel() // Setup trie - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { @@ -603,56 +608,56 @@ func TestDBCommits(t *testing.T) { t.Run("commit_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) - err := trie.Put([]byte("leaf"), []byte("leafvalue")) + err := trie.Set([]byte("leaf"), []byte("leafvalue")) assert.NoError(t, err) err = trie.commit() assert.NoError(t, err) // 1 leaf - assert.Len(t, inmemoryDB.data, 1) + 
assert.Len(t, inmemoryDB.Keys(), 1) // Get values using lazy loading - value := trie.Get([]byte("leaf")) + value, _ := trie.Get([]byte("leaf")) assert.Equal(t, []byte("leafvalue"), value) }) t.Run("commit_branch_and_inlined_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) - err := trie.Put([]byte("branchleaf"), []byte("leafvalue")) + err := trie.Set([]byte("branchleaf"), []byte("leafvalue")) assert.NoError(t, err) - err = trie.Put([]byte("branch"), []byte("branchvalue")) + err = trie.Set([]byte("branch"), []byte("branchvalue")) assert.NoError(t, err) err = trie.commit() assert.NoError(t, err) // 1 branch with its inlined leaf - assert.Len(t, inmemoryDB.data, 1) + assert.Len(t, inmemoryDB.Keys(), 1) // Get values using lazy loading - value := trie.Get([]byte("branch")) + value, _ := trie.Get([]byte("branch")) assert.Equal(t, []byte("branchvalue"), value) - value = trie.Get([]byte("branchleaf")) + value, _ = trie.Get([]byte("branchleaf")) assert.Equal(t, []byte("leafvalue"), value) }) t.Run("commit_branch_and_hashed_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) - err := tr.Put([]byte("branchleaf"), make([]byte, 40)) + err := tr.Set([]byte("branchleaf"), make([]byte, 40)) assert.NoError(t, err) - err = tr.Put([]byte("branch"), []byte("branchvalue")) + err = tr.Set([]byte("branch"), []byte("branchvalue")) assert.NoError(t, err) err = tr.commit() @@ -660,23 +665,23 @@ func TestDBCommits(t *testing.T) { // 1 branch with 1 hashed leaf child // 1 hashed leaf - assert.Len(t, inmemoryDB.data, 2) + assert.Len(t, inmemoryDB.Keys(), 2) // Get values using lazy loading - value := tr.Get([]byte("branch")) + value, _ := tr.Get([]byte("branch")) assert.Equal(t, 
[]byte("branchvalue"), value) - value = tr.Get([]byte("branchleaf")) + value, _ = tr.Get([]byte("branchleaf")) assert.Equal(t, make([]byte, 40), value) }) t.Run("commit_leaf_with_hashed_value", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) - err := tr.Put([]byte("leaf"), make([]byte, 40)) + err := tr.Set([]byte("leaf"), make([]byte, 40)) assert.NoError(t, err) err = tr.commit() @@ -684,21 +689,21 @@ func TestDBCommits(t *testing.T) { // 1 hashed leaf with hashed value // 1 hashed value - assert.Len(t, inmemoryDB.data, 2) + assert.Len(t, inmemoryDB.Keys(), 2) // Get values using lazy loading - value := tr.Get([]byte("leaf")) + value, _ := tr.Get([]byte("leaf")) assert.Equal(t, make([]byte, 40), value) }) t.Run("commit_leaf_with_hashed_value_then_remove_it", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) - err := tr.Put([]byte("leaf"), make([]byte, 40)) + err := tr.Set([]byte("leaf"), make([]byte, 40)) assert.NoError(t, err) err = tr.commit() @@ -706,25 +711,25 @@ func TestDBCommits(t *testing.T) { // 1 hashed leaf with hashed value // 1 hashed value - assert.Len(t, inmemoryDB.data, 2) + assert.Len(t, inmemoryDB.Keys(), 2) // Get values using lazy loading err = tr.Delete([]byte("leaf")) assert.NoError(t, err) tr.commit() - assert.Len(t, inmemoryDB.data, 0) + assert.Len(t, inmemoryDB.Keys(), 0) }) t.Run("commit_branch_and_hashed_leaf_with_hashed_value", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) - err := tr.Put([]byte("branchleaf"), make([]byte, 
40)) + err := tr.Set([]byte("branchleaf"), make([]byte, 40)) assert.NoError(t, err) - err = tr.Put([]byte("branch"), []byte("branchvalue")) + err = tr.Set([]byte("branch"), []byte("branchvalue")) assert.NoError(t, err) err = tr.commit() @@ -733,25 +738,25 @@ func TestDBCommits(t *testing.T) { // 1 branch with 1 hashed leaf child // 1 hashed leaf with hashed value // 1 hashed value - assert.Len(t, inmemoryDB.data, 3) + assert.Len(t, inmemoryDB.Keys(), 3) // Get values using lazy loading - value := tr.Get([]byte("branch")) + value, _ := tr.Get([]byte("branch")) assert.Equal(t, []byte("branchvalue"), value) - value = tr.Get([]byte("branchleaf")) + value, _ = tr.Get([]byte("branchleaf")) assert.Equal(t, make([]byte, 40), value) }) t.Run("commit_branch_and_hashed_leaf_with_hashed_value_then_delete_it", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + inmemoryDB := NewMemoryDB() tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) - err := tr.Put([]byte("branchleaf"), make([]byte, 40)) + err := tr.Set([]byte("branchleaf"), make([]byte, 40)) assert.NoError(t, err) - err = tr.Put([]byte("branch"), []byte("branchvalue")) + err = tr.Set([]byte("branch"), []byte("branchvalue")) assert.NoError(t, err) err = tr.commit() @@ -760,7 +765,7 @@ func TestDBCommits(t *testing.T) { // 1 branch with 1 hashed leaf child // 1 hashed leaf with hashed value // 1 hashed value - assert.Len(t, inmemoryDB.data, 3) + assert.Len(t, inmemoryDB.Keys(), 3) err = tr.Delete([]byte("branchleaf")) assert.NoError(t, err) @@ -769,18 +774,18 @@ func TestDBCommits(t *testing.T) { // 1 branch transformed in a leaf // previous leaf was deleted // previous hashed (V1) value was deleted too - assert.Len(t, inmemoryDB.data, 1) + assert.Len(t, inmemoryDB.Keys(), 1) }) t.Run("commit_branch_with_leaf_then_delete_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + 
inmemoryDB := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) - err := trie.Put([]byte("branchleaf"), []byte("leafvalue")) + err := trie.Set([]byte("branchleaf"), []byte("leafvalue")) assert.NoError(t, err) - err = trie.Put([]byte("branch"), []byte("branchvalue")) + err = trie.Set([]byte("branch"), []byte("branchvalue")) assert.NoError(t, err) err = trie.commit() @@ -794,11 +799,11 @@ func TestDBCommits(t *testing.T) { // 1 branch transformed in a leaf // previous leaf was deleted - assert.Len(t, inmemoryDB.data, 1) + assert.Len(t, inmemoryDB.Keys(), 1) - v := trie.Get([]byte("branch")) + v, _ := trie.Get([]byte("branch")) assert.Equal(t, []byte("branchvalue"), v) - v = trie.Get([]byte("branchleaf")) + v, _ = trie.Get([]byte("branchleaf")) assert.Nil(t, v) }) } @@ -818,12 +823,12 @@ func Test_TrieDB(t *testing.T) { } // Add some initial data to the trie - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for _, entry := range keyValues[:1] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -836,12 +841,12 @@ func Test_TrieDB(t *testing.T) { overlay := db.Clone() newRoot := root { - trie := NewTrieDB(newRoot, overlay, + trie := NewTrieDB(newRoot, &overlay, WithRecorder[hash.H256, runtime.BlakeTwo256](recorder), ) trie.SetVersion(version) for _, entry := range keyValues[1:] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -849,10 +854,11 @@ func Test_TrieDB(t *testing.T) { newRoot = trie.rootHash } - partialDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + partialDB := NewMemoryDB() for _, record := range recorder.Drain() { - key := runtime.BlakeTwo256{}.Hash(record.Data).Bytes() - 
require.NoError(t, partialDB.Put(key, record.Data)) + // key := runtime.BlakeTwo256{}.Hash(record.Data).Bytes() + // require.NoError(t, partialDB.Set(key, record.Data)) + partialDB.Insert(hashdb.EmptyPrefix, record.Data) } // Replay the it, but this time we use the proof. @@ -861,7 +867,7 @@ func Test_TrieDB(t *testing.T) { trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, partialDB) trie.SetVersion(version) for _, entry := range keyValues[1:] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -887,12 +893,12 @@ func Test_TrieDB(t *testing.T) { } // Add some initial data to the trie - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for _, entry := range keyValues[:1] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -905,7 +911,7 @@ func Test_TrieDB(t *testing.T) { trie := NewTrieDB(trie.rootHash, db, WithCache[hash.H256, runtime.BlakeTwo256](cache)) trie.SetVersion(version) // Only read one entry, using GetWith which should cache the root node - _, err := GetWith(trie, keyValues[0].key, func([]byte) any { return nil }) + _, err := GetWith(trie, keyValues[0].key, func(v []byte) []byte { return v }) assert.NoError(t, err) } @@ -918,13 +924,13 @@ func Test_TrieDB(t *testing.T) { overlay := db.Clone() var newRoot hash.H256 { - trie := NewTrieDB(trie.rootHash, overlay, + trie := NewTrieDB(trie.rootHash, &overlay, WithCache[hash.H256, runtime.BlakeTwo256](cache), WithRecorder[hash.H256, runtime.BlakeTwo256](recorder), ) trie.SetVersion(version) for _, entry := range keyValues[1:] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() 
require.NoError(t, err) @@ -940,10 +946,11 @@ func Test_TrieDB(t *testing.T) { }, cachedValue) } - partialDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + partialDB := NewMemoryDB() for _, record := range recorder.Drain() { - key := runtime.BlakeTwo256{}.Hash(record.Data).Bytes() - require.NoError(t, partialDB.Put(key, record.Data)) + // key := runtime.BlakeTwo256{}.Hash(record.Data).Bytes() + // require.NoError(t, partialDB.Set(key, record.Data)) + partialDB.Insert(hashdb.EmptyPrefix, record.Data) } // Replay the it, but this time we use the proof. @@ -952,7 +959,7 @@ func Test_TrieDB(t *testing.T) { trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, partialDB) trie.SetVersion(version) for _, entry := range keyValues[1:] { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -981,7 +988,7 @@ func Test_TrieDB(t *testing.T) { cache := NewTestTrieCache[hash.H256]() recorder := NewRecorder[hash.H256]() - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db, @@ -992,7 +999,7 @@ func Test_TrieDB(t *testing.T) { // Add all values for _, entry := range keyValues { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } // Remove only the last 2 elements @@ -1055,7 +1062,7 @@ func Test_TrieDB(t *testing.T) { cache := NewTestTrieCache[hash.H256]() recorder := NewRecorder[hash.H256]() - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db, @@ -1066,7 +1073,7 @@ func Test_TrieDB(t *testing.T) { // Add all values for _, entry := range keyValues { - require.NoError(t, trie.Put(slices.Clone(entry.key), entry.value)) + require.NoError(t, trie.Set(slices.Clone(entry.key), entry.value)) } 
err := trie.commit() @@ -1120,7 +1127,7 @@ func Test_TrieDB(t *testing.T) { cache := NewTestTrieCache[hash.H256]() recorder := NewRecorder[hash.H256]() - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db, @@ -1131,7 +1138,7 @@ func Test_TrieDB(t *testing.T) { // Add all values for _, entry := range keyValues { - require.NoError(t, trie.Put(slices.Clone(entry.key), entry.value)) + require.NoError(t, trie.Set(slices.Clone(entry.key), entry.value)) } err := trie.commit() @@ -1172,16 +1179,16 @@ func Test_TrieDB(t *testing.T) { ) trie.SetVersion(version) - require.NoError(t, trie.Put([]byte("AAB"), []byte{1, 1, 1, 1})) + require.NoError(t, trie.Set([]byte("AAB"), []byte{1, 1, 1, 1})) - val := trie.Get([]byte("AAB")) + val, _ := trie.Get([]byte("AAB")) require.NotNil(t, val) require.Equal(t, []byte{1, 1, 1, 1}, val) err := trie.commit() require.NoError(t, err) - val = trie.Get([]byte("AAB")) + val, _ = trie.Get([]byte("AAB")) require.NotNil(t, val) require.Equal(t, []byte{1, 1, 1, 1}, val) @@ -1211,7 +1218,7 @@ func Test_TrieDB(t *testing.T) { {[]byte("AC"), bytes.Repeat([]byte{8}, 8)}, } - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) @@ -1219,7 +1226,7 @@ func Test_TrieDB(t *testing.T) { // Add all values for _, entry := range keyValues { - require.NoError(t, trie.Put(slices.Clone(entry.key), entry.value)) + require.NoError(t, trie.Set(slices.Clone(entry.key), entry.value)) } err := trie.commit() @@ -1256,7 +1263,7 @@ func Test_TrieDB(t *testing.T) { // get all keys again from cache, by passing in brand new db { - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewTrieDB(root, db, WithCache[hash.H256, runtime.BlakeTwo256](cache), ) @@ -1290,7 +1297,7 @@ func Test_TrieDB(t *testing.T) { {[]byte("AC"), 
bytes.Repeat([]byte{8}, 8)}, } - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) @@ -1298,7 +1305,7 @@ func Test_TrieDB(t *testing.T) { // Add all values for _, entry := range keyValues { - require.NoError(t, trie.Put(slices.Clone(entry.key), entry.value)) + require.NoError(t, trie.Set(slices.Clone(entry.key), entry.value)) } err := trie.commit() @@ -1326,7 +1333,7 @@ func Test_TrieDB(t *testing.T) { // get all keys again from cache, by passing in brand new db { - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() trie := NewTrieDB(root, db, WithCache[hash.H256, runtime.BlakeTwo256](cache), ) @@ -1371,13 +1378,13 @@ func Test_TrieDB(t *testing.T) { {[]byte("BC"), bytes.Repeat([]byte{4}, 64)}, } - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for _, entry := range keyValues { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -1467,13 +1474,13 @@ func Test_TrieDB(t *testing.T) { {[]byte("BC"), bytes.Repeat([]byte{4}, 64)}, } - db := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + db := NewMemoryDB() var root hash.H256 { trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) trie.SetVersion(version) for _, entry := range keyValues { - require.NoError(t, trie.Put(entry.key, entry.value)) + require.NoError(t, trie.Set(entry.key, entry.value)) } err := trie.commit() require.NoError(t, err) @@ -1530,4 +1537,297 @@ func Test_TrieDB(t *testing.T) { } }) + t.Run("test_merkle_value_internal", func(t *testing.T) { + for _, version := range []trie.TrieLayout{ + trie.V0, + trie.V1, + } { + t.Run(version.String(), func(t *testing.T) { + keyValues := []struct { + key []byte + 
value []byte + }{ + {[]byte("A"), bytes.Repeat([]byte{1}, 64)}, + {[]byte("AA"), bytes.Repeat([]byte{2}, 64)}, + {[]byte("AAAA"), bytes.Repeat([]byte{3}, 64)}, + {[]byte("AAB"), bytes.Repeat([]byte{4}, 64)}, + {[]byte("AABBBB"), bytes.Repeat([]byte{4}, 1)}, + {[]byte("AB"), bytes.Repeat([]byte{5}, 1)}, + {[]byte("B"), bytes.Repeat([]byte{6}, 1)}, + } + + db := NewMemoryDB() + var root hash.H256 + { + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trie.SetVersion(version) + for _, entry := range keyValues { + require.NoError(t, trie.Set(entry.key, entry.value)) + } + err := trie.commit() + require.NoError(t, err) + require.NotEmpty(t, trie.rootHash) + root = trie.rootHash + } + + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, db) + trie.SetVersion(version) + for _, entry := range keyValues { + h, err := trie.LookupFirstDescendant(entry.key) + require.NoError(t, err) + require.NotNil(t, h) + } + + // Key is not present and has no descendant, but shares a prefix. + for _, key := range []string{ + "AAAAX", "AABX", "ABX", "AABBBX", "BX", "AC", "AAAAX", + "C", // Key shares the first nibble with b"A". + } { + mv, err := trie.LookupFirstDescendant([]byte(key)) + require.NoError(t, err) + require.Nil(t, mv) + } + + // Key not present, but has a descendant. 
+ hash, err := trie.LookupFirstDescendant([]byte("AAA")) + require.NoError(t, err) + require.NotNil(t, hash) + expected, err := trie.LookupFirstDescendant([]byte("AAAA")) + require.NoError(t, err) + require.NotNil(t, expected) + require.Equal(t, expected, hash) + + hash, err = trie.LookupFirstDescendant([]byte("AABB")) + require.NoError(t, err) + require.NotNil(t, hash) + expected, err = trie.LookupFirstDescendant([]byte("AABBBB")) + require.NoError(t, err) + require.NotNil(t, expected) + require.Equal(t, expected, hash) + + hash, err = trie.LookupFirstDescendant([]byte("AABBB")) + require.NoError(t, err) + require.NotNil(t, hash) + expected, err = trie.LookupFirstDescendant([]byte("AABBBB")) + require.NoError(t, err) + require.NotNil(t, expected) + require.Equal(t, expected, hash) + + // Prefix AABB in between AAB and AABBBB, but has different ending char. + hash, err = trie.LookupFirstDescendant([]byte("AABBX")) + require.NoError(t, err) + require.Nil(t, hash) + }) + } + }) + + t.Run("test_merkle_value_branches_internal", func(t *testing.T) { + for _, version := range []trie.TrieLayout{ + trie.V0, + trie.V1, + } { + t.Run(version.String(), func(t *testing.T) { + keyValues := []struct { + key []byte + value []byte + }{ + {[]byte("AAAA"), bytes.Repeat([]byte{1}, 64)}, + {[]byte("AABA"), bytes.Repeat([]byte{2}, 64)}, + } + + db := NewMemoryDB() + var root hash.H256 + { + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trie.SetVersion(version) + for _, entry := range keyValues { + require.NoError(t, trie.Set(entry.key, entry.value)) + } + err := trie.commit() + require.NoError(t, err) + require.NotEmpty(t, trie.rootHash) + root = trie.rootHash + } + + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, db) + trie.SetVersion(version) + + // The hash is returned from the branch node. 
+ hash, err := trie.LookupFirstDescendant([]byte("A")) + require.NoError(t, err) + require.NotNil(t, hash) + aaaa_hash, err := trie.LookupFirstDescendant([]byte("AAAA")) + require.NoError(t, err) + require.NotNil(t, aaaa_hash) + aaba_hash, err := trie.LookupFirstDescendant([]byte("AABA")) + require.NoError(t, err) + require.NotNil(t, aaba_hash) + + assert.NotEqual(t, hash, aaaa_hash) + assert.NotEqual(t, hash, aaba_hash) + }) + } + }) + + t.Run("test_merkle_value_empty_trie_internal", func(t *testing.T) { + for _, version := range []trie.TrieLayout{ + trie.V0, + trie.V1, + } { + t.Run(version.String(), func(t *testing.T) { + keyValues := []struct { + key []byte + value []byte + }{ + // test for both empty cases + {[]byte{}, []byte{}}, + {nil, nil}, + {[]byte{}, nil}, + {nil, []byte{}}, + } + + for _, entry := range keyValues { + db := NewMemoryDB() + var root hash.H256 + { + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trie.SetVersion(version) + + require.NoError(t, trie.Set(entry.key, entry.value)) + + err := trie.commit() + require.NoError(t, err) + require.NotEmpty(t, trie.rootHash) + // Valid state root. + root = trie.rootHash + } + + // Data set is empty. 
+ trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, db) + trie.SetVersion(version) + hash, err := trie.LookupFirstDescendant([]byte("A")) + require.NoError(t, err) + require.Nil(t, hash) + + hash, err = trie.LookupFirstDescendant([]byte("A")) + require.NoError(t, err) + require.Nil(t, hash) + + hash, err = trie.LookupFirstDescendant([]byte("AA")) + require.NoError(t, err) + require.Nil(t, hash) + + hash, err = trie.LookupFirstDescendant([]byte("AAA")) + require.NoError(t, err) + require.Nil(t, hash) + + hash, err = trie.LookupFirstDescendant([]byte("AAAA")) + require.NoError(t, err) + require.Nil(t, hash) + } + }) + } + }) + + t.Run("test_merkle_value_modification_internal", func(t *testing.T) { + for _, version := range []trie.TrieLayout{ + trie.V0, + trie.V1, + } { + t.Run(version.String(), func(t *testing.T) { + keyValues := []struct { + key []byte + value []byte + }{ + {[]byte("AAAA"), bytes.Repeat([]byte{1}, 64)}, + {[]byte("AABA"), bytes.Repeat([]byte{2}, 64)}, + } + + db := NewMemoryDB() + var root hash.H256 + { + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trie.SetVersion(version) + for _, entry := range keyValues { + require.NoError(t, trie.Set(entry.key, entry.value)) + } + err := trie.commit() + require.NoError(t, err) + require.NotEmpty(t, trie.rootHash) + root = trie.rootHash + } + + var ( + aHashLHS MerkleValue[hash.H256] + aaaaHashLHS MerkleValue[hash.H256] + aabaHashLHS MerkleValue[hash.H256] + ) + { + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, db) + trie.SetVersion(version) + + // The hash is returned from the branch node. + hash, err := trie.LookupFirstDescendant([]byte("A")) + require.NoError(t, err) + require.NotNil(t, hash) + aaaa_hash, err := trie.LookupFirstDescendant([]byte("AAAA")) + require.NoError(t, err) + require.NotNil(t, aaaa_hash) + aaba_hash, err := trie.LookupFirstDescendant([]byte("AABA")) + require.NoError(t, err) + require.NotNil(t, aaba_hash) + + // Ensure the hash is not from any leaf. 
+ assert.NotEqual(t, hash, aaaa_hash) + assert.NotEqual(t, hash, aaba_hash) + + aHashLHS = hash + aaaaHashLHS = aaaa_hash + aabaHashLHS = aaba_hash + } + + var ( + aHashRHS MerkleValue[hash.H256] + aaaaHashRHS MerkleValue[hash.H256] + aabaHashRHS MerkleValue[hash.H256] + ) + // Modify AABA and expect AAAA to return the same merkle value + { + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256](root, db) + trie.SetVersion(version) + require.NoError(t, trie.Set([]byte("AABA"), bytes.Repeat([]byte{3}, 64))) + err := trie.commit() + require.NoError(t, err) + require.NotEmpty(t, trie.rootHash) + require.NotEqual(t, root, trie.rootHash) + + // The hash is returned from the branch node. + hash, err := trie.LookupFirstDescendant([]byte("A")) + require.NoError(t, err) + require.NotNil(t, hash) + aaaa_hash, err := trie.LookupFirstDescendant([]byte("AAAA")) + require.NoError(t, err) + require.NotNil(t, aaaa_hash) + aaba_hash, err := trie.LookupFirstDescendant([]byte("AABA")) + require.NoError(t, err) + require.NotNil(t, aaba_hash) + + // Ensure the hash is not from any leaf. + require.NotEqual(t, hash, aaaa_hash) + require.NotEqual(t, hash, aaba_hash) + + aHashRHS = hash + aaaaHashRHS = aaaa_hash + aabaHashRHS = aaba_hash + } + + // AAAA was not modified. + require.Equal(t, aaaaHashLHS, aaaaHashRHS) + // Changes to AABA must propagate to the root. 
+ require.NotEqual(t, aabaHashLHS, aabaHashRHS) + require.NotEqual(t, aHashLHS, aHashRHS) + + }) + } + }) } diff --git a/pkg/trie/triedb/util_test.go b/pkg/trie/triedb/util_test.go index d847113875..fe190dbb4f 100644 --- a/pkg/trie/triedb/util_test.go +++ b/pkg/trie/triedb/util_test.go @@ -4,102 +4,18 @@ package triedb import ( - "bytes" - "strings" - - "github.com/ChainSafe/gossamer/internal/database" + memorydb "github.com/ChainSafe/gossamer/internal/memory-db" chash "github.com/ChainSafe/gossamer/internal/primitives/core/hash" - "github.com/ChainSafe/gossamer/pkg/trie/db" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" - "github.com/stretchr/testify/assert" - "golang.org/x/exp/maps" ) -// MemoryDB is an in-memory implementation of the Database interface backed by a -// map. It uses blake2b as hashing algorithm -type MemoryDB struct { - data map[string][]byte - hashedNullNode string - nullNodeData []byte -} - -func NewMemoryDB[H hash.Hash, Hasher hash.Hasher[H]](data []byte) *MemoryDB { - return &MemoryDB{ - data: make(map[string][]byte), - hashedNullNode: string((*new(Hasher)).Hash(data).Bytes()), - nullNodeData: data, - } -} - -func (db *MemoryDB) emplace(key []byte, value []byte) { - if bytes.Equal(value, db.nullNodeData) { - return - } - - db.data[string(key)] = value -} - -func (db *MemoryDB) Get(key []byte) ([]byte, error) { - dbKey := string(key) - if strings.Contains(dbKey, db.hashedNullNode) { - return db.nullNodeData, nil - } - if value, has := db.data[dbKey]; has { - return value, nil - } - - return nil, nil -} - -func (db *MemoryDB) Put(key []byte, value []byte) error { - db.emplace(key, value) - return nil -} - -func (db *MemoryDB) Del(key []byte) error { - dbKey := string(key) - delete(db.data, dbKey) - return nil -} - -func (db *MemoryDB) Flush() error { - return nil -} - -func (db *MemoryDB) NewBatch() database.Batch { - return &MemoryBatch{db} -} - -func (db *MemoryDB) Clone() 
*MemoryDB { - return &MemoryDB{ - data: maps.Clone(db.data), - hashedNullNode: db.hashedNullNode, - nullNodeData: db.nullNodeData, - } -} - -var _ db.RWDatabase = &MemoryDB{} - -type MemoryBatch struct { - *MemoryDB -} - -func (b *MemoryBatch) Close() error { - return nil -} - -func (*MemoryBatch) Reset() {} - -func (b *MemoryBatch) ValueSize() int { - return 1 -} - -var _ database.Batch = &MemoryBatch{} - -func newTestDB(t assert.TestingT) database.Table { - db, err := database.NewPebble("", true) - assert.NoError(t, err) - return database.NewTable(db, "trie") +func NewMemoryDB() *memorydb.MemoryDB[ + chash.H256, runtime.BlakeTwo256, chash.H256, memorydb.HashKey[chash.H256]] { + db := memorydb.NewMemoryDB[ + chash.H256, runtime.BlakeTwo256, chash.H256, memorydb.HashKey[chash.H256], + ]([]byte{0}) + return &db } type TestTrieCache[H hash.Hash] struct {