From 40637c4bc5fc9917191695f789e41c16845976ed Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 4 Oct 2023 16:06:52 +0100 Subject: [PATCH] Bugfix to SingleDiskStorageSchedule, update associated unit tests --- checkpoint_schedules/basic_schedules.py | 2 +- tests/test_trivial.py | 83 +++++++++++-------------- 2 files changed, 39 insertions(+), 46 deletions(-) diff --git a/checkpoint_schedules/basic_schedules.py b/checkpoint_schedules/basic_schedules.py index 15c337e..ffb63c8 100644 --- a/checkpoint_schedules/basic_schedules.py +++ b/checkpoint_schedules/basic_schedules.py @@ -116,7 +116,7 @@ def _iterator(self): while self._max_n is None: n0 = self._n - n1 = n0 + sys.maxsize + n1 = n0 + 1 self._n = n1 yield Forward(n0, n1, False, True, StorageType.DISK) diff --git a/tests/test_trivial.py b/tests/test_trivial.py index fb0c83e..aa80b3b 100644 --- a/tests/test_trivial.py +++ b/tests/test_trivial.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- + import functools import pytest -from checkpoint_schedules.schedule import \ - Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType -from checkpoint_schedules import SingleDiskStorageSchedule, \ - SingleMemoryStorageSchedule +from checkpoint_schedules.schedule import ( + Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType) +from checkpoint_schedules import ( + SingleDiskStorageSchedule, SingleMemoryStorageSchedule) def single_disk_copy(n): - checkpointing = SingleDiskStorageSchedule() - return (checkpointing, + cp_schedule = SingleDiskStorageSchedule(move_data=False) + return (cp_schedule, {StorageType.RAM: 0, StorageType.DISK: n}, 1) def single_disk_move(n): - checkpointing = SingleDiskStorageSchedule(move_data=True) - return (checkpointing, + cp_schedule = SingleDiskStorageSchedule(move_data=True) + return (cp_schedule, {StorageType.RAM: 0, StorageType.DISK: n}, 1) def single_memory(n): - checkpointing = SingleMemoryStorageSchedule() - return (checkpointing, + cp_schedule = SingleMemoryStorageSchedule() + return (cp_schedule, {StorageType.RAM: 0, StorageType.DISK: 0}, n) @@ -35,17 +36,16 @@ def single_memory(n): ] ) def test_validity(schedule, n=10): - """Test the checkpoint revolvers. + """Test basic checkpointing schedules. Parameters ---------- - schedule : object - Revolver schedule. + schedule : callable + Accepts the number of forward steps and returns a schedule. n : int - Total forward steps. - S : int - Snapshots. + Number of forward steps """ + @functools.singledispatch def action(cp_action): raise TypeError("Unexpected action") @@ -56,8 +56,6 @@ def action_forward(cp_action): # Start at the current location of the forward assert model_n is not None and model_n == cp_action.n0 - assert cp_action.storage == StorageType.WORK or \ - cp_action.storage == StorageType.DISK # If the schedule has been finalized, end at or before the end of the # forward assert cp_schedule.max_n is None or cp_action.n1 <= n @@ -68,21 +66,17 @@ def action_forward(cp_action): n1 = min(cp_action.n1, n) model_n = n1 - data.clear() + assert len(ics) == 0 assert not cp_action.write_ics - assert cp_action.write_adj_deps is True assert len(data.intersection(range(cp_action.n0, n1))) == 0 + assert cp_action.write_adj_deps if cp_action.storage == StorageType.DISK: - for step in range(cp_action.n0, n1): - data.update(range(step, step + 1)) - snapshots[cp_action.storage][step] = (set(ics), set(data)) - data.clear() - assert cp_action.n0 == min(snapshots[cp_action.storage]) - else: + snapshots[cp_action.storage][cp_action.n0] = \ + (set(ics), set(range(cp_action.n0, n1))) + elif cp_action.storage == StorageType.WORK: data.update(range(cp_action.n0, n1)) - assert cp_action.n0 == min(data) - - assert len(ics) == 0 + else: + raise ValueError("Unexpected storage") if n1 == n: cp_schedule.finalize(n1) @@ -100,7 +94,7 @@ def action_reverse(cp_action): model_r += cp_action.n1 - cp_action.n0 if cp_action.clear_adj_deps: - data.clear() + data.difference_update(range(cp_action.n0, cp_action.n1)) @action.register(Copy) def action_copy(cp_action): @@ -109,17 +103,18 @@ def action_copy(cp_action): assert cp_action.n in snapshots[cp_action.from_storage] cp = snapshots[cp_action.from_storage][cp_action.n] - # No data is currently stored for this step - assert cp_action.n not in ics - assert cp_action.n not in data + assert len(ics.intersection(cp[0])) == 0 + assert len(data.intersection(cp[1])) == 0 # The checkpoint contains forward data assert len(cp[0]) == 0 and len(cp[1]) > 0 # The checkpoint data is before the current location of the adjoint assert cp_action.n < n - model_r - model_n = None + model_n = None + assert len(ics) == 0 + assert cp_action.to_storage == StorageType.WORK data.clear() data.update(cp[1]) @@ -128,22 +123,22 @@ def action_move(cp_action): nonlocal model_n # The checkpoint exists assert cp_action.n in snapshots[cp_action.from_storage] - assert cp_action.n == max(snapshots[cp_action.from_storage]) - cp = snapshots[cp_action.from_storage][cp_action.n] + cp = snapshots[cp_action.from_storage].pop(cp_action.n) - # No data is currently stored for this step - assert cp_action.n not in ics - assert cp_action.n not in data + assert len(ics.intersection(cp[0])) == 0 + assert len(data.intersection(cp[1])) == 0 # The checkpoint contains forward data assert len(cp[0]) == 0 and len(cp[1]) > 0 # The checkpoint data is before the current location of the adjoint assert cp_action.n < n - model_r + model_n = None + assert len(ics) == 0 + assert cp_action.to_storage == StorageType.WORK data.clear() data.update(cp[1]) - del snapshots[cp_action.from_storage][cp_action.n] @action.register(EndForward) def action_end_forward(cp_action): @@ -152,23 +147,21 @@ def action_end_forward(cp_action): @action.register(EndReverse) def action_end_reverse(cp_action): - nonlocal model_r, cp_schedule + nonlocal model_r # The correct number of adjoint steps has been taken assert model_r == n is_exhausted = cp_schedule.is_exhausted - if is_exhausted is False: + if not is_exhausted: model_r = 0 model_n = 0 model_r = 0 ics = set() data = set() - assert isinstance(n, int) snapshots = {StorageType.RAM: {}, StorageType.DISK: {}} cp_schedule, storage_limits, data_limit = schedule(n) - if cp_schedule is None: - raise TypeError("Incompatible with schedule type.") + assert cp_schedule is not None assert cp_schedule.n == 0 assert cp_schedule.r == 0 assert cp_schedule.max_n is None or cp_schedule.max_n == n