Skip to content

Commit

Permalink
Merge pull request #69 from jrmaddison/jrmaddison/mixed_update
Browse files Browse the repository at this point in the history
Update `MixedCheckpointSchedule`
  • Loading branch information
jrmaddison authored Jul 8, 2024
2 parents 8a5256b + 630e73c commit 4ab0544
Showing 1 changed file with 54 additions and 141 deletions.
195 changes: 54 additions & 141 deletions checkpoint_schedules/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class MixedCheckpointSchedule(CheckpointSchedule):
Notes
-----
Assumes that the data required to restart the forward has the same size as
the data required to advance the adjoint over a step. Additionall details
about the mixed checkpointing schedule is avaiable in [1].
the data required to advance the adjoint over a step. An updated version of
the algorithm described in [1].
This is a offline checkpointing strategy, one adjoint calculation
permitted.
Expand Down Expand Up @@ -73,35 +73,29 @@ def _iterator(self):
if numba is None:
warnings.warn("Numba not available -- using memoization",
RuntimeWarning)
schedule = None
else:
schedule = mixed_steps_tabulation(self._max_n, self._snapshots)
schedule_0 = mixed_steps_tabulation_0(self._max_n, self._snapshots, schedule) # noqa: E501

step_type = StepType.NONE
while True:
step_type = StepType.NONE
while self._n < self._max_n - self._r:
n0 = self._n
if n0 in snapshot_n:
# n0 checkpoint exists
if numba is None:
step_type, n1, _ = mixed_step_memoization_0(
self._max_n - self._r - n0,
self._snapshots - len(snapshots))
else:
step_type, n1, _ = schedule_0[
self._max_n - self._r - n0,
self._snapshots - len(snapshots)]
reuse_snapshot = bool(n0 in snapshot_n)

if schedule is None:
step_type, n1, _ = mixed_step_memoization(
self._max_n - self._r - n0,
self._snapshots - len(snapshots) + int(reuse_snapshot))
else:
# n0 checkpoint does not exist
if numba is None:
step_type, n1, _ = mixed_step_memoization(
self._max_n - self._r - n0,
self._snapshots - len(snapshots))
else:
step_type, n1, _ = schedule[
self._max_n - self._r - n0,
self._snapshots - len(snapshots)]
step_type, n1, _ = schedule[
self._max_n - self._r - n0,
self._snapshots - len(snapshots) + int(reuse_snapshot)]
n1 += n0
if reuse_snapshot and \
(snapshots[-1][:2] != (step_type, n0)
or snapshots[-1][2] < n1):
raise RuntimeError("Invalid checkpointing state")

if step_type == StepType.FORWARD_REVERSE:
if n1 > n0 + 1:
Expand All @@ -119,31 +113,31 @@ def _iterator(self):
elif step_type == StepType.WRITE_ADJ_DEPS:
if n1 != n0 + 1:
raise InvalidForwardStep
self._n = n1
yield Forward(n0, n1, False, True, self._storage)
if n0 in snapshot_n:
if reuse_snapshot:
raise RuntimeError("Invalid checkpointing state")
elif len(snapshots) > self._snapshots - 1:
raise RuntimeError("Invalid checkpointing state")
self._n = n1
yield Forward(n0, n1, False, True, self._storage)
snapshot_n.add(n0)
snapshots.append((StepType.READ_ADJ_DEPS, n0))
snapshots.append((StepType.WRITE_ADJ_DEPS, n0, n1))
elif step_type == StepType.WRITE_ICS:
if n1 <= n0 + 1:
raise InvalidActionIndex
self._n = n1
yield Forward(n0, n1, True, False, self._storage)
if n0 in snapshot_n:
raise RuntimeError("Invalid checkpointing state")
elif len(snapshots) > self._snapshots - 1:
raise RuntimeError("Invalid checkpointing state")
snapshot_n.add(n0)
snapshots.append((StepType.READ_ICS, n0))
if reuse_snapshot:
yield Forward(n0, n1, False, False, StorageType.WORK)
else:
yield Forward(n0, n1, True, False, self._storage)
if len(snapshots) > self._snapshots - 1:
raise RuntimeError("Invalid checkpointing state")
snapshot_n.add(n0)
snapshots.append((StepType.WRITE_ICS, n0, n1))
else:
raise RuntimeError("Unexpected step type")
if self._n != self._max_n - self._r:
raise InvalidForwardStep
if step_type not in (StepType.FORWARD_REVERSE,
StepType.READ_ADJ_DEPS):
raise RuntimeError("Invalid checkpointing state")
if step_type not in {StepType.NONE, StepType.FORWARD_REVERSE}:
raise RuntimeError("Invalid checkpointing state")

if self._r == 0:
Expand All @@ -155,35 +149,39 @@ def _iterator(self):
if self._r == self._max_n:
break

step_type, cp_n = snapshots[-1]
cp_step_type, cp_n, _ = snapshots[-1]
if cp_step_type not in {StepType.WRITE_ICS, StepType.WRITE_ADJ_DEPS}: # noqa: E501
raise RuntimeError("Invalid checkpointing state")

# Delete if we have (possibly after deleting this checkpoint)
# enough storage left to store all non-linear dependency data
cp_delete = (cp_n >= (self._max_n - self._r - 1
- (self._snapshots - len(snapshots) + 1)))
if schedule is None:
next_step_type, _, _ = mixed_step_memoization(
self._max_n - self._r - cp_n,
self._snapshots - len(snapshots) + 1)
else:
next_step_type, _, _ = schedule[
self._max_n - self._r - cp_n,
self._snapshots - len(snapshots) + 1]
cp_delete = (cp_step_type != next_step_type)
if cp_delete:
snapshot_n.remove(cp_n)
snapshots.pop()

self._n = cp_n
if step_type == StepType.READ_ADJ_DEPS:
if cp_step_type == StepType.WRITE_ICS:
if cp_n + 1 >= self._max_n - self._r:
raise RuntimeError("Invalid checkpointing state")
self._n = cp_n
elif cp_step_type == StepType.WRITE_ADJ_DEPS:
# Non-linear dependency data checkpoint
if not cp_delete:
if not cp_delete or cp_n + 1 != self._max_n - self._r:
# We cannot advance from a loaded non-linear dependency
# checkpoint, and so we expect to use it immediately
raise RuntimeError("Invalid checkpointing state")
# Note that we cannot in general restart the forward here
self._n += 1
elif step_type != StepType.READ_ICS:
raise RuntimeError("Invalid checkpointing state")
if step_type == StepType.READ_ADJ_DEPS:
storage_type = StorageType.WORK
elif step_type == StepType.READ_ICS:
storage_type = StorageType.WORK
self._n = cp_n + 1
if cp_delete:
yield Move(cp_n, self._storage, storage_type)
yield Move(cp_n, self._storage, StorageType.WORK)
else:
yield Copy(cp_n, self._storage, storage_type)
yield Copy(cp_n, self._storage, StorageType.WORK)

if len(snapshot_n) > 0 or len(snapshots) > 0:
raise RuntimeError("Invalid checkpointing state")
Expand Down Expand Up @@ -273,7 +271,7 @@ def mixed_step_memoization(n, s):
if m is None:
raise RuntimeError("Failed to determine total number of steps")
m1 = 1 + mixed_step_memoization(n - 1, s - 1)[2]
if m1 <= m[2]:
if m1 < m[2]:
m = (StepType.WRITE_ADJ_DEPS, 1, m1)
return m

Expand Down Expand Up @@ -335,96 +333,11 @@ def mixed_steps_tabulation(n, s):
"steps")
assert schedule[n_i - 1, s_i - 1, 2] > 0
m1 = 1 + schedule[n_i - 1, s_i - 1, 2]
if m1 <= schedule[n_i, s_i, 2]:
if m1 < schedule[n_i, s_i, 2]:
schedule[n_i, s_i, :] = (_WRITE_ADJ_DEPS, 1, m1)
return schedule


def cache_step_0(fn):
_cache = {}

@functools.wraps(fn)
def wrapped_fn(n, s):
# Avoid some cache misses
s = min(s, n - 2)
if (n, s) not in _cache:
_cache[(n, s)] = fn(n, s)
return _cache[(n, s)]

return wrapped_fn


@cache_step_0
def mixed_step_memoization_0(n, s):
if s < 0:
raise ValueError("Invalid number of snapshots")
if n < s + 2:
raise ValueError("Invalid number of steps")

if s == 0:
return (StepType.FORWARD_REVERSE, n, n * (n + 1) // 2 - 1)
else:
m = None
for i in range(1, n):
m1 = (
i
+ mixed_step_memoization(i, s + 1)[2]
+ mixed_step_memoization(n - i, s)[2])
if m is None or m1 <= m[2]:
m = (StepType.FORWARD, i, m1)
if m is None:
raise RuntimeError("Failed to determine total number of steps")
return m


@njit
def mixed_steps_tabulation_0(n, s, schedule):
"""Tabulate actions for a 'mixed' schedule, for the case where a forward
restart checkpoint is stored at the start of the first step.
Parameters
----------
n : int
The number of forward steps.
s : int
The number of checkpointing units.
schedule: ndarray
As returned by `mixed_steps_tabulation`.
Returns
-------
ndarray
Defines the schedule. `schedule[n_i, s_i, :]` indicates the action for
the case of `n_i` steps and `s_i` checkpointing units. `schedule[n_i,
s_i, 0]` defines the actions, `schedule[n_i, s_i, 1]` defines the
number of forward steps to advance, and `schedule[n_i, s_i, 2]` defines
the cost.
"""

schedule_0 = np.zeros((n + 1, s + 1, 3), dtype=np.int64)
schedule_0[:, :, 0] = _NONE
schedule_0[:, :, 1] = 0
schedule_0[:, :, 2] = -1

for n_i in range(2, n + 1):
schedule_0[n_i, 0, :] = (_FORWARD_REVERSE, n_i, n_i * (n_i + 1) // 2 - 1) # noqa: E501
for s_i in range(1, s):
for n_i in range(s_i + 2, n + 1):
for i in range(1, n_i):
assert schedule[i, s_i + 1, 2] > 0
assert schedule[n_i - i, s_i, 2] > 0
m1 = (
i
+ schedule[i, s_i + 1, 2]
+ schedule[n_i - i, s_i, 2])
if schedule_0[n_i, s_i, 2] < 0 or m1 <= schedule_0[n_i, s_i, 2]: # noqa: E501
schedule_0[n_i, s_i, :] = (_FORWARD, i, m1)
if schedule_0[n_i, s_i, 2] < 0:
raise RuntimeError("Failed to determine total number of "
"steps")
return schedule_0


class InvalidForwardStep(IndexError):
"The forward step is not correct."

Expand Down

0 comments on commit 4ab0544

Please sign in to comment.