Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into dolci/fix_sphinx_fail
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Feb 16, 2024
2 parents 6199ea8 + 9bfe70f commit e2123d3
Show file tree
Hide file tree
Showing 10 changed files with 729 additions and 81 deletions.
3 changes: 1 addition & 2 deletions docs/source/documentation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ a tape. The current working tape can be set and retrieved with the functions :py
:py:func:`get_working_tape`.

Annotation can be temporarily disabled using :py:func:`pause_annotation` and enabled again using :py:func:`continue_annotation`.
Note that if you call :py:func:`pause_annotation` twice, then :py:func:`continue_annotation` must be called twice
to enable annotation. Due to this, the recommended annotation control functions are :py:class:`stop_annotating` and :py:func:`no_annotations`.
It is recommended to use :py:class:`stop_annotating` and :py:func:`no_annotations` for annotation control.
:py:class:`stop_annotating` is a context manager and should be used as follows

.. code-block:: python
Expand Down
2 changes: 2 additions & 0 deletions docs/source/documentation/pyadjoint_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Core classes
.. automethod:: add_block
.. automethod:: visualise
.. autoproperty:: progress_bar
.. automethod:: end_timestep
.. automethod:: timestepper

.. autoclass:: Block

Expand Down
2 changes: 1 addition & 1 deletion pyadjoint/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def add_dependency(self, dep, no_duplicates=False):
"""
if not no_duplicates or dep.block_variable not in self._dependencies:
dep._ad_will_add_as_dependency()
dep.block_variable.will_add_as_dependency()
self._dependencies.append(dep.block_variable)

def get_dependencies(self):
Expand Down
24 changes: 19 additions & 5 deletions pyadjoint/block_variable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .tape import no_annotations
from .tape import no_annotations, get_working_tape


class BlockVariable(object):
Expand All @@ -16,6 +16,10 @@ def __init__(self, output):
self.floating_type = False
# Helper flag for use during tape traversals.
self.marked_in_path = False
# By default assume the variable is created externally to the tape.
self.creation_timestep = -1
# The timestep during which this variable was last used as an input.
self.last_use = -1

def add_adj_output(self, val):
if self.adj_value is None:
Expand Down Expand Up @@ -59,13 +63,23 @@ def saved_output(self):

def will_add_as_dependency(self):
overwrite = self.output._ad_will_add_as_dependency()
overwrite = False if overwrite is None else overwrite
self.save_output(overwrite=overwrite)
overwrite = bool(overwrite)
tape = get_working_tape()
if self.last_use < tape.latest_checkpoint:
self.save_output(overwrite=overwrite)
tape.add_to_checkpointable_state(self, self.last_use)
self.last_use = tape.latest_timestep

def will_add_as_output(self):
tape = get_working_tape()
self.creation_timestep = tape.latest_timestep
self.last_use = self.creation_timestep
overwrite = self.output._ad_will_add_as_output()
overwrite = True if overwrite is None else overwrite
self.save_output(overwrite=overwrite)
overwrite = bool(overwrite)
if not overwrite:
self._checkpoint = None
if tape._eagerly_checkpoint_outputs:
self.save_output()

def __str__(self):
return str(self.output)
Expand Down
Loading

0 comments on commit e2123d3

Please sign in to comment.