diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index 01809101..80da71fe 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -1,8 +1,7 @@ from enum import Enum +import sys from functools import singledispatchmethod -from checkpoint_schedules import ( - Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType) -from checkpoint_schedules import Revolve, MultistageCheckpointSchedule +from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType class CheckpointError(RuntimeError): @@ -42,8 +41,7 @@ class CheckpointManager: forward model. reverse_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the reverse model. - timesteps (int): The initial number of timesteps. - adjoint_evaluated (bool): A boolean indicating whether the adjoint model has been evaluated. + total_timesteps (int): The total number of timesteps to execute the forward model. mode (Mode): The mode of the checkpoint manager. The possible modes are `RECORD`, `FINISHED_RECORDING`, `EVALUATED`, `EXHAUSTED`, `RECOMPUTE`, and `EVALUATE_ADJOINT`. Additional information about the modes can be found class:`Mode`. @@ -51,14 +49,6 @@ class CheckpointManager: """ def __init__(self, schedule, tape): - if ( - not isinstance(schedule, Revolve) - and not isinstance(schedule, MultistageCheckpointSchedule) - ): - raise CheckpointError( - "Only Revolve and MultistageCheckpointSchedule schedules are supported." - ) - if ( schedule.uses_storage_type(StorageType.DISK) and not tape._package_data @@ -70,9 +60,15 @@ def __init__(self, schedule, tape): self._schedule = schedule self.forward_schedule = [] self.reverse_schedule = [] - self.timesteps = schedule.max_n - # This variable is used to indicate whether the adjoint model has been evaluated at the checkpoint. - self.adjoint_evaluated = False + if self._schedule.max_n: + self.total_timesteps = schedule.max_n + else: + # We have schedules in `checkpoint_schedules` offering the flexibility to determine + # the desired steps during the forward execution. For this type of schedule, we do not + # have the total number of timesteps `self._schedule.max_n`. Therefore, we set the + # `self.total_timesteps` to the maximum value of the `sys.maxsize` to indicate that + # the total number of timesteps is not known. + self.total_timesteps = sys.maxsize self.mode = Mode.RECORD self._current_action = next(self._schedule) self.forward_schedule.append(self._current_action) @@ -97,6 +93,11 @@ def end_timestep(self, timestep): def end_taping(self): """Process the end of the forward execution.""" current_timestep = self.tape.latest_timestep + if not self._schedule.max_n: + # Inform the schedule that the forward model has finished. + self._schedule.finalize(len(self.tape.timesteps)) + # `self._schedule.finalize` updates `self._schedule.max_n`. + self.total_timesteps = self._schedule.max_n while self.mode != Mode.EVALUATED: self.end_timestep(current_timestep) current_timestep += 1 @@ -143,13 +144,19 @@ def _(self, cp_action, timestep): if timestep > cp_action.n0: if cp_action.write_ics and timestep == (cp_action.n0 + 1): - # Stores the checkpoint data in RAM - # This data will be used to restart the forward model - # from the step `n0` in the reverse computations. - self.tape.timesteps[cp_action.n0].checkpoint() - - if not cp_action.write_adj_deps: - # Remove unnecessary variables from previous steps. + # Store the checkpoint data. This is the required data for restarting the forward model + # from the step `n0`. + self.tape.timesteps[timestep - 1].checkpoint() + if (cp_action.write_adj_deps and timestep == (cp_action.n1 - 1) and cp_action.storage != StorageType.WORK): + # Store the checkpoint data. This is the required data for computing the adjoint model + # from the step `n1`. + self.tape.timesteps[timestep].checkpoint() + + if ( + (cp_action.write_adj_deps and cp_action.storage != StorageType.WORK) + or not cp_action.write_adj_deps + ): + # Remove unnecessary variables in working memory from previous steps. for var in self.tape.timesteps[timestep - 1].checkpointable_state: var._checkpoint = None for block in self.tape.timesteps[timestep - 1]: @@ -157,7 +164,7 @@ def _(self, cp_action, timestep): for output in block.get_outputs(): output._checkpoint = None - if timestep in cp_action: + if timestep in cp_action and timestep < self.total_timesteps: self.tape.get_blocks().append_step() if cp_action.write_ics: self.tape.latest_checkpoint = cp_action.n0 @@ -167,7 +174,7 @@ def _(self, cp_action, timestep): @process_taping.register(EndForward) def _(self, cp_action, timestep): - if timestep != self.timesteps: + if timestep != self.total_timesteps: raise CheckpointError( "The correct number of forward steps has notbeen taken." ) @@ -180,12 +187,12 @@ def recompute(self, functional=None): Args: functional (BlockVariable): The functional to be evaluated. """ + if self.mode == Mode.RECORD: + # Finalise the taping process. + self.end_taping() self.mode = Mode.RECOMPUTE - - with self.tape.progress_bar("Evaluating Functional", - max=self.timesteps) as bar: - # Restore the initial condition to advance the forward model - # from the step 0. + with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as bar: + # Restore the initial condition to advance the forward model from the step 0. current_step = self.tape.timesteps[self.forward_schedule[0].n0] current_step.restore_from_checkpoint() for cp_action in self.forward_schedule: @@ -206,30 +213,27 @@ def evaluate_adj(self, last_block, markings): raise NotImplementedError( "Only the first block can be evaluated at present." ) - if self.mode == Mode.RECORD: - # The declared timesteps were not exhausted while taping. + # Finalise the taping process. self.end_taping() if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING): raise CheckpointError("Evaluate Functional before calling gradient.") - with self.tape.progress_bar("Evaluating Adjoint", max=self.timesteps) as bar: - if self.adjoint_evaluated: - reverse_iterator = iter(self.reverse_schedule) - while not isinstance(self._current_action, EndReverse): - if not self.adjoint_evaluated: - self._current_action = next(self._schedule) - self.reverse_schedule.append(self._current_action) - else: - self._current_action = next(reverse_iterator) - self.process_operation(self._current_action, bar, markings=markings) - # Only set the mode after the first backward in order to handle - # that step correctly. - self.mode = Mode.EVALUATE_ADJOINT - - # Inform that the adjoint model has been evaluated. - self.adjoint_evaluated = True + with self.tape.progress_bar("Evaluating Adjoint", max=self.total_timesteps) as bar: + if self.reverse_schedule: + for cp_action in self.reverse_schedule: + self.process_operation(cp_action, bar, markings=markings) + else: + while not isinstance(self._current_action, EndReverse): + cp_action = next(self._schedule) + self._current_action = cp_action + self.reverse_schedule.append(cp_action) + self.process_operation(cp_action, bar, markings=markings) + + # Only set the mode after the first backward in order to handle + # that step correctly. + self.mode = Mode.EVALUATE_ADJOINT @singledispatchmethod def process_operation(self, cp_action, bar, **kwargs): @@ -254,44 +258,57 @@ def process_operation(self, cp_action, bar, **kwargs): @process_operation.register(Forward) def _(self, cp_action, bar, functional=None, **kwargs): - for step in cp_action: + step = cp_action.n0 + # In a dynamic schedule `cp_action` can be unbounded so we also need to check `self.total_timesteps`. + while step in cp_action and step < self.total_timesteps: if self.mode == Mode.RECOMPUTE: - if bar: - bar.next() + bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: block.recompute() - - if cp_action.write_ics: - if step == cp_action.n0: - for var in current_step.checkpointable_state: - if var.checkpoint: - current_step._checkpoint.update( - {var: var.checkpoint} - ) - if not cp_action.write_adj_deps: + if ( + (cp_action.write_ics and step == cp_action.n0) + or (cp_action.write_adj_deps and step == cp_action.n1 - 1 + and cp_action.storage != StorageType.WORK) + ): + # Store the checkpoint data required for restarting the + # forward model or computing the adjoint model. + # If `cp_action.write_ics` is `True`, the checkpointed data + # will restart the forward model from the step `n0`. + # If `cp_action.write_adj_deps` is `True`, the checkpointed + # data will be used for computing the adjoint model from the + # step `n1`. + for var in current_step.checkpointable_state: + if var.checkpoint: + current_step._checkpoint.update( + {var: var.checkpoint} + ) + if ( + (cp_action.write_adj_deps and cp_action.storage != StorageType.WORK) + or not cp_action.write_adj_deps + ): + to_keep = set() + if step < (self.total_timesteps - 1): next_step = self.tape.timesteps[step + 1] # The checkpointable state set of the current step. to_keep = next_step.checkpointable_state - if functional: - # `to_keep` holds informations of the blocks required - # for restarting the forward model from a step `n`. - to_keep = to_keep.union([functional.block_variable]) - for block in current_step: - # Remove unnecessary variables from previous steps. - for bv in block.get_outputs(): - if bv not in to_keep: - bv._checkpoint = None + if functional: + to_keep = to_keep.union([functional.block_variable]) + for block in current_step: # Remove unnecessary variables from previous steps. - for var in (current_step.checkpointable_state - to_keep): - var._checkpoint = None + for bv in block.get_outputs(): + if bv not in to_keep: + bv._checkpoint = None + # Remove unnecessary variables from previous steps. + for var in (current_step.checkpointable_state - to_keep): + var._checkpoint = None + step += 1 @process_operation.register(Reverse) def _(self, cp_action, bar, markings, functional=None, **kwargs): for step in cp_action: - if bar: - bar.next() + bar.next() # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step): diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 98dfe994..63853d44 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -721,11 +721,12 @@ def __iter__(self): return self def __next__(self): + step = next(self.iterator) if self._first: self._first = False else: self.tape.end_timestep() - return next(self.iterator) + return step class TimeStep(list): diff --git a/tests/firedrake_adjoint/test_burgers_newton.py b/tests/firedrake_adjoint/test_burgers_newton.py index e063296e..052bed56 100644 --- a/tests/firedrake_adjoint/test_burgers_newton.py +++ b/tests/firedrake_adjoint/test_burgers_newton.py @@ -7,7 +7,8 @@ from firedrake import * from firedrake.adjoint import * -from checkpoint_schedules import Revolve, MultistageCheckpointSchedule +from checkpoint_schedules import Revolve, SingleMemoryStorageSchedule, MixedCheckpointSchedule,\ + NoneCheckpointSchedule, StorageType import numpy as np set_log_level(CRITICAL) continue_annotation() @@ -33,28 +34,31 @@ def J(ic, solve_type, checkpointing): + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx bc = DirichletBC(V, 0.0, "on_boundary") - t = 0.0 if solve_type == "NLVS": problem = NonlinearVariationalProblem(F, u, bcs=bc) solver = NonlinearVariationalSolver(problem) tape = get_working_tape() - t += float(timestep) - for t in tape.timestepper(np.arange(t, end + t, float(timestep))): + for _ in tape.timestepper(range(steps)): if solve_type == "NLVS": solver.solve() else: solve(F == 0, u, bc) u_.assign(u) - return assemble(u_*u_*dx + ic*ic*dx), u_ + return assemble(u_*u_*dx + ic*ic*dx) @pytest.mark.parametrize("solve_type, checkpointing", - [("solve", "Revolve"), + [ + ("solve", "Revolve"), ("NLVS", "Revolve"), - ("solve", "Multistage"), - ("NLVS", "Multistage"), + ("solve", "SingleMemory"), + ("NLVS", "SingleMemory"), + ("solve", "NoneAdjoint"), + ("NLVS", "NoneAdjoint"), + ("solve", "Mixed"), + ("NLVS", "Mixed"), ("solve", None), ("NLVS", None), ]) @@ -63,18 +67,24 @@ def test_burgers_newton(solve_type, checkpointing): """ tape = get_working_tape() tape.progress_bar = ProgressBar - if checkpointing == "Revolve": - tape.enable_checkpointing(Revolve(steps, steps//3)) - if checkpointing == "Multistage": - tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0)) + if checkpointing: + if checkpointing == "Revolve": + schedule = Revolve(steps, steps//3) + if checkpointing == "SingleMemory": + schedule = SingleMemoryStorageSchedule() + if checkpointing == "Mixed": + schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM) + if checkpointing == "NoneAdjoint": + schedule = NoneCheckpointSchedule() + tape.enable_checkpointing(schedule) x, = SpatialCoordinate(mesh) - ic = project(sin(2.*pi*x), V) - val, _ = J(ic, solve_type, checkpointing) + ic = project(sin(2. * pi * x), V) + val = J(ic, solve_type, checkpointing) if checkpointing: assert len(tape.timesteps) == steps - Jhat = ReducedFunctional(val, Control(ic)) - dJ = Jhat.derivative() + if checkpointing != "NoneAdjoint": + dJ = Jhat.derivative() # Recomputing the functional with a modified control variable # before the recompute test. @@ -82,22 +92,19 @@ def test_burgers_newton(solve_type, checkpointing): # Recompute test assert(np.allclose(Jhat(ic), val)) - - dJbar = Jhat.derivative() - # Test recompute adjoint-based gradient - assert np.allclose(dJ.dat.data_ro[:], dJbar.dat.data_ro[:]) - - # Taylor test - h = Function(V) - h.assign(1, annotate=False) - assert taylor_test(Jhat, ic, h) > 1.9 + if checkpointing != "NoneAdjoint": + dJbar = Jhat.derivative() + # Test recompute adjoint-based gradient + assert np.allclose(dJ.dat.data_ro[:], dJbar.dat.data_ro[:]) + # Taylor test + assert taylor_test(Jhat, ic, Function(V).assign(1, annotate=False)) > 1.9 @pytest.mark.parametrize("solve_type, checkpointing", [("solve", "Revolve"), ("NLVS", "Revolve"), - ("solve", "Multistage"), - ("NLVS", "Multistage") + ("solve", "Mixed"), + ("NLVS", "Mixed"), ]) def test_checkpointing_validity(solve_type, checkpointing): """Compare forward and backward results with and without checkpointing. @@ -108,7 +115,7 @@ def test_checkpointing_validity(solve_type, checkpointing): x, = SpatialCoordinate(mesh) ic = project(sin(2.*pi*x), V) - val0, u0 = J(ic, solve_type, False) + val0 = J(ic, solve_type, False) Jhat = ReducedFunctional(val0, Control(ic)) dJ0 = Jhat.derivative() tape.clear_tape() @@ -117,14 +124,10 @@ def test_checkpointing_validity(solve_type, checkpointing): tape.progress_bar = ProgressBar if checkpointing == "Revolve": tape.enable_checkpointing(Revolve(steps, steps//3)) - if checkpointing == "Multistage": - tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0)) - x, = SpatialCoordinate(mesh) - ic = project(sin(2.*pi*x), V) - val1, u1 = J(ic, solve_type, True) + if checkpointing == "Mixed": + tape.enable_checkpointing(MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM)) + val1 = J(ic, solve_type, True) Jhat = ReducedFunctional(val1, Control(ic)) - dJ1 = Jhat.derivative() assert len(tape.timesteps) == steps assert np.allclose(val0, val1) - assert np.allclose(u0.dat.data_ro[:], u1.dat.data_ro[:]) - assert np.allclose(dJ0.dat.data_ro[:], dJ1.dat.data_ro[:]) + assert np.allclose(dJ0.dat.data_ro[:], Jhat.derivative().dat.data_ro[:]) diff --git a/tests/firedrake_adjoint/test_checkpointing_multistep.py b/tests/firedrake_adjoint/test_checkpointing_multistep.py new file mode 100644 index 00000000..c0239a60 --- /dev/null +++ b/tests/firedrake_adjoint/test_checkpointing_multistep.py @@ -0,0 +1,79 @@ +import pytest +pytest.importorskip("firedrake") + +from firedrake import * +from firedrake.adjoint import * +from checkpoint_schedules import Revolve +import numpy as np +from collections import deque +continue_annotation() +total_steps = 20 +dt = 0.01 +mesh = UnitIntervalMesh(1) +V = FunctionSpace(mesh, "DG", 0) + + +def J(displacement_0): + stiff = Constant(2.5) + damping = Constant(0.3) + rho = Constant(1.0) + # Adams-Bashforth coefficients. + adams_bashforth_coeffs = [55.0/24.0, -59.0/24.0, 37.0/24.0, -3.0/8.0] + # Adams-Moulton coefficients. + adams_moulton_coeffs = [9.0/24.0, 19.0/24.0, -5.0/24.0, 1.0/24.0] + displacement = Function(V) + velocity = deque([Function(V) for _ in adams_bashforth_coeffs]) + forcing = deque([Function(V) for _ in adams_bashforth_coeffs]) + displacement.assign(displacement_0) + tape = get_working_tape() + for _ in tape.timestepper(range(total_steps)): + for _ in range(len(adams_bashforth_coeffs) - 1): + forcing.append(forcing.popleft()) + forcing[0].assign(-(stiff * displacement + damping * velocity[0])/rho) + for _ in range(len(adams_bashforth_coeffs) - 1): + velocity.append(velocity.popleft()) + for m, coef in enumerate(adams_bashforth_coeffs): + velocity[0].assign(velocity[0] + dt * coef * forcing[m]) + for m, coef in enumerate(adams_moulton_coeffs): + displacement.assign(displacement + dt * coef * velocity[m]) + return assemble(displacement * displacement * dx) + + +def test_multisteps(): + tape = get_working_tape() + tape.progress_bar = ProgressBar + tape.enable_checkpointing(Revolve(total_steps, 2)) + displacement_0 = Function(V).assign(1.0) + val = J(displacement_0) + c = Control(displacement_0) + J_hat = ReducedFunctional(val, c) + dJ = J_hat.derivative() + # Recomputing the functional with a modified control variable + # before the recompute test. + J_hat(Function(V).assign(0.5)) + # Recompute test + assert(np.allclose(J_hat(displacement_0), val)) + # Test recompute adjoint-based gradient + assert np.allclose(dJ.dat.data_ro[:], J_hat.derivative().dat.data_ro[:]) + assert taylor_test(J_hat, displacement_0, Function(V).assign(1, annotate=False)) > 1.9 + + +def test_validity(): + tape = get_working_tape() + tape.progress_bar = ProgressBar + displacement_0 = Function(V).assign(1.0) + # Without checkpointing. + val0 = J(displacement_0) + J_hat0 = ReducedFunctional(val0, Control(displacement_0)) + dJ0 = J_hat0.derivative() + val_recomputed0 = J(displacement_0) + tape.clear_tape() + + # With checkpointing. + tape.enable_checkpointing(Revolve(total_steps, 2)) + val = J(displacement_0) + J_hat = ReducedFunctional(val, Control(displacement_0)) + dJ = J_hat.derivative() + val_recomputed = J_hat(displacement_0) + assert np.allclose(val_recomputed, val_recomputed0) + assert np.allclose(dJ.dat.data_ro[:], dJ0.dat.data_ro[:])