diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index d20a0b2f..01809101 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -214,8 +214,7 @@ def evaluate_adj(self, last_block, markings): if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING): raise CheckpointError("Evaluate Functional before calling gradient.") - with self.tape.progress_bar("Evaluating Adjoint", - max=self.timesteps) as bar: + with self.tape.progress_bar("Evaluating Adjoint", max=self.timesteps) as bar: if self.adjoint_evaluated: reverse_iterator = iter(self.reverse_schedule) while not isinstance(self._current_action, EndReverse): @@ -257,7 +256,8 @@ def process_operation(self, cp_action, bar, **kwargs): def _(self, cp_action, bar, functional=None, **kwargs): for step in cp_action: if self.mode == Mode.RECOMPUTE: - bar.next() + if bar: + bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: @@ -290,7 +290,8 @@ def _(self, cp_action, bar, functional=None, **kwargs): @process_operation.register(Reverse) def _(self, cp_action, bar, markings, functional=None, **kwargs): for step in cp_action: - bar.next() + if bar: + bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step): diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 0774d260..98dfe994 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -703,7 +703,7 @@ def __init__(self, *args, **kwargs): def __enter__(self): pass - def __exit__(self): + def __exit__(self, *args, **kwargs): pass def iter(self, iterator):