diff --git a/docs/source/documentation/index.rst b/docs/source/documentation/index.rst
index 06f72727..759fc31e 100644
--- a/docs/source/documentation/index.rst
+++ b/docs/source/documentation/index.rst
@@ -361,8 +361,7 @@ a tape. The current working tape can be set and retrieved with the functions :py
:py:func:`get_working_tape`.
Annotation can be temporarily disabled using :py:func:`pause_annotation` and enabled again using :py:func:`continue_annotation`.
-Note that if you call :py:func:`pause_annotation` twice, then :py:func:`continue_annotation` must be called twice
-to enable annotation. Due to this, the recommended annotation control functions are :py:class:`stop_annotating` and :py:func:`no_annotations`.
+It is recommended to use :py:class:`stop_annotating` and :py:func:`no_annotations` to control annotation.
:py:class:`stop_annotating` is a context manager and should be used as follows
.. code-block:: python
diff --git a/docs/source/documentation/pyadjoint_api.rst b/docs/source/documentation/pyadjoint_api.rst
index 1bfb0ee0..36fed246 100644
--- a/docs/source/documentation/pyadjoint_api.rst
+++ b/docs/source/documentation/pyadjoint_api.rst
@@ -19,6 +19,8 @@ Core classes
.. automethod:: add_block
.. automethod:: visualise
.. autoproperty:: progress_bar
+ .. automethod:: end_timestep
+ .. automethod:: timestepper
.. autoclass:: Block
diff --git a/pyadjoint/block.py b/pyadjoint/block.py
index e9929861..29dbc70b 100644
--- a/pyadjoint/block.py
+++ b/pyadjoint/block.py
@@ -50,7 +50,7 @@ def add_dependency(self, dep, no_duplicates=False):
"""
if not no_duplicates or dep.block_variable not in self._dependencies:
- dep._ad_will_add_as_dependency()
+ dep.block_variable.will_add_as_dependency()
self._dependencies.append(dep.block_variable)
def get_dependencies(self):
diff --git a/pyadjoint/block_variable.py b/pyadjoint/block_variable.py
index 86002638..3b3fa25d 100644
--- a/pyadjoint/block_variable.py
+++ b/pyadjoint/block_variable.py
@@ -1,4 +1,4 @@
-from .tape import no_annotations
+from .tape import no_annotations, get_working_tape
class BlockVariable(object):
@@ -16,6 +16,10 @@ def __init__(self, output):
self.floating_type = False
# Helper flag for use during tape traversals.
self.marked_in_path = False
+ # By default assume the variable is created externally to the tape.
+ self.creation_timestep = -1
+ # The timestep during which this variable was last used as an input.
+ self.last_use = -1
def add_adj_output(self, val):
if self.adj_value is None:
@@ -59,13 +63,23 @@ def saved_output(self):
def will_add_as_dependency(self):
overwrite = self.output._ad_will_add_as_dependency()
- overwrite = False if overwrite is None else overwrite
- self.save_output(overwrite=overwrite)
+ overwrite = bool(overwrite)
+ tape = get_working_tape()
+ if self.last_use < tape.latest_checkpoint:
+ self.save_output(overwrite=overwrite)
+ tape.add_to_checkpointable_state(self, self.last_use)
+ self.last_use = tape.latest_timestep
def will_add_as_output(self):
+ tape = get_working_tape()
+ self.creation_timestep = tape.latest_timestep
+ self.last_use = self.creation_timestep
overwrite = self.output._ad_will_add_as_output()
- overwrite = True if overwrite is None else overwrite
- self.save_output(overwrite=overwrite)
+ overwrite = bool(overwrite)
+ if not overwrite:
+ self._checkpoint = None
+ if tape._eagerly_checkpoint_outputs:
+ self.save_output()
def __str__(self):
return str(self.output)
diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py
new file mode 100644
index 00000000..d20a0b2f
--- /dev/null
+++ b/pyadjoint/checkpointing.py
@@ -0,0 +1,334 @@
+from enum import Enum
+from functools import singledispatchmethod
+from checkpoint_schedules import (
+ Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType)
+from checkpoint_schedules import Revolve, MultistageCheckpointSchedule
+
+
+class CheckpointError(RuntimeError):
+ pass
+
+
+class Mode(Enum):
+ """The mode of the checkpoint manager.
+
+ RECORD: The forward model is being taped.
+ FINISHED_RECORDING: The forward model is finished being taped.
+ EVALUATED: The forward model was evaluated.
+ EXHAUSTED: The forward and the adjoint models were evaluated and the schedule has concluded.
+ RECOMPUTE: The forward model is being recomputed.
+ EVALUATE_ADJOINT: The adjoint model is being evaluated.
+
+ """
+ RECORD = 1
+ FINISHED_RECORDING = 2
+ EVALUATED = 3
+ EXHAUSTED = 4
+ RECOMPUTE = 5
+ EVALUATE_ADJOINT = 6
+
+
+class CheckpointManager:
+ """Manage the executions of the forward and adjoint solvers.
+
+ Args:
+ schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
+        tape (Tape): The tape containing the :class:`Block` instances to be checkpointed.
+
+ Attributes:
+        tape (Tape): The tape containing the :class:`Block` instances to be checkpointed.
+ _schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
+ forward_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the
+ forward model.
+ reverse_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the
+ reverse model.
+        timesteps (int): The total number of timesteps in the schedule.
+ adjoint_evaluated (bool): A boolean indicating whether the adjoint model has been evaluated.
+ mode (Mode): The mode of the checkpoint manager. The possible modes are `RECORD`, `FINISHED_RECORDING`,
+ `EVALUATED`, `EXHAUSTED`, `RECOMPUTE`, and `EVALUATE_ADJOINT`. Additional information about the modes
+            can be found in :class:`Mode`.
+ _current_action (checkpoint_schedules.CheckpointAction): The current `checkpoint_schedules` action.
+
+ """
+ def __init__(self, schedule, tape):
+ if (
+ not isinstance(schedule, Revolve)
+ and not isinstance(schedule, MultistageCheckpointSchedule)
+ ):
+ raise CheckpointError(
+ "Only Revolve and MultistageCheckpointSchedule schedules are supported."
+ )
+
+ if (
+ schedule.uses_storage_type(StorageType.DISK)
+ and not tape._package_data
+ ):
+ raise CheckpointError(
+                "The schedule employs disk checkpointing, but disk checkpointing is not configured for this tape."
+ )
+ self.tape = tape
+ self._schedule = schedule
+ self.forward_schedule = []
+ self.reverse_schedule = []
+ self.timesteps = schedule.max_n
+        # Indicates whether the adjoint model has been evaluated at least once.
+ self.adjoint_evaluated = False
+ self.mode = Mode.RECORD
+ self._current_action = next(self._schedule)
+ self.forward_schedule.append(self._current_action)
+ # Tell the tape to only checkpoint input data until told otherwise.
+ self.tape.latest_checkpoint = 0
+ self.end_timestep(-1)
+
+ def end_timestep(self, timestep):
+ """Mark the end of one timestep when taping the forward model.
+
+ Args:
+ timestep (int): The current timestep.
+ """
+ if self.mode == Mode.EVALUATED:
+ raise CheckpointError("Not enough timesteps in schedule.")
+ elif self.mode != Mode.RECORD:
+ raise CheckpointError(f"Cannot end timestep in {self.mode}")
+ while not self.process_taping(self._current_action, timestep + 1):
+ self._current_action = next(self._schedule)
+ self.forward_schedule.append(self._current_action)
+
+ def end_taping(self):
+ """Process the end of the forward execution."""
+ current_timestep = self.tape.latest_timestep
+ while self.mode != Mode.EVALUATED:
+ self.end_timestep(current_timestep)
+ current_timestep += 1
+
+ @singledispatchmethod
+ def process_taping(self, cp_action, timestep):
+ """Implement checkpointing schedule actions while taping.
+
+ A single-dispatch generic function.
+
+ Note:
+            For more information about `checkpoint_schedules`, please refer to the
+ `documentation `_.
+            Detailed descriptions of the actions used during taping can be found at the following links:
+ `Forward `_ and `End_Forward `_.
+
+ Args:
+ cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from the
+ `checkpoint_schedules`.
+ timestep (int): The current timestep.
+
+ Returns:
+ bool: Returns `True` if the timestep is in the `checkpoint_schedules` action.
+ For example, if the `checkpoint_schedules` action is `Forward(0, 4, True, False, StorageType.DISK)`,
+            then timesteps `0, 1, 2, 3` fall within the action, whereas timestep `4` does not,
+            in which case `False` is returned.
+
+ Raises:
+ CheckpointError: If the checkpoint action is not supported.
+ """
+
+ raise CheckpointError(f"Unable to process {cp_action} while taping.")
+
+ @process_taping.register(Forward)
+ def _(self, cp_action, timestep):
+ if timestep < (cp_action.n0):
+ raise CheckpointError(
+ "Timestep is before start of Forward action."
+ )
+
+ self.tape._eagerly_checkpoint_outputs = cp_action.write_adj_deps
+
+ if timestep > cp_action.n0:
+ if cp_action.write_ics and timestep == (cp_action.n0 + 1):
+ # Stores the checkpoint data in RAM
+ # This data will be used to restart the forward model
+ # from the step `n0` in the reverse computations.
+ self.tape.timesteps[cp_action.n0].checkpoint()
+
+ if not cp_action.write_adj_deps:
+ # Remove unnecessary variables from previous steps.
+ for var in self.tape.timesteps[timestep - 1].checkpointable_state:
+ var._checkpoint = None
+ for block in self.tape.timesteps[timestep - 1]:
+ # Remove unnecessary variables from previous steps.
+ for output in block.get_outputs():
+ output._checkpoint = None
+
+ if timestep in cp_action:
+ self.tape.get_blocks().append_step()
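+            # Record the timestep whose initial conditions are being checkpointed.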
+ if cp_action.write_ics:
+ self.tape.latest_checkpoint = cp_action.n0
+ return True
+ else:
+ return False
+
+ @process_taping.register(EndForward)
+ def _(self, cp_action, timestep):
+ if timestep != self.timesteps:
+ raise CheckpointError(
+                "The correct number of forward steps has not been taken."
+ )
+ self.mode = Mode.EVALUATED
+ return True
+
+ def recompute(self, functional=None):
+ """Recompute the forward model.
+
+ Args:
+ functional (BlockVariable): The functional to be evaluated.
+ """
+ self.mode = Mode.RECOMPUTE
+
+ with self.tape.progress_bar("Evaluating Functional",
+ max=self.timesteps) as bar:
+ # Restore the initial condition to advance the forward model
+ # from the step 0.
+ current_step = self.tape.timesteps[self.forward_schedule[0].n0]
+ current_step.restore_from_checkpoint()
+ for cp_action in self.forward_schedule:
+ self._current_action = cp_action
+ self.process_operation(cp_action, bar, functional=functional)
+
+ def evaluate_adj(self, last_block, markings):
+ """Evaluate the adjoint model.
+
+ Args:
+ last_block (int): The last block to be evaluated.
+            markings (bool): If `True`, then each `BlockVariable` of the current block will have its
+                `marked_in_path` attribute set, indicating whether its adjoint components are relevant
+                for computing the final target adjoint values.
+ """
+ # Work out other cases when they arise.
+ if last_block != 0:
+ raise NotImplementedError(
+ "Only the first block can be evaluated at present."
+ )
+
+ if self.mode == Mode.RECORD:
+ # The declared timesteps were not exhausted while taping.
+ self.end_taping()
+
+ if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING):
+ raise CheckpointError("Evaluate Functional before calling gradient.")
+
+ with self.tape.progress_bar("Evaluating Adjoint",
+ max=self.timesteps) as bar:
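+            # On repeat evaluations, replay the previously recorded reverse
+            # schedule rather than drawing new actions from the generator.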
+ if self.adjoint_evaluated:
+ reverse_iterator = iter(self.reverse_schedule)
+ while not isinstance(self._current_action, EndReverse):
+ if not self.adjoint_evaluated:
+ self._current_action = next(self._schedule)
+ self.reverse_schedule.append(self._current_action)
+ else:
+ self._current_action = next(reverse_iterator)
+ self.process_operation(self._current_action, bar, markings=markings)
+ # Only set the mode after the first backward in order to handle
+ # that step correctly.
+ self.mode = Mode.EVALUATE_ADJOINT
+
+ # Inform that the adjoint model has been evaluated.
+ self.adjoint_evaluated = True
+
+ @singledispatchmethod
+ def process_operation(self, cp_action, bar, **kwargs):
+ """A function used to process the forward and adjoint executions.
+ This single-dispatch generic function is used in the `Blocks`
+ recomputation and adjoint evaluation with checkpointing.
+
+ Note:
+ The documentation of the `checkpoint_schedules` actions is available
+ `here `_.
+
+ Args:
+ cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from the
+ `checkpoint_schedules`.
+            bar (progressbar.ProgressBar): A progress bar to display the progress of the
+                forward and adjoint executions.
+ kwargs: Additional keyword arguments.
+
+ Raises:
+ CheckpointError: If the checkpoint action is not supported.
+ """
+ raise CheckpointError(f"Unable to process {cp_action}.")
+
+ @process_operation.register(Forward)
+ def _(self, cp_action, bar, functional=None, **kwargs):
+ for step in cp_action:
+ if self.mode == Mode.RECOMPUTE:
+ bar.next()
+ # Get the blocks of the current step.
+ current_step = self.tape.timesteps[step]
+ for block in current_step:
+ block.recompute()
+
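+                # Update the stored restart data for step `n0` with the
+                # recomputed checkpoint values.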
+ if cp_action.write_ics:
+ if step == cp_action.n0:
+ for var in current_step.checkpointable_state:
+ if var.checkpoint:
+ current_step._checkpoint.update(
+ {var: var.checkpoint}
+ )
+ if not cp_action.write_adj_deps:
+ next_step = self.tape.timesteps[step + 1]
+                    # The checkpointable state set of the next step.
+ to_keep = next_step.checkpointable_state
+ if functional:
+                        # `to_keep` holds the variables required to restart
+                        # the forward model from a step `n`.
+ to_keep = to_keep.union([functional.block_variable])
+ for block in current_step:
+ # Remove unnecessary variables from previous steps.
+ for bv in block.get_outputs():
+ if bv not in to_keep:
+ bv._checkpoint = None
+ # Remove unnecessary variables from previous steps.
+ for var in (current_step.checkpointable_state - to_keep):
+ var._checkpoint = None
+
+ @process_operation.register(Reverse)
+ def _(self, cp_action, bar, markings, functional=None, **kwargs):
+ for step in cp_action:
+ bar.next()
+ # Get the blocks of the current step.
+ current_step = self.tape.timesteps[step]
+ for block in reversed(current_step):
+ block.evaluate_adj(markings=markings)
+ # Output variables are used for the last time when running
+ # backwards.
+ for block in current_step:
+ for var in block.get_outputs():
+ var.checkpoint = None
+ var.reset_variables(("tlm",))
+ if not var.is_control:
+ var.reset_variables(("adjoint", "hessian"))
+ if cp_action.clear_adj_deps:
+ to_keep = current_step.checkpointable_state
+ if functional:
+ to_keep = to_keep.union([functional.block_variable])
+ for output in block.get_outputs():
+ if output not in to_keep:
+ output._checkpoint = None
+
+ @process_operation.register(Copy)
+ def _(self, cp_action, bar, **kwargs):
+ current_step = self.tape.timesteps[cp_action.n]
+ current_step.restore_from_checkpoint()
+
+ @process_operation.register(Move)
+ def _(self, cp_action, bar, **kwargs):
+ current_step = self.tape.timesteps[cp_action.n]
+ current_step.restore_from_checkpoint()
+ current_step.delete_checkpoint()
+
+ @process_operation.register(EndForward)
+ def _(self, cp_action, bar, **kwargs):
+ self.mode = Mode.EVALUATED
+
+ @process_operation.register(EndReverse)
+ def _(self, cp_action, bar, **kwargs):
+ if self._schedule.is_exhausted:
+ self.mode = Mode.EXHAUSTED
+ else:
+ self.mode = Mode.EVALUATED
diff --git a/pyadjoint/drivers.py b/pyadjoint/drivers.py
index fb631d17..9bf03b0c 100644
--- a/pyadjoint/drivers.py
+++ b/pyadjoint/drivers.py
@@ -26,7 +26,8 @@ def compute_gradient(J, m, options=None, tape=None, adj_value=1.0):
with stop_annotating():
with tape.marked_nodes(m):
- tape.evaluate_adj(markings=True)
+ with marked_controls(m):
+ tape.evaluate_adj(markings=True)
grads = [i.get_derivative(options=options) for i in m]
return m.delist(grads)
@@ -91,3 +92,26 @@ def solve_adjoint(J, tape=None, adj_value=1.0):
with stop_annotating():
tape.evaluate_adj(markings=False)
+
+
+class marked_controls:
+ """A context manager for marking controls.
+
+ Note:
+        This is a context manager for marking whether a :class:`BlockVariable` is
+        a control. On exiting the context, the :class:`BlockVariable` instances that
+        were marked as controls are automatically unmarked.
+
+ Args:
+ controls (list): A list of :class:`Control` to mark within the context manager.
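+
+    Example:
+        A minimal sketch, assuming `controls` is a list of :class:`Control`
+        objects and `tape` is the working tape::
+
+            with marked_controls(controls):
+                tape.evaluate_adj(markings=True)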
+ """
+ def __init__(self, controls):
+ self.controls = controls
+
+ def __enter__(self):
+ for control in self.controls:
+ control.mark_as_control()
+
+ def __exit__(self, *args):
+ for control in self.controls:
+ control.unmark_as_control()
diff --git a/pyadjoint/reduced_functional.py b/pyadjoint/reduced_functional.py
index 993615df..8b2af45f 100644
--- a/pyadjoint/reduced_functional.py
+++ b/pyadjoint/reduced_functional.py
@@ -206,10 +206,13 @@ def __call__(self, values):
blocks = self.tape.get_blocks()
with self.marked_controls():
with stop_annotating():
- for i in self.tape._bar("Evaluating functional").iter(
- range(len(blocks))
- ):
- blocks[i].recompute()
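+                # With checkpointing enabled, the recomputation is delegated
+                # to the checkpoint manager, which follows its schedule.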
+ if self.tape._checkpoint_manager:
+ self.tape._checkpoint_manager.recompute(self.functional)
+ else:
+ for i in self.tape._bar("Evaluating functional").iter(
+ range(len(blocks))
+ ):
+ blocks[i].recompute()
# ReducedFunctional can result in a scalar or an assembled 1-form
func_value = self.functional.block_variable.saved_output
diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py
index 9814beb8..996527c1 100644
--- a/pyadjoint/tape.py
+++ b/pyadjoint/tape.py
@@ -6,7 +6,9 @@
from functools import wraps
from itertools import chain
from abc import ABC, abstractmethod
-
+from typing import Optional
+from collections.abc import Iterable
+from .checkpointing import CheckpointManager, CheckpointError
_working_tape = None
_annotation_enabled = False
@@ -160,11 +162,13 @@ class Tape(object):
"""
__slots__ = ["_blocks", "_tf_tensors", "_tf_added_blocks", "_nodes",
- "_tf_registered_blocks", "_bar", "_package_data"]
+ "_tf_registered_blocks", "_bar", "_package_data",
+ "_checkpoint_manager", "latest_checkpoint",
+ "_eagerly_checkpoint_outputs"]
def __init__(self, blocks=None, package_data=None):
# Initialize the list of blocks on the tape.
- self._blocks = [] if blocks is None else blocks
+ self._blocks = TimeStepSequence(blocks=blocks)
# Dictionary of TensorFlow tensors. Key is id(block).
self._tf_tensors = {}
# Keep a list of blocks that has been added to the TensorFlow graph
@@ -174,12 +178,54 @@ def __init__(self, blocks=None, package_data=None):
# Hook location for packages which need to store additional data on the
# tape. Packages should store the data under a "packagename" key.
self._package_data = package_data or {}
+ # Default to checkpointing all block variables.
+ self.latest_checkpoint = float("inf")
+ self._checkpoint_manager = None
+ # Whether to store the adjoint dependencies.
+ self._eagerly_checkpoint_outputs = False
def clear_tape(self):
+ """Clear the tape."""
self.reset_variables()
- self._blocks = []
+ self._blocks = TimeStepSequence()
for data in self._package_data.values():
data.clear()
+ self._checkpoint_manager = None
+
+ @property
+ def latest_timestep(self):
+ """The current time step to which blocks will be added."""
+ return max(len(self._blocks.steps) - 1, 0)
+
+ def end_timestep(self):
+ """Mark the end of a timestep when taping the forward model."""
+ if self._checkpoint_manager:
+ self._checkpoint_manager.end_timestep(self.latest_timestep)
+ else:
+ self._blocks.append_step()
+
+ def timestepper(self, iterable):
+ """Return an iterator that advances the tape timestep.
+
+ Note:
+ This method facilitates taping timestepping simulations so that recompute
+ checkpointing can be used on the tape. For example, a simulation with
+ 10 timesteps might use a timestepping loop of this form::
+
+ tape = get_working_tape()
+
+ for timestep in tape.timestepper(range(10)):
+ ...
+
+ This has the effect of calling `tape.end_timestep()` after each iteration.
+
+ Args:
+            iterable (iterable): The iterable defining the sequence of timesteps.
+
+ Returns:
+ TapeTimeStepper: An iterator that advances the tape timestep.
+ """
+ return TapeTimeStepper(self, iterable)
def reset_blocks(self):
"""Calls the Block.reset method of all blocks on the tape.
@@ -200,6 +246,39 @@ def add_block(self, block):
# len() is computed in constant time, so this should be fine.
return len(self._blocks) - 1
+ def add_to_checkpointable_state(self, block_var, last_used):
+ """Add a block variable into the checkpointable state set.
+
+ Note:
+ `checkpointable_state` is a set of block variables which are needed
+ to restart from the start of a timestep.
+
+ Args:
+ block_var (BlockVariable): The block variable to add.
+ last_used (int): The last timestep in which the block variable was used.
+ """
+ if not self.timesteps:
+ self._blocks.append_step()
+ for step in self.timesteps[last_used + 1:]:
+ step.checkpointable_state.add(block_var)
+
+ def enable_checkpointing(self, schedule):
+ """Enable checkpointing on the adjoint evaluation.
+
+        This creates a checkpoint manager that executes the forward and adjoint
+        computations according to the schedule provided by the
+        checkpoint_schedules package.
+
+ Args:
+ schedule (checkpoint_schedules.schedule): A schedule provided by the
+ checkpoint_schedules package.
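+
+        Example:
+            A minimal sketch, assuming a simulation with `steps` timesteps and
+            room for `steps // 3` checkpoints in memory::
+
+                from checkpoint_schedules import Revolve
+
+                tape = get_working_tape()
+                tape.enable_checkpointing(Revolve(steps, steps // 3))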
+ """
+ if self._blocks:
+ raise CheckpointError(
+ "Checkpointing must be enabled before any blocks are added to the tape."
+ )
+ self._checkpoint_manager = CheckpointManager(schedule, self)
+
def get_blocks(self, tag=None):
"""Returns a list of the blocks on the tape.
@@ -226,10 +305,21 @@ def get_tags(self):
return tags
def evaluate_adj(self, last_block=0, markings=False):
- for i in self._bar("Evaluating adjoint").iter(
- range(len(self._blocks) - 1, last_block - 1, -1)
- ):
- self._blocks[i].evaluate_adj(markings=markings)
+ """Evaluate the adjoint of the tape.
+
+ Args:
+ last_block (int, optional): The index of the last block to evaluate.
+ markings (bool, optional): If True, then each `BlockVariable` of the current block
+                will have its `marked_in_path` attribute set, indicating whether its adjoint
+ components are relevant for computing the final target adjoint values.
+ """
+ if self._checkpoint_manager:
+ self._checkpoint_manager.evaluate_adj(last_block, markings)
+ else:
+ for i in self._bar("Evaluating adjoint").iter(
+ range(len(self._blocks) - 1, last_block - 1, -1)
+ ):
+ self._blocks[i].evaluate_adj(markings=markings)
def evaluate_tlm(self):
for i in self._bar("Evaluating TLM").iter(
@@ -264,7 +354,7 @@ def copy(self):
"""
# TODO: Offer deepcopying. But is it feasible memory wise to copy all checkpoints?
return Tape(
- blocks=self._blocks[:],
+ blocks=self._blocks,
package_data={k: v.copy() for k, v in self._package_data.items()}
)
@@ -319,40 +409,60 @@ def optimize(self, controls=None, functionals=None):
def optimize_for_controls(self, controls):
# TODO: Consider if we want Enlist wherever it is possible. Like in this case.
# TODO: Consider warning/message on empty tape.
- blocks = self.get_blocks()
nodes = set([control.block_variable for control in controls])
- valid_blocks = []
-
- for block in blocks:
- depends_on_control = False
- for dep in block.get_dependencies():
- if dep in nodes:
- depends_on_control = True
-
- if depends_on_control:
- for output in block.get_outputs():
- if output in nodes:
- raise RuntimeError("Control depends on another control.")
- nodes.add(output)
- valid_blocks.append(block)
- self._blocks = valid_blocks
+ discarded_variables = set()
+ optimized_timesteps = TimeStepSequence()
+
+ for step in self._blocks.steps:
+ optimized_timesteps.append_step()
+
+ for block in step:
+ depends_on_control = False
+ for dep in block.get_dependencies():
+ if dep in nodes:
+ depends_on_control = True
+ break
+
+ if depends_on_control:
+ for output in block.get_outputs():
+ if output in nodes:
+ raise RuntimeError("Control depends on another control.")
+ nodes.add(output)
+ optimized_timesteps.append(block)
+ else:
+                discarded_variables.update(block.get_outputs())
+ optimized_timesteps.steps[-1].checkpointable_state = \
+ step.checkpointable_state - discarded_variables
+
+ self._blocks = optimized_timesteps
def optimize_for_functionals(self, functionals):
- blocks = self.get_blocks()
- nodes = set([functional.block_variable for functional in functionals])
- valid_blocks = []
+        retained_nodes = set(functional.block_variable
+                             for functional in functionals)
+ optimized_timesteps = []
+
+ for step in reversed(self._blocks.steps):
+ current_blocks = []
+ for block in reversed(step):
+ produces_functional = False
+ for dep in block.get_outputs():
+ if dep in retained_nodes:
+ produces_functional = True
+
+ if produces_functional:
+ for dep in block.get_dependencies():
+ retained_nodes.add(dep)
+ current_blocks.append(block)
+ optimized_timesteps.append(TimeStep(reversed(current_blocks)))
- for block in reversed(blocks):
- produces_functional = False
- for dep in block.get_outputs():
- if dep in nodes:
- produces_functional = True
+ optimized_timesteps.reverse()
- if produces_functional:
- for dep in block.get_dependencies():
- nodes.add(dep)
- valid_blocks.append(block)
- self._blocks = list(reversed(valid_blocks))
+ for step, new_step in zip(self._blocks.steps, optimized_timesteps):
+ new_step.checkpointable_state = \
+ step.checkpointable_state & retained_nodes
+
+ self._blocks = TimeStepSequence(steps=optimized_timesteps)
@contextmanager
def marked_nodes(self, controls):
@@ -363,6 +473,11 @@ def marked_nodes(self, controls):
for node in nodes:
node.marked_in_path = False
+ @property
+ def timesteps(self):
+ """Return the list of time steps on this tape."""
+ return self._blocks.steps
+
def _valid_tf_scope_name(self, name):
"""Return a valid TensorFlow scope name"""
valid_name = ""
@@ -596,6 +711,105 @@ def iter(self, iterator):
return iterator
+class TapeTimeStepper:
+ """Iterator wrapper which advances the timestep after each iteration."""
+ def __init__(self, tape, iterable):
+ self.tape = tape
+ self.iterator = tape.progress_bar("Taping forward").iter(iterable)
+ self._first = True
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
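+        # End the previous timestep before yielding the next item; the first
+        # item has no predecessor, so nothing is ended on the first call.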
+ if self._first:
+ self._first = False
+ else:
+ self.tape.end_timestep()
+ return next(self.iterator)
+
+
+class TimeStep(list):
+ """A list of blocks in a single time step, plus associated metadata."""
+ def __init__(self, blocks=()):
+ super().__init__(blocks)
+ # The set of block variables which are needed to restart from the start
+ # of this timestep.
+ self.checkpointable_state = set()
+ # A dictionary mapping the block variables in the checkpointable state
+ # to their checkpoint values.
+ self._checkpoint = {}
+
+ def copy(self, blocks=None):
+ out = TimeStep(blocks or self)
+ out.checkpointable_state = self.checkpointable_state
+ return out
+
+ def checkpoint(self):
+ """Store a copy of the checkpoints in the checkpointable state."""
+
+ with stop_annotating():
+ self._checkpoint = {
+ var: var.saved_output._ad_create_checkpoint()
+ for var in self.checkpointable_state
+ }
+
+ def restore_from_checkpoint(self):
+ """Restore the block var checkpoints from the timestep checkpoint."""
+
+ for var in self._checkpoint:
+ var.checkpoint = self._checkpoint[var]
+
+ def delete_checkpoint(self):
+ """Delete the stored checkpoint references."""
+ self._checkpoint = {}
+
+
+class TimeStepSequence(list):
+ """A list of Blocks separated into timesteps to facilitate checkpointing.
+
+ This behaves like a list of blocks. To access a list of the timesteps, use
+ the :attr:`steps` property.
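+
+    Example:
+        A minimal sketch, assuming `block0` and `block1` are hypothetical
+        :class:`Block` instances::
+
+            sequence = TimeStepSequence(blocks=[block0, block1])
+            assert list(sequence) == [block0, block1]
+            assert len(sequence.steps) == 1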
+ """
+
+    def __init__(self, blocks=None, steps: Optional[Iterable[TimeStep]] = None):
+ # Keep both per-timestep and unified block lists.
+ if steps and blocks:
+            raise ValueError("Set blocks or steps, but not both.")
+ elif isinstance(blocks, TimeStepSequence):
+ self._steps = [step.copy() for step in blocks._steps]
+ elif blocks:
+ self._steps = [TimeStep(blocks)]
+ else:
+ self._steps = list(step.copy() for step in steps) if steps else []
+ super().__init__(chain.from_iterable(self._steps))
+
+ @property
+ def steps(self):
+ return self._steps
+
+ def append(self, other):
+ """Add a new block to the sequence and to the current TimeStep."""
+ if not self.steps:
+ self.append_step()
+ self._steps[-1].append(other)
+ super().append(other)
+
+ def append_step(self, step=None):
+ """Add a new TimeStep."""
+ self._steps.append(step or TimeStep())
+
+ def __setitem__(self, key, value):
+ raise ValueError(
+ "Unable to set arbitrary blocks. Try appending instead."
+ )
+
+    def __delitem__(self, key):
+ raise ValueError(
+ "Unable to delete blocks from sequence."
+ )
+
+
class TapePackageData(ABC):
"""Abstract base class for additional data that packages store on the tape.
diff --git a/setup.py b/setup.py
index 73e521cd..c94b7e63 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,6 @@
package_dir={'pyadjoint': 'pyadjoint',
'firedrake_adjoint': 'firedrake_adjoint',
'numpy_adjoint': 'numpy_adjoint'},
- install_requires=['scipy>=1.0'],
+ install_requires=['scipy>=1.0', 'checkpoint-schedules'],
extras_require=extras
)
diff --git a/tests/firedrake_adjoint/test_burgers_newton.py b/tests/firedrake_adjoint/test_burgers_newton.py
index 9197bf28..e063296e 100644
--- a/tests/firedrake_adjoint/test_burgers_newton.py
+++ b/tests/firedrake_adjoint/test_burgers_newton.py
@@ -7,27 +7,29 @@
from firedrake import *
from firedrake.adjoint import *
-
-
+from checkpoint_schedules import Revolve, MultistageCheckpointSchedule
+import numpy as np
set_log_level(CRITICAL)
-
+continue_annotation()
n = 30
mesh = UnitIntervalMesh(n)
V = FunctionSpace(mesh, "CG", 2)
+end = 0.3
+timestep = Constant(1.0/n)
+steps = int(end/float(timestep)) + 1
+
def Dt(u, u_, timestep):
return (u - u_)/timestep
-def J(ic, solve_type):
+
+def J(ic, solve_type, checkpointing):
u_ = Function(V)
u = Function(V)
v = TestFunction(V)
-
+ u_.assign(ic)
nu = Constant(0.0001)
-
- timestep = Constant(1.0/n)
-
- F = (Dt(u, ic, timestep)*v
+ F = (Dt(u, u_, timestep)*v
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
bc = DirichletBC(V, 0.0, "on_boundary")
@@ -35,38 +37,94 @@ def J(ic, solve_type):
if solve_type == "NLVS":
problem = NonlinearVariationalProblem(F, u, bcs=bc)
solver = NonlinearVariationalSolver(problem)
- solver.solve()
- else:
- solve(F == 0, u, bc)
- u_.assign(u)
- t += float(timestep)
- F = (Dt(u, u_, timestep)*v
- + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
-
- end = 0.2
- while (t <= end):
+ tape = get_working_tape()
+ t += float(timestep)
+ for t in tape.timestepper(np.arange(t, end + t, float(timestep))):
if solve_type == "NLVS":
solver.solve()
else:
solve(F == 0, u, bc)
u_.assign(u)
- t += float(timestep)
+ return assemble(u_*u_*dx + ic*ic*dx), u_
+
+
+@pytest.mark.parametrize("solve_type, checkpointing",
+ [("solve", "Revolve"),
+ ("NLVS", "Revolve"),
+ ("solve", "Multistage"),
+ ("NLVS", "Multistage"),
+ ("solve", None),
+ ("NLVS", None),
+ ])
+def test_burgers_newton(solve_type, checkpointing):
+ """Adjoint-based gradient tests with and without checkpointing.
+ """
+ tape = get_working_tape()
+ tape.progress_bar = ProgressBar
+ if checkpointing == "Revolve":
+ tape.enable_checkpointing(Revolve(steps, steps//3))
+ if checkpointing == "Multistage":
+ tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0))
+ x, = SpatialCoordinate(mesh)
+ ic = project(sin(2.*pi*x), V)
+ val, _ = J(ic, solve_type, checkpointing)
+ if checkpointing:
+ assert len(tape.timesteps) == steps
- return assemble(u_*u_*dx + ic*ic*dx)
+ Jhat = ReducedFunctional(val, Control(ic))
+ dJ = Jhat.derivative()
+ # Recomputing the functional with a modified control variable
+ # before the recompute test.
+ Jhat(project(sin(pi*x), V))
-@pytest.mark.parametrize("solve_type",
- ["solve", "NLVS"])
-def test_burgers_newton(solve_type):
- x, = SpatialCoordinate(mesh)
- ic = project(sin(2*pi*x), V)
+ # Recompute test
+    assert np.allclose(Jhat(ic), val)
- val = J(ic, solve_type)
-
- Jhat = ReducedFunctional(val, Control(ic))
+ dJbar = Jhat.derivative()
+    # Test that the recomputed adjoint-based gradient matches the original.
+ assert np.allclose(dJ.dat.data_ro[:], dJbar.dat.data_ro[:])
+ # Taylor test
h = Function(V)
h.assign(1, annotate=False)
assert taylor_test(Jhat, ic, h) > 1.9
+
+
+@pytest.mark.parametrize("solve_type, checkpointing",
+ [("solve", "Revolve"),
+ ("NLVS", "Revolve"),
+ ("solve", "Multistage"),
+ ("NLVS", "Multistage")
+ ])
+def test_checkpointing_validity(solve_type, checkpointing):
+ """Compare forward and backward results with and without checkpointing.
+ """
+ # Without checkpointing
+ tape = get_working_tape()
+ tape.progress_bar = ProgressBar
+ x, = SpatialCoordinate(mesh)
+ ic = project(sin(2.*pi*x), V)
+
+ val0, u0 = J(ic, solve_type, False)
+ Jhat = ReducedFunctional(val0, Control(ic))
+ dJ0 = Jhat.derivative()
+ tape.clear_tape()
+
+ # With checkpointing
+ tape.progress_bar = ProgressBar
+ if checkpointing == "Revolve":
+ tape.enable_checkpointing(Revolve(steps, steps//3))
+ if checkpointing == "Multistage":
+ tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0))
+ x, = SpatialCoordinate(mesh)
+ ic = project(sin(2.*pi*x), V)
+ val1, u1 = J(ic, solve_type, True)
+ Jhat = ReducedFunctional(val1, Control(ic))
+ dJ1 = Jhat.derivative()
+ assert len(tape.timesteps) == steps
+ assert np.allclose(val0, val1)
+ assert np.allclose(u0.dat.data_ro[:], u1.dat.data_ro[:])
+ assert np.allclose(dJ0.dat.data_ro[:], dJ1.dat.data_ro[:])