From 67ab312c87a1fb3bdb4c4dfd6bb9b6d95a245b17 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:50:58 +0800 Subject: [PATCH] bugfix: fix state prefetcher concurrent bugs; bugfix: fix trie epoch update bugs; --- consensus/parlia/parlia.go | 18 +++++++++--------- core/state/state_object.go | 10 ++++++---- core/state/trie_prefetcher.go | 2 -- core/state_prefetcher.go | 9 +++++++-- trie/proof.go | 18 ++++++++++++++++++ trie/trie.go | 33 ++++++++++++++++++++++++--------- 6 files changed, 64 insertions(+), 26 deletions(-) diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index acfa1816f7..4e969d1e06 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -1103,7 +1103,7 @@ func (p *Parlia) Finalize(chain consensus.ChainHeaderReader, header *types.Heade err = p.slash(spoiledVal, state, header, cx, txs, receipts, systemTxs, usedGas, false) if err != nil { // it is possible that slash validator failed because of the slash channel is disabled. - log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal) + log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err) } } } @@ -1164,7 +1164,7 @@ func (p *Parlia) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header * err = p.slash(spoiledVal, state, header, cx, &txs, &receipts, nil, &header.GasUsed, true) if err != nil { // it is possible that slash validator failed because of the slash channel is disabled. - log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal) + log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err) } } } @@ -1673,13 +1673,13 @@ func (p *Parlia) applyTransaction( } actualTx := (*receivedTxs)[0] if !bytes.Equal(p.signer.Hash(actualTx).Bytes(), expectedHash.Bytes()) { - return fmt.Errorf("expected tx hash %v, get %v, nonce %d, to %s, value %s, gas %d, gasPrice %s, data %s", expectedHash.String(), actualTx.Hash().String(), - expectedTx.Nonce(), - expectedTx.To().String(), - expectedTx.Value().String(), - expectedTx.Gas(), - expectedTx.GasPrice().String(), - hex.EncodeToString(expectedTx.Data()), + return fmt.Errorf("expected tx hash %v, get %v, nonce %d:%d, to %s:%s, value %s:%s, gas %d:%d, gasPrice %s:%s, data %s:%s", expectedHash.String(), actualTx.Hash().String(), + expectedTx.Nonce(), actualTx.Nonce(), + expectedTx.To().String(), actualTx.To().String(), + expectedTx.Value().String(), actualTx.Value().String(), + expectedTx.Gas(), actualTx.Gas(), + expectedTx.GasPrice().String(), actualTx.GasPrice().String(), + hex.EncodeToString(expectedTx.Data()), hex.EncodeToString(actualTx.Data()), ) } expectedTx = actualTx diff --git a/core/state/state_object.go b/core/state/state_object.go index e80839f621..2eb184677b 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -469,7 +469,6 @@ func (s *stateObject) updateTrie() (Trie, error) { s.db.setError(fmt.Errorf("state object update trie UpdateStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) } s.db.StorageUpdated += 1 - log.Debug("updateTrie UpdateStorage", "contract", s.address, "key", key, "epoch", s.db.epoch, "value", value, "tr", tr.Epoch()) } // Cache the items for preloading usedStorage = append(usedStorage, common.CopyBytes(key[:])) @@ -505,7 +504,6 @@ func (s *stateObject) updateTrie() (Trie, error) { snapshotVal, _ = rlp.EncodeToBytes(value) } storage[khash] = snapshotVal // snapshotVal will be nil if it's deleted - log.Debug("updateTrie snapshot", "contract", s.address, "key", key, "epoch", s.db.epoch, "value", snapshotVal) // Track the original value of slot only if it's mutated first time prev := s.originStorage[key] @@ -810,7 +808,6 @@ func (s *stateObject) fetchExpiredFromRemote(prefixKey []byte, key common.Hash, prefixKey = enErr.Path } - log.Info("fetchExpiredStorageFromRemote in stateDB", "addr", s.address, "prefixKey", prefixKey, "key", key, "tr", fmt.Sprintf("%p", tr)) kvs, err := fetchExpiredStorageFromRemote(s.db.fullStateDB, s.db.originalRoot, s.address, s.data.Root, tr, prefixKey, key) if err != nil { @@ -850,11 +847,16 @@ func (s *stateObject) getExpirySnapStorage(key common.Hash) (snapshot.SnapValue, return val, nil, nil } + // TODO(0xbundler): if found value not been pruned, just return + //if len(val.GetVal()) > 0 { + // return val, nil, nil + //} + // handle from remoteDB, if got err just setError, just return to revert in consensus version. valRaw, err := s.fetchExpiredFromRemote(nil, key, true) if err != nil { return nil, nil, err } - return snapshot.NewValueWithEpoch(s.db.epoch, valRaw), nil, nil + return snapshot.NewValueWithEpoch(val.GetEpoch(), valRaw), nil, nil } diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index eedd597992..56103e0c79 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -17,7 +17,6 @@ package state import ( - "fmt" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "sync" @@ -589,7 +588,6 @@ func (sf *subfetcher) loop() { if sf.enableStateExpiry { if exErr, match := err.(*trie2.ExpiredNodeError); match { key := common.BytesToHash(task) - log.Debug("fetchExpiredStorageFromRemote in trie prefetcher", "addr", sf.addr, "prefixKey", exErr.Path, "key", key, "tr", fmt.Sprintf("%p", sf.trie)) _, err = fetchExpiredStorageFromRemote(sf.fullStateDB, sf.state, sf.addr, sf.root, sf.trie, exErr.Path, key) if err != nil { log.Error("subfetcher fetchExpiredStorageFromRemote err", "addr", sf.addr, "path", exErr.Path, "err", err) diff --git a/core/state_prefetcher.go b/core/state_prefetcher.go index f1bb60febd..f8b7fb5fd5 100644 --- a/core/state_prefetcher.go +++ b/core/state_prefetcher.go @@ -59,7 +59,9 @@ func (p *statePrefetcher) Prefetch(block *types.Block, statedb *state.StateDB, c for i := 0; i < prefetchThread; i++ { go func() { newStatedb := statedb.CopyDoPrefetch() - newStatedb.EnableWriteOnSharedStorage() + if !statedb.EnableExpire() { + newStatedb.EnableWriteOnSharedStorage() + } gaspool := new(GasPool).AddGas(block.GasLimit()) blockContext := NewEVMBlockContext(header, p.bc, nil) evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, *cfg) @@ -106,7 +108,10 @@ func (p *statePrefetcher) PrefetchMining(txs TransactionsByPriceAndNonce, header go func(startCh <-chan *types.Transaction, stopCh <-chan struct{}) { idx := 0 newStatedb := statedb.CopyDoPrefetch() - newStatedb.EnableWriteOnSharedStorage() + // TODO(0xbundler): access empty in trie cause shared concurrent bug? opt later + if !statedb.EnableExpire() { + newStatedb.EnableWriteOnSharedStorage() + } gaspool := new(GasPool).AddGas(gasLimit) blockContext := NewEVMBlockContext(header, p.bc, nil) evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, cfg) diff --git a/trie/proof.go b/trie/proof.go index edfde61b4f..f7fc1cd4a1 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -18,6 +18,7 @@ package trie import ( "bytes" + "encoding/hex" "errors" "fmt" "github.com/ethereum/go-ethereum/rlp" @@ -906,6 +907,23 @@ func (m *MPTProofNub) GetValue() []byte { return nil } +func (m *MPTProofNub) String() string { + buf := bytes.NewBuffer(nil) + buf.WriteString("n1: ") + buf.WriteString(hex.EncodeToString(m.n1PrefixKey)) + buf.WriteString(", n1proof: ") + if m.n1 != nil { + buf.WriteString(m.n1.fstring("")) + } + buf.WriteString(", n2: ") + buf.WriteString(hex.EncodeToString(m.n2PrefixKey)) + buf.WriteString(", n2proof: ") + if m.n2 != nil { + buf.WriteString(m.n2.fstring("")) + } + return buf.String() +} + func getNubValue(origin node, prefixKey []byte) []byte { switch n := origin.(type) { case nil, hashNode: diff --git a/trie/trie.go b/trie/trie.go index 03df2366d9..c7ad8d0892 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -55,7 +55,7 @@ type Trie struct { unhashed int // reader is the handler trie can retrieve nodes from. - reader *trieReader // TODO (asyukii): create a reader for state expiry metadata + reader *trieReader // tracer is the tool to track the trie changes. // It will be reset after each commit operation. @@ -211,6 +211,7 @@ func (t *Trie) GetAndUpdateEpoch(key []byte) (value []byte, err error) { if err == nil && didResolve { t.root = newroot + t.rootEpoch = t.currentEpoch } return value, err } @@ -324,8 +325,8 @@ func (t *Trie) updateChildNodeEpoch(origNode node, key []byte, pos int, epoch ty n = n.copy() n.Val = newnode n.setEpoch(t.currentEpoch) + n.flags = t.newFlag() } - return n, true, err case *fullNode: newnode, updateEpoch, err = t.updateChildNodeEpoch(n.Children[key[pos]], key, pos+1, epoch) @@ -334,6 +335,7 @@ func (t *Trie) updateChildNodeEpoch(origNode node, key []byte, pos int, epoch ty n.Children[key[pos]] = newnode n.setEpoch(t.currentEpoch) n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) + n.flags = t.newFlag() } return n, true, err case hashNode: @@ -627,12 +629,14 @@ func (t *Trie) insertWithEpoch(n node, prefix, key []byte, value node, epoch typ // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { + t.tracer.onExpandToBranchNode(prefix) return true, branch, nil } // New branch node is created as a child of the original short node. // Track the newly inserted node in the tracer. The node identifier // passed is the path from the root node. t.tracer.onInsert(append(prefix, key[:matchlen]...)) + t.tracer.onExpandToBranchNode(append(prefix, key[:matchlen]...)) // Replace it with a short node leading up to the branch. return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil @@ -908,6 +912,8 @@ func (t *Trie) deleteWithEpoch(n node, prefix, key []byte, epoch types.StateEpoc n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) // Because n is a full node, it must've contained at least two children // before the delete operation. If the new child value is non-nil, n still @@ -990,7 +996,7 @@ func (t *Trie) deleteWithEpoch(n node, prefix, key []byte, epoch types.StateEpoc } dirty, nn, err := t.deleteWithEpoch(rn, prefix, key, epoch) - if !dirty || err != nil { + if !t.renewNode(epoch, dirty, true) || err != nil { return false, rn, err } return true, nn, nil @@ -1226,8 +1232,15 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo if err != nil { return nil, false, fmt.Errorf("update child node epoch while reviving failed, err: %v", err) } + n1 = n1.copy() n1.Val = newnode - return n1, true, nil + n1.flags = t.newFlag() + tryUpdateNodeEpoch(nub.n1, t.currentEpoch) + renew, _, err := t.updateChildNodeEpoch(nub.n1, key, pos, t.currentEpoch) + if err != nil { + return nil, false, fmt.Errorf("update child node epoch while reviving failed, err: %v", err) + } + return renew, true, nil } tryUpdateNodeEpoch(nub.n1, t.currentEpoch) @@ -1253,6 +1266,7 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo n = n.copy() n.Val = newNode n.setEpoch(t.currentEpoch) + n.flags = t.newFlag() } return n, didRevive, err case *fullNode: @@ -1264,6 +1278,7 @@ func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProo n.Children[childIndex] = newNode n.setEpoch(t.currentEpoch) n.UpdateChildEpoch(childIndex, t.currentEpoch) + n.flags = t.newFlag() } if e, ok := err.(*ExpiredNodeError); ok { @@ -1384,13 +1399,13 @@ func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bo return childDirty } - // when no epoch update, same as before - if epoch == t.currentEpoch { - return childDirty + // node need update epoch, just renew + if t.currentEpoch > epoch { + return true } - // node need update epoch, just renew - return true + // when no epoch update, same as before + return childDirty } func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool {