Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disk checkpointing #3812

Merged
merged 20 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
pause_annotation, continue_annotation, \
stop_annotating, annotate_tape # noqa F401
from pyadjoint.reduced_functional import ReducedFunctional # noqa F401
from pyadjoint.checkpointing import disk_checkpointing_callback # noqa F401
from firedrake.adjoint_utils.checkpointing import \
enable_disk_checkpointing, pause_disk_checkpointing, \
continue_disk_checkpointing, stop_disk_checkpointing, \
checkpointable_mesh # noqa F401
checkpointable_mesh # noqa F401
from firedrake.adjoint_utils import get_solve_blocks # noqa F401

from pyadjoint.verification import taylor_test, taylor_to_dict # noqa F401
Expand Down
15 changes: 13 additions & 2 deletions firedrake/adjoint_utils/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""A module providing support for disk checkpointing of the adjoint tape."""
from pyadjoint import get_working_tape, OverloadedType
from pyadjoint import get_working_tape, OverloadedType, disk_checkpointing_callback
from pyadjoint.tape import TapePackageData
from pyop2.mpi import COMM_WORLD
import tempfile
Expand All @@ -10,6 +10,8 @@
from numbers import Number
_enable_disk_checkpoint = False
_checkpoint_init_data = False
disk_checkpointing_callback["firedrake"] = "Please call enable_disk_checkpointing() "\
"before checkpointing on the disk."

__all__ = ["enable_disk_checkpointing", "disk_checkpointing",
"pause_disk_checkpointing", "continue_disk_checkpointing",
Expand Down Expand Up @@ -204,6 +206,12 @@ def restore_from_checkpoint(self, state):
self.init_checkpoint_file = state["init"]
self.current_checkpoint_file = state["current"]

def continue_checkpointing(self):
continue_disk_checkpointing()

def pause_checkpointing(self):
pause_disk_checkpointing()


def checkpointable_mesh(mesh):
"""Write a mesh to disk and read it back.
Expand Down Expand Up @@ -251,7 +259,7 @@ def restore(self):
pass


class CheckpointFunction(CheckpointBase):
class CheckpointFunction(CheckpointBase, OverloadedType):
"""Metadata for a Function checkpointed to disk.

An object of this class replaces the :class:`~firedrake.Function` stored as
Expand Down Expand Up @@ -304,6 +312,9 @@ def restore(self):
return type(function)(function.function_space(),
function.dat, name=self.name(), count=self.count)

def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint.restore()


def maybe_disk_checkpoint(function):
"""Checkpoint a Function to disk if disk checkpointing is active."""
Expand Down
Loading