From 5f5f9ee31bc4a89f8eeb8d2e1005a57304fdc36c Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 26 Jun 2024 11:34:11 +0100 Subject: [PATCH 1/2] Update MixedCheckpointSchedule --- checkpoint_schedules/mixed.py | 193 ++++++++++------------------------ 1 file changed, 53 insertions(+), 140 deletions(-) diff --git a/checkpoint_schedules/mixed.py b/checkpoint_schedules/mixed.py index 045f530..aed723b 100644 --- a/checkpoint_schedules/mixed.py +++ b/checkpoint_schedules/mixed.py @@ -42,8 +42,8 @@ class MixedCheckpointSchedule(CheckpointSchedule): Notes ----- Assumes that the data required to restart the forward has the same size as - the data required to advance the adjoint over a step. Additionall details - about the mixed checkpointing schedule is avaiable in [1]. + the data required to advance the adjoint over a step. An updated version of + the algorithm described in [1]. This is a offline checkpointing strategy, one adjoint calculation permitted. @@ -73,35 +73,29 @@ def _iterator(self): if numba is None: warnings.warn("Numba not available -- using memoization", RuntimeWarning) + schedule = None else: schedule = mixed_steps_tabulation(self._max_n, self._snapshots) - schedule_0 = mixed_steps_tabulation_0(self._max_n, self._snapshots, schedule) # noqa: E501 - step_type = StepType.NONE while True: + step_type = StepType.NONE while self._n < self._max_n - self._r: n0 = self._n - if n0 in snapshot_n: - # n0 checkpoint exists - if numba is None: - step_type, n1, _ = mixed_step_memoization_0( - self._max_n - self._r - n0, - self._snapshots - len(snapshots)) - else: - step_type, n1, _ = schedule_0[ - self._max_n - self._r - n0, - self._snapshots - len(snapshots)] + reuse_snapshot = bool(n0 in snapshot_n) + + if schedule is None: + step_type, n1, _ = mixed_step_memoization( + self._max_n - self._r - n0, + self._snapshots - len(snapshots) + int(reuse_snapshot)) else: - # n0 checkpoint does not exist - if numba is None: - step_type, n1, _ = mixed_step_memoization( - self._max_n - self._r - n0, - self._snapshots - len(snapshots)) - else: - step_type, n1, _ = schedule[ - self._max_n - self._r - n0, - self._snapshots - len(snapshots)] + step_type, n1, _ = schedule[ + self._max_n - self._r - n0, + self._snapshots - len(snapshots) + int(reuse_snapshot)] n1 += n0 + if reuse_snapshot and \ + (snapshots[-1][:2] != (step_type, n0) + or snapshots[-1][2] < n1): + raise RuntimeError("Invalid checkpointing state") if step_type == StepType.FORWARD_REVERSE: if n1 > n0 + 1: @@ -119,31 +113,31 @@ def _iterator(self): elif step_type == StepType.WRITE_ADJ_DEPS: if n1 != n0 + 1: raise InvalidForwardStep - self._n = n1 - yield Forward(n0, n1, False, True, self._storage) - if n0 in snapshot_n: + if reuse_snapshot: raise RuntimeError("Invalid checkpointing state") elif len(snapshots) > self._snapshots - 1: raise RuntimeError("Invalid checkpointing state") + self._n = n1 + yield Forward(n0, n1, False, True, self._storage) snapshot_n.add(n0) - snapshots.append((StepType.READ_ADJ_DEPS, n0)) + snapshots.append((StepType.WRITE_ADJ_DEPS, n0, n1)) elif step_type == StepType.WRITE_ICS: if n1 <= n0 + 1: raise InvalidActionIndex self._n = n1 - yield Forward(n0, n1, True, False, self._storage) - if n0 in snapshot_n: - raise RuntimeError("Invalid checkpointing state") - elif len(snapshots) > self._snapshots - 1: - raise RuntimeError("Invalid checkpointing state") - snapshot_n.add(n0) - snapshots.append((StepType.READ_ICS, n0)) + if reuse_snapshot: + yield Forward(n0, n1, False, False, StorageType.WORK) + else: + yield Forward(n0, n1, True, False, self._storage) + if len(snapshots) > self._snapshots - 1: + raise RuntimeError("Invalid checkpointing state") + snapshot_n.add(n0) + snapshots.append((StepType.WRITE_ICS, n0, n1)) else: raise RuntimeError("Unexpected step type") if self._n != self._max_n - self._r: raise InvalidForwardStep - if step_type not in (StepType.FORWARD_REVERSE, - StepType.READ_ADJ_DEPS): + if step_type not in {StepType.NONE, StepType.FORWARD_REVERSE}: raise RuntimeError("Invalid checkpointing state") if self._r == 0: @@ -155,35 +149,39 @@ def _iterator(self): if self._r == self._max_n: break - step_type, cp_n = snapshots[-1] + cp_step_type, cp_n, _ = snapshots[-1] + if cp_step_type not in {StepType.WRITE_ICS, StepType.WRITE_ADJ_DEPS}: # noqa: E501 + raise RuntimeError("Invalid checkpointing state") - # Delete if we have (possibly after deleting this checkpoint) - # enough storage left to store all non-linear dependency data - cp_delete = (cp_n >= (self._max_n - self._r - 1 - - (self._snapshots - len(snapshots) + 1))) + if schedule is None: + next_step_type, _, _ = mixed_step_memoization( + self._max_n - self._r - cp_n, + self._snapshots - len(snapshots) + 1) + else: + next_step_type, _, _ = schedule[ + self._max_n - self._r - cp_n, + self._snapshots - len(snapshots) + 1] + cp_delete = (cp_step_type != next_step_type) if cp_delete: snapshot_n.remove(cp_n) snapshots.pop() - self._n = cp_n - if step_type == StepType.READ_ADJ_DEPS: + if cp_step_type == StepType.WRITE_ICS: + if cp_n + 1 >= self._max_n - self._r: + raise RuntimeError("Invalid checkpointing state") + self._n = cp_n + elif cp_step_type == StepType.WRITE_ADJ_DEPS: # Non-linear dependency data checkpoint - if not cp_delete: + if not cp_delete or cp_n + 1 != self._max_n - self._r: # We cannot advance from a loaded non-linear dependency # checkpoint, and so we expect to use it immediately raise RuntimeError("Invalid checkpointing state") # Note that we cannot in general restart the forward here - self._n += 1 - elif step_type != StepType.READ_ICS: - raise RuntimeError("Invalid checkpointing state") - if step_type == StepType.READ_ADJ_DEPS: - storage_type = StorageType.WORK - elif step_type == StepType.READ_ICS: - storage_type = StorageType.WORK + self._n = cp_n + 1 if cp_delete: - yield Move(cp_n, self._storage, storage_type) + yield Move(cp_n, self._storage, StorageType.WORK) else: - yield Copy(cp_n, self._storage, storage_type) + yield Copy(cp_n, self._storage, StorageType.WORK) if len(snapshot_n) > 0 or len(snapshots) > 0: raise RuntimeError("Invalid checkpointing state") @@ -273,7 +271,7 @@ def mixed_step_memoization(n, s): if m is None: raise RuntimeError("Failed to determine total number of steps") m1 = 1 + mixed_step_memoization(n - 1, s - 1)[2] - if m1 <= m[2]: + if m1 < m[2]: m = (StepType.WRITE_ADJ_DEPS, 1, m1) return m @@ -335,96 +333,11 @@ def mixed_steps_tabulation(n, s): "steps") assert schedule[n_i - 1, s_i - 1, 2] > 0 m1 = 1 + schedule[n_i - 1, s_i - 1, 2] - if m1 <= schedule[n_i, s_i, 2]: + if m1 < schedule[n_i, s_i, 2]: schedule[n_i, s_i, :] = (_WRITE_ADJ_DEPS, 1, m1) return schedule -def cache_step_0(fn): - _cache = {} - - @functools.wraps(fn) - def wrapped_fn(n, s): - # Avoid some cache misses - s = min(s, n - 2) - if (n, s) not in _cache: - _cache[(n, s)] = fn(n, s) - return _cache[(n, s)] - - return wrapped_fn - - -@cache_step_0 -def mixed_step_memoization_0(n, s): - if s < 0: - raise ValueError("Invalid number of snapshots") - if n < s + 2: - raise ValueError("Invalid number of steps") - - if s == 0: - return (StepType.FORWARD_REVERSE, n, n * (n + 1) // 2 - 1) - else: - m = None - for i in range(1, n): - m1 = ( - i - + mixed_step_memoization(i, s + 1)[2] - + mixed_step_memoization(n - i, s)[2]) - if m is None or m1 <= m[2]: - m = (StepType.FORWARD, i, m1) - if m is None: - raise RuntimeError("Failed to determine total number of steps") - return m - - -@njit -def mixed_steps_tabulation_0(n, s, schedule): - """Tabulate actions for a 'mixed' schedule, for the case where a forward - restart checkpoint is stored at the start of the first step. - - Parameters - ---------- - n : int - The number of forward steps. - s : int - The number of checkpointing units. - schedule: ndarray - As returned by `mixed_steps_tabulation`. - - Returns - ------- - ndarray - Defines the schedule. `schedule[n_i, s_i, :]` indicates the action for - the case of `n_i` steps and `s_i` checkpointing units. `schedule[n_i, - s_i, 0]` defines the actions, `schedule[n_i, s_i, 1]` defines the - number of forward steps to advance, and `schedule[n_i, s_i, 2]` defines - the cost. - """ - - schedule_0 = np.zeros((n + 1, s + 1, 3), dtype=np.int64) - schedule_0[:, :, 0] = _NONE - schedule_0[:, :, 1] = 0 - schedule_0[:, :, 2] = -1 - - for n_i in range(2, n + 1): - schedule_0[n_i, 0, :] = (_FORWARD_REVERSE, n_i, n_i * (n_i + 1) // 2 - 1) # noqa: E501 - for s_i in range(1, s): - for n_i in range(s_i + 2, n + 1): - for i in range(1, n_i): - assert schedule[i, s_i + 1, 2] > 0 - assert schedule[n_i - i, s_i, 2] > 0 - m1 = ( - i - + schedule[i, s_i + 1, 2] - + schedule[n_i - i, s_i, 2]) - if schedule_0[n_i, s_i, 2] < 0 or m1 <= schedule_0[n_i, s_i, 2]: # noqa: E501 - schedule_0[n_i, s_i, :] = (_FORWARD, i, m1) - if schedule_0[n_i, s_i, 2] < 0: - raise RuntimeError("Failed to determine total number of " - "steps") - return schedule_0 - - class InvalidForwardStep(IndexError): "The forward step is not correct." From 630e73cf98aeddca2949d175ae439f7e71455b6f Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 26 Jun 2024 11:40:06 +0100 Subject: [PATCH 2/2] Fix Exception type --- checkpoint_schedules/mixed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkpoint_schedules/mixed.py b/checkpoint_schedules/mixed.py index aed723b..54f2706 100644 --- a/checkpoint_schedules/mixed.py +++ b/checkpoint_schedules/mixed.py @@ -136,7 +136,7 @@ def _iterator(self): else: raise RuntimeError("Unexpected step type") if self._n != self._max_n - self._r: - raise InvalidForwardStep + raise RuntimeError("Invalid checkpointing state") if step_type not in {StepType.NONE, StepType.FORWARD_REVERSE}: raise RuntimeError("Invalid checkpointing state")