Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix to SingleDiskStorageSchedule, update associated unit tests #48

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checkpoint_schedules/basic_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _iterator(self):

while self._max_n is None:
n0 = self._n
n1 = n0 + sys.maxsize
n1 = n0 + 1
self._n = n1
yield Forward(n0, n1, False, True, StorageType.DISK)

Expand Down
83 changes: 38 additions & 45 deletions tests/test_trivial.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import functools
import pytest
from checkpoint_schedules.schedule import \
Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType
from checkpoint_schedules import SingleDiskStorageSchedule, \
SingleMemoryStorageSchedule
from checkpoint_schedules.schedule import (
Forward, Reverse, Copy, Move, EndForward, EndReverse, StorageType)
from checkpoint_schedules import (
SingleDiskStorageSchedule, SingleMemoryStorageSchedule)


def single_disk_copy(n):
checkpointing = SingleDiskStorageSchedule()
return (checkpointing,
cp_schedule = SingleDiskStorageSchedule(move_data=False)
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_disk_move(n):
checkpointing = SingleDiskStorageSchedule(move_data=True)
return (checkpointing,
cp_schedule = SingleDiskStorageSchedule(move_data=True)
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: n}, 1)


def single_memory(n):
checkpointing = SingleMemoryStorageSchedule()
return (checkpointing,
cp_schedule = SingleMemoryStorageSchedule()
return (cp_schedule,
{StorageType.RAM: 0, StorageType.DISK: 0}, n)


Expand All @@ -35,17 +36,16 @@ def single_memory(n):
]
)
def test_validity(schedule, n=10):
"""Test the checkpoint revolvers.
"""Test basic checkpointing schedules.

Parameters
----------
schedule : object
Revolver schedule.
schedule : callable
Accepts the number of forward steps and returns a schedule.
n : int
Total forward steps.
S : int
Snapshots.
Number of forward steps
"""

@functools.singledispatch
def action(cp_action):
raise TypeError("Unexpected action")
Expand All @@ -56,8 +56,6 @@ def action_forward(cp_action):
# Start at the current location of the forward
assert model_n is not None and model_n == cp_action.n0

assert cp_action.storage == StorageType.WORK or \
cp_action.storage == StorageType.DISK
# If the schedule has been finalized, end at or before the end of the
# forward
assert cp_schedule.max_n is None or cp_action.n1 <= n
Expand All @@ -68,21 +66,17 @@ def action_forward(cp_action):
n1 = min(cp_action.n1, n)

model_n = n1
data.clear()
assert len(ics) == 0
assert not cp_action.write_ics
assert cp_action.write_adj_deps is True
assert len(data.intersection(range(cp_action.n0, n1))) == 0
assert cp_action.write_adj_deps
if cp_action.storage == StorageType.DISK:
for step in range(cp_action.n0, n1):
data.update(range(step, step + 1))
snapshots[cp_action.storage][step] = (set(ics), set(data))
data.clear()
assert cp_action.n0 == min(snapshots[cp_action.storage])
else:
snapshots[cp_action.storage][cp_action.n0] = \
(set(ics), set(range(cp_action.n0, n1)))
elif cp_action.storage == StorageType.WORK:
data.update(range(cp_action.n0, n1))
assert cp_action.n0 == min(data)

assert len(ics) == 0
else:
raise ValueError("Unexpected storage")

if n1 == n:
cp_schedule.finalize(n1)
Expand All @@ -100,7 +94,7 @@ def action_reverse(cp_action):

model_r += cp_action.n1 - cp_action.n0
if cp_action.clear_adj_deps:
data.clear()
data.difference_update(range(cp_action.n0, cp_action.n1))

@action.register(Copy)
def action_copy(cp_action):
Expand All @@ -109,17 +103,18 @@ def action_copy(cp_action):
assert cp_action.n in snapshots[cp_action.from_storage]
cp = snapshots[cp_action.from_storage][cp_action.n]

# No data is currently stored for this step
assert cp_action.n not in ics
assert cp_action.n not in data
assert len(ics.intersection(cp[0])) == 0
assert len(data.intersection(cp[1])) == 0

# The checkpoint contains forward data
assert len(cp[0]) == 0 and len(cp[1]) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r
model_n = None

model_n = None
assert len(ics) == 0
assert cp_action.to_storage == StorageType.WORK
data.clear()
data.update(cp[1])

Expand All @@ -128,22 +123,22 @@ def action_move(cp_action):
nonlocal model_n
# The checkpoint exists
assert cp_action.n in snapshots[cp_action.from_storage]
assert cp_action.n == max(snapshots[cp_action.from_storage])
cp = snapshots[cp_action.from_storage][cp_action.n]
cp = snapshots[cp_action.from_storage].pop(cp_action.n)

# No data is currently stored for this step
assert cp_action.n not in ics
assert cp_action.n not in data
assert len(ics.intersection(cp[0])) == 0
assert len(data.intersection(cp[1])) == 0

# The checkpoint contains forward data
assert len(cp[0]) == 0 and len(cp[1]) > 0

# The checkpoint data is before the current location of the adjoint
assert cp_action.n < n - model_r

model_n = None
assert len(ics) == 0
assert cp_action.to_storage == StorageType.WORK
data.clear()
data.update(cp[1])
del snapshots[cp_action.from_storage][cp_action.n]

@action.register(EndForward)
def action_end_forward(cp_action):
Expand All @@ -152,23 +147,21 @@ def action_end_forward(cp_action):

@action.register(EndReverse)
def action_end_reverse(cp_action):
nonlocal model_r, cp_schedule
nonlocal model_r

# The correct number of adjoint steps has been taken
assert model_r == n
is_exhausted = cp_schedule.is_exhausted
if is_exhausted is False:
if not is_exhausted:
model_r = 0

model_n = 0
model_r = 0
ics = set()
data = set()
assert isinstance(n, int)
snapshots = {StorageType.RAM: {}, StorageType.DISK: {}}
cp_schedule, storage_limits, data_limit = schedule(n)
if cp_schedule is None:
raise TypeError("Incompatible with schedule type.")
assert cp_schedule is not None
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.max_n is None or cp_schedule.max_n == n
Expand Down
Loading