From 6360250ca3a1ad97fce69439c67645a98fe26dad Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 4 Dec 2024 12:25:16 +0000 Subject: [PATCH] 4dvar rf test --- firedrake/adjoint/__init__.py | 1 + .../adjoint/all_at_once_reduced_functional.py | 200 +++++++---- firedrake/adjoint_utils/ensemblefunction.py | 6 + .../test_4dvar_reduced_functional.py | 325 ++++++++++++++++++ 4 files changed, 458 insertions(+), 74 deletions(-) create mode 100644 tests/firedrake/regression/test_4dvar_reduced_functional.py diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index c48b990420..08f6a3c2ec 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -38,6 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \ UFLEqualityConstraint # noqa F401 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 +from firedrake.adjoint.all_at_once_reduced_functional import AllAtOnceReducedFunctional # noqa F401 import numpy_adjoint # noqa F401 import firedrake.ufl_expr import types diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 0b78f5e236..42f3f90965 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -5,11 +5,12 @@ from typing import Callable, Optional from contextlib import contextmanager from mpi4py import MPI +from firedrake.petsc import PETSc __all__ = ['AllAtOnceReducedFunctional'] -@set_working_tape() +# @set_working_tape() # ends up using old_tape = None because evaluates when imported - need separate decorator def isolated_rf(operation, control, functional_name=None, control_name=None): @@ -17,28 +18,29 @@ def isolated_rf(operation, control, Return a ReducedFunctional where the functional is `operation` applied to a copy of `control`, and the tape contains only `operation`. """ - controls = Enlist(control) - control_names = Enlist(control_name) + with set_working_tape(): + controls = Enlist(control) + control_names = Enlist(control_name) - with stop_annotating(): - control_copies = [control._ad_copy() for control in controls] + with stop_annotating(): + control_copies = [control._ad_copy() for control in controls] - if control_names: - for control, name in zip(control_copies, control_names): - _rename(control, name) + if control_names: + for control, name in zip(control_copies, control_names): + _rename(control, name) - if len(control_copies) == 1: - functional = operation(control_copies[0]) - control = Control(control_copies[0]) - else: - functional = operation(control_copies) - control = [Control(control) for control in control_copies] + if len(control_copies) == 1: + functional = operation(control_copies[0]) + control = Control(control_copies[0]) + else: + functional = operation(control_copies) + control = [Control(control) for control in control_copies] - if functional_name: - _rename(functional, functional_name) + if functional_name: + _rename(functional, functional_name) - return ReducedFunctional( - functional, control) + return ReducedFunctional( + functional, control) def sc_passthrough(func): @@ -151,15 +153,16 @@ def __init__(self, control: Control, self.trank = ensemble.ensemble_comm.rank if ensemble else 0 self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 - x = control.control.subfunctions - self.x = x + self._cbuf = control.copy() + _x = self._cbuf.subfunctions + self._x = _x + self._controls = tuple(Control(xi) for xi in _x) self.control = control self.controls = [control] - self._controls = tuple(Control(xi) for xi in x) # first control on rank 0 is initial conditions, not end of observation stage - self.nlocal_stages = len(x) - (1 if self.trank == 0 else 0) + self.nlocal_stages = len(_x) - (1 if self.trank == 0 else 0) self.stages = [] # The record of each observation stage @@ -169,7 +172,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (x_0 - x_b) self.background_error = isolated_rf( operation=lambda x0: _ad_sub(x0, self.background), - control=x[0], + control=_x[0], functional_name="bkg_err_vec", control_name="Control_0_bkg_copy") @@ -184,7 +187,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (H(x_0) - y_0) self.initial_observation_error = isolated_rf( operation=observation_err, - control=x[0], + control=_x[0], functional_name="obs_err_vec_0", control_name="Control_0_obs_copy") @@ -197,13 +200,13 @@ def __init__(self, control: Control, # create halo for previous state if self.ensemble and self.trank != 0: with stop_annotating(): - self.xprev = x[0]._ad_copy() + self.xprev = _x[0]._ad_copy() self._control_prev = Control(self.xprev) # halo for the derivative from the next chunk if self.ensemble and self.trank != self.nchunks - 1: with stop_annotating(): - self.xnext = x[0]._ad_copy() + self.xnext = _x[0]._ad_copy() else: self._annotate_accumulation = True @@ -268,14 +271,21 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - self.control.assign(values[0] if isinstance(values, list) else values) + value = values[0] if isinstance(values, list) else values + + if not isinstance(value, type(self.control.control)): + raise ValueError(f"Value must be of type {type(self.control.control)} not type {type(value)}") + + self.control.update(value) + self._cbuf.assign(value) + trank = self.trank # first "control" for later ranks is the halo if self.ensemble and trank != 0: - x = [self.xprev, *self.x] + x = [self.xprev, *self._x] else: - x = [*self.x] + x = [*self._x] # post messages for control of forward model propogation on next chunk if self.ensemble: @@ -346,49 +356,53 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): Should be an instance of the same type as the control. """ trank = self.trank - # create a list of overloaded types to put derivative into - derivatives = self.control._ad_copy() - derivatives.zero() - - if self.ensemble and trank != 0: - self.xprev.zero() - derivs = [self.xprev, *derivatives.subfunctions] - else: - derivs = [*derivatives.subfunctions] # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} intermediate_options = { - 'riesz_representation': None, + 'riesz_representation': 'l2', **{k: v for k, v in options.items() if (k != 'riesz_representation')} } - # initial condition derivatives - if trank == 0: - bkg_deriv = self.background_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.background_error.derivative( - adj_input=bkg_deriv, options=options) - - # observations at time 0 - if self.initial_observations: - obs_deriv = self.initial_observation_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.initial_observation_error.derivative( - adj_input=obs_deriv, options=options) - # evaluate first forward model, which contributes to previous chunk sderiv0 = self.stages[0].derivative( adj_input=adj_input, options=options) + # create the derivative in the right primal or dual space + from ufl.duals import is_primal, is_dual + if is_primal(sderiv0[0]): + from firedrake.ensemblefunction import EnsembleFunction + derivatives = EnsembleFunction( + self.ensemble, self.control.local_function_spaces) + else: + if not is_dual(sderiv0[0]): + raise ValueError( + "Do not know how to handle stage derivative which is not primal or dual") + from firedrake.ensemblefunction import EnsembleCofunction + derivatives = EnsembleCofunction( + self.ensemble, [V.dual() for V in self.control.local_function_spaces]) + + derivatives.zero() + + if self.ensemble: + with stop_annotating(): + xprev = derivatives.subfunctions[0]._ad_copy() + xnext = derivatives.subfunctions[0]._ad_copy() + xprev.zero() + xnext.zero() + if trank != 0: + derivs = [xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] + + # start accumulating the complete derivative derivs[0] += sderiv0[0] derivs[1] += sderiv0[1] # post the derivative halo exchange if self.ensemble: + # halos sent backward in time src = trank + 1 dst = trank - 1 @@ -398,7 +412,23 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): if trank != self.nchunks - 1: recv_reqs = self.ensemble.irecv( - self.xnext, source=src, tag=trank) + xnext, source=src, tag=trank) + + # initial condition derivatives + if trank == 0: + bkg_deriv = self.background_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.background_error.derivative( + adj_input=bkg_deriv, options=options) + + # observations at time 0 + if self.initial_observations: + obs_deriv = self.initial_observation_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.initial_observation_error.derivative( + adj_input=obs_deriv, options=options) # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): @@ -412,7 +442,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): if self.ensemble: if trank != self.nchunks - 1: MPI.Request.Waitall(recv_reqs) - derivs[-1] += self.xnext + derivs[-1] += xnext return derivatives @@ -475,13 +505,17 @@ def recording_stages(self, sequential=True, **stage_kwargs): trank = self.trank - global_index = 0 + # index of "previous" stage and observation in global series + global_index = -1 + observation_index = 0 if self.initial_observations else -1 with stop_annotating(): - xhalo = self.x[0]._ad_copy() + xhalo = self._x[0]._ad_copy() - # add our data onto the users context data + # add our data onto the user's context data ekwargs = {k: v for k, v in stage_kwargs.items()} ekwargs['global_index'] = global_index + ekwargs['observation_index'] = observation_index + ekwargs['xhalo'] = xhalo # proceed one ensemble rank at a time @@ -503,6 +537,7 @@ def recording_stages(self, sequential=True, **stage_kwargs): # initialise iterator for local stages stage_sequence = ObservationStageSequence( controls, self, ectx.global_index, + ectx.observation_index, local_stage_kwargs, sequential) # let the user record the local stages @@ -512,6 +547,12 @@ def recording_stages(self, sequential=True, **stage_kwargs): with stop_annotating(): state = self.stages[-1].controls[1].control ectx.xhalo.assign(state) + # grab the user's information to send forward + for k in local_stage_kwargs.keys(): + setattr(ectx, k, getattr(stage_sequence.ctx, k)) + # increment the global indices for the last local stage + ectx.global_index = self.stages[-1].global_index + ectx.observation_index = self.stages[-1].observation_index else: # strong constraint @@ -523,16 +564,17 @@ class ObservationStageSequence: def __init__(self, controls: Control, aaorf: AllAtOnceReducedFunctional, global_index: int, + observation_index: int, stage_kwargs: dict = None, sequential: bool = True): self.controls = controls self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) - self.index = 0 self.weak_constraint = aaorf.weak_constraint self.global_index = global_index - self.local_index = 0 + self.observation_index = observation_index + self.local_index = -1 def __iter__(self): return self @@ -542,34 +584,36 @@ def __next__(self): if self.weak_constraint: stages = self.aaorf.stages + # increment global indices + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 + # start of the next stage - next_control = self.controls[self.index] + next_control = self.controls[self.local_index] # smuggle state forward and increment stage indices - if self.index > 0: - self.local_index += 1 - self.global_index += 1 - + if self.local_index > 0: state = stages[-1].controls[1].control with stop_annotating(): next_control.control.assign(state) # stop after we've recorded all stages - if self.index >= self.nstages: + if self.local_index >= self.nstages: raise StopIteration - self.index += 1 stage = WeakObservationStage(next_control, local_index=self.local_index, - global_index=self.global_index) + global_index=self.global_index, + observation_index=self.observation_index) stages.append(stage) else: # strong constraint # increment stage indices - if self.index > 0: - self.local_index += 1 - self.global_index += 1 + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 # stop after we've recorded all stages if self.index >= self.nstages: @@ -583,6 +627,7 @@ def __next__(self): stage = StrongObservationStage(control, self.aaorf) self._prev_stage = stage + return stage, self.ctx @@ -661,10 +706,16 @@ class WeakObservationStage: global_index The index of this stage in the global timeseries. + observation_index + The index of the observation at the end of this stage in + the global timeseries. May be different from global_index if + an observation is taken at the initial time. + """ def __init__(self, control: Control, local_index: Optional[int] = None, - global_index: Optional[int] = None): + global_index: Optional[int] = None, + observation_index: Optional[int] = None): # "control" to use as initial condition. # Not actual `Control` for consistency with strong constraint self.control = control.control @@ -672,6 +723,7 @@ def __init__(self, control: Control, self.controls = Enlist(control) self.local_index = local_index self.global_index = global_index + self.observation_index = observation_index set_working_tape() self._stage_tape = get_working_tape() diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py index 1c39e04517..ba6882f054 100644 --- a/firedrake/adjoint_utils/ensemblefunction.py +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -82,6 +82,12 @@ def _ad_create_checkpoint(self): else: return self.copy() + def _ad_restore_at_checkpoint(self, checkpoint): + if isinstance(checkpoint, type(self)): + return checkpoint + raise NotImplementedError( + "Checkpointing not implemented for EnsembleFunctions") + def _ad_from_petsc(self, vec): with self.vec_wo as self_v: vec.copy(result=self_v) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py new file mode 100644 index 0000000000..bf0530293f --- /dev/null +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -0,0 +1,325 @@ +import pytest +import firedrake as fd +from firedrake.__future__ import interpolate +from firedrake.adjoint import ( + continue_annotation, pause_annotation, stop_annotating, set_working_tape, + Control, taylor_test, ReducedFunctional, AllAtOnceReducedFunctional) + + +def function_space(comm): + """DG0 periodic advection""" + mesh = fd.PeriodicUnitIntervalMesh(nx, comm=comm) + return fd.FunctionSpace(mesh, "DG", 0) + + +def timestepper(V): + """Implicit midpoint timestepper for the advection equation""" + qn = fd.Function(V, name="qn") + qn1 = fd.Function(V, name="qn1") + + def mass(q, phi): + return fd.inner(q, phi)*fd.dx + + def tendency(q, phi): + u = fd.as_vector([vconst]) + n = fd.FacetNormal(V.mesh()) + un = fd.Constant(0.5)*(fd.dot(u, n) + abs(fd.dot(u, n))) + return (- q*fd.div(phi*u)*fd.dx + + fd.jump(phi)*fd.jump(un*q)*fd.dS) + + # midpoint rule + q = fd.TrialFunction(V) + phi = fd.TestFunction(V) + + qh = fd.Constant(0.5)*(q + qn) + eqn = mass(q - qn, phi) + fd.Constant(dt)*tendency(qh, phi) + + stepper = fd.LinearVariationalSolver( + fd.LinearVariationalProblem( + fd.lhs(eqn), fd.rhs(eqn), qn1, + constant_jacobian=True)) + + return qn, qn1, stepper + + +def prod2(w): + """generate weighted inner products to pass to FourDVarReducedFunctional""" + def n2(x): + return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) + return n2 + + +prodB = prod2(0.1) # background error +prodR = prod2(10.) # observation error +prodQ = prod2(1.0) # model error + + +"""Advecting velocity""" +velocity = 1 +vconst = fd.Constant(velocity) + +"""Number of cells""" +nx = 16 + +"""Timestep size""" +cfl = 2.3523 +dx = 1.0/nx +dt = cfl*dx/velocity + +"""How many times / how often we take observations +(one extra at initial time)""" +observation_frequency = 5 +observation_n = 6 +observation_times = [i*observation_frequency*dt + for i in range(observation_n+1)] + + +def nlocal_observations(ensemble): + """How many observations on the current ensemble member""" + esize = ensemble.ensemble_comm.size + erank = ensemble.ensemble_comm.rank + if esize == 1: + return observation_n + 1 + assert (observation_n % esize == 0), "Must be able to split observations across ensemble" # noqa: E501 + return observation_n//esize + (1 if erank == 0 else 0) + + +def analytic_solution(V, t, mag=1.0, phase=0.0): + """Exact advection of sin wave after time t""" + x, = fd.SpatialCoordinate(V.mesh()) + return fd.Function(V).interpolate( + mag*fd.sin(2*fd.pi*((x + phase) - vconst*t))) + + +def analytic_series(V, tshift=0.0, mag=1.0, phase=0.0, ensemble=None): + """Timeseries of the analytic solution""" + series = [analytic_solution(V, t+tshift, mag=mag, phase=phase) + for t in observation_times] + + if ensemble is None: + return series + else: + nlocal_obs = nlocal_observations(ensemble) + rank = ensemble.ensemble_comm.rank + offset = (0 if rank == 0 else rank*nlocal_obs + 1) + + efunc = fd.EnsembleFunction( + ensemble, [V for _ in range(nlocal_obs)]) + + for e, s in zip(efunc.subfunctions, + series[offset:offset+nlocal_obs]): + e.assign(s) + return efunc + + +def observation_errors(V): + """List of functions to evaluate the observation error + at each observation time""" + + observation_locations = [ + [x] for x in [0.13, 0.18, 0.34, 0.36, 0.49, 0.61, 0.72, 0.99] + ] + + observation_mesh = fd.VertexOnlyMesh(V.mesh(), observation_locations) + Vobs = fd.FunctionSpace(observation_mesh, "DG", 0) + + # observation operator + def H(x): + return fd.assemble(interpolate(x, Vobs)) + + # ground truth + targets = analytic_series(V) + + # take observations + y = [H(x) for x in targets] + + # generate function to evaluate observation error at observation time i + def observation_error(i): + def obs_err(x): + return fd.Function(Vobs).assign(H(x) - y[i]) + return obs_err + + return observation_error + + +def background(V): + """Prior for initial condition""" + return analytic_solution(V, t=0, mag=0.9, phase=0.1) + + +def m(V, ensemble=None): + """The expansion points for the Taylor test""" + return analytic_series(V, tshift=0.1, mag=1.1, phase=-0.2, + ensemble=ensemble) + + +def h(V, ensemble=None): + """The perturbation direction for the Taylor test""" + return analytic_series(V, tshift=0.3, mag=0.1, phase=0.3, + ensemble=ensemble) + + +def fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # One control for each observation time + controls = [fd.Function(V) + for _ in range(len(observation_times))] + + # Prior + bkg = background(V) + + controls[0].assign(bkg) + + # generate ground truths + obs_errors = observation_errors(V) + + # start building the 4DVar system + continue_annotation() + set_working_tape() + + # background error + J = prodB(controls[0] - bkg) + + # initial observation error + J += prodR(obs_errors(0)(controls[0])) + + # record observation stages + for i in range(1, len(controls)): + qn.assign(controls[i-1]) + + # forward model propogation + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # we need to smuggle the state over to next + # control without the tape seeing so that we + # can continue the timeseries through the next + # stage but with the tape thinking that the + # forward model in each stage is independent. + with stop_annotating(): + controls[i].assign(qn) + + # model error for this stage + J += prodQ(qn - controls[i]) + + # observation error + J += prodR(obs_errors(i)(controls[i])) + + pause_annotation() + + Jhat = ReducedFunctional( + J, [Control(c) for c in controls]) + + return Jhat + + +def fdvar_firedrake(V, ensemble): + """Build an AllAtOnceReducedFunctional for the 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # One control for each observation time + + nlocal_obs = nlocal_observations(ensemble) + + control = fd.EnsembleFunction( + ensemble, [V for _ in range(nlocal_obs)]) + + # Prior + bkg = background(V) + + if ensemble.ensemble_comm.rank == 0: + control.subfunctions[0].assign(bkg) + + # generate ground truths + obs_errors = observation_errors(V) + + # start building the 4DVar system + continue_annotation() + set_working_tape() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = AllAtOnceReducedFunctional( + Control(control), + background_iprod=prodB, + observation_iprod=prodR, + observation_err=obs_errors(0), + weak_constraint=True) + + # record observation stages + with Jhat.recording_stages() as stages: + + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # take observation + obs_err = obs_errors(stage.observation_index) + stage.set_observation(qn, obs_err, + observation_iprod=prodR, + forward_model_iprod=prodQ) + + pause_annotation() + + return Jhat + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) +def test_advection(): + main_test_advection() + + +def main_test_advection(): + global_comm = fd.COMM_WORLD + if global_comm.size in (1, 2): # time serial + nspace = global_comm.size + elif global_comm.size == 3: # time parallel + nspace = 1 + elif global_comm.size == 4: # space-time parallel + nspace = 2 + + ensemble = fd.Ensemble(global_comm, nspace) + V = function_space(ensemble.comm) + + erank = ensemble.ensemble_comm.rank + + # only setup the reference pyadjoint rf on the first ensemble member + if erank == 0: + Jhat_pyadj = fdvar_pyadjoint(V) + mp = m(V) + hp = h(V) + # make sure we've set up the reference rf correctly + assert taylor_test(Jhat_pyadj, mp, hp) > 1.99 + + Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None) + Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None) + + Jhat_aaorf = fdvar_firedrake(V, ensemble) + + ma = m(V, ensemble) + ha = h(V, ensemble) + + eps = 1e-12 + # Does evaluating the functional match the reference rf? + assert abs(Jpm - Jhat_aaorf(ma)) < eps + assert abs(Jph - Jhat_aaorf(ha)) < eps + + # If we match the functional, then passing the taylor test + # should mean we match the derivative too. + assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + + +if __name__ == '__main__': + main_test_advection()