Skip to content

Commit

Permalink
Dynamic schedules (#137)
Browse files Browse the repository at this point in the history
* The checkpointing adapted for dynamic schedules and the addition of a test for linear multistep employment.
  • Loading branch information
Ig-dolci committed Jun 26, 2024
1 parent 43b4032 commit 7eeb49a
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 110 deletions.
163 changes: 90 additions & 73 deletions pyadjoint/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from enum import Enum
import sys
from functools import singledispatchmethod
from checkpoint_schedules import (
Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType)
from checkpoint_schedules import Revolve, MultistageCheckpointSchedule
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType


class CheckpointError(RuntimeError):
Expand Down Expand Up @@ -42,23 +41,14 @@ class CheckpointManager:
forward model.
reverse_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the
reverse model.
timesteps (int): The initial number of timesteps.
adjoint_evaluated (bool): A boolean indicating whether the adjoint model has been evaluated.
total_timesteps (int): The total number of timesteps to execute the forward model.
mode (Mode): The mode of the checkpoint manager. The possible modes are `RECORD`, `FINISHED_RECORDING`,
`EVALUATED`, `EXHAUSTED`, `RECOMPUTE`, and `EVALUATE_ADJOINT`. Additional information about the modes
can be found class:`Mode`.
_current_action (checkpoint_schedules.CheckpointAction): The current `checkpoint_schedules` action.
"""
def __init__(self, schedule, tape):
if (
not isinstance(schedule, Revolve)
and not isinstance(schedule, MultistageCheckpointSchedule)
):
raise CheckpointError(
"Only Revolve and MultistageCheckpointSchedule schedules are supported."
)

if (
schedule.uses_storage_type(StorageType.DISK)
and not tape._package_data
Expand All @@ -70,9 +60,15 @@ def __init__(self, schedule, tape):
self._schedule = schedule
self.forward_schedule = []
self.reverse_schedule = []
self.timesteps = schedule.max_n
# This variable is used to indicate whether the adjoint model has been evaluated at the checkpoint.
self.adjoint_evaluated = False
if self._schedule.max_n:
self.total_timesteps = schedule.max_n
else:
# We have schedules in `checkpoint_schedules` offering the flexibility to determine
# the desired steps during the forward execution. For this type of schedule, we do not
# have the total number of timesteps `self._schedule.max_n`. Therefore, we set the
# `self.total_timesteps` to the maximum value of the `sys.maxsize` to indicate that
# the total number of timesteps is not known.
self.total_timesteps = sys.maxsize
self.mode = Mode.RECORD
self._current_action = next(self._schedule)
self.forward_schedule.append(self._current_action)
Expand All @@ -97,6 +93,11 @@ def end_timestep(self, timestep):
def end_taping(self):
"""Process the end of the forward execution."""
current_timestep = self.tape.latest_timestep
if not self._schedule.max_n:
# Inform the schedule that the forward model has finished.
self._schedule.finalize(len(self.tape.timesteps))
# `self._schedule.finalize` updates `self._schedule.max_n`.
self.total_timesteps = self._schedule.max_n
while self.mode != Mode.EVALUATED:
self.end_timestep(current_timestep)
current_timestep += 1
Expand Down Expand Up @@ -143,21 +144,27 @@ def _(self, cp_action, timestep):

if timestep > cp_action.n0:
if cp_action.write_ics and timestep == (cp_action.n0 + 1):
# Stores the checkpoint data in RAM
# This data will be used to restart the forward model
# from the step `n0` in the reverse computations.
self.tape.timesteps[cp_action.n0].checkpoint()

if not cp_action.write_adj_deps:
# Remove unnecessary variables from previous steps.
# Store the checkpoint data. This is the required data for restarting the forward model
# from the step `n0`.
self.tape.timesteps[timestep - 1].checkpoint()
if (cp_action.write_adj_deps and timestep == (cp_action.n1 - 1) and cp_action.storage != StorageType.WORK):
# Store the checkpoint data. This is the required data for computing the adjoint model
# from the step `n1`.
self.tape.timesteps[timestep].checkpoint()

if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
):
# Remove unnecessary variables in working memory from previous steps.
for var in self.tape.timesteps[timestep - 1].checkpointable_state:
var._checkpoint = None
for block in self.tape.timesteps[timestep - 1]:
# Remove unnecessary variables from previous steps.
for output in block.get_outputs():
output._checkpoint = None

if timestep in cp_action:
if timestep in cp_action and timestep < self.total_timesteps:
self.tape.get_blocks().append_step()
if cp_action.write_ics:
self.tape.latest_checkpoint = cp_action.n0
Expand All @@ -167,7 +174,7 @@ def _(self, cp_action, timestep):

@process_taping.register(EndForward)
def _(self, cp_action, timestep):
if timestep != self.timesteps:
if timestep != self.total_timesteps:
raise CheckpointError(
"The correct number of forward steps has notbeen taken."
)
Expand All @@ -180,12 +187,12 @@ def recompute(self, functional=None):
Args:
functional (BlockVariable): The functional to be evaluated.
"""
if self.mode == Mode.RECORD:
# Finalise the taping process.
self.end_taping()
self.mode = Mode.RECOMPUTE

with self.tape.progress_bar("Evaluating Functional",
max=self.timesteps) as bar:
# Restore the initial condition to advance the forward model
# from the step 0.
with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as bar:
# Restore the initial condition to advance the forward model from the step 0.
current_step = self.tape.timesteps[self.forward_schedule[0].n0]
current_step.restore_from_checkpoint()
for cp_action in self.forward_schedule:
Expand All @@ -206,30 +213,27 @@ def evaluate_adj(self, last_block, markings):
raise NotImplementedError(
"Only the first block can be evaluated at present."
)

if self.mode == Mode.RECORD:
# The declared timesteps were not exhausted while taping.
# Finalise the taping process.
self.end_taping()

if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING):
raise CheckpointError("Evaluate Functional before calling gradient.")

with self.tape.progress_bar("Evaluating Adjoint", max=self.timesteps) as bar:
if self.adjoint_evaluated:
reverse_iterator = iter(self.reverse_schedule)
while not isinstance(self._current_action, EndReverse):
if not self.adjoint_evaluated:
self._current_action = next(self._schedule)
self.reverse_schedule.append(self._current_action)
else:
self._current_action = next(reverse_iterator)
self.process_operation(self._current_action, bar, markings=markings)
# Only set the mode after the first backward in order to handle
# that step correctly.
self.mode = Mode.EVALUATE_ADJOINT

# Inform that the adjoint model has been evaluated.
self.adjoint_evaluated = True
with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as bar:
if self.reverse_schedule:
for cp_action in self.reverse_schedule:
self.process_operation(cp_action, bar, markings=markings)
else:
while not isinstance(self._current_action, EndReverse):
cp_action = next(self._schedule)
self._current_action = cp_action
self.reverse_schedule.append(cp_action)
self.process_operation(cp_action, bar, markings=markings)

# Only set the mode after the first backward in order to handle
# that step correctly.
self.mode = Mode.EVALUATE_ADJOINT

@singledispatchmethod
def process_operation(self, cp_action, bar, **kwargs):
Expand All @@ -254,44 +258,57 @@ def process_operation(self, cp_action, bar, **kwargs):

@process_operation.register(Forward)
def _(self, cp_action, bar, functional=None, **kwargs):
for step in cp_action:
step = cp_action.n0
# In a dynamic schedule `cp_action` can be unbounded so we also need to check `self.total_timesteps`.
while step in cp_action and step < self.total_timesteps:
if self.mode == Mode.RECOMPUTE:
if bar:
bar.next()
bar.next()
# Get the blocks of the current step.
current_step = self.tape.timesteps[step]
for block in current_step:
block.recompute()

if cp_action.write_ics:
if step == cp_action.n0:
for var in current_step.checkpointable_state:
if var.checkpoint:
current_step._checkpoint.update(
{var: var.checkpoint}
)
if not cp_action.write_adj_deps:
if (
(cp_action.write_ics and step == cp_action.n0)
or (cp_action.write_adj_deps and step == cp_action.n1 - 1
and cp_action.storage != StorageType.WORK)
):
# Store the checkpoint data required for restarting the
# forward model or computing the adjoint model.
# If `cp_action.write_ics` is `True`, the checkpointed data
# will restart the forward model from the step `n0`.
# If `cp_action.write_adj_deps` is `True`, the checkpointed
# data will be used for computing the adjoint model from the
# step `n1`.
for var in current_step.checkpointable_state:
if var.checkpoint:
current_step._checkpoint.update(
{var: var.checkpoint}
)
if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
):
to_keep = set()
if step < (self.total_timesteps - 1):
next_step = self.tape.timesteps[step + 1]
# The checkpointable state set of the current step.
to_keep = next_step.checkpointable_state
if functional:
# `to_keep` holds informations of the blocks required
# for restarting the forward model from a step `n`.
to_keep = to_keep.union([functional.block_variable])
for block in current_step:
# Remove unnecessary variables from previous steps.
for bv in block.get_outputs():
if bv not in to_keep:
bv._checkpoint = None
if functional:
to_keep = to_keep.union([functional.block_variable])
for block in current_step:
# Remove unnecessary variables from previous steps.
for var in (current_step.checkpointable_state - to_keep):
var._checkpoint = None
for bv in block.get_outputs():
if bv not in to_keep:
bv._checkpoint = None
# Remove unnecessary variables from previous steps.
for var in (current_step.checkpointable_state - to_keep):
var._checkpoint = None
step += 1

@process_operation.register(Reverse)
def _(self, cp_action, bar, markings, functional=None, **kwargs):
for step in cp_action:
if bar:
bar.next()
bar.next()
# Get the blocks of the current step.
current_step = self.tape.timesteps[step]
for block in reversed(current_step):
Expand Down
3 changes: 2 additions & 1 deletion pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,11 +721,12 @@ def __iter__(self):
return self

def __next__(self):
step = next(self.iterator)
if self._first:
self._first = False
else:
self.tape.end_timestep()
return next(self.iterator)
return step


class TimeStep(list):
Expand Down
Loading

0 comments on commit 7eeb49a

Please sign in to comment.