diff --git a/core/state/multi_tx_snapshot.go b/core/state/multi_tx_snapshot.go index b59735b960..6faa0dbb5c 100644 --- a/core/state/multi_tx_snapshot.go +++ b/core/state/multi_tx_snapshot.go @@ -2,9 +2,9 @@ package state import ( "errors" + "fmt" "math/big" - - "github.com/ethereum/go-ethereum/core/types" + "reflect" "github.com/ethereum/go-ethereum/common" ) @@ -53,6 +53,59 @@ func newMultiTxSnapshot() MultiTxSnapshot { } } +// Equal returns true if the two MultiTxSnapshot are equal +func (s *MultiTxSnapshot) Equal(other *MultiTxSnapshot) bool { + if other == nil { + return false + } + if s.invalid != other.invalid { + return false + } + + visited := make(map[common.Address]bool) + for address, obj := range other.prevObjects { + current, exist := s.prevObjects[address] + if !exist { + return false + } + if current == nil && obj != nil { + return false + } + + if current != nil && obj == nil { + return false + } + + visited[address] = true + } + + for address, obj := range s.prevObjects { + otherObject, exist := other.prevObjects[address] + if !exist { + return false + } + + if otherObject == nil && obj != nil { + return false + } + + if otherObject != nil && obj == nil { + return false + } + } + + return reflect.DeepEqual(s.numLogsAdded, other.numLogsAdded) && + reflect.DeepEqual(s.accountStorage, other.accountStorage) && + reflect.DeepEqual(s.accountBalance, other.accountBalance) && + reflect.DeepEqual(s.accountNonce, other.accountNonce) && + reflect.DeepEqual(s.accountCode, other.accountCode) && + reflect.DeepEqual(s.accountCodeHash, other.accountCodeHash) && + reflect.DeepEqual(s.accountSuicided, other.accountSuicided) && + reflect.DeepEqual(s.accountDeleted, other.accountDeleted) && + reflect.DeepEqual(s.accountNotPending, other.accountNotPending) && + reflect.DeepEqual(s.accountNotDirty, other.accountNotDirty) +} + // updateFromJournal updates the snapshot with the changes from the journal. func (s *MultiTxSnapshot) updateFromJournal(journal *journal) { for _, journalEntry := range journal.entries { @@ -208,17 +261,29 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // update storage keys if they do not exist for a given account's storage, // and update pending storage for accounts that don't already exist in current snapshot for address, storage := range other.accountStorage { + if s.objectChanged(address) { + continue + } + + if _, exist := s.accountStorage[address]; !exist { + s.accountStorage[address] = make(map[common.Hash]*common.Hash) + s.accountStorage[address] = storage + continue + } + for key, value := range storage { - if value == nil { - s.updatePendingStorage(address, key, types.EmptyCodeHash, false) - } else { - s.updatePendingStorage(address, key, common.BytesToHash(value.Bytes()), true) + if _, exists := s.accountStorage[address][key]; !exists { + s.accountStorage[address][key] = value } } } // add previous balance(s) for any addresses that don't exist in current snapshot for address, balance := range other.accountBalance { + if s.objectChanged(address) { + continue + } + if _, exist := s.accountBalance[address]; !exist { s.accountBalance[address] = balance } @@ -226,6 +291,9 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // add previous nonce for accounts that don't exist in current snapshot for address, nonce := range other.accountNonce { + if s.objectChanged(address) { + continue + } if _, exist := s.accountNonce[address]; !exist { s.accountNonce[address] = nonce } @@ -233,6 +301,9 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // add previous code for accounts not found in current snapshot for address, code := range other.accountCode { + if s.objectChanged(address) { + continue + } if _, exist := s.accountCode[address]; !exist { if _, found := other.accountCodeHash[address]; !found { // every codeChange has code and code hash set - @@ -247,6 +318,10 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // add previous suicide for addresses not in current snapshot for address, suicided := range other.accountSuicided { + if s.objectChanged(address) { + continue + } + if _, exist := s.accountSuicided[address]; !exist { s.accountSuicided[address] = suicided } else { @@ -256,6 +331,9 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // add previous account deletions if they don't exist for address, deleted := range other.accountDeleted { + if s.objectChanged(address) { + continue + } if _, exist := s.accountDeleted[address]; !exist { s.accountDeleted[address] = deleted } @@ -303,8 +381,14 @@ func (s *MultiTxSnapshot) revertState(st *StateDB) { for address, storage := range s.accountStorage { for key, value := range storage { if value == nil { + if _, ok := st.stateObjects[address].pendingStorage[key]; !ok { + panic(fmt.Sprintf("storage key %x not found in pending storage", key)) + } delete(st.stateObjects[address].pendingStorage, key) } else { + if _, ok := st.stateObjects[address].pendingStorage[key]; !ok { + panic(fmt.Sprintf("storage key %x not found in pending storage", key)) + } st.stateObjects[address].pendingStorage[key] = *value } } @@ -409,6 +493,16 @@ func (stack *MultiTxSnapshotStack) Revert() (*MultiTxSnapshot, error) { return head, nil } +// RevertAll reverts all snapshots in the stack. +func (stack *MultiTxSnapshotStack) RevertAll() (snapshot *MultiTxSnapshot, err error) { + for len(stack.snapshots) > 0 { + if snapshot, err = stack.Revert(); err != nil { + break + } + } + return +} + // Commit merges the changes from the head snapshot with the previous snapshot and removes it from the stack. func (stack *MultiTxSnapshotStack) Commit() (*MultiTxSnapshot, error) { if len(stack.snapshots) == 0 { diff --git a/core/state/multi_tx_snapshot_test.go b/core/state/multi_tx_snapshot_test.go index 4cf5ed2db2..50de20280d 100644 --- a/core/state/multi_tx_snapshot_test.go +++ b/core/state/multi_tx_snapshot_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" "math/rand" + "reflect" "testing" "github.com/ethereum/go-ethereum/common" @@ -21,7 +22,7 @@ func init() { for i := 0; i < 20; i++ { addrs = append(addrs, common.HexToAddress(fmt.Sprintf("0x%02x", i))) } - for i := 0; i < 10; i++ { + for i := 0; i < 100; i++ { keys = append(keys, common.HexToHash(fmt.Sprintf("0x%02x", i))) } } @@ -127,6 +128,25 @@ func randFillAccountState(addr common.Address, s *StateDB) { } } +func genRandomAccountState(seed int64) map[common.Address]map[common.Hash]common.Hash { + rng = rand.New(rand.NewSource(seed)) + + state := make(map[common.Address]map[common.Hash]common.Hash) + + for _, addr := range addrs { + state[addr] = make(map[common.Hash]common.Hash) + for i, key := range keys { + if i%5 == 0 { + state[addr][key] = common.BigToHash(common.Big0) + } else { + state[addr][key] = randomHash() + } + } + } + + return state +} + func randFillAccount(addr common.Address, s *StateDB) { s.SetNonce(addr, rng.Uint64()) s.SetBalance(addr, big.NewInt(rng.Int63())) @@ -259,7 +279,6 @@ func testMultiTxSnapshot(t *testing.T, actions func(s *StateDB)) { dirtyAddressesBefore[k] = v } - //err := s.state.MultiTxSnapshot() err := s.state.NewMultiTxSnapshot() if err != nil { t.Fatal("MultiTxSnapshot failed", err) @@ -399,3 +418,482 @@ func TestMultiTxSnapshotStateChanges(t *testing.T) { s.Finalise(true) }) } + +func testStackBasic(t *testing.T) { + for i := 0; i < 10; i++ { + testMultiTxSnapshot(t, func(s *StateDB) { + // when test starts, actions are performed after new snapshot is created + // we initialize additional snapshot on top of that + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("NewMultiTxSnapshot failed: %v", err) + t.FailNow() + } + + seed := rand.Int63() + stateMap := genRandomAccountState(seed) + for account, accountKeys := range stateMap { + for key, value := range accountKeys { + s.SetState(account, key, value) + } + } + s.Finalise(true) + + stack := s.multiTxSnapshotStack + + // the test starts with 1 snapshot, and we just created new one above + startSize := stack.Size() + if startSize != 2 { + t.Errorf("expected stack size to be 2, got %d", startSize) + t.FailNow() + } + + for _, addr := range addrs { + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("NewMultiTxSnapshot failed: %v", err) + t.FailNow() + } + randFillAccountState(addr, s) + s.Finalise(true) + } + afterAddrSize := stack.Size() + if afterAddrSize != startSize+len(addrs) { + t.Errorf("expected stack size to be %d, got %d", startSize+len(addrs), afterAddrSize) + t.FailNow() + } + + // the testMultiTxSnapshot subroutine calls MultiTxSnapshotRevert after applying actions + // we test here to make sure that the flattened commitments on the head of stack + // yield the same final root hash + // this ensures that we are properly flattening the stack on commit + for stack.Size() > 1 { + if _, err := stack.Commit(); err != nil { + t.Errorf("Commit failed: %v", err) + t.FailNow() + } + } + }) + } +} + +func testStackSelfDestruct(t *testing.T) { + testMultiTxSnapshot(t, func(s *StateDB) { + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("NewMultiTxSnapshot failed: %v", err) + t.FailNow() + } + for _, addr := range addrs { + s.SetNonce(addr, 78) + s.SetBalance(addr, big.NewInt(79)) + s.SetCode(addr, []byte{0x80}) + s.Finalise(true) + } + + for _, addr := range addrs { + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("NewMultiTxSnapshot failed: %v", err) + t.FailNow() + } + s.Suicide(addr) + } + stack := s.multiTxSnapshotStack + + // merge all the suicide operations + for stack.Size() > 1 { + if _, err := stack.Commit(); err != nil { + t.Errorf("Commit failed: %v", err) + t.FailNow() + } + } + s.Finalise(true) + + for _, addr := range addrs { + s.SetNonce(addr, 79) + s.SetBalance(addr, big.NewInt(80)) + s.SetCode(addr, []byte{0x81}) + } + s.Finalise(true) + }) +} + +func testStackAgainstSingleSnap(t *testing.T) { + // we generate a random seed ten times to fuzz test multiple stack snapshots against single layer snapshot + for i := 0; i < 10; i++ { + testMultiTxSnapshot(t, func(s *StateDB) { + original := s.Copy() + baselineStateDB := s.Copy() + + baselineRootHash, targetRootHash := baselineStateDB.originalRoot, s.originalRoot + + if !bytes.Equal(baselineRootHash.Bytes(), targetRootHash.Bytes()) { + t.Errorf("expected root hash to be %x, got %x", baselineRootHash, targetRootHash) + t.FailNow() + } + + // basic - add multiple snapshots and commit them, and compare them to single snapshot that has all + // state changes + + if err := baselineStateDB.NewMultiTxSnapshot(); err != nil { + t.Errorf("Error initializing snapshot: %v", err) + t.FailNow() + } + + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("Error initializing snapshot: %v", err) + t.FailNow() + } + + // we should be able to revert back to the same intermediate root hash + // for single snapshot and snapshot stack + seed := rand.Int63() + state := genRandomAccountState(seed) + for account, accountKeys := range state { + for key, value := range accountKeys { + baselineStateDB.SetState(account, key, value) + + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("Error initializing snapshot: %v", err) + t.FailNow() + } + s.SetState(account, key, value) + s.Finalise(true) + } + } + baselineStateDB.Finalise(true) + + // commit all but last snapshot + stack := s.multiTxSnapshotStack + for stack.Size() > 1 { + if _, err := stack.Commit(); err != nil { + t.Errorf("Commit failed: %v", err) + t.FailNow() + } + } + + var ( + baselineSnapshot = baselineStateDB.multiTxSnapshotStack.Peek() + targetSnapshot = s.multiTxSnapshotStack.Peek() + ) + if !targetSnapshot.Equal(baselineSnapshot) { + CompareAndPrintSnapshotMismatches(t, targetSnapshot, baselineSnapshot) + t.Errorf("expected snapshots to be equal") + t.FailNow() + } + + // revert back to previously calculated root hash + if err := baselineStateDB.MultiTxSnapshotRevert(); err != nil { + t.Errorf("MultiTxSnapshotRevert failed: %v", err) + t.FailNow() + } + + if err := s.MultiTxSnapshotRevert(); err != nil { + t.Errorf("MultiTxSnapshotRevert failed: %v", err) + t.FailNow() + } + + var err error + if targetRootHash, err = s.Commit(true); err != nil { + t.Errorf("Commit failed: %v", err) + t.FailNow() + } + + if baselineRootHash, err = baselineStateDB.Commit(true); err != nil { + t.Errorf("Commit failed: %v", err) + t.FailNow() + } + if !bytes.Equal(baselineRootHash.Bytes(), targetRootHash.Bytes()) { + t.Errorf("expected root hash to be %x, got %x", baselineRootHash, targetRootHash) + t.FailNow() + } + + *s = *original + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("Error initializing snapshot: %v", err) + t.FailNow() + } + }) + } +} + +func TestMultiTxSnapshotStack(t *testing.T) { + // test state changes are valid after merging snapshots + testStackBasic(t) + + // test self-destruct + testStackSelfDestruct(t) + + // test against baseline single snapshot + testStackAgainstSingleSnap(t) +} + +func CompareAndPrintSnapshotMismatches(t *testing.T, target, other *MultiTxSnapshot) { + var out bytes.Buffer + if target.Equal(other) { + t.Logf("Snapshots are equal") + return + } + + if target.invalid != other.invalid { + out.WriteString(fmt.Sprintf("invalid: %v != %v\n", target.invalid, other.invalid)) + return + } + + // check log mismatch + visited := make(map[common.Hash]bool) + for address, logCount := range other.numLogsAdded { + targetLogCount, exists := target.numLogsAdded[address] + if !exists { + out.WriteString(fmt.Sprintf("target<>other numLogsAdded[missing]: %v\n", address)) + continue + } + if targetLogCount != logCount { + out.WriteString(fmt.Sprintf("target<>other numLogsAdded[%x]: %v != %v\n", address, targetLogCount, logCount)) + } + } + + for address, logCount := range target.numLogsAdded { + if visited[address] { + continue + } + + otherLogCount, exists := other.numLogsAdded[address] + if !exists { + out.WriteString(fmt.Sprintf("other<>target numLogsAdded[missing]: %v\n", address)) + continue + } + + if otherLogCount != logCount { + out.WriteString(fmt.Sprintf("other<>target numLogsAdded[%x]: %v != %v\n", address, otherLogCount, logCount)) + } + } + + // check previous objects mismatch + for address := range other.prevObjects { + // TODO: we only check existence, need to add RLP comparison + _, exists := target.prevObjects[address] + if !exists { + out.WriteString(fmt.Sprintf("target<>other prevObjects[missing]: %v\n", address.String())) + continue + } + } + + for address, obj := range target.prevObjects { + otherObj, exists := other.prevObjects[address] + if !exists { + out.WriteString(fmt.Sprintf("other<>target prevObjects[missing]: %v\n", address)) + continue + } + if !reflect.DeepEqual(otherObj, obj) { + out.WriteString(fmt.Sprintf("other<>target prevObjects[%x]: %v != %v\n", address, otherObj, obj)) + } + } + + // check account storage mismatch + for account, storage := range other.accountStorage { + targetStorage, exists := target.accountStorage[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountStorage[missing]: %v\n", account)) + continue + } + + for key, value := range storage { + targetValue, exists := targetStorage[key] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountStorage[%s][missing]: %v\n", account.String(), key.String())) + continue + } + if !reflect.DeepEqual(targetValue, value) { + out.WriteString(fmt.Sprintf("target<>other accountStorage[%s][%s]: %v != %v\n", account.String(), key.String(), targetValue.String(), value.String())) + } + } + } + + for account, storage := range target.accountStorage { + otherStorage, exists := other.accountStorage[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountStorage[missing]: %v\n", account)) + continue + } + + for key, value := range storage { + otherValue, exists := otherStorage[key] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountStorage[%s][missing]: %v\n", account.String(), key.String())) + continue + } + if !reflect.DeepEqual(otherValue, value) { + out.WriteString(fmt.Sprintf("other<>target accountStorage[%s][%s]: %v != %v\n", account.String(), key.String(), otherValue.String(), value.String())) + } + } + } + + // check account balance mismatch + for account, balance := range other.accountBalance { + targetBalance, exists := target.accountBalance[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountBalance[missing]: %v\n", account)) + continue + } + if !reflect.DeepEqual(targetBalance, balance) { + out.WriteString(fmt.Sprintf("target<>other accountBalance[%x]: %v != %v\n", account, targetBalance, balance)) + } + } + + for account, balance := range target.accountBalance { + otherBalance, exists := other.accountBalance[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountBalance[missing]: %v\n", account)) + continue + } + if !bytes.Equal(otherBalance.Bytes(), balance.Bytes()) { + out.WriteString(fmt.Sprintf("other<>target accountBalance[%x]: %v != %v\n", account, otherBalance, balance)) + } + } + + // check account nonce mismatch + for account, nonce := range other.accountNonce { + targetNonce, exists := target.accountNonce[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountNonce[missing]: %v\n", account)) + continue + } + if targetNonce != nonce { + out.WriteString(fmt.Sprintf("target<>other accountNonce[%x]: %v != %v\n", account, targetNonce, nonce)) + } + } + + for account, nonce := range target.accountNonce { + otherNonce, exists := other.accountNonce[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountNonce[missing]: %v\n", account)) + continue + } + if otherNonce != nonce { + out.WriteString(fmt.Sprintf("other<>target accountNonce[%x]: %v != %v\n", account, otherNonce, nonce)) + } + } + + // check account code mismatch + for account, code := range other.accountCode { + targetCode, exists := target.accountCode[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountCode[missing]: %v\n", account)) + continue + } + if !bytes.Equal(targetCode, code) { + out.WriteString(fmt.Sprintf("target<>other accountCode[%x]: %v != %v\n", account, targetCode, code)) + } + } + + for account, code := range target.accountCode { + otherCode, exists := other.accountCode[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountCode[missing]: %v\n", account)) + continue + } + if !bytes.Equal(otherCode, code) { + out.WriteString(fmt.Sprintf("other<>target accountCode[%x]: %v != %v\n", account, otherCode, code)) + } + } + + // check account codeHash mismatch + for account, codeHash := range other.accountCodeHash { + targetCodeHash, exists := target.accountCodeHash[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountCodeHash[missing]: %v\n", account)) + continue + } + if !bytes.Equal(targetCodeHash, codeHash) { + out.WriteString(fmt.Sprintf("target<>other accountCodeHash[%x]: %v != %v\n", account, targetCodeHash, codeHash)) + } + } + + for account, codeHash := range target.accountCodeHash { + otherCodeHash, exists := other.accountCodeHash[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountCodeHash[missing]: %v\n", account)) + continue + } + if !bytes.Equal(otherCodeHash, codeHash) { + out.WriteString(fmt.Sprintf("other<>target accountCodeHash[%x]: %v != %v\n", account, otherCodeHash, codeHash)) + } + } + + // check account suicide mismatch + for account, suicide := range other.accountSuicided { + targetSuicide, exists := target.accountSuicided[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountSuicided[missing]: %v\n", account)) + continue + } + + if targetSuicide != suicide { + out.WriteString(fmt.Sprintf("target<>other accountSuicided[%x]: %t != %t\n", account, targetSuicide, suicide)) + } + } + + for account, suicide := range target.accountSuicided { + otherSuicide, exists := other.accountSuicided[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountSuicided[missing]: %v\n", account)) + continue + } + + if otherSuicide != suicide { + out.WriteString(fmt.Sprintf("other<>target accountSuicided[%x]: %t != %t\n", account, otherSuicide, suicide)) + } + } + + // check account deletion mismatch + for account, del := range other.accountDeleted { + targetDelete, exists := target.accountDeleted[account] + if !exists { + out.WriteString(fmt.Sprintf("target<>other accountDeleted[missing]: %v\n", account)) + continue + } + + if targetDelete != del { + out.WriteString(fmt.Sprintf("target<>other accountDeleted[%x]: %v != %v\n", account, targetDelete, del)) + } + } + + for account, del := range target.accountDeleted { + otherDelete, exists := other.accountDeleted[account] + if !exists { + out.WriteString(fmt.Sprintf("other<>target accountDeleted[missing]: %v\n", account)) + continue + } + + if otherDelete != del { + out.WriteString(fmt.Sprintf("other<>target accountDeleted[%x]: %v != %v\n", account, otherDelete, del)) + } + } + + // check account not pending mismatch + for account := range other.accountNotPending { + if _, exists := target.accountNotPending[account]; !exists { + out.WriteString(fmt.Sprintf("target<>other accountNotPending[missing]: %v\n", account)) + } + } + + for account := range target.accountNotPending { + if _, exists := other.accountNotPending[account]; !exists { + out.WriteString(fmt.Sprintf("other<>target accountNotPending[missing]: %v\n", account)) + } + } + + // check account not dirty mismatch + for account := range other.accountNotDirty { + if _, exists := target.accountNotDirty[account]; !exists { + out.WriteString(fmt.Sprintf("target<>other accountNotDirty[missing]: %v\n", account)) + } + } + + for account := range target.accountNotDirty { + if _, exists := other.accountNotDirty[account]; !exists { + out.WriteString(fmt.Sprintf("other<>target accountNotDirty[missing]: %v\n", account)) + } + } + + fmt.Println(out.String()) + out.Reset() +} diff --git a/core/state/state_object.go b/core/state/state_object.go index f47484c8d1..cd720019db 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -249,10 +249,7 @@ func (s *stateObject) finalise(prefetch bool) { for key, value := range s.dirtyStorage { prev, ok := s.pendingStorage[key] s.db.multiTxSnapshotStack.UpdatePendingStorage(s.address, key, prev, ok) - //if multiSnap := s.db.multiTxSnapshot; multiSnap != nil { - // prev, ok := s.pendingStorage[key] - // multiSnap.updatePendingStorage(s.address, key, prev, ok) - //} + s.pendingStorage[key] = value if value != s.originStorage[key] { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure diff --git a/core/state/statedb.go b/core/state/statedb.go index e3befb8f42..a062bd3793 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -109,8 +109,6 @@ type StateDB struct { validRevisions []revision nextRevisionId int - // Multi-Transaction Snapshot - //multiTxSnapshot *MultiTxSnapshot // Multi-Transaction Snapshot Stack multiTxSnapshotStack *MultiTxSnapshotStack @@ -719,6 +717,7 @@ func (s *StateDB) Copy() *StateDB { journal: newJournal(), hasher: crypto.NewKeccakState(), } + // Initialize new multi-transaction snapshot stack for the copied state state.multiTxSnapshotStack = NewMultiTxSnapshotStack(state) // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { @@ -851,9 +850,7 @@ func (s *StateDB) GetRefund() uint64 { // into the tries just yet. Only IntermediateRoot or Commit will do that. func (s *StateDB) Finalise(deleteEmptyObjects bool) { s.multiTxSnapshotStack.UpdateFromJournal(s.journal) - //if multiSnap := s.multiTxSnapshot; multiSnap != nil { - // multiSnap.updateFromJournal(s.journal) - //} + addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties)) for addr := range s.journal.dirties { obj, exist := s.stateObjects[addr] @@ -868,9 +865,7 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { } if obj.suicided || (deleteEmptyObjects && obj.empty()) { s.multiTxSnapshotStack.UpdateObjectDeleted(obj.address, obj.deleted) - //if multiSnap := s.multiTxSnapshot; multiSnap != nil { - // multiSnap.updateObjectDeleted(obj.address, obj.deleted) - //} + //s.multiTxSnapshotStack.UpdateObjectDeleted(obj.address, obj.deleted) obj.deleted = true @@ -895,11 +890,6 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { _, wasDirty := s.stateObjectsDirty[addr] s.multiTxSnapshotStack.UpdatePendingStatus(addr, wasPending, wasDirty) } - //if multiSnap := s.multiTxSnapshot; multiSnap != nil { - // _, wasPending := s.stateObjectsPending[addr] - // _, wasDirty := s.stateObjectsDirty[addr] - //multiSnap.updatePendingStatus(addr, wasPending, wasDirty) - //} s.stateObjectsPending[addr] = struct{}{} s.stateObjectsDirty[addr] = struct{}{} @@ -922,10 +912,9 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { // Finalise all the dirty storage states and write them into the tries s.Finalise(deleteEmptyObjects) + // Intermediate root writes updates to the trie, which will cause + // in memory multi-transaction snapshot to be incompatible with the committed state, so we invalidate. s.multiTxSnapshotStack.Invalidate() - //if s.multiTxSnapshot != nil { - // s.multiTxSnapshot.invalid = true - //} // If there was a trie prefetcher operating, it gets aborted and irrevocably // modified after we start retrieving tries. Remove it from the statedb after @@ -1223,41 +1212,12 @@ func (s *StateDB) NewMultiTxSnapshot() error { return nil } -// MultiTxSnapshot creates new checkpoint for multi txs reverts -//func (s *StateDB) MultiTxSnapshot() error { -// if s.multiTxSnapshot != nil { -// return errors.New("multi tx snapshot already exists") -// } -// s.multiTxSnapshot = NewMultiTxSnapshot() -// return nil -//} - func (s *StateDB) MultiTxSnapshotRevert() error { _, err := s.multiTxSnapshotStack.Revert() return err } -//func (s *StateDB) MultiTxSnapshotRevert() error { -// if s.multiTxSnapshot == nil { -// return errors.New("multi tx snapshot does not exist") -// } -// if s.multiTxSnapshot.invalid { -// return errors.New("multi tx snapshot is invalid") -// } -// s.multiTxSnapshot.revertState(s) -// s.multiTxSnapshot = nil -// return nil -//} - -func (s *StateDB) MultiTxSnapshotDiscard() error { +func (s *StateDB) MultiTxSnapshotCommit() error { _, err := s.multiTxSnapshotStack.Commit() return err } - -//func (s *StateDB) MultiTxSnapshotDiscard() error { -// if s.multiTxSnapshot == nil { -// return errors.New("multi tx snapshot does not exist") -// } -// s.multiTxSnapshot = nil -// return nil -//} diff --git a/miner/env_changes.go b/miner/env_changes.go index 2ca2e3f376..6fc7cdc747 100644 --- a/miner/env_changes.go +++ b/miner/env_changes.go @@ -27,9 +27,6 @@ func newEnvChanges(env *environment) (*envChanges, error) { if err := env.state.NewMultiTxSnapshot(); err != nil { return nil, err } - //if err := env.state.MultiTxSnapshot(); err != nil { - // return nil, err - //} return &envChanges{ env: env, @@ -411,7 +408,7 @@ func (c *envChanges) rollback( } func (c *envChanges) apply() error { - if err := c.env.state.MultiTxSnapshotDiscard(); err != nil { + if err := c.env.state.MultiTxSnapshotCommit(); err != nil { return err }