Skip to content

Commit

Permalink
Update validity test
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Oct 4, 2023
1 parent 40637c4 commit 04bcbfa
Showing 1 changed file with 122 additions and 78 deletions.
200 changes: 122 additions & 78 deletions tests/test_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,24 @@

import functools
import pytest
from checkpoint_schedules.schedule import \
Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType
from checkpoint_schedules import HRevolve, DiskRevolve, PeriodicDiskRevolve, \
Revolve, MultistageCheckpointSchedule, TwoLevelCheckpointSchedule, \
MixedCheckpointSchedule
from checkpoint_schedules.schedule import (
Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType)
from checkpoint_schedules import (
HRevolve, DiskRevolve, PeriodicDiskRevolve, Revolve,
MultistageCheckpointSchedule, TwoLevelCheckpointSchedule,
MixedCheckpointSchedule, SingleDiskStorageSchedule,
SingleMemoryStorageSchedule)


def h_revolve(n, s):
snap_ram = s//3
snap_disk = s - s//3
if s//3 < 1:
pytest.skip("H-Revolve accepts snapshots in RAM > 1")
snap_ram = s // 3
snap_disk = s - snap_ram
if snap_ram < 1 or snap_disk < 1:
return (None,
{StorageType.RAM: 0, StorageType.DISK: 0}, 0)
else:
revolver = HRevolve(n, snap_ram, snap_disk)
return (revolver,
cp_schedule = HRevolve(n, snap_ram, snap_disk)
return (cp_schedule,
{StorageType.RAM: snap_ram, StorageType.DISK: snap_disk}, 1)


Expand All @@ -44,8 +45,8 @@ def disk_revolve(n, s):
return (None,
{StorageType.RAM: 0, StorageType.DISK: 0}, 0)
else:
revolver = DiskRevolve(n, s, n - s)
return (revolver,
cp_schedule = DiskRevolve(n, s, n - s)
return (cp_schedule,
{StorageType.RAM: s, StorageType.DISK: n - s}, 1)


Expand All @@ -64,36 +65,57 @@ def periodic_disk(n, s):
return (None,
{StorageType.RAM: 0, StorageType.DISK: 0}, 0)
else:
revolver = PeriodicDiskRevolve(n, s)
return (revolver,
{StorageType.RAM: s, StorageType.DISK: n - s}, 1)
cp_schedule = PeriodicDiskRevolve(n, s)
return (cp_schedule,
{StorageType.RAM: s, StorageType.DISK: n - s}, 1)


def revolve(n, s):
if s < 1:
return (None,
{StorageType.RAM: 0, StorageType.DISK: 0}, 0)
else:
revolver = Revolve(n, s)
return (revolver,
{StorageType.RAM: s, StorageType.DISK: 0}, 1)
cp_schedule = Revolve(n, s)
return (cp_schedule,
{StorageType.RAM: s, StorageType.DISK: 0}, 1)


def mixed(n, s):
return (MixedCheckpointSchedule(n, s),
{StorageType.RAM: 0, StorageType.DISK: s}, 1)


def single_disk_copy(n, s):
cp_schedule = SingleDiskStorageSchedule(move_data=False)
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_disk_move(n, s):
cp_schedule = SingleDiskStorageSchedule(move_data=True)
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_memory(n, s):
cp_schedule = SingleMemoryStorageSchedule()
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: 0}, n)


@pytest.mark.parametrize(
"schedule",
[
revolve,
periodic_disk,
disk_revolve,
h_revolve,
disk_revolve,
multistage,
twolevel_binomial,
periodic_disk,
revolve,
mixed,
single_disk_copy,
single_disk_move,
single_memory
]
)
@pytest.mark.parametrize("n, S", [
Expand All @@ -104,53 +126,62 @@ def mixed(n, s):
(250, tuple(range(25, 250, 25)))
])
def test_validity(schedule, n, S):
"""Test the checkpoint revolvers.
"""Test validity of checkpoint schedules. Tests that an adjoint calculation
can be performed without exceeding storage limits.
Parameters
----------
schedule : object
Revolver schedule.
schedule : callable
Accepts the number of forward steps and checkpoint units, and returns a
schedule.
n : int
Total forward steps.
Number of forward steps.
S : int
Snapshots.
Number of checkpoint units.
"""

@functools.singledispatch
def action(cp_action):
raise TypeError("Unexpected action")

@action.register(Forward)
def action_forward(cp_action):
nonlocal model_n
# Start at the current location of the forward
assert model_n is not None and model_n == cp_action.n0

# Start at the current location of the forward
assert model_n is not None and cp_action.n0 == model_n
# If the schedule has been finalized, end at or before the end of the
# forward
assert cp_schedule.max_n is None or cp_action.n1 <= n

if cp_schedule.max_n is not None:
# Do not advance further than the current location of the adjoint
assert cp_action.n1 <= n - model_r

n1 = min(cp_action.n1, n)

model_n = n1
ics.clear()
data.clear()
if cp_action.write_ics:
# No forward restart data for these steps is stored
assert cp_action.n0 not in snapshots[cp_action.storage]
# No forward restart data for these steps is stored
assert len(ics.intersection(range(cp_action.n0, n1))) == 0
ics.update(range(cp_action.n0, n1))
snapshots[cp_action.storage][cp_action.n0] = (set(ics), set(data))

cp_ics = set(range(cp_action.n0, n1))
else:
cp_ics = set()
if cp_action.write_adj_deps:
# No non-linear dependency data for these steps is stored
assert len(data.intersection(range(cp_action.n0, n1))) == 0
data.update(range(cp_action.n0, n1))
if cp_action.storage == StorageType.DISK:
snapshots[cp_action.storage][cp_action.n0] = (set(ics), set(data)) # noqa: E501
cp_data = set(range(cp_action.n0, n1))
else:
cp_data = set()

if cp_action.storage in {StorageType.RAM, StorageType.DISK}:
assert cp_action.n0 not in snapshots[cp_action.storage]
snapshots[cp_action.storage][cp_action.n0] = (set(cp_ics), set(cp_data)) # noqa: E501
elif cp_action.storage == StorageType.WORK:
assert len(ics.intersection(cp_ics)) == 0
ics.update(cp_ics)
assert len(data.intersection(cp_data)) == 0
data.update(cp_data)
elif cp_action.storage == StorageType.NONE:
pass
else:
raise ValueError("Unexpected storage")

if len(ics) > 0:
if len(data) > 0:
Expand All @@ -176,78 +207,87 @@ def action_reverse(cp_action):

model_r += cp_action.n1 - cp_action.n0
if cp_action.clear_adj_deps:
data.clear()
data.difference_update(range(cp_action.n0, cp_action.n1))

@action.register(Copy)
def action_copy(cp_action):
nonlocal model_n

# The checkpoint exists
assert cp_action.n in snapshots[cp_action.from_storage]
cp = snapshots[cp_action.from_storage][cp_action.n]

# No data is currently stored for this step
assert cp_action.n not in ics
assert cp_action.n not in data
cp_ics, cp_data = snapshots[cp_action.from_storage][cp_action.n]

# The checkpoint contains forward restart or non-linear dependency data
assert len(cp[0]) > 0 or len(cp[1]) > 0
assert len(cp_ics) > 0 or len(cp_data) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r
model_n = None
if len(cp[0]) > 0:
ics.clear()
ics.update(cp[0])
model_n = cp_action.n

if len(cp[1]) > 0:
data.clear()
data.update(cp[1])
if cp_action.to_storage in {StorageType.RAM, StorageType.DISK}:
assert cp_action.n not in snapshots[cp_action.to_storage]
snapshots[cp_action.to_storage][cp_action.n] = (set(cp_ics), set(cp_data)) # noqa: E501
elif cp_action.to_storage == StorageType.WORK:
if cp_action.n in cp_ics:
model_n = cp_action.n
else:
model_n = None
assert len(ics.intersection(cp_ics)) == 0
ics.update(cp_ics)
assert len(data.intersection(cp_data)) == 0
data.update(cp_data)
elif cp_action.storage == StorageType.NONE:
pass
else:
raise ValueError("Unexpected storage")

@action.register(Move)
def action_move(cp_action):
nonlocal model_n
# The checkpoint exists
assert cp_action.n in snapshots[cp_action.from_storage]
cp = snapshots[cp_action.from_storage][cp_action.n]
cp_ics, cp_data = snapshots[cp_action.from_storage].pop(cp_action.n)

# The checkpoint contains forward restart or non-linear dependency data
assert len(cp[0]) > 0 or len(cp[1]) > 0
assert len(cp_ics) > 0 or len(cp_data) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r

assert cp_action.n < n - model_r

model_n = None
if len(cp[0]) > 0:
ics.clear()
ics.update(cp[0])
model_n = cp_action.n

if len(cp[1]) > 0:
data.clear()
data.update(cp[1])

del snapshots[cp_action.from_storage][cp_action.n]
if cp_action.to_storage in {StorageType.RAM, StorageType.DISK}:
assert cp_action.n not in snapshots[cp_action.to_storage]
snapshots[cp_action.to_storage][cp_action.n] = (set(cp_ics), set(cp_data)) # noqa: E501
elif cp_action.to_storage == StorageType.WORK:
if cp_action.n in cp_ics:
model_n = cp_action.n
else:
model_n = None
assert len(ics.intersection(cp_ics)) == 0
ics.update(cp_ics)
assert len(data.intersection(cp_data)) == 0
data.update(cp_data)
elif cp_action.storage == StorageType.NONE:
pass
else:
raise ValueError("Unexpected storage")

@action.register(EndForward)
def action_end_forward(cp_action):
ics.clear()
# The correct number of forward steps has been taken
assert model_n is not None and model_n == n

@action.register(EndReverse)
def action_end_reverse(cp_action):
nonlocal model_r, cp_schedule
nonlocal model_r

# The correct number of adjoint steps has been taken
assert model_r == n
is_exhausted = cp_schedule.is_exhausted
if is_exhausted is False:

if not cp_schedule.is_exhausted:
model_r = 0

for s in S:
print(f"{n=:d} {s=:d}")

model_n = 0
model_r = 0
ics = set()
Expand All @@ -256,15 +296,19 @@ def action_end_reverse(cp_action):
snapshots = {StorageType.RAM: {}, StorageType.DISK: {}}
cp_schedule, storage_limits, data_limit = schedule(n, s)
if cp_schedule is None:
raise TypeError("Incompatible with schedule type.")
pytest.skip("Incompatible with schedule type")
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.max_n is None or cp_schedule.max_n == n

for _, cp_action in enumerate(cp_schedule):
action(cp_action)

# The schedule state is consistent with both the forward and
# adjoint
assert model_n is None or model_n == cp_schedule.n
assert model_r == cp_schedule.r
assert cp_schedule.max_n is None or cp_schedule.max_n == n

# Checkpoint storage limits are not exceeded
for storage_type, storage_limit in storage_limits.items():
Expand Down

0 comments on commit 04bcbfa

Please sign in to comment.