-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stateful testing of TWA module (#11)
* test: hypothesis structure for TWA module * test: hypothesis parallel testing * test: hypothesis parallel testing * dev: TWA contract min_dt enforced; test: rm parallel hypothesis, added TWA invariant --------- Co-authored-by: Alberto <[email protected]>
- Loading branch information
1 parent
bc976fc
commit de3bc42
Showing
7 changed files
with
263 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from hypothesis import Phase, Verbosity, settings | ||
|
||
settings.register_profile("debug", settings(verbosity=Verbosity.verbose, phases=list(Phase)[:4])) | ||
settings.load_profile("debug") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import boa | ||
from hypothesis import note | ||
from hypothesis import strategies as st | ||
from hypothesis.stateful import RuleBasedStateMachine, initialize # , invariant, rule | ||
|
||
|
||
class TWAStatefulBase(RuleBasedStateMachine): | ||
twa_deployer = boa.load_partial("contracts/TWA.vy") | ||
|
||
def __init__(self): | ||
super().__init__() | ||
note("INIT") | ||
self.twa_contract = None | ||
self.twa_window = None | ||
self.min_snapshot_dt_seconds = None | ||
self.snapshots = [] | ||
self.last_snapshot_timestamp = 0 | ||
|
||
@initialize( | ||
twa_window=st.integers(min_value=1, max_value=86400 * 7), # 1 second to 1 week | ||
min_snapshot_dt_seconds=st.integers(min_value=1, max_value=86400), # 1 second to 1 day | ||
) | ||
def setup(self, twa_window, min_snapshot_dt_seconds): | ||
"""Initialize the TWA contract and set up initial parameters.""" | ||
note("SETUP") | ||
self.twa_contract = TWAStatefulBase.twa_deployer(twa_window, min_snapshot_dt_seconds) | ||
|
||
self.twa_window = twa_window | ||
self.min_snapshot_dt_seconds = min_snapshot_dt_seconds | ||
self.snapshots = [] | ||
self.last_snapshot_timestamp = 0 | ||
|
||
def python_take_snapshot(self, value): | ||
""" | ||
Python model of the contract's `_take_snapshot` function. | ||
Mirrors the contract logic and updates the internal state. | ||
""" | ||
# Contract logic: only take a snapshot if the time condition is met | ||
block_timestamp = boa.env.evm.patch.timestamp | ||
if self.last_snapshot_timestamp + self.min_snapshot_dt_seconds <= block_timestamp: | ||
self.last_snapshot_timestamp = block_timestamp | ||
self.snapshots.append({"tracked_value": value, "timestamp": block_timestamp}) | ||
note( | ||
f"python_take_snapshot: Python snapshot added: value={value}, timestamp={block_timestamp}" # noqa: E501 | ||
) | ||
else: | ||
note("python_take_snapshot: Python snapshot skipped (time condition not met)") | ||
|
||
def python_compute_twa(self): | ||
""" | ||
Python version of the contract's _compute function. | ||
Computes the TWA (Time-Weighted Average) based on the snapshots in self.snapshots. | ||
""" | ||
block_timestamp = boa.env.evm.patch.timestamp | ||
|
||
num_snapshots = len(self.snapshots) | ||
if num_snapshots == 0: | ||
note("python_compute_twa: No snapshots, no TWA") | ||
return 0 | ||
|
||
time_window_start = block_timestamp - self.twa_window | ||
|
||
total_weighted_tracked_value = 0 | ||
total_time = 0 | ||
|
||
# Iterate backwards over all snapshots | ||
index_array_end = num_snapshots - 1 | ||
for i in range(0, num_snapshots): | ||
i_backwards = index_array_end - i | ||
current_snapshot = self.snapshots[i_backwards] | ||
next_snapshot = current_snapshot | ||
|
||
if i != 0: # If not the first iteration, get the next snapshot | ||
next_snapshot = self.snapshots[i_backwards + 1] | ||
|
||
interval_start = current_snapshot["timestamp"] | ||
|
||
# Adjust interval start if it is before the time window start | ||
if interval_start < time_window_start: | ||
interval_start = time_window_start | ||
|
||
if i == 0: | ||
# For the last snapshot, interval end is the block_timestamp | ||
interval_end = block_timestamp | ||
else: | ||
# For other snapshots, interval end is the timestamp of the next snapshot | ||
interval_end = next_snapshot["timestamp"] | ||
|
||
if interval_end <= time_window_start: | ||
break | ||
|
||
time_delta = interval_end - interval_start | ||
|
||
# Interpolation using the trapezoidal rule | ||
averaged_tracked_value = ( | ||
current_snapshot["tracked_value"] + next_snapshot["tracked_value"] | ||
) // 2 | ||
|
||
# Accumulate weighted rate and time | ||
total_weighted_tracked_value += averaged_tracked_value * time_delta | ||
total_time += time_delta | ||
|
||
if total_time == 0 and len(self.snapshots) == 1: | ||
# case when only snapshot is taken in the block where computation is called | ||
return self.snapshots[0]["tracked_value"] | ||
|
||
# Ensure there is non-zero time for division | ||
if total_time == 0: | ||
raise ValueError("TWA: Zero total time!") | ||
|
||
# Calculate TWA | ||
twa = total_weighted_tracked_value // total_time | ||
note(f"python_compute_twa: Computed TWA: {twa}") | ||
return twa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import boa | ||
from hypothesis import HealthCheck, Verbosity, settings | ||
from hypothesis import strategies as st | ||
from hypothesis.stateful import invariant, rule | ||
|
||
from tests.hypothesis.twa.stateful_base import TWAStatefulBase | ||
|
||
|
||
def test_state_machine(): | ||
# Explicitly run the state machine | ||
TestTWAStateful = TWAStateful.TestCase() | ||
TestTWAStateful.run() | ||
|
||
|
||
@settings( | ||
max_examples=10, | ||
stateful_step_count=1000, | ||
suppress_health_check=[ | ||
HealthCheck.large_base_example | ||
], # skips issue when trying to add 1000 examples with 0 dt | ||
verbosity=Verbosity.verbose, | ||
) | ||
class TWAStateful(TWAStatefulBase): | ||
@invariant() | ||
def check_initialization(self): | ||
assert self.twa_window > 0, "TWA window must be set" | ||
assert self.min_snapshot_dt_seconds > 0, "Minimum snapshot interval must be set" | ||
|
||
@invariant() | ||
def check_crude_twa_invariant(self): | ||
""" | ||
Crude invariant to ensure that the computed TWA is reasonable. | ||
It checks that the TWA is non-negative and is between the minimum and maximum | ||
values of the snapshots within the TWA window. | ||
""" | ||
# Get current block timestamp | ||
current_time = boa.env.evm.patch.timestamp | ||
|
||
# Calculate the time window start | ||
time_window_start = current_time - self.twa_window | ||
|
||
# Collect snapshots within the TWA window | ||
snapshots_in_window = [ | ||
snapshot for snapshot in self.snapshots if snapshot["timestamp"] >= time_window_start | ||
] | ||
|
||
# Also consider the last snapshot just outside TWA window (needed for trapezoidal rule) | ||
previous_snapshot = None | ||
for snapshot in self.snapshots: | ||
if snapshot["timestamp"] < time_window_start: | ||
previous_snapshot = snapshot | ||
else: | ||
break # We passed the start of the window | ||
|
||
# If a previous snapshot exists, we add it to the window (on the boundary) | ||
# not changing timestamp as we only assert values here | ||
if previous_snapshot: | ||
snapshots_in_window.append(previous_snapshot) | ||
|
||
# If there are still no snapshots (even outside the window), TWA should be zero | ||
if not snapshots_in_window: | ||
contract_twa = self.twa_contract.compute_twa() | ||
python_twa = self.python_compute_twa() | ||
|
||
# Assert both TWAs are zero | ||
assert contract_twa == 0, f"Contract TWA should be zero but is {contract_twa}" | ||
assert python_twa == 0, f"Python TWA should be zero but is {python_twa}" | ||
return | ||
|
||
# Extract tracked values from snapshots in the window | ||
tracked_values = [snapshot["tracked_value"] for snapshot in snapshots_in_window] | ||
|
||
# Compute the min and max values of the tracked values | ||
min_value = min(tracked_values) | ||
max_value = max(tracked_values) | ||
# Compute the TWA from the contract and Python model | ||
contract_twa = self.twa_contract.compute_twa() | ||
python_twa = self.python_compute_twa() | ||
|
||
# Ensure that the TWA is non-negative | ||
assert contract_twa >= 0, f"Contract TWA is negative: {contract_twa}" | ||
assert python_twa >= 0, f"Python TWA is negative: {python_twa}" | ||
|
||
# Ensure that the TWA is between the min and max values of the snapshots | ||
assert ( | ||
min_value <= contract_twa <= max_value | ||
), f"Contract TWA {contract_twa} is not between min {min_value} and max {max_value}" | ||
assert ( | ||
min_value <= python_twa <= max_value | ||
), f"Python TWA {python_twa} is not between min {min_value} and max {max_value}" | ||
|
||
@rule( | ||
value=st.integers(min_value=0, max_value=100_000_000 * 10**18), # 0 to 100 million crvUSD | ||
timestamp_delta=st.integers( | ||
min_value=0, max_value=10 * 86400 | ||
), # 0s to 10 days between snapshots | ||
) | ||
def take_snapshot_rule(self, value, timestamp_delta): | ||
""" | ||
Rule to test taking snapshots in both the Python model and the contract. | ||
""" | ||
boa.env.time_travel(seconds=timestamp_delta) | ||
# Call snapshot-taking functions in both the Python model and the contract | ||
self.twa_contract.eval(f"self._take_snapshot({value})") | ||
self.python_take_snapshot(value) | ||
|
||
# Assert equal numbe of the snapshots | ||
contract_snapshot_len = self.twa_contract.get_len_snapshots() | ||
python_snapshot_len = len(self.snapshots) | ||
|
||
assert contract_snapshot_len == python_snapshot_len, ( | ||
"Mismatch in snapshot length: " | ||
+ f"contract={contract_snapshot_len}, python={python_snapshot_len}" | ||
) | ||
|
||
@rule( | ||
timestamp_delta=st.integers( | ||
min_value=0, max_value=10 * 86400 | ||
), # 0s to 10days between compute calls | ||
) | ||
def compute_twa_rule(self, timestamp_delta): | ||
boa.env.time_travel(seconds=timestamp_delta) | ||
# TWA computation for contract/python model | ||
contract_twa = self.twa_contract.compute_twa() | ||
python_twa = self.python_compute_twa() | ||
|
||
# Assert that both values are the same | ||
assert ( | ||
contract_twa == python_twa | ||
), f"Mismatch in TWA: contract={contract_twa}, python={python_twa}" | ||
|
||
|
||
# TestTWAStateful = TWAStateful.TestCase |