Skip to content

Commit

Permalink
bugfix: fix state prefetcher concurrent bugs;
Browse files Browse the repository at this point in the history
bugfix: fix trie epoch update bugs;
  • Loading branch information
0xbundler committed Sep 22, 2023
1 parent 2e0ee20 commit 67ab312
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 26 deletions.
18 changes: 9 additions & 9 deletions consensus/parlia/parlia.go
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ func (p *Parlia) Finalize(chain consensus.ChainHeaderReader, header *types.Heade
err = p.slash(spoiledVal, state, header, cx, txs, receipts, systemTxs, usedGas, false)
if err != nil {
// it is possible that slash validator failed because of the slash channel is disabled.
log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal)
log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err)
}
}
}
Expand Down Expand Up @@ -1164,7 +1164,7 @@ func (p *Parlia) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header *
err = p.slash(spoiledVal, state, header, cx, &txs, &receipts, nil, &header.GasUsed, true)
if err != nil {
// it is possible that slash validator failed because of the slash channel is disabled.
log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal)
log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err)
}
}
}
Expand Down Expand Up @@ -1673,13 +1673,13 @@ func (p *Parlia) applyTransaction(
}
actualTx := (*receivedTxs)[0]
if !bytes.Equal(p.signer.Hash(actualTx).Bytes(), expectedHash.Bytes()) {
return fmt.Errorf("expected tx hash %v, get %v, nonce %d, to %s, value %s, gas %d, gasPrice %s, data %s", expectedHash.String(), actualTx.Hash().String(),
expectedTx.Nonce(),
expectedTx.To().String(),
expectedTx.Value().String(),
expectedTx.Gas(),
expectedTx.GasPrice().String(),
hex.EncodeToString(expectedTx.Data()),
return fmt.Errorf("expected tx hash %v, get %v, nonce %d:%d, to %s:%s, value %s:%s, gas %d:%d, gasPrice %s:%s, data %s:%s", expectedHash.String(), actualTx.Hash().String(),
expectedTx.Nonce(), actualTx.Nonce(),
expectedTx.To().String(), actualTx.To().String(),
expectedTx.Value().String(), actualTx.Value().String(),
expectedTx.Gas(), actualTx.Gas(),
expectedTx.GasPrice().String(), actualTx.GasPrice().String(),
hex.EncodeToString(expectedTx.Data()), hex.EncodeToString(actualTx.Data()),
)
}
expectedTx = actualTx
Expand Down
10 changes: 6 additions & 4 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,6 @@ func (s *stateObject) updateTrie() (Trie, error) {
s.db.setError(fmt.Errorf("state object update trie UpdateStorage err, contract: %v, key: %v, err: %v", s.address, key, err))
}
s.db.StorageUpdated += 1
log.Debug("updateTrie UpdateStorage", "contract", s.address, "key", key, "epoch", s.db.epoch, "value", value, "tr", tr.Epoch())
}
// Cache the items for preloading
usedStorage = append(usedStorage, common.CopyBytes(key[:]))
Expand Down Expand Up @@ -505,7 +504,6 @@ func (s *stateObject) updateTrie() (Trie, error) {
snapshotVal, _ = rlp.EncodeToBytes(value)
}
storage[khash] = snapshotVal // snapshotVal will be nil if it's deleted
log.Debug("updateTrie snapshot", "contract", s.address, "key", key, "epoch", s.db.epoch, "value", snapshotVal)

// Track the original value of slot only if it's mutated first time
prev := s.originStorage[key]
Expand Down Expand Up @@ -810,7 +808,6 @@ func (s *stateObject) fetchExpiredFromRemote(prefixKey []byte, key common.Hash,
prefixKey = enErr.Path
}

log.Info("fetchExpiredStorageFromRemote in stateDB", "addr", s.address, "prefixKey", prefixKey, "key", key, "tr", fmt.Sprintf("%p", tr))
kvs, err := fetchExpiredStorageFromRemote(s.db.fullStateDB, s.db.originalRoot, s.address, s.data.Root, tr, prefixKey, key)

if err != nil {
Expand Down Expand Up @@ -850,11 +847,16 @@ func (s *stateObject) getExpirySnapStorage(key common.Hash) (snapshot.SnapValue,
return val, nil, nil
}

// TODO(0xbundler): if found value not been pruned, just return
//if len(val.GetVal()) > 0 {
// return val, nil, nil
//}

// handle from remoteDB, if got err just setError, just return to revert in consensus version.
valRaw, err := s.fetchExpiredFromRemote(nil, key, true)
if err != nil {
return nil, nil, err
}

return snapshot.NewValueWithEpoch(s.db.epoch, valRaw), nil, nil
return snapshot.NewValueWithEpoch(val.GetEpoch(), valRaw), nil, nil
}
2 changes: 0 additions & 2 deletions core/state/trie_prefetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package state

import (
"fmt"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"sync"
Expand Down Expand Up @@ -589,7 +588,6 @@ func (sf *subfetcher) loop() {
if sf.enableStateExpiry {
if exErr, match := err.(*trie2.ExpiredNodeError); match {
key := common.BytesToHash(task)
log.Debug("fetchExpiredStorageFromRemote in trie prefetcher", "addr", sf.addr, "prefixKey", exErr.Path, "key", key, "tr", fmt.Sprintf("%p", sf.trie))
_, err = fetchExpiredStorageFromRemote(sf.fullStateDB, sf.state, sf.addr, sf.root, sf.trie, exErr.Path, key)
if err != nil {
log.Error("subfetcher fetchExpiredStorageFromRemote err", "addr", sf.addr, "path", exErr.Path, "err", err)
Expand Down
9 changes: 7 additions & 2 deletions core/state_prefetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ func (p *statePrefetcher) Prefetch(block *types.Block, statedb *state.StateDB, c
for i := 0; i < prefetchThread; i++ {
go func() {
newStatedb := statedb.CopyDoPrefetch()
newStatedb.EnableWriteOnSharedStorage()
if !statedb.EnableExpire() {
newStatedb.EnableWriteOnSharedStorage()
}
gaspool := new(GasPool).AddGas(block.GasLimit())
blockContext := NewEVMBlockContext(header, p.bc, nil)
evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, *cfg)
Expand Down Expand Up @@ -106,7 +108,10 @@ func (p *statePrefetcher) PrefetchMining(txs TransactionsByPriceAndNonce, header
go func(startCh <-chan *types.Transaction, stopCh <-chan struct{}) {
idx := 0
newStatedb := statedb.CopyDoPrefetch()
newStatedb.EnableWriteOnSharedStorage()
// TODO(0xbundler): access empty in trie cause shared concurrent bug? opt later
if !statedb.EnableExpire() {
newStatedb.EnableWriteOnSharedStorage()
}
gaspool := new(GasPool).AddGas(gasLimit)
blockContext := NewEVMBlockContext(header, p.bc, nil)
evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, cfg)
Expand Down
18 changes: 18 additions & 0 deletions trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package trie

import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/rlp"
Expand Down Expand Up @@ -906,6 +907,23 @@ func (m *MPTProofNub) GetValue() []byte {
return nil
}

func (m *MPTProofNub) String() string {
buf := bytes.NewBuffer(nil)
buf.WriteString("n1: ")
buf.WriteString(hex.EncodeToString(m.n1PrefixKey))
buf.WriteString(", n1proof: ")
if m.n1 != nil {
buf.WriteString(m.n1.fstring(""))
}
buf.WriteString(", n2: ")
buf.WriteString(hex.EncodeToString(m.n2PrefixKey))
buf.WriteString(", n2proof: ")
if m.n2 != nil {
buf.WriteString(m.n2.fstring(""))
}
return buf.String()
}

func getNubValue(origin node, prefixKey []byte) []byte {
switch n := origin.(type) {
case nil, hashNode:
Expand Down
33 changes: 24 additions & 9 deletions trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type Trie struct {
unhashed int

// reader is the handler trie can retrieve nodes from.
reader *trieReader // TODO (asyukii): create a reader for state expiry metadata
reader *trieReader

// tracer is the tool to track the trie changes.
// It will be reset after each commit operation.
Expand Down Expand Up @@ -211,6 +211,7 @@ func (t *Trie) GetAndUpdateEpoch(key []byte) (value []byte, err error) {

if err == nil && didResolve {
t.root = newroot
t.rootEpoch = t.currentEpoch
}
return value, err
}
Expand Down Expand Up @@ -324,8 +325,8 @@ func (t *Trie) updateChildNodeEpoch(origNode node, key []byte, pos int, epoch ty
n = n.copy()
n.Val = newnode
n.setEpoch(t.currentEpoch)
n.flags = t.newFlag()
}

return n, true, err
case *fullNode:
newnode, updateEpoch, err = t.updateChildNodeEpoch(n.Children[key[pos]], key, pos+1, epoch)
Expand All @@ -334,6 +335,7 @@ func (t *Trie) updateChildNodeEpoch(origNode node, key []byte, pos int, epoch ty
n.Children[key[pos]] = newnode
n.setEpoch(t.currentEpoch)
n.UpdateChildEpoch(int(key[pos]), t.currentEpoch)
n.flags = t.newFlag()
}
return n, true, err
case hashNode:
Expand Down Expand Up @@ -627,12 +629,14 @@ func (t *Trie) insertWithEpoch(n node, prefix, key []byte, value node, epoch typ

// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
t.tracer.onExpandToBranchNode(prefix)
return true, branch, nil
}
// New branch node is created as a child of the original short node.
// Track the newly inserted node in the tracer. The node identifier
// passed is the path from the root node.
t.tracer.onInsert(append(prefix, key[:matchlen]...))
t.tracer.onExpandToBranchNode(append(prefix, key[:matchlen]...))

// Replace it with a short node leading up to the branch.
return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil
Expand Down Expand Up @@ -908,6 +912,8 @@ func (t *Trie) deleteWithEpoch(n node, prefix, key []byte, epoch types.StateEpoc
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
n.setEpoch(t.currentEpoch)
n.UpdateChildEpoch(int(key[0]), t.currentEpoch)

// Because n is a full node, it must've contained at least two children
// before the delete operation. If the new child value is non-nil, n still
Expand Down Expand Up @@ -990,7 +996,7 @@ func (t *Trie) deleteWithEpoch(n node, prefix, key []byte, epoch types.StateEpoc
}

dirty, nn, err := t.deleteWithEpoch(rn, prefix, key, epoch)
if !dirty || err != nil {
if !t.renewNode(epoch, dirty, true) || err != nil {
return false, rn, err
}
return true, nn, nil
Expand Down Expand Up @@ -1226,8 +1232,15 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo
if err != nil {
return nil, false, fmt.Errorf("update child node epoch while reviving failed, err: %v", err)
}
n1 = n1.copy()
n1.Val = newnode
return n1, true, nil
n1.flags = t.newFlag()
tryUpdateNodeEpoch(nub.n1, t.currentEpoch)
renew, _, err := t.updateChildNodeEpoch(nub.n1, key, pos, t.currentEpoch)
if err != nil {
return nil, false, fmt.Errorf("update child node epoch while reviving failed, err: %v", err)
}
return renew, true, nil
}

tryUpdateNodeEpoch(nub.n1, t.currentEpoch)
Expand All @@ -1253,6 +1266,7 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo
n = n.copy()
n.Val = newNode
n.setEpoch(t.currentEpoch)
n.flags = t.newFlag()
}
return n, didRevive, err
case *fullNode:
Expand All @@ -1264,6 +1278,7 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo
n.Children[childIndex] = newNode
n.setEpoch(t.currentEpoch)
n.UpdateChildEpoch(childIndex, t.currentEpoch)
n.flags = t.newFlag()
}

if e, ok := err.(*ExpiredNodeError); ok {
Expand Down Expand Up @@ -1384,13 +1399,13 @@ func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bo
return childDirty
}

// when no epoch update, same as before
if epoch == t.currentEpoch {
return childDirty
// node need update epoch, just renew
if t.currentEpoch > epoch {
return true
}

// node need update epoch, just renew
return true
// when no epoch update, same as before
return childDirty
}

func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool {
Expand Down

0 comments on commit 67ab312

Please sign in to comment.