Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Add refund support to efficient revert so state returns to correct re…
Browse files Browse the repository at this point in the history
…fund value on discard
  • Loading branch information
Wazzymandias committed Aug 7, 2023
1 parent a16791d commit d6d5cae
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 18 deletions.
16 changes: 12 additions & 4 deletions core/state/multi_tx_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@ type MultiTxSnapshot struct {

accountNotPending map[common.Address]struct{}
accountNotDirty map[common.Address]struct{}

previousRefund uint64
// TODO: snapdestructs, snapaccount storage
}

// NewMultiTxSnapshot creates a new MultiTxSnapshot
func NewMultiTxSnapshot() *MultiTxSnapshot {
multiTxSnapshot := newMultiTxSnapshot()
func NewMultiTxSnapshot(previousRefund uint64) *MultiTxSnapshot {
multiTxSnapshot := newMultiTxSnapshot(previousRefund)
return &multiTxSnapshot
}

func newMultiTxSnapshot() MultiTxSnapshot {
func newMultiTxSnapshot(previousRefund uint64) MultiTxSnapshot {
return MultiTxSnapshot{
numLogsAdded: make(map[common.Hash]int),
prevObjects: make(map[common.Address]*stateObject),
Expand All @@ -50,6 +52,7 @@ func newMultiTxSnapshot() MultiTxSnapshot {
accountDeleted: make(map[common.Address]bool),
accountNotPending: make(map[common.Address]struct{}),
accountNotDirty: make(map[common.Address]struct{}),
previousRefund: previousRefund,
}
}

Expand Down Expand Up @@ -361,6 +364,11 @@ func (s *MultiTxSnapshot) Merge(other *MultiTxSnapshot) error {

// revertState reverts the state to the snapshot.
func (s *MultiTxSnapshot) revertState(st *StateDB) {
// restore previous refund
if st.refund != s.previousRefund {
st.refund = s.previousRefund
}

// remove all the logs added
for txhash, numLogs := range s.numLogsAdded {
lens := len(st.logs[txhash])
Expand Down Expand Up @@ -455,7 +463,7 @@ func (stack *MultiTxSnapshotStack) NewSnapshot() (*MultiTxSnapshot, error) {
return nil, errors.New("failed to create new multi-transaction snapshot - invalid snapshot found at head")
}

snap := newMultiTxSnapshot()
snap := newMultiTxSnapshot(stack.state.refund)
stack.snapshots = append(stack.snapshots, snap)
return &snap, nil
}
Expand Down
75 changes: 61 additions & 14 deletions core/state/multi_tx_snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,21 @@ func prepareInitialState(s *StateDB) {
afterHook(addrs[i], s)
}
}

s.Finalise(true)

// NOTE(wazzymandias):
// We want to test refund is properly reverted for snapshots - state.StateDB clears refund on Finalise
// so refund is set here to emulate state with non-zero value.
s.AddRefund(rng.Uint64())
}

func testMultiTxSnapshot(t *testing.T, actions func(s *StateDB)) {
s := newStateTest()
prepareInitialState(s.state)

previousRefund := s.state.GetRefund()

var obsStates []*observableAccountState
for _, account := range addrs {
obsStates = append(obsStates, getObservableAccountState(s.state, account, keys))
Expand Down Expand Up @@ -300,6 +308,10 @@ func testMultiTxSnapshot(t *testing.T, actions func(s *StateDB)) {
}
}

if s.state.GetRefund() != previousRefund {
t.Error("refund mismatch", "got", s.state.GetRefund(), "expected", previousRefund)
}

if len(s.state.stateObjectsPending) != len(pendingAddressesBefore) {
t.Error("pending state objects count mismatch", "got", len(s.state.stateObjectsPending), "expected", len(pendingAddressesBefore))
}
Expand Down Expand Up @@ -339,6 +351,18 @@ func TestMultiTxSnapshotAccountChangesSimple(t *testing.T) {
})
}

func TestMultiTxSnapshotRefund(t *testing.T) {
testMultiTxSnapshot(t, func(s *StateDB) {
for _, addr := range addrs {
s.SetNonce(addr, 78)
s.SetBalance(addr, big.NewInt(79))
s.SetCode(addr, []byte{0x80})
}
s.Finalise(true)
s.AddRefund(1000)
})
}

func TestMultiTxSnapshotAccountChangesMultiTx(t *testing.T) {
testMultiTxSnapshot(t, func(s *StateDB) {
for _, addr := range addrs {
Expand Down Expand Up @@ -419,7 +443,7 @@ func TestMultiTxSnapshotStateChanges(t *testing.T) {
})
}

func testStackBasic(t *testing.T) {
func TestStackBasic(t *testing.T) {
for i := 0; i < 10; i++ {
testMultiTxSnapshot(t, func(s *StateDB) {
// when test starts, actions are performed after new snapshot is created
Expand Down Expand Up @@ -475,7 +499,41 @@ func testStackBasic(t *testing.T) {
}
}

func testStackSelfDestruct(t *testing.T) {
func TestStackRefund(t *testing.T) {
testMultiTxSnapshot(t, func(s *StateDB) {
const counter = 10

s.AddRefund(500)
previousRefunds := make([]uint64, 0, counter)
previousRefunds = append(previousRefunds, s.GetRefund())

for i := 0; i < counter; i++ {
previousRefunds = append(previousRefunds, s.GetRefund())
if err := s.NewMultiTxSnapshot(); err != nil {
t.Errorf("NewMultiTxSnapshot failed: %v", err)
t.FailNow()
}
s.Finalise(true)
s.AddRefund(1000)
}

for i := 0; i < counter; i++ {
if err := s.MultiTxSnapshotRevert(); err != nil {
t.Errorf("MultiTxSnapshotRevert failed: %v", err)
t.FailNow()
}
actualRefund := s.GetRefund()
expectedRefund := previousRefunds[len(previousRefunds)-1]
if actualRefund != expectedRefund {
t.Errorf("expected refund to be %d, got %d", expectedRefund, actualRefund)
t.FailNow()
}
previousRefunds = previousRefunds[:len(previousRefunds)-1]
}
})
}

func TestStackSelfDestruct(t *testing.T) {
testMultiTxSnapshot(t, func(s *StateDB) {
if err := s.NewMultiTxSnapshot(); err != nil {
t.Errorf("NewMultiTxSnapshot failed: %v", err)
Expand Down Expand Up @@ -515,7 +573,7 @@ func testStackSelfDestruct(t *testing.T) {
})
}

func testStackAgainstSingleSnap(t *testing.T) {
func TestStackAgainstSingleSnap(t *testing.T) {
// we generate a random seed ten times to fuzz test multiple stack snapshots against single layer snapshot
for i := 0; i < 10; i++ {
testMultiTxSnapshot(t, func(s *StateDB) {
Expand Down Expand Up @@ -614,17 +672,6 @@ func testStackAgainstSingleSnap(t *testing.T) {
}
}

func TestMultiTxSnapshotStack(t *testing.T) {
// test state changes are valid after merging snapshots
testStackBasic(t)

// test self-destruct
testStackSelfDestruct(t)

// test against baseline single snapshot
testStackAgainstSingleSnap(t)
}

func CompareAndPrintSnapshotMismatches(t *testing.T, target, other *MultiTxSnapshot) {
var out bytes.Buffer
if target.Equal(other) {
Expand Down

0 comments on commit d6d5cae

Please sign in to comment.