diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go
index af2efdccace7..295402d41b32 100644
--- a/triedb/pathdb/database.go
+++ b/triedb/pathdb/database.go
@@ -450,7 +450,7 @@ func (db *Database) Enable(root common.Hash) error {
// Re-construct a new disk layer backed by persistent state
// and schedule the state snapshot generation if it's permitted.
- db.tree.reset(generateSnapshot(db, root))
+ db.tree.init(generateSnapshot(db, root))
log.Info("Rebuilt trie database", "root", root)
return nil
}
@@ -491,7 +491,7 @@ func (db *Database) Recover(root common.Hash) error {
// reset layer with newly created disk layer. It must be
// done after each revert operation, otherwise the new
// disk layer won't be accessible from outside.
- db.tree.reset(dl)
+ db.tree.init(dl)
}
rawdb.DeleteTrieJournal(db.diskdb)
_, err := truncateFromHead(db.diskdb, db.freezer, dl.stateID())
diff --git a/triedb/pathdb/difflayer.go b/triedb/pathdb/difflayer.go
index c06026b6cac8..870842472790 100644
--- a/triedb/pathdb/difflayer.go
+++ b/triedb/pathdb/difflayer.go
@@ -156,7 +156,7 @@ func (dl *diffLayer) update(root common.Hash, id uint64, block uint64, nodes *no
}
// persist flushes the diff layer and all its parent layers to disk layer.
-func (dl *diffLayer) persist(force bool) (layer, error) {
+func (dl *diffLayer) persist(force bool) (*diskLayer, error) {
if parent, ok := dl.parentLayer().(*diffLayer); ok {
// Hold the lock to prevent any read operation until the new
// parent is linked correctly.
@@ -183,7 +183,7 @@ func (dl *diffLayer) size() uint64 {
// diffToDisk merges a bottom-most diff into the persistent disk layer underneath
// it. The method will panic if called onto a non-bottom-most diff layer.
-func diffToDisk(layer *diffLayer, force bool) (layer, error) {
+func diffToDisk(layer *diffLayer, force bool) (*diskLayer, error) {
disk, ok := layer.parentLayer().(*diskLayer)
if !ok {
panic(fmt.Sprintf("unknown layer type: %T", layer.parentLayer()))
diff --git a/triedb/pathdb/disklayer.go b/triedb/pathdb/disklayer.go
index 26c1064209b6..ab2bb1a515e4 100644
--- a/triedb/pathdb/disklayer.go
+++ b/triedb/pathdb/disklayer.go
@@ -87,15 +87,6 @@ func (dl *diskLayer) setGenerator(generator *generator) {
dl.generator = generator
}
-// isStale return whether this layer has become stale (was flattened across) or if
-// it's still live.
-func (dl *diskLayer) isStale() bool {
- dl.lock.RLock()
- defer dl.lock.RUnlock()
-
- return dl.stale
-}
-
// markStale sets the stale flag as true.
func (dl *diskLayer) markStale() {
dl.lock.Lock()
diff --git a/triedb/pathdb/layertree.go b/triedb/pathdb/layertree.go
index cf6b14e744ef..2ca9135d9a6f 100644
--- a/triedb/pathdb/layertree.go
+++ b/triedb/pathdb/layertree.go
@@ -32,29 +32,41 @@ import (
// thread-safe to use. However, callers need to ensure the thread-safety
// of the referenced layer by themselves.
type layerTree struct {
- lock sync.RWMutex
- layers map[common.Hash]layer
+ base *diskLayer
+ layers map[common.Hash]layer
+ descendants map[common.Hash]map[common.Hash]struct{}
+ lookup *lookup
+ lock sync.RWMutex
}
// newLayerTree constructs the layerTree with the given head layer.
func newLayerTree(head layer) *layerTree {
tree := new(layerTree)
- tree.reset(head)
+ tree.init(head)
return tree
}
-// reset initializes the layerTree by the given head layer.
-// All the ancestors will be iterated out and linked in the tree.
-func (tree *layerTree) reset(head layer) {
+// init initializes the layerTree with the given head layer.
+func (tree *layerTree) init(head layer) {
tree.lock.Lock()
defer tree.lock.Unlock()
- var layers = make(map[common.Hash]layer)
- for head != nil {
- layers[head.rootHash()] = head
- head = head.parentLayer()
+ current := head
+ tree.layers = make(map[common.Hash]layer)
+ tree.descendants = make(map[common.Hash]map[common.Hash]struct{})
+
+ for {
+ tree.layers[current.rootHash()] = current
+ tree.fillAncestors(current)
+
+ parent := current.parentLayer()
+ if parent == nil {
+ break
+ }
+ current = parent
}
- tree.layers = layers
+ tree.base = current.(*diskLayer) // panic if it's not a disk layer
+ tree.lookup = newLookup(head, tree.isDescendant)
}
// get retrieves a layer belonging to the given state root.
@@ -65,6 +77,43 @@ func (tree *layerTree) get(root common.Hash) layer {
return tree.layers[types.TrieRootHash(root)]
}
+// isDescendant returns whether the specified layer with given root is a
+// descendant of a specific ancestor.
+//
+// This function assumes the read lock has been held.
+func (tree *layerTree) isDescendant(root common.Hash, ancestor common.Hash) bool {
+ subset := tree.descendants[ancestor]
+ if subset == nil {
+ return false
+ }
+ _, ok := subset[root]
+ return ok
+}
+
+// fillAncestors identifies the ancestors of the given layer and populates the
+// descendants set. The ancestors include the diff layers below the supplied
+// layer and also the disk layer.
+//
+// This function assumes the write lock has been held.
+func (tree *layerTree) fillAncestors(layer layer) {
+ hash := layer.rootHash()
+ for {
+ parent := layer.parentLayer()
+ if parent == nil {
+ break
+ }
+ layer = parent
+
+ phash := parent.rootHash()
+ subset := tree.descendants[phash]
+ if subset == nil {
+ subset = make(map[common.Hash]struct{})
+ tree.descendants[phash] = subset
+ }
+ subset[hash] = struct{}{}
+ }
+}
+
// forEach iterates the stored layers inside and applies the
// given callback on them.
func (tree *layerTree) forEach(onLayer func(layer)) {
@@ -103,8 +152,11 @@ func (tree *layerTree) add(root common.Hash, parentRoot common.Hash, block uint6
l := parent.update(root, parent.stateID()+1, block, newNodeSet(nodes.Flatten()), states)
tree.lock.Lock()
+ defer tree.lock.Unlock()
+
tree.layers[l.rootHash()] = l
- tree.lock.Unlock()
+ tree.fillAncestors(l)
+ tree.lookup.addLayer(l)
return nil
}
@@ -130,8 +182,14 @@ func (tree *layerTree) cap(root common.Hash, layers int) error {
if err != nil {
return err
}
- // Replace the entire layer tree with the flat base
- tree.layers = map[common.Hash]layer{base.rootHash(): base}
+ tree.base = base
+
+ // Reset the layer tree with the single new disk layer
+ tree.layers = map[common.Hash]layer{
+ base.rootHash(): base,
+ }
+ tree.descendants = make(map[common.Hash]map[common.Hash]struct{})
+ tree.lookup = newLookup(base, tree.isDescendant)
return nil
}
// Dive until we run out of layers or reach the persistent database
@@ -146,6 +204,11 @@ func (tree *layerTree) cap(root common.Hash, layers int) error {
}
// We're out of layers, flatten anything below, stopping if it's the disk or if
// the memory limit is not yet exceeded.
+ var (
+ err error
+ replaced layer
+ newBase *diskLayer
+ )
switch parent := diff.parentLayer().(type) {
case *diskLayer:
return nil
@@ -155,14 +218,33 @@ func (tree *layerTree) cap(root common.Hash, layers int) error {
// parent is linked correctly.
diff.lock.Lock()
- base, err := parent.persist(false)
+ // Hold the reference of the original layer being replaced
+ replaced = parent
+
+ // Replace the original parent layer with new disk layer. The procedure
+ // can be illustrated as below:
+ //
+ // Before change:
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C2'->C3'->C4'
+ //
+ // After change:
+ // Chain:
+ // (a) C3->C4 (HEAD)
+ // (b) C1->C2
+ // ->C2'->C3'->C4'
+ // The original C3 is replaced by the new base (with root C3)
+ // Dangling layers in (b) will be removed later
+ newBase, err = parent.persist(false)
if err != nil {
diff.lock.Unlock()
return err
}
- tree.layers[base.rootHash()] = base
- diff.parent = base
+ tree.layers[newBase.rootHash()] = newBase
+ // Link the new parent and release the lock
+ diff.parent = newBase
diff.lock.Unlock()
default:
@@ -176,19 +258,28 @@ func (tree *layerTree) cap(root common.Hash, layers int) error {
children[parent] = append(children[parent], root)
}
}
+ clearDiff := func(layer layer) {
+ diff, ok := layer.(*diffLayer)
+ if !ok {
+ return
+ }
+ tree.lookup.removeLayer(diff)
+ }
var remove func(root common.Hash)
remove = func(root common.Hash) {
+ clearDiff(tree.layers[root])
+
+ // Unlink the layer from the layer tree and cascade to its children
+ delete(tree.descendants, root)
delete(tree.layers, root)
for _, child := range children[root] {
remove(child)
}
delete(children, root)
}
- for root, layer := range tree.layers {
- if dl, ok := layer.(*diskLayer); ok && dl.isStale() {
- remove(root)
- }
- }
+ remove(tree.base.rootHash()) // remove the old/stale disk layer
+ clearDiff(replaced) // remove the lookup data of the stale parent being replaced
+ tree.base = newBase // update the base layer with newly constructed one
return nil
}
@@ -197,17 +288,39 @@ func (tree *layerTree) bottom() *diskLayer {
tree.lock.RLock()
defer tree.lock.RUnlock()
- if len(tree.layers) == 0 {
- return nil // Shouldn't happen, empty tree
+ return tree.base
+}
+
+// lookupAccount returns the layer that is confirmed to contain the account data
+// being searched for.
+func (tree *layerTree) lookupAccount(accountHash common.Hash, state common.Hash) (layer, error) {
+ tree.lock.RLock()
+ defer tree.lock.RUnlock()
+
+ tip := tree.lookup.accountTip(accountHash, state, tree.base.root)
+ if tip == (common.Hash{}) {
+ return nil, fmt.Errorf("[%#x] %w", state, errSnapshotStale)
}
- // pick a random one as the entry point
- var current layer
- for _, layer := range tree.layers {
- current = layer
- break
+ l := tree.layers[tip]
+ if l == nil {
+ return nil, fmt.Errorf("triedb layer [%#x] missing", tip)
}
- for current.parentLayer() != nil {
- current = current.parentLayer()
+ return l, nil
+}
+
+// lookupStorage returns the layer that is confirmed to contain the storage slot
+// data being searched for.
+func (tree *layerTree) lookupStorage(accountHash common.Hash, slotHash common.Hash, state common.Hash) (layer, error) {
+ tree.lock.RLock()
+ defer tree.lock.RUnlock()
+
+ tip := tree.lookup.storageTip(accountHash, slotHash, state, tree.base.root)
+ if tip == (common.Hash{}) {
+ return nil, fmt.Errorf("[%#x] %w", state, errSnapshotStale)
+ }
+ l := tree.layers[tip]
+ if l == nil {
+ return nil, fmt.Errorf("triedb layer [%#x] missing", tip)
}
- return current.(*diskLayer)
+ return l, nil
}
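The key change in layertree.go is the descendants index maintained by fillAncestors: once it is populated, "is root X built on top of ancestor Y" becomes a set membership test rather than a parent-pointer walk. Below is a minimal, self-contained sketch of the same bookkeeping, using toy string roots instead of the pathdb layer types (hypothetical names, not part of the patch):

package main

import "fmt"

func main() {
	// Parent links for a toy chain C1 <- C2 <- C3 (C1 stands in for the disk layer).
	parents := map[string]string{"C2": "C1", "C3": "C2"}

	// descendants[ancestor] = set of layers built on top of it.
	descendants := map[string]map[string]struct{}{}
	fill := func(root string) {
		// Walk the parent chain and register root under every ancestor.
		for p, ok := parents[root]; ok; p, ok = parents[p] {
			if descendants[p] == nil {
				descendants[p] = map[string]struct{}{}
			}
			descendants[p][root] = struct{}{}
		}
	}
	fill("C2")
	fill("C3")

	// Constant-time ancestry check, mirroring layerTree.isDescendant.
	isDescendant := func(root, ancestor string) bool {
		_, ok := descendants[ancestor][root]
		return ok
	}
	fmt.Println(isDescendant("C3", "C1")) // true
	fmt.Println(isDescendant("C1", "C3")) // false
}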
diff --git a/triedb/pathdb/layertree_test.go b/triedb/pathdb/layertree_test.go
new file mode 100644
index 000000000000..8ca090ef80ed
--- /dev/null
+++ b/triedb/pathdb/layertree_test.go
@@ -0,0 +1,885 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/trie/trienode"
+)
+
+func newTestLayerTree() *layerTree {
+ db := New(rawdb.NewMemoryDatabase(), nil, false)
+ l := newDiskLayer(common.Hash{0x1}, 0, db, nil, nil, newBuffer(0, nil, nil, 0))
+ t := newLayerTree(l)
+ return t
+}
+
+func TestLayerCap(t *testing.T) {
+ var cases = []struct {
+ init func() *layerTree
+ head common.Hash
+ layers int
+ base common.Hash
+ snapshot map[common.Hash]struct{}
+ }{
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ head: common.Hash{0x4},
+ layers: 2,
+ base: common.Hash{0x2},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x2}: {},
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C3->C4 (HEAD)
+ head: common.Hash{0x4},
+ layers: 1,
+ base: common.Hash{0x3},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C4 (HEAD)
+ head: common.Hash{0x4},
+ layers: 0,
+ base: common.Hash{0x4},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x4}: {},
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C2'->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2a}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2a}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x2b}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2b}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ head: common.Hash{0x4a},
+ layers: 2,
+ base: common.Hash{0x2a},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x4a}: {},
+ common.Hash{0x3a}: {},
+ common.Hash{0x2a}: {},
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C2'->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2a}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2a}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x2b}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2b}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C3->C4 (HEAD)
+ head: common.Hash{0x4a},
+ layers: 1,
+ base: common.Hash{0x3a},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x4a}: {},
+ common.Hash{0x3a}: {},
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ // ->C3'->C4'
+ head: common.Hash{0x4a},
+ layers: 2,
+ base: common.Hash{0x2},
+ snapshot: map[common.Hash]struct{}{
+ common.Hash{0x4a}: {},
+ common.Hash{0x3a}: {},
+ common.Hash{0x4b}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x2}: {},
+ },
+ },
+ }
+ for _, c := range cases {
+ tr := c.init()
+ if err := tr.cap(c.head, c.layers); err != nil {
+ t.Fatalf("Failed to cap the layer tree %v", err)
+ }
+ if tr.bottom().root != c.base {
+ t.Fatalf("Unexpected bottom layer tree root, want %v, got %v", c.base, tr.bottom().root)
+ }
+ if len(c.snapshot) != len(tr.layers) {
+ t.Fatalf("Unexpected layer tree size, want %v, got %v", len(c.snapshot), len(tr.layers))
+ }
+ for h := range tr.layers {
+ if _, ok := c.snapshot[h]; !ok {
+ t.Fatalf("Unexpected layer %v", h)
+ }
+ }
+ }
+}
+
+func TestBaseLayer(t *testing.T) {
+ tr := newTestLayerTree()
+
+ var cases = []struct {
+ op func()
+ base common.Hash
+ }{
+ // Chain:
+ // C1 (HEAD)
+ {
+ func() {},
+ common.Hash{0x1},
+ },
+ // Chain:
+ // C1->C2->C3 (HEAD)
+ {
+ func() {
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ },
+ common.Hash{0x1},
+ },
+ // Chain:
+ // C3 (HEAD)
+ {
+ func() {
+ tr.cap(common.Hash{0x3}, 0)
+ },
+ common.Hash{0x3},
+ },
+ // Chain:
+ // C4->C5->C6 (HEAD)
+ {
+ func() {
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x5}, common.Hash{0x4}, 4, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x6}, common.Hash{0x5}, 5, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.cap(common.Hash{0x6}, 2)
+ },
+ common.Hash{0x4},
+ },
+ }
+ for _, c := range cases {
+ c.op()
+ if tr.base.rootHash() != c.base {
+ t.Fatalf("Unexpected base root, want %v, got: %v", c.base, tr.base.rootHash())
+ }
+ }
+}
+
+func TestDescendant(t *testing.T) {
+ var cases = []struct {
+ init func() *layerTree
+ snapshotA map[common.Hash]map[common.Hash]struct{}
+ op func(tr *layerTree)
+ snapshotB map[common.Hash]map[common.Hash]struct{}
+ }{
+ {
+ // Chain:
+ // C1->C2 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ },
+ },
+ // Chain:
+ // C1->C2->C3 (HEAD)
+ op: func(tr *layerTree) {
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ common.Hash{0x3}: {},
+ },
+ common.Hash{0x2}: {
+ common.Hash{0x3}: {},
+ },
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x2}: {
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x3}: {
+ common.Hash{0x4}: {},
+ },
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4}, 2)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x2}: {
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x3}: {
+ common.Hash{0x4}: {},
+ },
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x2}: {
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x3}: {
+ common.Hash{0x4}: {},
+ },
+ },
+ // Chain:
+ // C3->C4 (HEAD)
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4}, 1)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x3}: {
+ common.Hash{0x4}: {},
+ },
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x2}: {
+ common.Hash{0x3}: {},
+ common.Hash{0x4}: {},
+ },
+ common.Hash{0x3}: {
+ common.Hash{0x4}: {},
+ },
+ },
+ // Chain:
+ // C4 (HEAD)
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4}, 0)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{},
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C2'->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2a}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2a}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x2b}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2b}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2a}: {},
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ common.Hash{0x2b}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x2a}: {
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x2b}: {
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x3b}: {
+ common.Hash{0x4b}: {},
+ },
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4a}, 2)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x2a}: {
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C2'->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2a}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2a}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x2b}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2b}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2a}: {},
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ common.Hash{0x2b}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x2a}: {
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x2b}: {
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x3b}: {
+ common.Hash{0x4b}: {},
+ },
+ },
+ // Chain:
+ // C3->C4 (HEAD)
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4a}, 1)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ },
+ },
+ {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ // ->C3'->C4'
+ init: func() *layerTree {
+ tr := newTestLayerTree()
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3a}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4a}, common.Hash{0x3a}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x3b}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ tr.add(common.Hash{0x4b}, common.Hash{0x3b}, 3, trienode.NewMergedNodeSet(), NewStateSetWithOrigin(nil, nil, nil, nil))
+ return tr
+ },
+ snapshotA: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x1}: {
+ common.Hash{0x2}: {},
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x2}: {
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x3b}: {
+ common.Hash{0x4b}: {},
+ },
+ },
+ // Chain:
+ // C2->C3->C4 (HEAD)
+ // ->C3'->C4'
+ op: func(tr *layerTree) {
+ tr.cap(common.Hash{0x4a}, 2)
+ },
+ snapshotB: map[common.Hash]map[common.Hash]struct{}{
+ common.Hash{0x2}: {
+ common.Hash{0x3a}: {},
+ common.Hash{0x4a}: {},
+ common.Hash{0x3b}: {},
+ common.Hash{0x4b}: {},
+ },
+ common.Hash{0x3a}: {
+ common.Hash{0x4a}: {},
+ },
+ common.Hash{0x3b}: {
+ common.Hash{0x4b}: {},
+ },
+ },
+ },
+ }
+ check := func(setA, setB map[common.Hash]map[common.Hash]struct{}) bool {
+ if len(setA) != len(setB) {
+ return false
+ }
+ for h, subA := range setA {
+ subB, ok := setB[h]
+ if !ok {
+ return false
+ }
+ if len(subA) != len(subB) {
+ return false
+ }
+ for hh := range subA {
+ if _, ok := subB[hh]; !ok {
+ return false
+ }
+ }
+ }
+ return true
+ }
+ for _, c := range cases {
+ tr := c.init()
+ if !check(c.snapshotA, tr.descendants) {
+ t.Fatalf("Unexpected descendants")
+ }
+ c.op(tr)
+ if !check(c.snapshotB, tr.descendants) {
+ t.Fatalf("Unexpected descendants")
+ }
+ }
+}
+
+func TestAccountLookup(t *testing.T) {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ tr := newTestLayerTree() // base = 0x1
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xa"), nil, nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xb"), nil, nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xa", "0xc"), nil, nil, nil))
+
+ var cases = []struct {
+ account common.Hash
+ state common.Hash
+ expect common.Hash
+ }{
+ {
+ // unknown account
+ common.HexToHash("0xd"), common.Hash{0x4}, common.Hash{0x1},
+ },
+ /*
+ lookup account from the top
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x4}, common.Hash{0x4},
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x4}, common.Hash{0x3},
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x4}, common.Hash{0x4},
+ },
+ /*
+ lookup account from the middle
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x3}, common.Hash{0x2},
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x3}, common.Hash{0x3},
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x3}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0xa"), common.Hash{0x2}, common.Hash{0x2},
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x2}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x2}, common.Hash{0x1}, // not found
+ },
+ /*
+ lookup account from the bottom
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ }
+ for i, c := range cases {
+ l, err := tr.lookupAccount(c.account, c.state)
+ if err != nil {
+ t.Fatalf("%d: %v", i, err)
+ }
+ if l.rootHash() != c.expect {
+ t.Errorf("Unexpected tiphash, %d, want: %x, got: %x", i, c.expect, l.rootHash())
+ }
+ }
+
+ // Chain:
+ // C3->C4 (HEAD)
+ tr.cap(common.Hash{0x4}, 1)
+
+ cases2 := []struct {
+ account common.Hash
+ state common.Hash
+ expect common.Hash
+ expectErr error
+ }{
+ {
+ // unknown account
+ common.HexToHash("0xd"), common.Hash{0x4}, common.Hash{0x3}, nil,
+ },
+ /*
+ lookup account from the top
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x4}, common.Hash{0x4}, nil,
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x4}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x4}, common.Hash{0x4}, nil,
+ },
+ /*
+ lookup account from the bottom
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x3}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x3}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x3}, common.Hash{0x3}, nil, // not found
+ },
+ /*
+ stale states
+ */
+ {
+ common.HexToHash("0xa"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0xa"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0xb"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0xc"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ }
+ for i, c := range cases2 {
+ l, err := tr.lookupAccount(c.account, c.state)
+ if c.expectErr != nil {
+ if !errors.Is(err, c.expectErr) {
+ t.Fatalf("%d: unexpected error, want %v, got %v", i, c.expectErr, err)
+ }
+ }
+ if c.expectErr == nil {
+ if err != nil {
+ t.Fatalf("%d: %v", i, err)
+ }
+ if l.rootHash() != c.expect {
+ t.Errorf("Unexpected tiphash, %d, want: %x, got: %x", i, c.expect, l.rootHash())
+ }
+ }
+ }
+}
+
+func TestStorageLookup(t *testing.T) {
+ // Chain:
+ // C1->C2->C3->C4 (HEAD)
+ tr := newTestLayerTree() // base = 0x1
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x1"}}, nil), nil, nil))
+ tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x2"}}, nil), nil, nil))
+ tr.add(common.Hash{0x4}, common.Hash{0x3}, 3, trienode.NewMergedNodeSet(),
+ NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x1", "0x3"}}, nil), nil, nil))
+
+ var cases = []struct {
+ storage common.Hash
+ state common.Hash
+ expect common.Hash
+ }{
+ {
+ // unknown storage slot
+ common.HexToHash("0x4"), common.Hash{0x4}, common.Hash{0x1},
+ },
+ /*
+ lookup storage slot from the top
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x4}, common.Hash{0x4},
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x4}, common.Hash{0x3},
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x4}, common.Hash{0x4},
+ },
+ /*
+ lookup storage slot from the middle
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x3}, common.Hash{0x2},
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x3}, common.Hash{0x3},
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x3}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0x1"), common.Hash{0x2}, common.Hash{0x2},
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x2}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x2}, common.Hash{0x1}, // not found
+ },
+ /*
+ lookup storage slot from the bottom
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x1}, common.Hash{0x1}, // not found
+ },
+ }
+ for i, c := range cases {
+ l, err := tr.lookupStorage(common.HexToHash("0xa"), c.storage, c.state)
+ if err != nil {
+ t.Fatalf("%d: %v", i, err)
+ }
+ if l.rootHash() != c.expect {
+ t.Errorf("Unexpected tiphash, %d, want: %x, got: %x", i, c.expect, l.rootHash())
+ }
+ }
+
+ // Chain:
+ // C3->C4 (HEAD)
+ tr.cap(common.Hash{0x4}, 1)
+
+ cases2 := []struct {
+ storage common.Hash
+ state common.Hash
+ expect common.Hash
+ expectErr error
+ }{
+ {
+ // unknown storage slot
+ common.HexToHash("0x4"), common.Hash{0x4}, common.Hash{0x3}, nil,
+ },
+ /*
+ lookup account from the top
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x4}, common.Hash{0x4}, nil,
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x4}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x4}, common.Hash{0x4}, nil,
+ },
+ /*
+ lookup account from the bottom
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x3}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x3}, common.Hash{0x3}, nil,
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x3}, common.Hash{0x3}, nil, // not found
+ },
+ /*
+ stale states
+ */
+ {
+ common.HexToHash("0x1"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x2}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0x1"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0x2"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ {
+ common.HexToHash("0x3"), common.Hash{0x1}, common.Hash{}, errSnapshotStale,
+ },
+ }
+ for i, c := range cases2 {
+ l, err := tr.lookupStorage(common.HexToHash("0xa"), c.storage, c.state)
+ if c.expectErr != nil {
+ if !errors.Is(err, c.expectErr) {
+ t.Fatalf("%d: unexpected error, want %v, got %v", i, c.expectErr, err)
+ }
+ }
+ if c.expectErr == nil {
+ if err != nil {
+ t.Fatalf("%d: %v", i, err)
+ }
+ if l.rootHash() != c.expect {
+ t.Errorf("Unexpected tiphash, %d, want: %x, got: %x", i, c.expect, l.rootHash())
+ }
+ }
+ }
+}
diff --git a/triedb/pathdb/lookup.go b/triedb/pathdb/lookup.go
new file mode 100644
index 000000000000..a43621db77dd
--- /dev/null
+++ b/triedb/pathdb/lookup.go
@@ -0,0 +1,281 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "golang.org/x/sync/errgroup"
+)
+
+// slicePool is a shared pool of hash slices, used to reduce GC pressure.
+var slicePool = sync.Pool{
+ New: func() interface{} {
+ slice := make([]common.Hash, 0, 16) // Pre-allocate a slice with a reasonable capacity.
+ return &slice
+ },
+}
+
+// getSlice obtains the hash slice from the shared pool.
+func getSlice() []common.Hash {
+ slice := *slicePool.Get().(*[]common.Hash)
+ slice = slice[:0]
+ return slice
+}
+
+// returnSlice returns the hash slice to the shared pool for later reuse.
+func returnSlice(slice []common.Hash) {
+ slicePool.Put(&slice)
+}
+
+// lookup is an internal structure used to efficiently determine the layer in
+// which a state entry resides.
+type lookup struct {
+ accounts map[common.Hash][]common.Hash
+ storages map[common.Hash]map[common.Hash][]common.Hash
+ descendant func(state common.Hash, ancestor common.Hash) bool
+}
+
+// newLookup initializes the lookup structure.
+func newLookup(head layer, descendant func(state common.Hash, ancestor common.Hash) bool) *lookup {
+ var (
+ current = head
+ layers []layer
+ )
+ for current != nil {
+ layers = append(layers, current)
+ current = current.parentLayer()
+ }
+ l := &lookup{
+ accounts: make(map[common.Hash][]common.Hash),
+ storages: make(map[common.Hash]map[common.Hash][]common.Hash),
+ descendant: descendant,
+ }
+ // Apply the diff layers from bottom to top
+ for i := len(layers) - 1; i >= 0; i-- {
+ switch diff := layers[i].(type) {
+ case *diskLayer:
+ continue
+ case *diffLayer:
+ l.addLayer(diff)
+ }
+ }
+ return l
+}
+
+// accountTip traverses the layer list associated with the given account in
+// reverse order to locate the first entry that either matches the specified
+// stateID or is a descendant of it.
+//
+// If found, the account data corresponding to the supplied stateID resides
+// in that layer. Otherwise, two scenarios are possible:
+//
+// The account remains unmodified from the current disk layer up to the state
+// layer specified by the stateID: fallback to the disk layer for data retrieval.
+// Or the layer specified by the stateID is stale: reject the data retrieval.
+func (l *lookup) accountTip(accountHash common.Hash, stateID common.Hash, base common.Hash) common.Hash {
+ list := l.accounts[accountHash]
+ for i := len(list) - 1; i >= 0; i-- {
+ if list[i] == stateID || l.descendant(stateID, list[i]) {
+ return list[i]
+ }
+ }
+ // No layer matching the stateID or its descendants was found. Use the
+ // current disk layer as a fallback.
+ if base == stateID || l.descendant(stateID, base) {
+ return base
+ }
+ // The layer associated with 'stateID' is not the descendant of the current
+ // disk layer, it's already stale, return nothing.
+ return common.Hash{}
+}
+
+// storageTip traverses the layer list associated with the given account and
+// slot hash in reverse order to locate the first entry that either matches
+// the specified stateID or is a descendant of it.
+//
+// If found, the storage data corresponding to the supplied stateID resides
+// in that layer. Otherwise, two scenarios are possible:
+//
+// The storage slot remains unmodified from the current disk layer up to the
+// state layer specified by the stateID: fallback to the disk layer for data
+// retrieval. Or the layer specified by the stateID is stale: reject the data
+// retrieval.
+func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, stateID common.Hash, base common.Hash) common.Hash {
+ subset, exists := l.storages[accountHash]
+ if exists {
+ list := subset[slotHash]
+ for i := len(list) - 1; i >= 0; i-- {
+ if list[i] == stateID || l.descendant(stateID, list[i]) {
+ return list[i]
+ }
+ }
+ }
+ // No layer matching the stateID or its descendants was found. Use the
+ // current disk layer as a fallback.
+ if base == stateID || l.descendant(stateID, base) {
+ return base
+ }
+ // The layer associated with 'stateID' is not the descendant of the current
+ // disk layer, it's already stale, return nothing.
+ return common.Hash{}
+}
+
+// addLayer traverses the state data retained in the specified diff layer and
+// integrates it into the lookup set.
+//
+// This function assumes that all layers older than the provided one have already
+// been processed, ensuring that layers are processed strictly in a bottom-to-top
+// order.
+func (l *lookup) addLayer(diff *diffLayer) {
+ defer func(now time.Time) {
+ lookupAddLayerTimer.UpdateSince(now)
+ }(time.Now())
+
+ var (
+ wg sync.WaitGroup
+ state = diff.rootHash()
+ )
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for accountHash := range diff.states.accountData {
+ list, exists := l.accounts[accountHash]
+ if !exists {
+ list = getSlice()
+ }
+ list = append(list, state)
+ l.accounts[accountHash] = list
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for accountHash, slots := range diff.states.storageData {
+ subset := l.storages[accountHash]
+ if subset == nil {
+ subset = make(map[common.Hash][]common.Hash)
+ l.storages[accountHash] = subset
+ }
+ for slotHash := range slots {
+ list, exists := subset[slotHash]
+ if !exists {
+ list = getSlice()
+ }
+ list = append(list, state)
+ subset[slotHash] = list
+ }
+ }
+ }()
+ wg.Wait()
+}
+
+// removeLayer traverses the state data retained in the specified diff layer and
+// unlinks it from the lookup set.
+func (l *lookup) removeLayer(diff *diffLayer) error {
+ defer func(now time.Time) {
+ lookupRemoveLayerTimer.UpdateSince(now)
+ }(time.Now())
+
+ var (
+ wg errgroup.Group
+ state = diff.rootHash()
+ )
+ wg.Go(func() error {
+ for accountHash := range diff.states.accountData {
+ var (
+ found bool
+ list = l.accounts[accountHash]
+ )
+ // Traverse the list from oldest to newest to quickly locate the ID
+ // of the stale layer.
+ for i := 0; i < len(list); i++ {
+ if list[i] == state {
+ if i == 0 {
+ list = list[1:]
+ if cap(list) > 1024 {
+ list = append(getSlice(), list...)
+ }
+ } else {
+ list = append(list[:i], list[i+1:]...)
+ }
+ found = true
+ break
+ }
+ }
+ if !found {
+ return fmt.Errorf("account lookup is not found, %x, state: %x", accountHash, state)
+ }
+ if len(list) != 0 {
+ l.accounts[accountHash] = list
+ } else {
+ returnSlice(list)
+ delete(l.accounts, accountHash)
+ }
+ }
+ return nil
+ })
+
+ wg.Go(func() error {
+ for accountHash, slots := range diff.states.storageData {
+ subset := l.storages[accountHash]
+ if subset == nil {
+ return fmt.Errorf("storage lookup is not found, %x", accountHash)
+ }
+ for slotHash := range slots {
+ var (
+ found bool
+ list = subset[slotHash]
+ )
+ // Traverse the list from oldest to newest to quickly locate the ID
+ // of the stale layer.
+ for i := 0; i < len(list); i++ {
+ if list[i] == state {
+ if i == 0 {
+ list = list[1:]
+ if cap(list) > 1024 {
+ list = append(getSlice(), list...)
+ }
+ } else {
+ list = append(list[:i], list[i+1:]...)
+ }
+ found = true
+ break
+ }
+ }
+ if !found {
+ return fmt.Errorf("storage lookup is not found, %x %x, state: %x", accountHash, slotHash, state)
+ }
+ if len(list) != 0 {
+ subset[slotHash] = list
+ } else {
+ returnSlice(subset[slotHash])
+ delete(subset, slotHash)
+ }
+ }
+ if len(subset) == 0 {
+ delete(l.storages, accountHash)
+ }
+ }
+ return nil
+ })
+ return wg.Wait()
+}
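lookup.go keeps, per account (and per storage slot), the ordered list of diff-layer roots that modified it; accountTip walks that list newest-to-oldest and returns the first entry visible from the queried state, falling back to the disk layer or reporting staleness. Here is a standalone toy sketch of that traversal, with hypothetical names and a simple linear chain in place of the real descendant check:

package main

import "fmt"

func main() {
	// Layers that touched account A, ordered oldest -> newest.
	touched := []string{"C2", "C4"}

	// Toy ancestry for a linear chain C1 -> C2 -> C3 -> C4.
	order := map[string]int{"C1": 1, "C2": 2, "C3": 3, "C4": 4}
	descendant := func(state, ancestor string) bool {
		return order[state] > order[ancestor]
	}
	base := "C1" // current disk layer

	// Mirrors lookup.accountTip: newest-to-oldest scan, then disk fallback.
	accountTip := func(state string) string {
		for i := len(touched) - 1; i >= 0; i-- {
			if touched[i] == state || descendant(state, touched[i]) {
				return touched[i] // data lives in this layer
			}
		}
		if state == base || descendant(state, base) {
			return base // unmodified since the disk layer
		}
		return "" // the queried state is stale
	}
	fmt.Println(accountTip("C4")) // C4: modified in the head layer itself
	fmt.Println(accountTip("C3")) // C2: last modification visible from C3
	fmt.Println(accountTip("C1")) // C1: fall back to the disk layer
}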
diff --git a/triedb/pathdb/metrics.go b/triedb/pathdb/metrics.go
index b2b849157cd6..a9c78ad82e50 100644
--- a/triedb/pathdb/metrics.go
+++ b/triedb/pathdb/metrics.go
@@ -69,9 +69,12 @@ var (
gcStorageMeter = metrics.NewRegisteredMeter("pathdb/gc/storage/count", nil)
gcStorageBytesMeter = metrics.NewRegisteredMeter("pathdb/gc/storage/bytes", nil)
- historyBuildTimeMeter = metrics.NewRegisteredTimer("pathdb/history/time", nil)
+ historyBuildTimeMeter = metrics.NewRegisteredResettingTimer("pathdb/history/time", nil)
historyDataBytesMeter = metrics.NewRegisteredMeter("pathdb/history/bytes/data", nil)
historyIndexBytesMeter = metrics.NewRegisteredMeter("pathdb/history/bytes/index", nil)
+
+ lookupAddLayerTimer = metrics.NewRegisteredResettingTimer("pathdb/lookup/add/time", nil)
+ lookupRemoveLayerTimer = metrics.NewRegisteredResettingTimer("pathdb/lookup/remove/time", nil)
)
// Metrics in generation
diff --git a/triedb/pathdb/reader.go b/triedb/pathdb/reader.go
index a404409035b6..deb2d2e97528 100644
--- a/triedb/pathdb/reader.go
+++ b/triedb/pathdb/reader.go
@@ -50,8 +50,10 @@ func (loc *nodeLoc) string() string {
// reader implements the database.NodeReader interface, providing the functionalities to
// retrieve trie nodes by wrapping the internal state layer.
type reader struct {
- layer layer
+ db *Database
+ state common.Hash
noHashCheck bool
+ layer layer
}
// Node implements database.NodeReader interface, retrieving the node with specified
@@ -94,7 +96,11 @@ func (r *reader) Node(owner common.Hash, path []byte, hash common.Hash) ([]byte,
// - the returned account data is not a copy, please don't modify it
// - no error will be returned if the requested account is not found in database
func (r *reader) AccountRLP(hash common.Hash) ([]byte, error) {
- return r.layer.account(hash, 0)
+ l, err := r.db.tree.lookupAccount(hash, r.state)
+ if err != nil {
+ return nil, err
+ }
+ return l.account(hash, 0)
}
// Account directly retrieves the account associated with a particular hash in
@@ -105,7 +111,7 @@ func (r *reader) AccountRLP(hash common.Hash) ([]byte, error) {
// - the returned account object is safe to modify
// - no error will be returned if the requested account is not found in database
func (r *reader) Account(hash common.Hash) (*types.SlimAccount, error) {
- blob, err := r.layer.account(hash, 0)
+ blob, err := r.AccountRLP(hash)
if err != nil {
return nil, err
}
@@ -127,7 +133,11 @@ func (r *reader) Account(hash common.Hash) (*types.SlimAccount, error) {
// - the returned storage data is not a copy, please don't modify it
// - no error will be returned if the requested slot is not found in database
func (r *reader) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
- return r.layer.storage(accountHash, storageHash, 0)
+ l, err := r.db.tree.lookupStorage(accountHash, storageHash, r.state)
+ if err != nil {
+ return nil, err
+ }
+ return l.storage(accountHash, storageHash, 0)
}
// NodeReader retrieves a layer belonging to the given state root.
@@ -136,7 +146,12 @@ func (db *Database) NodeReader(root common.Hash) (database.NodeReader, error) {
if layer == nil {
return nil, fmt.Errorf("state %#x is not available", root)
}
- return &reader{layer: layer, noHashCheck: db.isVerkle}, nil
+ return &reader{
+ db: db,
+ state: root,
+ noHashCheck: db.isVerkle,
+ layer: layer,
+ }, nil
}
// StateReader returns a reader that allows access to the state data associated
@@ -146,5 +161,9 @@ func (db *Database) StateReader(root common.Hash) (database.StateReader, error)
if layer == nil {
return nil, fmt.Errorf("state %#x is not available", root)
}
- return &reader{layer: layer}, nil
+ return &reader{
+ db: db,
+ state: root,
+ layer: layer,
+ }, nil
}
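From the outside the reader API is unchanged; each account or storage read is now routed through the layer-tree lookup instead of the pinned layer. A minimal usage sketch, assuming the constructor arguments used by the new test helper; on a fresh in-memory database the requested root will typically not be available, so the error branch is the expected outcome and the example only shows the call shape:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/triedb/pathdb"
)

func main() {
	// Same construction as newTestLayerTree in the new test file.
	db := pathdb.New(rawdb.NewMemoryDatabase(), nil, false)

	// Ask for a reader at some state root; hypothetical root, likely not
	// present on an empty database.
	reader, err := db.StateReader(common.Hash{0x1})
	if err != nil {
		fmt.Println("state not available:", err)
		return
	}
	// Account reads now resolve the owning layer via layerTree.lookupAccount.
	acct, err := reader.Account(common.HexToHash("0xa"))
	fmt.Println(acct, err)
}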