Checkpointing enhancement (#160)
* Optimizing checkpointing
Ig-dolci authored Sep 11, 2024
1 parent 92121af commit 19f8718
Showing 4 changed files with 90 additions and 63 deletions.
10 changes: 9 additions & 1 deletion pyadjoint/block.py
@@ -12,14 +12,17 @@ class Block(object):
:func:`evaluate_adj`
"""
__slots__ = ['_dependencies', '_outputs', 'block_helper']
__slots__ = ['_dependencies', '_outputs', 'block_helper', 'adj_state', 'tag']
pop_kwargs_keys = []

def __init__(self, ad_block_tag=None):
self._dependencies = []
self._outputs = []
self.block_helper = None
self.tag = ad_block_tag
# The adjoint state of the block; used, for example, to store the
# adjoint state of a solver.
self.adj_state = None

@classmethod
def pop_kwargs(cls, kwargs):
@@ -95,6 +98,11 @@ def reset_variables(self, types=None):
for output in self._outputs:
output.reset_variables(types)

def reset_adjoint_state(self):
"""Resets the adjoint state of the block.
"""
self.adj_state = None

@no_annotations
def evaluate_adj(self, markings=False):
"""Computes the adjoint action and stores the result in the `adj_value` attribute of the dependencies.
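The new `adj_state` slot gives a block a place to cache data between adjoint evaluations, released via `reset_adjoint_state()` during the reverse sweep. A minimal sketch of the intended pattern, assuming a hypothetical `DenseSolveBlock`: the class, its matrix attribute, and the NumPy-based adjoint solve are illustrative, not part of pyadjoint.

    import numpy as np

    from pyadjoint.block import Block

    class DenseSolveBlock(Block):
        """Hypothetical block recording a dense linear solve A x = b."""

        def __init__(self, A, ad_block_tag=None):
            super().__init__(ad_block_tag=ad_block_tag)
            self.A = A

        def evaluate_adj_component(self, inputs, adj_inputs, block_variable,
                                   idx, prepared=None):
            if self.adj_state is None:
                # Cache the inverted adjoint operator; the checkpointing
                # machinery calls reset_adjoint_state() once the reverse
                # sweep is done with this step.
                self.adj_state = np.linalg.inv(self.A.T)
            return self.adj_state @ adj_inputs[0]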
1 change: 1 addition & 0 deletions pyadjoint/block_variable.py
@@ -80,6 +80,7 @@ def will_add_as_output(self):
self._checkpoint = None
if tape._eagerly_checkpoint_outputs:
self.save_output()
tape.add_to_adjoint_dependencies(self, self.last_use - 1)

def __str__(self):
return str(self.output._ad_str)
104 changes: 48 additions & 56 deletions pyadjoint/checkpointing.py
@@ -141,29 +141,25 @@ def _(self, cp_action, timestep):
)

self.tape._eagerly_checkpoint_outputs = cp_action.write_adj_deps

if timestep > cp_action.n0:
_store_checkpointable_state = False
_store_adj_dependencies = False
if timestep > cp_action.n0 and cp_action.storage != StorageType.WORK:
if cp_action.write_ics and timestep == (cp_action.n0 + 1):
# Store the checkpoint data. This is the required data for restarting the forward model
# from the step `n0`.
self.tape.timesteps[timestep - 1].checkpoint()
if (cp_action.write_adj_deps and timestep == (cp_action.n1 - 1) and cp_action.storage != StorageType.WORK):
# Store the checkpoint data. This is the required data for computing the adjoint model
# from the step `n1`.
self.tape.timesteps[timestep].checkpoint()

if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
):
# Remove unnecessary variables in working memory from previous steps.
for var in self.tape.timesteps[timestep - 1].checkpointable_state:
var._checkpoint = None
for block in self.tape.timesteps[timestep - 1]:
# Remove unnecessary variables from previous steps.
for output in block.get_outputs():
output._checkpoint = None

# Store the checkpoint data. This is the required data for
# restarting the forward model from the step `n0`.
_store_checkpointable_state = True
if cp_action.write_adj_deps:
# Store the checkpoint data. This is the required data for
# computing the adjoint model from the step `n1`.
_store_adj_dependencies = True
self.tape.timesteps[timestep - 1].checkpoint(
_store_checkpointable_state, _store_adj_dependencies)
# Remove unnecessary variables in working memory from previous steps.
for var in self.tape.timesteps[timestep - 1].checkpointable_state:
var._checkpoint = None
for block in self.tape.timesteps[timestep - 1]:
for out in block.get_outputs():
out._checkpoint = None
if timestep in cp_action and timestep < self.total_timesteps:
self.tape.get_blocks().append_step()
if cp_action.write_ics:
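The `write_ics` and `write_adj_deps` flags consumed here come directly from the `checkpoint_schedules` actions. A quick way to inspect the actions a schedule emits is sketched below; the two-argument `Revolve(total_steps, snapshots)` call follows the pyadjoint documentation, and direct iteration over the schedule object is an assumption suggested by the handler's use of `next()`.

    from itertools import islice

    from checkpoint_schedules import Revolve

    # Print the first few actions of a 10-step, 3-snapshot revolve schedule;
    # Forward actions carry the write_ics/write_adj_deps flags handled above.
    for action in islice(Revolve(10, 3), 8):
        print(action)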
@@ -191,13 +187,13 @@ def recompute(self, functional=None):
# Finalise the taping process.
self.end_taping()
self.mode = Mode.RECOMPUTE
with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as bar:
with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as progress_bar:
# Restore the initial condition to advance the forward model from the step 0.
current_step = self.tape.timesteps[self.forward_schedule[0].n0]
current_step.restore_from_checkpoint()
for cp_action in self.forward_schedule:
self._current_action = cp_action
self.process_operation(cp_action, bar, functional=functional)
self.process_operation(cp_action, progress_bar, functional=functional)

def evaluate_adj(self, last_block, markings):
"""Evaluate the adjoint model.
@@ -220,23 +216,23 @@ def evaluate_adj(self, last_block, markings):
if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING):
raise CheckpointError("Evaluate Functional before calling gradient.")

with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as bar:
with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as progress_bar:
if self.reverse_schedule:
for cp_action in self.reverse_schedule:
self.process_operation(cp_action, bar, markings=markings)
self.process_operation(cp_action, progress_bar, markings=markings)
else:
while not isinstance(self._current_action, EndReverse):
cp_action = next(self._schedule)
self._current_action = cp_action
self.reverse_schedule.append(cp_action)
self.process_operation(cp_action, bar, markings=markings)
self.process_operation(cp_action, progress_bar, markings=markings)

# Only set the mode after the first backward in order to handle
# that step correctly.
self.mode = Mode.EVALUATE_ADJOINT

@singledispatchmethod
def process_operation(self, cp_action, bar, **kwargs):
def process_operation(self, cp_action, progress_bar, **kwargs):
"""A function used to process the forward and adjoint executions.
This single-dispatch generic function is used in the `Blocks`
recomputation and adjoint evaluation with checkpointing.
@@ -248,7 +244,7 @@ def process_operation(self, cp_action, bar, **kwargs):
Args:
cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from the
`checkpoint_schedules`.
bar (progressbar.ProgressBar): A progress bar to display the progress of the reverse executions.
progress_bar (progressbar.ProgressBar): A progress bar to display the progress of the reverse executions.
kwargs: Additional keyword arguments.
Raises:
@@ -257,33 +253,27 @@ def process_operation(self, cp_action, bar, **kwargs):
raise CheckpointError(f"Unable to process {cp_action}.")

@process_operation.register(Forward)
def _(self, cp_action, bar, functional=None, **kwargs):
def _(self, cp_action, progress_bar, functional=None, **kwargs):
step = cp_action.n0
# In a dynamic schedule `cp_action` can be unbounded so we also need to check `self.total_timesteps`.
# In a dynamic schedule `cp_action` can be unbounded so we also need
# to check `self.total_timesteps`.
while step in cp_action and step < self.total_timesteps:
if self.mode == Mode.RECOMPUTE:
bar.next()
if self.mode == Mode.RECOMPUTE and progress_bar:
progress_bar.next()
# Get the blocks of the current step.
current_step = self.tape.timesteps[step]
for block in current_step:
block.recompute()
if (
(cp_action.write_ics and step == cp_action.n0)
or (cp_action.write_adj_deps and step == cp_action.n1 - 1
and cp_action.storage != StorageType.WORK)
):
# Store the checkpoint data required for restarting the
# forward model or computing the adjoint model.
# If `cp_action.write_ics` is `True`, the checkpointed data
# will restart the forward model from the step `n0`.
# If `cp_action.write_adj_deps` is `True`, the checkpointed
# data will be used for computing the adjoint model from the
# step `n1`.
for var in current_step.checkpointable_state:
if var.checkpoint:
current_step._checkpoint.update(
{var: var.checkpoint}
)
_store_checkpointable_state = False
_store_adj_dependencies = False
if cp_action.storage != StorageType.WORK:
if (cp_action.write_ics and step == cp_action.n0):
_store_checkpointable_state = True
if cp_action.write_adj_deps:
_store_adj_dependencies = True
current_step.checkpoint(
_store_checkpointable_state, _store_adj_dependencies)

if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
@@ -306,16 +296,18 @@ def _(self, cp_action, bar, functional=None, **kwargs):
step += 1
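Condensed, the per-step storage decision made in the loop above is equivalent to the following restatement (a sketch for clarity, not library code):

    from checkpoint_schedules import StorageType

    def storage_flags(cp_action, step):
        """Return (store_checkpointable_state, store_adj_dependencies)
        for one recomputed forward step, mirroring the handler above."""
        if cp_action.storage == StorageType.WORK:
            return False, False
        return (cp_action.write_ics and step == cp_action.n0,
                cp_action.write_adj_deps)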

@process_operation.register(Reverse)
def _(self, cp_action, bar, markings, functional=None, **kwargs):
def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
for step in cp_action:
bar.next()
if progress_bar:
progress_bar.next()
# Get the blocks of the current step.
current_step = self.tape.timesteps[step]
for block in reversed(current_step):
block.evaluate_adj(markings=markings)
# Output variables are used for the last time when running
# backwards.
for block in current_step:
block.reset_adjoint_state()
for var in block.get_outputs():
var.checkpoint = None
var.reset_variables(("tlm",))
@@ -330,22 +322,22 @@ def _(self, cp_action, bar, markings, functional=None, **kwargs):
output._checkpoint = None

@process_operation.register(Copy)
def _(self, cp_action, bar, **kwargs):
def _(self, cp_action, progress_bar, **kwargs):
current_step = self.tape.timesteps[cp_action.n]
current_step.restore_from_checkpoint()

@process_operation.register(Move)
def _(self, cp_action, bar, **kwargs):
def _(self, cp_action, progress_bar, **kwargs):
current_step = self.tape.timesteps[cp_action.n]
current_step.restore_from_checkpoint()
current_step.delete_checkpoint()

@process_operation.register(EndForward)
def _(self, cp_action, bar, **kwargs):
def _(self, cp_action, progress_bar, **kwargs):
self.mode = Mode.EVALUATED

@process_operation.register(EndReverse)
def _(self, cp_action, bar, **kwargs):
def _(self, cp_action, progress_bar, **kwargs):
if self._schedule.is_exhausted:
self.mode = Mode.EXHAUSTED
else:
38 changes: 32 additions & 6 deletions pyadjoint/tape.py
@@ -261,6 +261,22 @@ def add_to_checkpointable_state(self, block_var, last_used):
for step in self.timesteps[last_used + 1:]:
step.checkpointable_state.add(block_var)

def add_to_adjoint_dependencies(self, block_var, last_used):
"""Add a block variable into the adjoint dependencies set.
Note:
`adjoint_dependencies` is a set of block variables which are needed
to compute the adjoint of a timestep.
Args:
block_var (BlockVariable): The block variable to add.
last_used (int): The last timestep in which the block variable was used.
"""
if not self.timesteps:
self._blocks.append_step()
for step in self.timesteps[last_used + 1:]:
step.adjoint_dependencies.add(block_var)
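The slice `timesteps[last_used + 1:]` encodes the propagation rule: the variable becomes an adjoint dependency of every timestep after its last use. A toy model of just that rule, with plain sets standing in for `TimeStep` objects:

    timesteps = [set() for _ in range(4)]  # stand-in for tape.timesteps

    def add_to_adjoint_dependencies(block_var, last_used):
        for step in timesteps[last_used + 1:]:
            step.add(block_var)

    add_to_adjoint_dependencies("u_2", last_used=1)
    print(timesteps)  # [set(), set(), {'u_2'}, {'u_2'}]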

def enable_checkpointing(self, schedule):
"""Enable checkpointing on the adjoint evaluation.
@@ -736,6 +752,7 @@ def __init__(self, blocks=()):
# The set of block variables which are needed to restart from the start
# of this timestep.
self.checkpointable_state = set()
self.adjoint_dependencies = set()
# A dictionary mapping the block variables in the checkpointable state
# to their checkpoint values.
self._checkpoint = {}
@@ -745,14 +762,23 @@ def copy(self, blocks=None):
out.checkpointable_state = self.checkpointable_state
return out

def checkpoint(self):
"""Store a copy of the checkpoints in the checkpointable state."""
def checkpoint(self, checkpointable_state, adj_dependencies):
"""Store a copy of the checkpoints in the checkpointable state.
Args:
checkpointable_state (bool): If True, store the checkpointable state
required to restart from the start of a timestep.
adj_dependencies (bool): If True, store the adjoint dependencies required
to compute the adjoint of a timestep.
"""
with stop_annotating():
self._checkpoint = {
var: var.saved_output._ad_create_checkpoint()
for var in self.checkpointable_state
}
if checkpointable_state:
for var in self.checkpointable_state:
self._checkpoint[var] = var.checkpoint

if adj_dependencies:
for var in self.adjoint_dependencies:
self._checkpoint[var] = var.checkpoint

def restore_from_checkpoint(self):
"""Restore the block var checkpoints from the timestep checkpoint."""
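Taken together, these changes are driven from user code through the usual pyadjoint checkpointing pattern. A sketch of that driver loop follows; the model advance and functional are placeholders, and while `get_working_tape` and `enable_checkpointing` appear in pyadjoint and this diff, `end_timestep` should be verified against your pyadjoint version.

    from checkpoint_schedules import Revolve
    from pyadjoint import get_working_tape

    total_steps, snapshots = 10, 3
    tape = get_working_tape()
    tape.enable_checkpointing(Revolve(total_steps, snapshots))

    for step in range(total_steps):
        ...  # advance the annotated forward model by one timestep
        tape.end_timestep()

    # J = ...  # placeholder functional; pyadjoint.compute_gradient(J, m)
    # then replays the forward and adjoint sweeps under the schedule above.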
