From 19f87183db87e5184fb038c99a40f40686f43087 Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:49:32 +0100 Subject: [PATCH] Checkpointing enhancement (#160) * Optimizing checkpointing --- pyadjoint/block.py | 10 +++- pyadjoint/block_variable.py | 1 + pyadjoint/checkpointing.py | 104 +++++++++++++++++------------------- pyadjoint/tape.py | 38 ++++++++++--- 4 files changed, 90 insertions(+), 63 deletions(-) diff --git a/pyadjoint/block.py b/pyadjoint/block.py index 29dbc70b..4b779f00 100644 --- a/pyadjoint/block.py +++ b/pyadjoint/block.py @@ -12,7 +12,7 @@ class Block(object): :func:`evaluate_adj` """ - __slots__ = ['_dependencies', '_outputs', 'block_helper'] + __slots__ = ['_dependencies', '_outputs', 'block_helper', 'adj_state', 'tag'] pop_kwargs_keys = [] def __init__(self, ad_block_tag=None): @@ -20,6 +20,9 @@ def __init__(self, ad_block_tag=None): self._outputs = [] self.block_helper = None self.tag = ad_block_tag + # The adjoint state of the block. This is used to store the adjoint + # state, for instance to store the adjoint state of a solver. + self.adj_state = None @classmethod def pop_kwargs(cls, kwargs): @@ -95,6 +98,11 @@ def reset_variables(self, types=None): for output in self._outputs: output.reset_variables(types) + def reset_adjoint_state(self): + """Resets the adjoint state of the block. + """ + self.adj_state = None + @no_annotations def evaluate_adj(self, markings=False): """Computes the adjoint action and stores the result in the `adj_value` attribute of the dependencies. 
diff --git a/pyadjoint/block_variable.py b/pyadjoint/block_variable.py index e252b0ea..b723da61 100644 --- a/pyadjoint/block_variable.py +++ b/pyadjoint/block_variable.py @@ -80,6 +80,7 @@ def will_add_as_output(self): self._checkpoint = None if tape._eagerly_checkpoint_outputs: self.save_output() + tape.add_to_adjoint_dependencies(self, self.last_use - 1) def __str__(self): return str(self.output._ad_str) diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index 80da71fe..18ce4842 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -141,29 +141,25 @@ def _(self, cp_action, timestep): ) self.tape._eagerly_checkpoint_outputs = cp_action.write_adj_deps - - if timestep > cp_action.n0: + _store_checkpointable_state = False + _store_adj_dependencies = False + if timestep > cp_action.n0 and cp_action.storage != StorageType.WORK: if cp_action.write_ics and timestep == (cp_action.n0 + 1): - # Store the checkpoint data. This is the required data for restarting the forward model - # from the step `n0`. - self.tape.timesteps[timestep - 1].checkpoint() - if (cp_action.write_adj_deps and timestep == (cp_action.n1 - 1) and cp_action.storage != StorageType.WORK): - # Store the checkpoint data. This is the required data for computing the adjoint model - # from the step `n1`. - self.tape.timesteps[timestep].checkpoint() - - if ( - (cp_action.write_adj_deps and cp_action.storage != StorageType.WORK) - or not cp_action.write_adj_deps - ): - # Remove unnecessary variables in working memory from previous steps. - for var in self.tape.timesteps[timestep - 1].checkpointable_state: - var._checkpoint = None - for block in self.tape.timesteps[timestep - 1]: - # Remove unnecessary variables from previous steps. - for output in block.get_outputs(): - output._checkpoint = None - + # Store the checkpoint data. This is the required data for + # restarting the forward model from the step `n0`. 
+ _store_checkpointable_state = True + if cp_action.write_adj_deps: + # Store the checkpoint data. This is the required data for + # computing the adjoint model from the step `n1`. + _store_adj_dependencies = True + self.tape.timesteps[timestep - 1].checkpoint( + _store_checkpointable_state, _store_adj_dependencies) + # Remove unnecessary variables in working memory from previous steps. + for var in self.tape.timesteps[timestep - 1].checkpointable_state: + var._checkpoint = None + for block in self.tape.timesteps[timestep - 1]: + for out in block.get_outputs(): + out._checkpoint = None if timestep in cp_action and timestep < self.total_timesteps: self.tape.get_blocks().append_step() if cp_action.write_ics: @@ -191,13 +187,13 @@ def recompute(self, functional=None): # Finalise the taping process. self.end_taping() self.mode = Mode.RECOMPUTE - with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as bar: + with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as progress_bar: # Restore the initial condition to advance the forward model from the step 0. current_step = self.tape.timesteps[self.forward_schedule[0].n0] current_step.restore_from_checkpoint() for cp_action in self.forward_schedule: self._current_action = cp_action - self.process_operation(cp_action, bar, functional=functional) + self.process_operation(cp_action, progress_bar, functional=functional) def evaluate_adj(self, last_block, markings): """Evaluate the adjoint model. 
@@ -220,23 +216,23 @@ def evaluate_adj(self, last_block, markings): if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING): raise CheckpointError("Evaluate Functional before calling gradient.") - with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as bar: + with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as progress_bar: if self.reverse_schedule: for cp_action in self.reverse_schedule: - self.process_operation(cp_action, bar, markings=markings) + self.process_operation(cp_action, progress_bar, markings=markings) else: while not isinstance(self._current_action, EndReverse): cp_action = next(self._schedule) self._current_action = cp_action self.reverse_schedule.append(cp_action) - self.process_operation(cp_action, bar, markings=markings) + self.process_operation(cp_action, progress_bar, markings=markings) # Only set the mode after the first backward in order to handle # that step correctly. self.mode = Mode.EVALUATE_ADJOINT @singledispatchmethod - def process_operation(self, cp_action, bar, **kwargs): + def process_operation(self, cp_action, progress_bar, **kwargs): """A function used to process the forward and adjoint executions. This single-dispatch generic function is used in the `Blocks` recomputation and adjoint evaluation with checkpointing. @@ -248,7 +244,7 @@ def process_operation(self, cp_action, bar, **kwargs): Args: cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from the `checkpoint_schedules`. - bar (progressbar.ProgressBar): A progress bar to display the progress of the reverse executions. + progress_bar (progressbar.ProgressBar): A progress bar to display the progress of the reverse executions. kwargs: Additional keyword arguments. 
Raises: @@ -257,33 +253,27 @@ def process_operation(self, cp_action, bar, **kwargs): raise CheckpointError(f"Unable to process {cp_action}.") @process_operation.register(Forward) - def _(self, cp_action, bar, functional=None, **kwargs): + def _(self, cp_action, progress_bar, functional=None, **kwargs): step = cp_action.n0 - # In a dynamic schedule `cp_action` can be unbounded so we also need to check `self.total_timesteps`. + # In a dynamic schedule `cp_action` can be unbounded so we also need + # to check `self.total_timesteps`. while step in cp_action and step < self.total_timesteps: - if self.mode == Mode.RECOMPUTE: - bar.next() + if self.mode == Mode.RECOMPUTE and progress_bar: + progress_bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: block.recompute() - if ( - (cp_action.write_ics and step == cp_action.n0) - or (cp_action.write_adj_deps and step == cp_action.n1 - 1 - and cp_action.storage != StorageType.WORK) - ): - # Store the checkpoint data required for restarting the - # forward model or computing the adjoint model. - # If `cp_action.write_ics` is `True`, the checkpointed data - # will restart the forward model from the step `n0`. - # If `cp_action.write_adj_deps` is `True`, the checkpointed - # data will be used for computing the adjoint model from the - # step `n1`. 
- for var in current_step.checkpointable_state: - if var.checkpoint: - current_step._checkpoint.update( - {var: var.checkpoint} - ) + _store_checkpointable_state = False + _store_adj_dependencies = False + if cp_action.storage != StorageType.WORK: + if (cp_action.write_ics and step == cp_action.n0): + _store_checkpointable_state = True + if cp_action.write_adj_deps: + _store_adj_dependencies = True + current_step.checkpoint( + _store_checkpointable_state, _store_adj_dependencies) + if ( (cp_action.write_adj_deps and cp_action.storage != StorageType.WORK) or not cp_action.write_adj_deps @@ -306,9 +296,10 @@ def _(self, cp_action, bar, functional=None, **kwargs): step += 1 @process_operation.register(Reverse) - def _(self, cp_action, bar, markings, functional=None, **kwargs): + def _(self, cp_action, progress_bar, markings, functional=None, **kwargs): for step in cp_action: - bar.next() + if progress_bar: + progress_bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step): @@ -316,6 +307,7 @@ def _(self, cp_action, bar, markings, functional=None, **kwargs): # Output variables are used for the last time when running # backwards. 
for block in current_step: + block.reset_adjoint_state() for var in block.get_outputs(): var.checkpoint = None var.reset_variables(("tlm",)) @@ -330,22 +322,22 @@ def _(self, cp_action, bar, markings, functional=None, **kwargs): output._checkpoint = None @process_operation.register(Copy) - def _(self, cp_action, bar, **kwargs): + def _(self, cp_action, progress_bar, **kwargs): current_step = self.tape.timesteps[cp_action.n] current_step.restore_from_checkpoint() @process_operation.register(Move) - def _(self, cp_action, bar, **kwargs): + def _(self, cp_action, progress_bar, **kwargs): current_step = self.tape.timesteps[cp_action.n] current_step.restore_from_checkpoint() current_step.delete_checkpoint() @process_operation.register(EndForward) - def _(self, cp_action, bar, **kwargs): + def _(self, cp_action, progress_bar, **kwargs): self.mode = Mode.EVALUATED @process_operation.register(EndReverse) - def _(self, cp_action, bar, **kwargs): + def _(self, cp_action, progress_bar, **kwargs): if self._schedule.is_exhausted: self.mode = Mode.EXHAUSTED else: diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 63853d44..06a49782 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -261,6 +261,22 @@ def add_to_checkpointable_state(self, block_var, last_used): for step in self.timesteps[last_used + 1:]: step.checkpointable_state.add(block_var) + def add_to_adjoint_dependencies(self, block_var, last_used): + """Add a block variable into the adjoint dependencies set. + + Note: + `adjoint_dependencies` is a set of block variables which are needed + to compute the adjoint of a timestep. + + Args: + block_var (BlockVariable): The block variable to add. + last_used (int): The last timestep in which the block variable was used. + """ + if not self.timesteps: + self._blocks.append_step() + for step in self.timesteps[last_used + 1:]: + step.adjoint_dependencies.add(block_var) + def enable_checkpointing(self, schedule): """Enable checkpointing on the adjoint evaluation. 
@@ -736,6 +752,7 @@ def __init__(self, blocks=()): # The set of block variables which are needed to restart from the start # of this timestep. self.checkpointable_state = set() + self.adjoint_dependencies = set() # A dictionary mapping the block variables in the checkpointable state # to their checkpoint values. self._checkpoint = {} @@ -745,14 +762,23 @@ def copy(self, blocks=None): out.checkpointable_state = self.checkpointable_state return out - def checkpoint(self): - """Store a copy of the checkpoints in the checkpointable state.""" + def checkpoint(self, checkpointable_state, adj_dependencies): + """Store a copy of the checkpoints in the checkpointable state. + Args: + checkpointable_state (bool): If True, store the checkpointable state + required to restart from the start of a timestep. + adj_dependencies (bool): If True, store the adjoint dependencies required + to compute the adjoint of a timestep. + """ with stop_annotating(): + if checkpointable_state: + for var in self.checkpointable_state: + self._checkpoint[var] = var.checkpoint + + if adj_dependencies: + for var in self.adjoint_dependencies: + self._checkpoint[var] = var.checkpoint - self._checkpoint = { - var: var.saved_output._ad_create_checkpoint() - for var in self.checkpointable_state - } def restore_from_checkpoint(self): """Restore the block var checkpoints from the timestep checkpoint."""