diff --git a/core/state/multi_tx_snapshot.go b/core/state/multi_tx_snapshot.go index a3728b69b9..0900ecc33c 100644 --- a/core/state/multi_tx_snapshot.go +++ b/core/state/multi_tx_snapshot.go @@ -28,16 +28,18 @@ type MultiTxSnapshot struct { accountNotPending map[common.Address]struct{} accountNotDirty map[common.Address]struct{} + + previousRefund uint64 // TODO: snapdestructs, snapaccount storage } // NewMultiTxSnapshot creates a new MultiTxSnapshot -func NewMultiTxSnapshot() *MultiTxSnapshot { - multiTxSnapshot := newMultiTxSnapshot() +func NewMultiTxSnapshot(previousRefund uint64) *MultiTxSnapshot { + multiTxSnapshot := newMultiTxSnapshot(previousRefund) return &multiTxSnapshot } -func newMultiTxSnapshot() MultiTxSnapshot { +func newMultiTxSnapshot(previousRefund uint64) MultiTxSnapshot { return MultiTxSnapshot{ numLogsAdded: make(map[common.Hash]int), prevObjects: make(map[common.Address]*stateObject), @@ -50,6 +52,7 @@ func newMultiTxSnapshot() MultiTxSnapshot { accountDeleted: make(map[common.Address]bool), accountNotPending: make(map[common.Address]struct{}), accountNotDirty: make(map[common.Address]struct{}), + previousRefund: previousRefund, } } @@ -361,6 +364,11 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error { // revertState reverts the state to the snapshot. func (s *MultiTxSnapshot) revertState(st *StateDB) { + // restore previous refund + if st.refund != s.previousRefund { + st.refund = s.previousRefund + } + // remove all the logs added for txhash, numLogs := range s.numLogsAdded { lens := len(st.logs[txhash]) @@ -455,7 +463,7 @@ func (stack *MultiTxSnapshotStack) NewSnapshot() (*MultiTxSnapshot, error) { return nil, errors.New("failed to create new multi-transaction snapshot - invalid snapshot found at head") } - snap := newMultiTxSnapshot() + snap := newMultiTxSnapshot(stack.state.refund) stack.snapshots = append(stack.snapshots, snap) return &snap, nil } diff --git a/core/state/multi_tx_snapshot_test.go b/core/state/multi_tx_snapshot_test.go index 50de20280d..607b458924 100644 --- a/core/state/multi_tx_snapshot_test.go +++ b/core/state/multi_tx_snapshot_test.go @@ -258,13 +258,21 @@ func prepareInitialState(s *StateDB) { afterHook(addrs[i], s) } } + s.Finalise(true) + + // NOTE(wazzymandias): + // We want to test refund is properly reverted for snapshots - state.StateDB clears refund on Finalise + // so refund is set here to emulate state with non-zero value. + s.AddRefund(rng.Uint64()) } func testMultiTxSnapshot(t *testing.T, actions func(s *StateDB)) { s := newStateTest() prepareInitialState(s.state) + previousRefund := s.state.GetRefund() + var obsStates []*observableAccountState for _, account := range addrs { obsStates = append(obsStates, getObservableAccountState(s.state, account, keys)) @@ -300,6 +308,10 @@ func testMultiTxSnapshot(t *testing.T, actions func(s *StateDB)) { } } + if s.state.GetRefund() != previousRefund { + t.Error("refund mismatch", "got", s.state.GetRefund(), "expected", previousRefund) + } + if len(s.state.stateObjectsPending) != len(pendingAddressesBefore) { t.Error("pending state objects count mismatch", "got", len(s.state.stateObjectsPending), "expected", len(pendingAddressesBefore)) } @@ -339,6 +351,18 @@ func TestMultiTxSnapshotAccountChangesSimple(t *testing.T) { }) } +func TestMultiTxSnapshotRefund(t *testing.T) { + testMultiTxSnapshot(t, func(s *StateDB) { + for _, addr := range addrs { + s.SetNonce(addr, 78) + s.SetBalance(addr, big.NewInt(79)) + s.SetCode(addr, []byte{0x80}) + } + s.Finalise(true) + s.AddRefund(1000) + }) +} + func TestMultiTxSnapshotAccountChangesMultiTx(t *testing.T) { testMultiTxSnapshot(t, func(s *StateDB) { for _, addr := range addrs { @@ -419,7 +443,7 @@ func TestMultiTxSnapshotStateChanges(t *testing.T) { }) } -func testStackBasic(t *testing.T) { +func TestStackBasic(t *testing.T) { for i := 0; i < 10; i++ { testMultiTxSnapshot(t, func(s *StateDB) { // when test starts, actions are performed after new snapshot is created @@ -475,7 +499,41 @@ func testStackBasic(t *testing.T) { } } -func testStackSelfDestruct(t *testing.T) { +func TestStackRefund(t *testing.T) { + testMultiTxSnapshot(t, func(s *StateDB) { + const counter = 10 + + s.AddRefund(500) + previousRefunds := make([]uint64, 0, counter) + previousRefunds = append(previousRefunds, s.GetRefund()) + + for i := 0; i < counter; i++ { + previousRefunds = append(previousRefunds, s.GetRefund()) + if err := s.NewMultiTxSnapshot(); err != nil { + t.Errorf("NewMultiTxSnapshot failed: %v", err) + t.FailNow() + } + s.Finalise(true) + s.AddRefund(1000) + } + + for i := 0; i < counter; i++ { + if err := s.MultiTxSnapshotRevert(); err != nil { + t.Errorf("MultiTxSnapshotRevert failed: %v", err) + t.FailNow() + } + actualRefund := s.GetRefund() + expectedRefund := previousRefunds[len(previousRefunds)-1] + if actualRefund != expectedRefund { + t.Errorf("expected refund to be %d, got %d", expectedRefund, actualRefund) + t.FailNow() + } + previousRefunds = previousRefunds[:len(previousRefunds)-1] + } + }) +} + +func TestStackSelfDestruct(t *testing.T) { testMultiTxSnapshot(t, func(s *StateDB) { if err := s.NewMultiTxSnapshot(); err != nil { t.Errorf("NewMultiTxSnapshot failed: %v", err) @@ -515,7 +573,7 @@ func testStackSelfDestruct(t *testing.T) { }) } -func testStackAgainstSingleSnap(t *testing.T) { +func TestStackAgainstSingleSnap(t *testing.T) { // we generate a random seed ten times to fuzz test multiple stack snapshots against single layer snapshot for i := 0; i < 10; i++ { testMultiTxSnapshot(t, func(s *StateDB) { @@ -614,17 +672,6 @@ func testStackAgainstSingleSnap(t *testing.T) { } } -func TestMultiTxSnapshotStack(t *testing.T) { - // test state changes are valid after merging snapshots - testStackBasic(t) - - // test self-destruct - testStackSelfDestruct(t) - - // test against baseline single snapshot - testStackAgainstSingleSnap(t) -} - func CompareAndPrintSnapshotMismatches(t *testing.T, target, other *MultiTxSnapshot) { var out bytes.Buffer if target.Equal(other) {