Skip to content

Commit

Permalink
Stateful testing of TWA module (#11)
Browse files Browse the repository at this point in the history
* test: hypothesis structure for TWA module

* test: hypothesis parallel testing

* test: hypothesis parallel testing

* dev: TWA contract min_dt enforced; test: rm parallel hypothesis, added TWA invariant
---------

Co-authored-by: Alberto <[email protected]>
  • Loading branch information
heswithme and AlbertoCentonze authored Oct 14, 2024
1 parent bc976fc commit de3bc42
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 4 deletions.
16 changes: 12 additions & 4 deletions contracts/TWA.vy
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct Snapshot:
@deploy
def __init__(_twa_window: uint256, _min_snapshot_dt_seconds: uint256):
self._set_twa_window(_twa_window)
self._set_snapshot_dt(_min_snapshot_dt_seconds)
self._set_snapshot_dt(max(1, _min_snapshot_dt_seconds))


################################################################
Expand Down Expand Up @@ -158,15 +158,20 @@ def _compute() -> uint256:
i_backwards: uint256 = index_array_end - i
current_snapshot: Snapshot = self.snapshots[i_backwards]
next_snapshot: Snapshot = current_snapshot
if i != 0: # If not the first iteration, get the next snapshot
if i != 0: # If not the first iteration (last snapshot), get the next snapshot
next_snapshot = self.snapshots[i_backwards + 1]

# Time Axis (Increasing to the Right) --->
# SNAPSHOT
# |---------|---------|---------|------------------------|---------|---------|
# t0 time_window_start interval_start interval_end block.timestamp (Now)

interval_start: uint256 = current_snapshot.timestamp
# Adjust interval start if it is before the time window start
if interval_start < time_window_start:
interval_start = time_window_start

interval_end: uint256 = 0
interval_end: uint256 = interval_start
if i == 0: # First iteration - we are on the last snapshot (i_backwards = num_snapshots - 1)
# For the last snapshot, interval end is block.timestamp
interval_end = block.timestamp
Expand All @@ -186,7 +191,10 @@ def _compute() -> uint256:
total_weighted_tracked_value += averaged_tracked_value * time_delta
total_time += time_delta

if total_time == 0 and len(self.snapshots) == 1:
# case when only snapshot is taken in the block where computation is called
return self.snapshots[0].tracked_value

assert total_time > 0, "Zero total time!"
twa: uint256 = total_weighted_tracked_value // total_time

return twa
Empty file added tests/__init__.py
Empty file.
Empty file added tests/hypothesis/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions tests/hypothesis/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from hypothesis import Phase, Verbosity, settings

settings.register_profile("debug", settings(verbosity=Verbosity.verbose, phases=list(Phase)[:4]))
settings.load_profile("debug")
Empty file.
114 changes: 114 additions & 0 deletions tests/hypothesis/twa/stateful_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import boa
from hypothesis import note
from hypothesis import strategies as st
from hypothesis.stateful import RuleBasedStateMachine, initialize # , invariant, rule


class TWAStatefulBase(RuleBasedStateMachine):
twa_deployer = boa.load_partial("contracts/TWA.vy")

def __init__(self):
super().__init__()
note("INIT")
self.twa_contract = None
self.twa_window = None
self.min_snapshot_dt_seconds = None
self.snapshots = []
self.last_snapshot_timestamp = 0

@initialize(
twa_window=st.integers(min_value=1, max_value=86400 * 7), # 1 second to 1 week
min_snapshot_dt_seconds=st.integers(min_value=1, max_value=86400), # 1 second to 1 day
)
def setup(self, twa_window, min_snapshot_dt_seconds):
"""Initialize the TWA contract and set up initial parameters."""
note("SETUP")
self.twa_contract = TWAStatefulBase.twa_deployer(twa_window, min_snapshot_dt_seconds)

self.twa_window = twa_window
self.min_snapshot_dt_seconds = min_snapshot_dt_seconds
self.snapshots = []
self.last_snapshot_timestamp = 0

def python_take_snapshot(self, value):
"""
Python model of the contract's `_take_snapshot` function.
Mirrors the contract logic and updates the internal state.
"""
# Contract logic: only take a snapshot if the time condition is met
block_timestamp = boa.env.evm.patch.timestamp
if self.last_snapshot_timestamp + self.min_snapshot_dt_seconds <= block_timestamp:
self.last_snapshot_timestamp = block_timestamp
self.snapshots.append({"tracked_value": value, "timestamp": block_timestamp})
note(
f"python_take_snapshot: Python snapshot added: value={value}, timestamp={block_timestamp}" # noqa: E501
)
else:
note("python_take_snapshot: Python snapshot skipped (time condition not met)")

def python_compute_twa(self):
"""
Python version of the contract's _compute function.
Computes the TWA (Time-Weighted Average) based on the snapshots in self.snapshots.
"""
block_timestamp = boa.env.evm.patch.timestamp

num_snapshots = len(self.snapshots)
if num_snapshots == 0:
note("python_compute_twa: No snapshots, no TWA")
return 0

time_window_start = block_timestamp - self.twa_window

total_weighted_tracked_value = 0
total_time = 0

# Iterate backwards over all snapshots
index_array_end = num_snapshots - 1
for i in range(0, num_snapshots):
i_backwards = index_array_end - i
current_snapshot = self.snapshots[i_backwards]
next_snapshot = current_snapshot

if i != 0: # If not the first iteration, get the next snapshot
next_snapshot = self.snapshots[i_backwards + 1]

interval_start = current_snapshot["timestamp"]

# Adjust interval start if it is before the time window start
if interval_start < time_window_start:
interval_start = time_window_start

if i == 0:
# For the last snapshot, interval end is the block_timestamp
interval_end = block_timestamp
else:
# For other snapshots, interval end is the timestamp of the next snapshot
interval_end = next_snapshot["timestamp"]

if interval_end <= time_window_start:
break

time_delta = interval_end - interval_start

# Interpolation using the trapezoidal rule
averaged_tracked_value = (
current_snapshot["tracked_value"] + next_snapshot["tracked_value"]
) // 2

# Accumulate weighted rate and time
total_weighted_tracked_value += averaged_tracked_value * time_delta
total_time += time_delta

if total_time == 0 and len(self.snapshots) == 1:
# case when only snapshot is taken in the block where computation is called
return self.snapshots[0]["tracked_value"]

# Ensure there is non-zero time for division
if total_time == 0:
raise ValueError("TWA: Zero total time!")

# Calculate TWA
twa = total_weighted_tracked_value // total_time
note(f"python_compute_twa: Computed TWA: {twa}")
return twa
133 changes: 133 additions & 0 deletions tests/hypothesis/twa/test_twa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import boa
from hypothesis import HealthCheck, Verbosity, settings
from hypothesis import strategies as st
from hypothesis.stateful import invariant, rule

from tests.hypothesis.twa.stateful_base import TWAStatefulBase


def test_state_machine():
# Explicitly run the state machine
TestTWAStateful = TWAStateful.TestCase()
TestTWAStateful.run()


@settings(
max_examples=10,
stateful_step_count=1000,
suppress_health_check=[
HealthCheck.large_base_example
], # skips issue when trying to add 1000 examples with 0 dt
verbosity=Verbosity.verbose,
)
class TWAStateful(TWAStatefulBase):
@invariant()
def check_initialization(self):
assert self.twa_window > 0, "TWA window must be set"
assert self.min_snapshot_dt_seconds > 0, "Minimum snapshot interval must be set"

@invariant()
def check_crude_twa_invariant(self):
"""
Crude invariant to ensure that the computed TWA is reasonable.
It checks that the TWA is non-negative and is between the minimum and maximum
values of the snapshots within the TWA window.
"""
# Get current block timestamp
current_time = boa.env.evm.patch.timestamp

# Calculate the time window start
time_window_start = current_time - self.twa_window

# Collect snapshots within the TWA window
snapshots_in_window = [
snapshot for snapshot in self.snapshots if snapshot["timestamp"] >= time_window_start
]

# Also consider the last snapshot just outside TWA window (needed for trapezoidal rule)
previous_snapshot = None
for snapshot in self.snapshots:
if snapshot["timestamp"] < time_window_start:
previous_snapshot = snapshot
else:
break # We passed the start of the window

# If a previous snapshot exists, we add it to the window (on the boundary)
# not changing timestamp as we only assert values here
if previous_snapshot:
snapshots_in_window.append(previous_snapshot)

# If there are still no snapshots (even outside the window), TWA should be zero
if not snapshots_in_window:
contract_twa = self.twa_contract.compute_twa()
python_twa = self.python_compute_twa()

# Assert both TWAs are zero
assert contract_twa == 0, f"Contract TWA should be zero but is {contract_twa}"
assert python_twa == 0, f"Python TWA should be zero but is {python_twa}"
return

# Extract tracked values from snapshots in the window
tracked_values = [snapshot["tracked_value"] for snapshot in snapshots_in_window]

# Compute the min and max values of the tracked values
min_value = min(tracked_values)
max_value = max(tracked_values)
# Compute the TWA from the contract and Python model
contract_twa = self.twa_contract.compute_twa()
python_twa = self.python_compute_twa()

# Ensure that the TWA is non-negative
assert contract_twa >= 0, f"Contract TWA is negative: {contract_twa}"
assert python_twa >= 0, f"Python TWA is negative: {python_twa}"

# Ensure that the TWA is between the min and max values of the snapshots
assert (
min_value <= contract_twa <= max_value
), f"Contract TWA {contract_twa} is not between min {min_value} and max {max_value}"
assert (
min_value <= python_twa <= max_value
), f"Python TWA {python_twa} is not between min {min_value} and max {max_value}"

@rule(
value=st.integers(min_value=0, max_value=100_000_000 * 10**18), # 0 to 100 million crvUSD
timestamp_delta=st.integers(
min_value=0, max_value=10 * 86400
), # 0s to 10 days between snapshots
)
def take_snapshot_rule(self, value, timestamp_delta):
"""
Rule to test taking snapshots in both the Python model and the contract.
"""
boa.env.time_travel(seconds=timestamp_delta)
# Call snapshot-taking functions in both the Python model and the contract
self.twa_contract.eval(f"self._take_snapshot({value})")
self.python_take_snapshot(value)

# Assert equal numbe of the snapshots
contract_snapshot_len = self.twa_contract.get_len_snapshots()
python_snapshot_len = len(self.snapshots)

assert contract_snapshot_len == python_snapshot_len, (
"Mismatch in snapshot length: "
+ f"contract={contract_snapshot_len}, python={python_snapshot_len}"
)

@rule(
timestamp_delta=st.integers(
min_value=0, max_value=10 * 86400
), # 0s to 10days between compute calls
)
def compute_twa_rule(self, timestamp_delta):
boa.env.time_travel(seconds=timestamp_delta)
# TWA computation for contract/python model
contract_twa = self.twa_contract.compute_twa()
python_twa = self.python_compute_twa()

# Assert that both values are the same
assert (
contract_twa == python_twa
), f"Mismatch in TWA: contract={contract_twa}, python={python_twa}"


# TestTWAStateful = TWAStateful.TestCase

0 comments on commit de3bc42

Please sign in to comment.