Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Clean up code, add comprehensive stack tests with fuzzing, fix edge c…
Browse files Browse the repository at this point in the history
…ases where merge operation for stack commit was not properly updated
  • Loading branch information
Wazzymandias committed Aug 4, 2023
1 parent 072890f commit 6111eaf
Show file tree
Hide file tree
Showing 5 changed files with 608 additions and 62 deletions.
106 changes: 100 additions & 6 deletions core/state/multi_tx_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package state

import (
"errors"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/core/types"
"reflect"

"github.com/ethereum/go-ethereum/common"
)
Expand Down Expand Up @@ -53,6 +53,59 @@ func newMultiTxSnapshot() MultiTxSnapshot {
}
}

// Equal returns true if the two MultiTxSnapshot are equal
func (s *MultiTxSnapshot) Equal(other *MultiTxSnapshot) bool {
if other == nil {
return false
}
if s.invalid != other.invalid {
return false
}

visited := make(map[common.Address]bool)
for address, obj := range other.prevObjects {
current, exist := s.prevObjects[address]
if !exist {
return false
}
if current == nil && obj != nil {
return false
}

if current != nil && obj == nil {
return false
}

visited[address] = true
}

for address, obj := range s.prevObjects {
otherObject, exist := other.prevObjects[address]
if !exist {
return false
}

if otherObject == nil && obj != nil {
return false
}

if otherObject != nil && obj == nil {
return false
}
}

return reflect.DeepEqual(s.numLogsAdded, other.numLogsAdded) &&
reflect.DeepEqual(s.accountStorage, other.accountStorage) &&
reflect.DeepEqual(s.accountBalance, other.accountBalance) &&
reflect.DeepEqual(s.accountNonce, other.accountNonce) &&
reflect.DeepEqual(s.accountCode, other.accountCode) &&
reflect.DeepEqual(s.accountCodeHash, other.accountCodeHash) &&
reflect.DeepEqual(s.accountSuicided, other.accountSuicided) &&
reflect.DeepEqual(s.accountDeleted, other.accountDeleted) &&
reflect.DeepEqual(s.accountNotPending, other.accountNotPending) &&
reflect.DeepEqual(s.accountNotDirty, other.accountNotDirty)
}

// updateFromJournal updates the snapshot with the changes from the journal.
func (s *MultiTxSnapshot) updateFromJournal(journal *journal) {
for _, journalEntry := range journal.entries {
Expand Down Expand Up @@ -208,31 +261,49 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error {
// update storage keys if they do not exist for a given account's storage,
// and update pending storage for accounts that don't already exist in current snapshot
for address, storage := range other.accountStorage {
if s.objectChanged(address) {
continue
}

if _, exist := s.accountStorage[address]; !exist {
s.accountStorage[address] = make(map[common.Hash]*common.Hash)
s.accountStorage[address] = storage
continue
}

for key, value := range storage {
if value == nil {
s.updatePendingStorage(address, key, types.EmptyCodeHash, false)
} else {
s.updatePendingStorage(address, key, common.BytesToHash(value.Bytes()), true)
if _, exists := s.accountStorage[address][key]; !exists {
s.accountStorage[address][key] = value
}
}
}

// add previous balance(s) for any addresses that don't exist in current snapshot
for address, balance := range other.accountBalance {
if s.objectChanged(address) {
continue
}

if _, exist := s.accountBalance[address]; !exist {
s.accountBalance[address] = balance
}
}

// add previous nonce for accounts that don't exist in current snapshot
for address, nonce := range other.accountNonce {
if s.objectChanged(address) {
continue
}
if _, exist := s.accountNonce[address]; !exist {
s.accountNonce[address] = nonce
}
}

// add previous code for accounts not found in current snapshot
for address, code := range other.accountCode {
if s.objectChanged(address) {
continue
}
if _, exist := s.accountCode[address]; !exist {
if _, found := other.accountCodeHash[address]; !found {
// every codeChange has code and code hash set -
Expand All @@ -247,6 +318,10 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error {

// add previous suicide for addresses not in current snapshot
for address, suicided := range other.accountSuicided {
if s.objectChanged(address) {
continue
}

if _, exist := s.accountSuicided[address]; !exist {
s.accountSuicided[address] = suicided
} else {
Expand All @@ -256,6 +331,9 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error {

// add previous account deletions if they don't exist
for address, deleted := range other.accountDeleted {
if s.objectChanged(address) {
continue
}
if _, exist := s.accountDeleted[address]; !exist {
s.accountDeleted[address] = deleted
}
Expand Down Expand Up @@ -303,8 +381,14 @@ func (s *MultiTxSnapshot) revertState(st *StateDB) {
for address, storage := range s.accountStorage {
for key, value := range storage {
if value == nil {
if _, ok := st.stateObjects[address].pendingStorage[key]; !ok {
panic(fmt.Sprintf("storage key %x not found in pending storage", key))
}
delete(st.stateObjects[address].pendingStorage, key)
} else {
if _, ok := st.stateObjects[address].pendingStorage[key]; !ok {
panic(fmt.Sprintf("storage key %x not found in pending storage", key))
}
st.stateObjects[address].pendingStorage[key] = *value
}
}
Expand Down Expand Up @@ -409,6 +493,16 @@ func (stack *MultiTxSnapshotStack) Revert() (*MultiTxSnapshot, error) {
return head, nil
}

// RevertAll reverts all snapshots in the stack.
func (stack *MultiTxSnapshotStack) RevertAll() (snapshot *MultiTxSnapshot, err error) {
for len(stack.snapshots) > 0 {
if snapshot, err = stack.Revert(); err != nil {
break
}
}
return
}

// Commit merges the changes from the head snapshot with the previous snapshot and removes it from the stack.
func (stack *MultiTxSnapshotStack) Commit() (*MultiTxSnapshot, error) {
if len(stack.snapshots) == 0 {
Expand Down
Loading

0 comments on commit 6111eaf

Please sign in to comment.