diff --git a/firedrake/__init__.py b/firedrake/__init__.py
index 3c89e61429..0162671b5a 100644
--- a/firedrake/__init__.py
+++ b/firedrake/__init__.py
@@ -103,6 +103,7 @@ from firedrake.vector import *
 from firedrake.version import __version__ as ver, __version_info__, check  # noqa: F401
 from firedrake.ensemble import *
+from firedrake.ensemblefunction import *
 from firedrake.randomfunctiongen import *
 from firedrake.external_operators import *
 from firedrake.progress_bar import ProgressBar  # noqa: F401
diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py
index c48b990420..08f6a3c2ec 100644
--- a/firedrake/adjoint/__init__.py
+++ b/firedrake/adjoint/__init__.py
@@ -38,6 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \
     UFLEqualityConstraint  # noqa F401
 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional  # noqa F401
+from firedrake.adjoint.all_at_once_reduced_functional import AllAtOnceReducedFunctional  # noqa F401
 import numpy_adjoint  # noqa F401
 import firedrake.ufl_expr
 import types
diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py
index 429000b2a2..0dd2ccafe8 100644
--- a/firedrake/adjoint/all_at_once_reduced_functional.py
+++ b/firedrake/adjoint/all_at_once_reduced_functional.py
@@ -1,12 +1,47 @@
 from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \
-    stop_annotating, no_annotations, get_working_tape, set_working_tape
+    stop_annotating, get_working_tape, set_working_tape
 from pyadjoint.enlisting import Enlist
 from functools import wraps, cached_property
 from typing import Callable, Optional
+from contextlib import contextmanager
+from mpi4py import MPI
 
 __all__ = ['AllAtOnceReducedFunctional']
 
 
+# Note: we cannot use @set_working_tape() as a decorator here because it
+# would be evaluated at import time (ending up with old_tape = None), so a
+# separate decorator would be needed.
+def isolated_rf(operation, control,
+                functional_name=None,
+                control_name=None):
+    """
+    Return a ReducedFunctional where the functional is `operation` applied
+    to a copy of `control`, and the tape contains only `operation`.
+    """
+    with set_working_tape():
+        controls = Enlist(control)
+        control_names = Enlist(control_name)
+
+        with stop_annotating():
+            control_copies = [control._ad_copy() for control in controls]
+
+        if control_names:
+            for control, name in zip(control_copies, control_names):
+                _rename(control, name)
+
+        if len(control_copies) == 1:
+            functional = operation(control_copies[0])
+            control = Control(control_copies[0])
+        else:
+            functional = operation(control_copies)
+            control = [Control(control) for control in control_copies]
+
+        if functional_name:
+            _rename(functional, functional_name)
+
+        return ReducedFunctional(
+            functional, control)
+
+
 def sc_passthrough(func):
     """
     Wraps standard ReducedFunctional methods to differentiate strong or
@@ -41,6 +76,20 @@ def _ad_sub(left, right):
     return result
 
 
+def _intermediate_options(final_options):
+    """
+    Options for the intermediate stages of a chain of ReducedFunctionals.
+
+    Takes all elements of final_options except riesz_representation,
+    which is set to None so that intermediate derivatives are not mapped
+    back to the primal space.
+    """
+    return {
+        'riesz_representation': None,
+        **{k: v for k, v in final_options.items()
+           if (k != 'riesz_representation')}
+    }
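+
+
+# Illustrative sketch (comment only): `isolated_rf` wraps a single
+# operation on its own tape. Assuming a Firedrake FunctionSpace V:
+#
+#     u = Function(V).assign(1.0)
+#     rf = isolated_rf(lambda x: assemble(inner(x, x)*dx), u,
+#                      functional_name="J", control_name="u_copy")
+#     J = rf(u)             # re-evaluates on the isolated tape
+#     dJ = rf.derivative()  # derivative w.r.t. the internal control copy
+
+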
 class AllAtOnceReducedFunctional(ReducedFunctional):
     """ReducedFunctional for 4DVar data assimilation.
@@ -53,39 +102,43 @@ class AllAtOnceReducedFunctional(ReducedFunctional):
     ----------
 
     control
-        The initial condition :math:`x_{0}`. Starting value is used as the
-        background (prior) data :math:`x_{b}`.
+        The :class:`EnsembleFunction` for the controls :math:`x_{i}` at the
+        initial condition and at the end of each observation stage.
 
     background_iprod
         The inner product to calculate the background error functional
         from the background error :math:`x_{0} - x_{b}`. Can include the
-        error covariance matrix.
+        error covariance matrix. Only used on ensemble rank 0.
+
+    background
+        The background (prior) data for the initial condition :math:`x_{b}`.
+        If not provided, the value of the first Function on the first
+        ensemble member of the EnsembleFunction will be used.
 
     observation_err
         Given a state :math:`x`, returns the observations error
         :math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are
         the observations at the initial time and :math:`\\mathcal{H}_{0}` is
-        the observation operator for the initial time. Optional.
+        the observation operator for the initial time. Only used on
+        ensemble rank 0. Optional.
 
     observation_iprod
         The inner product to calculate the observation error functional
         from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`.
         Can include the error covariance matrix. Must be provided if
-        observation_err is provided.
+        observation_err is provided. Only used on ensemble rank 0.
 
     weak_constraint
         Whether to use the weak or strong constraint 4DVar formulation.
 
-    tape
-        The tape to record on.
-
     See Also
     --------
     :class:`pyadjoint.ReducedFunctional`.
     """
 
     def __init__(self, control: Control,
-                 background_iprod: Callable[[OverloadedType], AdjFloat],
+                 background_iprod: Optional[Callable[[OverloadedType], AdjFloat]],
+                 background: Optional[OverloadedType] = None,
                  observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None,
                  observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None,
                  weak_constraint: bool = True,
@@ -97,93 +150,86 @@ def __init__(self, control: Control,
         self.weak_constraint = weak_constraint
         self.initial_observations = observation_err is not None
 
-        # We need a copy for the prior, but this shouldn't be part of the tape
         with stop_annotating():
-            self.background = control.copy_data()
+            if background:
+                self.background = background._ad_copy()
+            else:
+                self.background = control.control.subfunctions[0]._ad_copy()
+            _rename(self.background, "Background")
 
         if self.weak_constraint:
             self._annotate_accumulation = _annotate_accumulation
-
-            # new tape for background error vector
-            with set_working_tape() as tape:
-                # start from a control independent of any other tapes
+            self._accumulation_started = False
+
+            ensemble = control.ensemble
+            self.ensemble = ensemble
+            self.trank = ensemble.ensemble_comm.rank if ensemble else 0
+            self.nchunks = ensemble.ensemble_comm.size if ensemble else 1
+
+            self._cbuf = control.copy()
+            _x = self._cbuf.subfunctions
+            self._x = _x
+            self._controls = tuple(Control(xi) for xi in _x)
+
+            self.control = control
+            self.controls = [control]
+
+            # the first control on rank 0 is the initial condition,
+            # not the end of an observation stage
+            self.nlocal_stages = len(_x) - (1 if self.trank == 0 else 0)
+
+            self.stages = []  # The record of each observation stage
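+
+            # Illustrative layout (comment only), e.g. 6 observation
+            # stages split over 3 ensemble ranks:
+            #   rank 0: _x = [x0, x1, x2]  (x0 is the initial condition)
+            #   rank 1: _x = [x3, x4]      (x2 arrives via the halo xprev)
+            #   rank 2: _x = [x5, x6]      (x4 arrives via the halo xprev)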
control_name="Control_0_bkg_copy") + + # RF to recalculate inner product |x_0 - x_b|_B + self.background_norm = isolated_rf( + operation=background_iprod, + control=self.background_error.functional, + control_name="bkg_err_vec_copy") + + if self.initial_observations: + + # RF to recalculate error vector (H(x_0) - y_0) + self.initial_observation_error = isolated_rf( + operation=observation_err, + control=_x[0], + functional_name="obs_err_vec_0", + control_name="Control_0_obs_copy") + + # RF to recalculate inner product |H(x_0) - y_0|_R + self.initial_observation_norm = isolated_rf( + operation=observation_iprod, + control=self.initial_observation_error.functional, + functional_name="obs_err_vec_0_copy") + + # create halo for previous state + if self.ensemble and self.trank != 0: with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_bkg_copy") - - # vector of x_0 - x_b - bkg_err_vec = _ad_sub(control_copy, self.background) - _rename(bkg_err_vec, "bkg_err_vec") + self.xprev = _x[0]._ad_copy() + self._control_prev = Control(self.xprev) - # RF to recover x_0 - x_b - self.background_error = ReducedFunctional( - bkg_err_vec, Control(control_copy), tape=tape) - - # new tape for background error reduction - with set_working_tape() as tape: - # start from a control independent of any other tapes + # halo for the derivative from the next chunk + if self.ensemble and self.trank != self.nchunks - 1: with stop_annotating(): - bkg_err_vec_copy = bkg_err_vec._ad_copy() - _rename(bkg_err_vec_copy, "bkg_err_vec_copy") - - # inner product |x_0 - x_b|_B - bkg_err = background_iprod(bkg_err_vec_copy) - - # RF to recover |x_0 - x_b|_B - self.background_rf = ReducedFunctional( - bkg_err, Control(bkg_err_vec_copy), tape=tape) - - self.controls = [control] # The solution at the beginning of each time-chunk - self.states = [] # The model propogation at the end of each time-chunk - self.forward_model_stages = [] # ReducedFunctional for each model propogation (returns state) - self.forward_model_errors = [] # Inner product for model errors (possibly including error covariance) - self.forward_model_rfs = [] # Inner product for model errors (possibly including error covariance) - self.observation_errors = [] # ReducedFunctional for each observation set (returns observation error) - self.observation_rfs = [] # Inner product for observation errors (possibly including error covariance) - - if self.initial_observations: - - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_obs_copy") - - # vector of H(x_0) - y_0 - obs_err_vec = observation_err(control_copy) - _rename(obs_err_vec, "obs_err_vec_0") - - # RF to recover H(x_0) - y_0 - self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, "obs_err_vec_0_copy") - - # inner product |H(x_0) - y_0|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_0) - y_0|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - set_working_tape() - self._stage_tape = get_working_tape() + 
self.xnext = _x[0]._ad_copy() else: self._annotate_accumulation = True + self._accumulation_started = False # initial conditions guess to be updated self.controls = Enlist(control) + self.tape = get_working_tape() if tape is None else tape + # Strong constraint functional to be converted to ReducedFunctional later # penalty for straying from prior @@ -195,155 +241,6 @@ def __init__(self, control: Control, self._accumulate_functional( observation_iprod(observation_err(control.control))) - def set_observation(self, state: OverloadedType, - observation_err: Callable[[OverloadedType], OverloadedType], - observation_iprod: Callable[[OverloadedType], AdjFloat], - forward_model_iprod: Optional[Callable[[OverloadedType], AdjFloat]]): - """ - Record an observation at the time of `state`. - - Parameters - ---------- - - state - The state at the current observation time. - - observation_err - Given a state :math:`x`, returns the observations error - :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are - the observations at the current observation time and - :math:`\\mathcal{H}_{i}` is the observation operator for the - current observation time. - - observation_iprod - The inner product to calculate the observation error functional - from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. - Can include the error covariance matrix. - - forward_model_iprod - The inner product to calculate the model error functional from - the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can - include the error covariance matrix. Ignored if using the strong - constraint formulation. - """ - if self.weak_constraint: - - stage_index = len(self.controls) - - # Cut the tape into seperate time-chunks. - # State is output from previous control i.e. forward model - # propogation over previous time-chunk. - - # get the tape used for this stage and make sure its the right one - prev_stage_tape = get_working_tape() - if prev_stage_tape is not self._stage_tape: - raise ValueError( - "Working tape at the end of the observation stage" - " differs from the tape at the stage beginning." - ) - - # # record forward propogation - with set_working_tape(prev_stage_tape.copy()) as tape: - prev_control = self.controls[-1] - self.forward_model_stages.append(ReducedFunctional( - state._ad_copy(), controls=prev_control, tape=tape) - ) - - # Beginning of next time-chunk is the control for this observation - # and the state at the end of the next time-chunk. 
- with stop_annotating(): - # smuggle initial guess at this time into the control without the tape seeing - next_control_state = state._ad_copy() - _rename(next_control_state, f"Control_{len(self.controls)}") - next_control = Control(next_control_state) - self.controls.append(next_control) - - # model error links time-chunks by depending on both the - # previous and current controls - - # new tape for model error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - state_copy = state._ad_copy() - _rename(state_copy, f"state_{stage_index}_copy") - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_model_copy") - - # vector of M_i - x_i - model_err_vec = _ad_sub(state_copy, next_control_copy) - _rename(model_err_vec, f"model_err_vec_{stage_index}") - - # RF to recover M_i - x_i - fmcontrols = [Control(state_copy), Control(next_control_copy)] - self.forward_model_errors.append(ReducedFunctional( - model_err_vec, fmcontrols, tape=tape) - ) - - # new tape for model error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - model_err_vec_copy = model_err_vec._ad_copy() - _rename(model_err_vec_copy, f"model_err_vec_{stage_index}_copy") - - # inner product |M_i - x_i|_Q - model_err = forward_model_iprod(model_err_vec_copy) - - # RF to recover |M_i - x_i|_Q - self.forward_model_rfs.append(ReducedFunctional( - model_err, Control(model_err_vec_copy), tape=tape) - ) - - # Observations after tape cut because this is now a control, not a state - - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_obs_copy") - - # vector of H(x_i) - y_i - obs_err_vec = observation_err(next_control_copy) - _rename(obs_err_vec, f"obs_err_vec_{stage_index}") - - # RF to recover H(x_i) - y_i - self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(next_control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, f"obs_err_vec_{stage_index}_copy") - - # inner product |H(x_i) - y_i|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_i) - y_i|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - - set_working_tape() - self._stage_tape = get_working_tape() - - # Look we're starting this time-chunk from an "unrelated" value... really! - state.assign(next_control.control) - - else: - - if hasattr(self, "_strong_reduced_functional"): - msg = "Cannot add observations once strong constraint ReducedFunctional instantiated" - raise ValueError(msg) - self._accumulate_functional( - observation_iprod(observation_err(state))) - @cached_property def strong_reduced_functional(self): """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system. @@ -362,13 +259,13 @@ def __getattr__(self, attr): """ If using strong constraint then grab attributes from self.strong_reduced_functional. 
""" - if self.weak_constraint: + # hasattr calls getattr, so check self.__dir__ directly here to avoid recursion + if self.weak_constraint or "_strong_reduced_functional" not in dir(self): raise AttributeError(f"'{type(self)}' object has no attribute '{attr}'") - else: - return getattr(self.strong_reduced_functional, attr) + return getattr(self.strong_reduced_functional, attr) @sc_passthrough - @no_annotations + @stop_annotating() def __call__(self, values: OverloadedType): """Computes the reduced functional with supplied control value. @@ -387,53 +284,70 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - # controls are updated by the sub ReducedFunctionals - # so we don't need to do it ourselves + value = values[0] if isinstance(values, list) else values - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. - # Model i propogates from i-1 to i. - # Observation i is at i. + if not isinstance(value, type(self.control.control)): + raise ValueError(f"Value must be of type {type(self.control.control)} not type {type(value)}") - for c, v in zip(self.controls, values): - c.control.assign(v) + self.control.update(value) + self._cbuf.assign(value) - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] + trank = self.trank - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) + # first "control" for later ranks is the halo + if self.ensemble and trank != 0: + x = [self.xprev, *self._x] + else: + x = [*self._x] + + # post messages for control of forward model propogation on next chunk + if self.ensemble: + src = trank - 1 + dst = trank + 1 + + if trank != self.nchunks - 1: + self.ensemble.isend( + x[-1], dest=dst, tag=dst) - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) + if trank != 0: + recv_reqs = self.ensemble.irecv( + self.xprev, source=src, tag=trank) # Initial condition functionals - bkg_err_vec = self.background_error(values[0]) - J = self.background_rf(bkg_err_vec) + if trank == 0: + Jlocal = ( + self.background_norm( + self.background_error(x[0]))) - # observations at time 0 - if self.initial_observations: - obs_err_vec = observation_errors[0](values[0]) - J += observation_rfs[0](obs_err_vec) + # observations at time 0 + if self.initial_observations: + Jlocal += ( + self.initial_observation_norm( + self.initial_observation_error(x[0]))) + else: + Jlocal = 0. - for i in range(1, len(observation_rfs)): - prev_control = values[i-1] - this_control = values[i] + # evaluate all stages on chunk except first + for i in range(1, len(self.stages)): + Jlocal += self.stages[i](x[i:i+2]) - # observation error - do we match the 'real world'? - obs_err_vec = observation_errors[i](this_control) - J += observation_rfs[i](obs_err_vec) + # wait for halo swap to finish + if trank != 0: + MPI.Request.Waitall(recv_reqs) - # Model error - does propogation from last control match this control? 
- Mi = model_stages[i](prev_control) - model_err_vec = model_errors[i]([Mi, this_control]) - J += model_rfs[i](model_err_vec) + # evaluate first stage model on chunk now we have data + Jlocal += self.stages[0](x[0:2]) + + # sum all stages + if self.ensemble: + J = self.ensemble.ensemble_comm.allreduce(Jlocal) + else: + J = Jlocal return J @sc_passthrough - @no_annotations + @stop_annotating() def derivative(self, adj_input: float = 1.0, options: dict = {}): """Returns the derivative of the functional w.r.t. the control. Using the adjoint method, the derivative of the functional with @@ -454,64 +368,95 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): The derivative with respect to the control. Should be an instance of the same type as the control. """ - # create a list of overloaded types to put derivative into - derivatives = [] + trank = self.trank # chaining ReducedFunctionals means we need to pass Cofunctions not Functions - intermediate_options = { - 'riesz_representation': None, - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. Model i propogates from i-1 to i. - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] + options = options or {} + intermediate_options = _intermediate_options(options) + + # evaluate first forward model, which contributes to previous chunk + sderiv0 = self.stages[0].derivative( + adj_input=adj_input, options=options) + + # create the derivative in the right primal or dual space + from ufl.duals import is_primal, is_dual + if is_primal(sderiv0[0]): + from firedrake.ensemblefunction import EnsembleFunction + derivatives = EnsembleFunction( + self.ensemble, self.control.local_function_spaces) + else: + if not is_dual(sderiv0[0]): + raise ValueError( + "Do not know how to handle stage derivative which is not primal or dual") + from firedrake.ensemblefunction import EnsembleCofunction + derivatives = EnsembleCofunction( + self.ensemble, [V.dual() for V in self.control.local_function_spaces]) - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) + derivatives.zero() - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) + if self.ensemble: + with stop_annotating(): + xprev = derivatives.subfunctions[0]._ad_copy() + xnext = derivatives.subfunctions[0]._ad_copy() + xprev.zero() + xnext.zero() + if trank != 0: + derivs = [xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] + + # start accumulating the complete derivative + derivs[0] += sderiv0[0] + derivs[1] += sderiv0[1] + + # post the derivative halo exchange + if self.ensemble: + # halos sent backward in time + src = trank + 1 + dst = trank - 1 + + if trank != 0: + self.ensemble.isend( + derivs[0], dest=dst, tag=dst) + + if trank != self.nchunks - 1: + recv_reqs = self.ensemble.irecv( + xnext, source=src, tag=trank) # initial condition derivatives - bkg_deriv = self.background_rf.derivative(adj_input=adj_input, - options=intermediate_options) - derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, - options=options)) - - # observations at time 0 - if self.initial_observations: - obs_deriv = observation_rfs[0].derivative(adj_input=adj_input, - options=intermediate_options) - derivatives[0] += 
observation_errors[0].derivative(adj_input=obs_deriv,
-                                                                options=options)
-
-        for i in range(1, len(observation_rfs)):
-            obs_deriv = observation_rfs[i].derivative(adj_input=adj_input,
-                                                      options=intermediate_options)
-            derivatives.append(observation_errors[i].derivative(adj_input=obs_deriv,
-                                                                options=options))
-
-            # derivative of model error will split:
-            # wrt x_i through error vector
-            # wrt x_i-1 through stage propogation
-            model_deriv = model_rfs[i].derivative(adj_input=adj_input,
-                                                  options=intermediate_options)
-            model_err_derivs = model_errors[i].derivative(adj_input=model_deriv,
-                                                          options=intermediate_options)
-            model_stage_deriv = model_stages[i].derivative(adj_input=model_err_derivs[0],
-                                                           options=options)
-
-            derivatives[i-1] += model_stage_deriv
-            derivatives[i] += model_err_derivs[1].riesz_representation()
+        if trank == 0:
+            bkg_deriv = self.background_norm.derivative(
+                adj_input=adj_input, options=intermediate_options)
+
+            derivs[0] += self.background_error.derivative(
+                adj_input=bkg_deriv, options=options)
+
+            # observations at time 0
+            if self.initial_observations:
+                obs_deriv = self.initial_observation_norm.derivative(
+                    adj_input=adj_input, options=intermediate_options)
+
+                derivs[0] += self.initial_observation_error.derivative(
+                    adj_input=obs_deriv, options=options)
+
+        # evaluate all forward models on the chunk except the first while the halo is in flight
+        for i in range(1, len(self.stages)):
+            sderiv = self.stages[i].derivative(
+                adj_input=adj_input, options=options)
+
+            derivs[i] += sderiv[0]
+            derivs[i+1] += sderiv[1]
+
+        # finish the derivative halo exchange
+        if self.ensemble:
+            if trank != self.nchunks - 1:
+                MPI.Request.Waitall(recv_reqs)
+                derivs[-1] += xnext
 
         return derivatives
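+
+    # Comment only: the stage-0 derivative computed above depends on the
+    # halo xprev, i.e. on the last control owned by the previous chunk, so
+    # each rank sends that contribution backwards (to trank - 1) while it
+    # evaluates its remaining stages, then adds the xnext contribution
+    # received from the next chunk onto its own last control.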
 
     @sc_passthrough
-    @no_annotations
+    @stop_annotating()
     def hessian(self, m_dot: OverloadedType, options: dict = {}):
         """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot.
 
@@ -529,64 +474,509 @@ def hessian(self, m_dot: OverloadedType, options: dict = {}):
             A dictionary of options. To find a list of available options
             have a look at the specific control type.
 
         Returns
         -------
         pyadjoint.OverloadedType
             The action of the Hessian in the direction m_dot.
             Should be an instance of the same type as the control.
         """
-        # create a list of overloaded types to put hessian into
-        hessians = []
-
-        kwargs = {'options': options}
-
-        # Shift lists so indexing matches standard nomenclature:
-        # index 0 is initial condition. Model i propogates from i-1 to i.
-        model_rfs = [None, *self.forward_model_rfs]
-
-        observation_rfs = (self.observation_rfs if self.initial_observations
-                           else [None, *self.observation_rfs])
+        raise NotImplementedError(
+            "AllAtOnceReducedFunctional.hessian is not implemented yet")
 
-        # initial condition hessians
-        hessians.append(
-            self.background_rf.hessian(m_dot[0], **kwargs))
-
-        if self.initial_observations:
-            hessians[0] += observation_rfs[0].hessian(m_dot[0], **kwargs)
-
-        for i in range(1, len(model_rfs)):
-            hessians.append(observation_rfs[i].hessian(m_dot[i], **kwargs))
-
-            mhess = model_rfs[i].hessian(m_dot[i-1:i+1], **kwargs)
-
-            hessians[i-1] += mhess[0]
-            hessians[i] += mhess[1]
-
-        return hessians
-
-    @no_annotations
+    @stop_annotating()
     def hessian_matrix(self):
         # Other reduced functionals don't have this.
         if not self.weak_constraint:
             raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix")
         raise NotImplementedError
 
-    @sc_passthrough
-    @no_annotations
-    def optimize_tape(self):
-        for rf in (self.background_error,
-                   self.background_rf,
-                   *self.observation_errors,
-                   *self.observation_rfs,
-                   *self.forward_model_stages,
-                   *self.forward_model_errors,
-                   *self.forward_model_rfs):
-            rf.optimize_tape()
-
     def _accumulate_functional(self, val):
         if not self._annotate_accumulation:
             return
-        if hasattr(self, '_total_functional'):
+        if self._accumulation_started:
             self._total_functional += val
         else:
             self._total_functional = val
+            self._accumulation_started = True
+
+    @contextmanager
+    def recording_stages(self, sequential=True, **stage_kwargs):
+        if not sequential:
+            raise ValueError("Recording stages concurrently is not yet implemented")
+
+        # record over ensemble
+        if self.weak_constraint:
+
+            trank = self.trank
+
+            # index of "previous" stage and observation in global series
+            global_index = -1
+            observation_index = 0 if self.initial_observations else -1
+            with stop_annotating():
+                xhalo = self._x[0]._ad_copy()
+
+            # add our data onto the user's context data
+            ekwargs = {k: v for k, v in stage_kwargs.items()}
+            ekwargs['global_index'] = global_index
+            ekwargs['observation_index'] = observation_index
+
+            ekwargs['xhalo'] = xhalo
+
+            # proceed one ensemble rank at a time
+            with self.ensemble.sequential(**ekwargs) as ectx:
+
+                # later ranks start from halo
+                if trank == 0:
+                    controls = self._controls
+                else:
+                    controls = [self._control_prev, *self._controls]
+                    with stop_annotating():
+                        controls[0].assign(ectx.xhalo)
+
+                # grab the user's data from the ensemble context
+                local_stage_kwargs = {
+                    k: getattr(ectx, k) for k in stage_kwargs.keys()
+                }
+
+                # initialise iterator for local stages
+                stage_sequence = ObservationStageSequence(
+                    controls, self, ectx.global_index,
+                    ectx.observation_index,
+                    local_stage_kwargs, sequential)
+
+                # let the user record the local stages
+                yield stage_sequence
+
+                # send the state forward
+                with stop_annotating():
+                    state = self.stages[-1].controls[1].control
+                    ectx.xhalo.assign(state)
+                    # grab the user's information to send forward
+                    for k in local_stage_kwargs.keys():
+                        setattr(ectx, k, getattr(stage_sequence.ctx, k))
+                    # increment the global indices for the last local stage
+                    ectx.global_index = self.stages[-1].global_index
+                    ectx.observation_index = self.stages[-1].observation_index
+
+        else:  # strong constraint
+
+            yield ObservationStageSequence(
+                self.controls, self, global_index=-1,
+                observation_index=0 if self.initial_observations else -1,
+                stage_kwargs=stage_kwargs, sequential=True)
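+
+    # Typical use (adapted from the tests; `qn`, `qn1`, `stepper`,
+    # `obs_errors`, `prodR`, `prodQ` and `observation_frequency` are user
+    # code, shown here only for illustration):
+    #
+    #     with Jhat.recording_stages() as stages:
+    #         for stage, ctx in stages:
+    #             qn.assign(stage.control)
+    #             for _ in range(observation_frequency):
+    #                 qn1.assign(qn)
+    #                 stepper.solve()
+    #                 qn.assign(qn1)
+    #             stage.set_observation(qn, obs_errors(stage.observation_index),
+    #                                   observation_iprod=prodR,
+    #                                   forward_model_iprod=prodQ)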
+
+
+class ObservationStageSequence:
+    def __init__(self, controls: Control,
+                 aaorf: AllAtOnceReducedFunctional,
+                 global_index: int,
+                 observation_index: int,
+                 stage_kwargs: dict = None,
+                 sequential: bool = True):
+        self.controls = controls
+        self.nstages = len(controls) - 1
+        self.aaorf = aaorf
+        self.ctx = StageContext(**(stage_kwargs or {}))
+        self.weak_constraint = aaorf.weak_constraint
+        self.global_index = global_index
+        self.observation_index = observation_index
+        self.local_index = -1
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+
+        if self.weak_constraint:
+            stages = self.aaorf.stages
+
+            # increment global indices
+            self.local_index += 1
+            self.global_index += 1
+            self.observation_index += 1
+
+            # start of the next stage
+            next_control = self.controls[self.local_index]
+
+            # smuggle the state forward into the next control
+            if self.local_index > 0:
+                state = stages[-1].controls[1].control
+                with stop_annotating():
+                    next_control.control.assign(state)
+
+            # stop after we've recorded all stages
+            if self.local_index >= self.nstages:
+                raise StopIteration
+
+            stage = WeakObservationStage(next_control,
+                                         local_index=self.local_index,
+                                         global_index=self.global_index,
+                                         observation_index=self.observation_index)
+            stages.append(stage)
+
+        else:  # strong constraint
+
+            # increment stage indices
+            self.local_index += 1
+            self.global_index += 1
+            self.observation_index += 1
+
+            # stop after we've recorded all stages
+            if self.local_index >= self.nstages:
+                raise StopIteration
+
+            # dummy control to "start" the stage from
+            control = (self.aaorf.controls[0].control if self.local_index == 0
+                       else self._prev_stage.state)
+
+            stage = StrongObservationStage(control, self.aaorf)
+            self._prev_stage = stage
+
+        return stage, self.ctx
+
+
+class StageContext:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+
+class StrongObservationStage:
+    """
+    Record an observation for strong constraint 4DVar at the time of `state`.
+
+    Parameters
+    ----------
+
+    control
+        The control (initial condition) for the stage.
+
+    aaorf
+        The strong constraint AllAtOnceReducedFunctional.
+
+    """
+
+    def __init__(self, control: OverloadedType,
+                 aaorf: AllAtOnceReducedFunctional):
+        self.aaorf = aaorf
+        self.control = control
+
+    def set_observation(self, state: OverloadedType,
+                        observation_err: Callable[[OverloadedType], OverloadedType],
+                        observation_iprod: Callable[[OverloadedType], AdjFloat]):
+        """
+        Record an observation at the time of `state`.
+
+        Parameters
+        ----------
+
+        state
+            The state at the current observation time.
+
+        observation_err
+            Given a state :math:`x`, returns the observations error
+            :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
+            the observations at the current observation time and
+            :math:`\\mathcal{H}_{i}` is the observation operator for the
+            current observation time.
+
+        observation_iprod
+            The inner product to calculate the observation error functional
+            from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
+            Can include the error covariance matrix.
+        """
+        if hasattr(self.aaorf, "_strong_reduced_functional"):
+            raise ValueError("Cannot add observations once strong"
+                             " constraint ReducedFunctional instantiated")
+        self.aaorf._accumulate_functional(
+            observation_iprod(observation_err(state)))
+        self.state = state
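+
+
+# Note (comment only): StrongObservationStage accumulates its observation
+# term directly onto the single global tape via _accumulate_functional,
+# whereas WeakObservationStage below cuts the tape so that each stage owns
+# its own chained ReducedFunctionals and can live on a different ensemble
+# member.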
+
+
+class WeakObservationStage:
+    """
+    A single stage for weak constraint 4DVar at the time of `state`.
+
+    Records the forward model propagation from the control at the beginning
+    of the stage, and the model and observation errors at the end of the stage.
+
+    Parameters
+    ----------
+
+    control
+        The control :math:`x_{i-1}` at the beginning of the stage.
+
+    local_index
+        The index of this stage in the timeseries on the
+        local ensemble member.
+
+    global_index
+        The index of this stage in the global timeseries.
+
+    observation_index
+        The index of the observation at the end of this stage in
+        the global timeseries. May be different from global_index if
+        an observation is taken at the initial time.
+
+    """
+    def __init__(self, control: Control,
+                 local_index: Optional[int] = None,
+                 global_index: Optional[int] = None,
+                 observation_index: Optional[int] = None):
+        # The "control" to use as the initial condition.
+        # Not an actual `Control`, for consistency with the strong constraint stage.
+        self.control = control.control
+
+        self.controls = Enlist(control)
+        self.local_index = local_index
+        self.global_index = global_index
+        self.observation_index = observation_index
+        set_working_tape()
+        self._stage_tape = get_working_tape()
+
+    def set_observation(self, state: OverloadedType,
+                        observation_err: Callable[[OverloadedType], OverloadedType],
+                        observation_iprod: Callable[[OverloadedType], AdjFloat],
+                        forward_model_iprod: Callable[[OverloadedType], AdjFloat]):
+        """
+        Record an observation at the time of `state`.
+
+        Parameters
+        ----------
+
+        state
+            The state at the current observation time.
+
+        observation_err
+            Given a state :math:`x`, returns the observations error
+            :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
+            the observations at the current observation time and
+            :math:`\\mathcal{H}_{i}` is the observation operator for the
+            current observation time.
+
+        observation_iprod
+            The inner product to calculate the observation error functional
+            from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
+            Can include the error covariance matrix.
+
+        forward_model_iprod
+            The inner product to calculate the model error functional from
+            the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can
+            include the error covariance matrix.
+        """
+        # get the tape used for this stage and make sure it's the right one
+        stage_tape = get_working_tape()
+        if stage_tape is not self._stage_tape:
+            raise ValueError(
+                "Working tape at the end of the observation stage"
+                " differs from the tape at the stage beginning."
+            )
+
+        # record the forward propagation
+        with set_working_tape(stage_tape.copy()) as tape:
+            self.forward_model = ReducedFunctional(
+                state._ad_copy(), controls=self.controls[0], tape=tape)
+
+        # The beginning of the next time-chunk is both the control for this
+        # stage's observation and the initial guess for the next stage.
+        with stop_annotating():
+            # smuggle the initial guess at this time into the control
+            # without the tape seeing
+            self.controls.append(Control(state._ad_copy()))
+            if self.global_index:
+                _rename(self.controls[-1].control, f"Control_{self.global_index}")
+
+        # model error links time-chunks by depending on both the
+        # previous and current controls
+
+        # RF to recalculate error vector (M_i - x_i)
+        names = {
+            'functional_name': f"model_err_vec_{self.global_index}",
+            'control_name': [f"state_{self.global_index}_copy",
+                             f"Control_{self.global_index}_model_copy"]
+        } if self.global_index else {}
+
+        self.model_error = isolated_rf(
+            operation=lambda controls: _ad_sub(*controls),
+            control=[state, self.controls[-1].control],
+            **names)
+
+        # RF to recalculate inner product |M_i - x_i|_Q
+        names = {
+            'control_name': f"model_err_vec_{self.global_index}_copy"
+        } if self.global_index else {}
+
+        self.model_norm = isolated_rf(
+            operation=forward_model_iprod,
+            control=self.model_error.functional,
+            **names)
+
+        # Observations after the tape cut, because x_i is now a control, not a state
+
+        # RF to recalculate error vector (H(x_i) - y_i)
+        names = {
+            'functional_name': f"obs_err_vec_{self.global_index}",
+            'control_name': f"Control_{self.global_index}_obs_copy"
+        } if self.global_index else {}
+
+        self.observation_error = isolated_rf(
+            operation=observation_err,
+            control=self.controls[-1].control,
+            **names)
+
+        # RF to recalculate inner product |H(x_i) - y_i|_R
+        names = {
+            'control_name': f"obs_err_vec_{self.global_index}_copy"
+        } if self.global_index else {}
+        self.observation_norm = isolated_rf(
+            operation=observation_iprod,
+            control=self.observation_error.functional,
+            **names)
+
+        # remove the stage initial condition "control" now we've finished recording
+        delattr(self, "control")
+
+        # stop the stage tape recording anything else
+        set_working_tape()
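+
+    # Comment only: after set_observation the stage consists of five
+    # chained ReducedFunctionals:
+    #   forward_model:      x_{i-1}        -> M_i(x_{i-1})
+    #   model_error:        (M_i, x_i)     -> M_i - x_i
+    #   model_norm:         M_i - x_i      -> |M_i - x_i|_Q
+    #   observation_error:  x_i            -> H_i(x_i) - y_i
+    #   observation_norm:   H_i(x_i) - y_i -> |H_i(x_i) - y_i|_R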
+
+    @stop_annotating()
+    def __call__(self, values: OverloadedType,
+                 rftype: Optional[str] = None):
+        """Computes the reduced functional with supplied control value.
+
+        Parameters
+        ----------
+
+        values
+            If you have multiple controls this should be a list of new values
+            for each control in the order you listed the controls to the constructor.
+            If you have a single control it can either be a list or a single object.
+            Each new value should have the same type as the corresponding control.
+
+        rftype:
+            Whether to evaluate:
+                - the model error ('model'),
+                - the observation error ('obs'),
+                - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The computed value. Typically an instance of :class:`pyadjoint.AdjFloat`.
+
+        """
+        J = 0.0
+
+        # evaluate the model error
+        if (rftype is None) or (rftype == 'model'):
+            Mi = self.forward_model(values[0])
+            J += self.model_norm(self.model_error([Mi, values[1]]))
+
+        # evaluate the observation error
+        if (rftype is None) or (rftype == 'obs'):
+            J += self.observation_norm(self.observation_error(values[1]))
+
+        return J
+
+    @stop_annotating()
+    def derivative(self, adj_input: float = 1.0, options: dict = {},
+                   rftype: Optional[str] = None):
+        """Returns the derivative of the functional w.r.t. the control.
+        Using the adjoint method, the derivative of the functional with
+        respect to the control, around the last supplied value of the
+        control, is computed and returned.
+
+        Parameters
+        ----------
+        adj_input
+            The adjoint input.
+
+        options
+            Additional options for the derivative computation.
+
+        rftype:
+            Whether to evaluate:
+                - the model error ('model'),
+                - the observation error ('obs'),
+                - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The derivative with respect to the control.
+            Should be an instance of the same type as the control.
+        """
+        # create a list of overloaded types to put the derivative into
+        derivatives = []
+
+        # chaining ReducedFunctionals means we need to pass Cofunctions not Functions
+        options = options or {}
+        intermediate_options = _intermediate_options(options)
+
+        if (rftype is None) or (rftype == 'model'):
+            # derivative of the reduction
+            dm_norm = self.model_norm.derivative(adj_input=adj_input,
+                                                 options=intermediate_options)
+
+            # derivative of the difference splits into (Mi, xi)
+            dm_errors = self.model_error.derivative(adj_input=dm_norm,
+                                                    options=intermediate_options)
+
+            # derivative through the forward model wrt xprev
+            dm_forward = self.forward_model.derivative(adj_input=dm_errors[0],
+                                                       options=options)
+
+            derivatives.append(dm_forward)
+
+            # dm_errors is still in the dual space, so we need to convert it to the
+            # type that the user has requested - this will be the type of dm_forward.
+            derivatives.append(dm_forward._ad_convert_type(dm_errors[1], options))
+
+        if (rftype is None) or (rftype == 'obs'):
+            # derivative of the reduction
+            do_norm = self.observation_norm.derivative(adj_input=adj_input,
+                                                       options=intermediate_options)
+            # derivative of the error
+            do_error = self.observation_error.derivative(adj_input=do_norm,
+                                                         options=options)
+
+            if len(derivatives) == 0:
+                derivatives.append(None)
+                derivatives.append(do_error)
+            else:
+                derivatives[1] += do_error
+
+        return derivatives
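+
+    # Comment only: the calls above apply the chain rule for
+    # J = model_norm(model_error(forward_model(x_{i-1}), x_i)).
+    # The model_norm and model_error derivatives are kept in the dual
+    # space (via _intermediate_options) and fed as adj_input to the next
+    # link; only the outermost derivatives are mapped back to the user's
+    # requested riesz representation.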
+
+    @stop_annotating()
+    def hessian(self, m_dot: OverloadedType, options: dict = {},
+                rftype: Optional[str] = None):
+        """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot.
+
+        Using the second-order adjoint method, the action of the Hessian of the
+        functional with respect to the control, around the last supplied value
+        of the control, is computed and returned.
+
+        Parameters
+        ----------
+
+        m_dot
+            The direction in which to compute the action of the Hessian.
+
+        options
+            A dictionary of options. To find a list of available options
+            have a look at the specific control type.
+
+        rftype:
+            Whether to evaluate:
+                - the model error ('model'),
+                - the observation error ('obs'),
+                - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The action of the Hessian in the direction m_dot.
+            Should be an instance of the same type as the control.
+        """
+        raise NotImplementedError(
+            "WeakObservationStage.hessian is not implemented yet")
diff --git a/firedrake/adjoint_utils/__init__.py b/firedrake/adjoint_utils/__init__.py
index 3b3426a850..c71da61ade 100644
--- a/firedrake/adjoint_utils/__init__.py
+++ b/firedrake/adjoint_utils/__init__.py
@@ -12,3 +12,4 @@ from firedrake.adjoint_utils.solving import *  # noqa: F401
 from firedrake.adjoint_utils.mesh import *  # noqa: F401
 from firedrake.adjoint_utils.checkpointing import *  # noqa: F401
+from firedrake.adjoint_utils.ensemblefunction import *  # noqa: F401
diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py
new file mode 100644
index 0000000000..ba6882f054
--- /dev/null
+++ b/firedrake/adjoint_utils/ensemblefunction.py
@@ -0,0 +1,101 @@
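+# Added note: this mixin supplies the pyadjoint OverloadedType interface
+# (copying, dot products, PETSc Vec transfer, checkpointing) so that an
+# EnsembleFunction can be used as a Control value for a ReducedFunctional.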
+from pyadjoint.overloaded_type import OverloadedType
+from firedrake.petsc import PETSc
+from .checkpointing import disk_checkpointing
+
+from functools import wraps
+
+
+class EnsembleFunctionMixin(OverloadedType):
+
+    @staticmethod
+    def _ad_annotate_init(init):
+        @wraps(init)
+        def wrapper(self, *args, **kwargs):
+            OverloadedType.__init__(self)
+            init(self, *args, **kwargs)
+        return wrapper
+
+    @staticmethod
+    def _ad_to_list(m):
+        with m.vec_ro() as gvec:
+            lcomm = PETSc.COMM_SELF
+            gsize = gvec.size
+            lvec = PETSc.Vec().createSeq(gsize, comm=lcomm)
+            is_ = PETSc.IS().createStride(gsize, 0, 1, comm=lcomm)
+
+            mode = PETSc.InsertMode.INSERT_VALUES
+            scatter = PETSc.Scatter().create(gvec, is_, lvec, None)
+            scatter.scatterBegin(gvec, lvec, addv=mode)
+            scatter.scatterEnd(gvec, lvec, addv=mode)
+
+        return lvec.array_r.tolist()
+
+    @staticmethod
+    def _ad_assign_numpy(dst, src, offset):
+        with dst.vec_wo() as vec:
+            begin, end = vec.owner_range
+            src_array = src[offset + begin: offset + end]
+            vec.array[:] = src_array
+            offset += vec.size
+        return dst, offset
+
+    def _ad_dot(self, other, options=None):
+        # local dot product
+        ldot = sum(
+            uself._ad_dot(uother, options=options)
+            for uself, uother in zip(self.subfunctions,
+                                     other.subfunctions))
+        # global dot product
+        gdot = self.ensemble.ensemble_comm.allreduce(ldot)
+        return gdot
+
+    def _ad_add(self, other):
+        new = self.copy()
+        new += other
+        return new
+
+    def _ad_mul(self, other):
+        new = self.copy()
+        # `self` can be a Cofunction, in which case only left
+        # multiplication with a scalar is allowed.
+        other = other._fbuf if type(other) is type(self) else other
+        new._fbuf.assign(other*new._fbuf)
+        return new
+
+    def _ad_iadd(self, other):
+        self += other
+        return self
+
+    def _ad_imul(self, other):
+        self *= other
+        return self
+
+    def _ad_copy(self):
+        return self.copy()
+
+    def _ad_convert_riesz(self, value, options=None):
+        raise NotImplementedError(
+            "Riesz conversion is not yet implemented for EnsembleFunctions")
+
+    def _ad_create_checkpoint(self):
+        if disk_checkpointing():
+            raise NotImplementedError(
+                "Disk checkpointing not implemented for EnsembleFunctions")
+        else:
+            return self.copy()
+
+    def _ad_restore_at_checkpoint(self, checkpoint):
+        if isinstance(checkpoint, type(self)):
+            return checkpoint
+        raise NotImplementedError(
+            "Checkpointing not implemented for EnsembleFunctions")
+
+    def _ad_from_petsc(self, vec):
+        with self.vec_wo() as self_v:
+            vec.copy(result=self_v)
+
+    def _ad_to_petsc(self, vec=None):
+        with self.vec_ro() as self_v:
+            if vec:
+                self_v.copy(result=vec)
+            else:
+                vec = self_v.copy()
+        return vec
diff --git a/firedrake/assign.py b/firedrake/assign.py
index 8d1b30681a..892ed5b1b0 100644
--- a/firedrake/assign.py
+++ b/firedrake/assign.py
@@ -100,6 +100,9 @@ def component_tensor(self, o, a, _):
     def coefficient(self, o):
         return ((o, 1),)
 
+    def cofunction(self, o):
+        return ((o, 1),)
+
     def constant_value(self, o):
         return ((o, 1),)
 
diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py
index f847be51bf..5bb5519044 100644
--- a/firedrake/ensemble.py
+++ b/firedrake/ensemble.py
@@ -1,8 +1,11 @@
 import weakref
+from contextlib import contextmanager
+from itertools import zip_longest
 
 from firedrake.petsc import PETSc
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction
 from pyop2.mpi import MPI, internal_comm
-from itertools import zip_longest
 
 __all__ = ("Ensemble", )
 
@@ -283,3 +286,60 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r
         requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
                          for dat in frecv.dat])
         return requests
+
+    @contextmanager
+    def sequential(self, **kwargs):
+        """
+        Context manager for executing code on each ensemble
+        member in turn.
+
+        Any data in `kwargs` will be made available in the context
+        and will be communicated forward after each ensemble member
+        exits. Firedrake Functions/Cofunctions will be sent with the
+        corresponding Ensemble methods.
+
+        with ensemble.sequential(index=0) as ctx:
+            print(ensemble.ensemble_comm.rank, ctx.index)
+            ctx.index += 2
+
+        Would print:
+        0 0
+        1 2
+        2 4
+        3 6
+        ... etc ...
+
+        """
+        rank = self.ensemble_comm.rank
+        first_rank = (rank == 0)
+        last_rank = (rank == self.ensemble_comm.size - 1)
+
+        if not first_rank:
+            src = rank - 1
+            for i, (k, v) in enumerate(kwargs.items()):
+                recv_kwargs = {'source': src, 'tag': rank+i*100}
+                if isinstance(v, (Function, Cofunction)):
+                    self.recv(kwargs[k], **recv_kwargs)
+                else:
+                    kwargs[k] = self.ensemble_comm.recv(
+                        **recv_kwargs)
+
+        ctx = _EnsembleContext(**kwargs)
+
+        yield ctx
+
+        if not last_rank:
+            dst = rank + 1
+            for i, v in enumerate((getattr(ctx, k)
+                                   for k in kwargs.keys())):
+                send_kwargs = {'dest': dst, 'tag': dst+i*100}
+                if isinstance(v, (Function, Cofunction)):
+                    self.send(v, **send_kwargs)
+                else:
+                    self.ensemble_comm.send(v, **send_kwargs)
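+
+    # Illustrative only: a Function passed through the context travels via
+    # Ensemble.send/recv rather than comm.send, e.g.
+    #
+    #     u = Function(V).assign(ensemble.ensemble_comm.rank)
+    #     with ensemble.sequential(u=u) as ctx:
+    #         ctx.u += 1  # each rank receives the previous rank's result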
+
+
+class _EnsembleContext:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
diff --git a/firedrake/ensemblefunction.py b/firedrake/ensemblefunction.py
new file mode 100644
index 0000000000..33ed7e7b2f
--- /dev/null
+++ b/firedrake/ensemblefunction.py
@@ -0,0 +1,295 @@
+from firedrake.petsc import PETSc
+from firedrake.adjoint_utils import EnsembleFunctionMixin
+from firedrake.functionspace import MixedFunctionSpace
+from firedrake.function import Function
+from ufl.duals import is_primal, is_dual
+from pyop2 import MixedDat
+
+from functools import cached_property
+from contextlib import contextmanager
+
+__all__ = ("EnsembleFunction", "EnsembleCofunction")
+
+
+class EnsembleFunctionBase(EnsembleFunctionMixin):
+    """
+    A mixed finite element (co)function distributed over an ensemble.
+
+    Parameters
+    ----------
+
+    ensemble
+        The ensemble communicator. The sub(co)functions are distributed
+        over the different ensemble members.
+
+    function_spaces
+        A list of function spaces for each (co)function on the
+        local ensemble member.
+    """
+
+    @PETSc.Log.EventDecorator()
+    @EnsembleFunctionMixin._ad_annotate_init
+    def __init__(self, ensemble, function_spaces):
+        self.ensemble = ensemble
+        self.local_function_spaces = function_spaces
+        self.local_size = len(function_spaces)
+
+        # the local functions are stored as a big mixed space
+        self._function_space = MixedFunctionSpace(function_spaces)
+        self._fbuf = Function(self._function_space)
+
+        # create a Vec containing the data for all functions on all
+        # ensemble members. Because we use the Vec of each local mixed
+        # function as the storage, if the data in the Function Vec
+        # is valid then the data in the EnsembleFunction Vec is valid.
+
+        with self._fbuf.dat.vec as fvec:
+            local_size = self._function_space.node_set.size
+            sizes = (local_size, PETSc.DETERMINE)
+            self._vec = PETSc.Vec().createWithArray(fvec.array,
+                                                    size=sizes,
+                                                    comm=ensemble.global_comm)
+            self._vec.setFromOptions()
+
+    @cached_property
+    def subfunctions(self):
+        """
+        The (co)functions on the local ensemble member.
+        """
+        def local_function(i):
+            V = self.local_function_spaces[i]
+            usubs = self._subcomponents(i)
+            if len(usubs) == 1:
+                dat = usubs[0].dat
+            else:
+                dat = MixedDat((u.dat for u in usubs))
+            return Function(V, val=dat)
+
+        self._subfunctions = tuple(local_function(i)
+                                   for i in range(self.local_size))
+        return self._subfunctions
+
+    def _subcomponents(self, i):
+        """
+        Return the subfunctions of the local mixed function storage
+        corresponding to the i-th local function.
+        """
+        return tuple(self._fbuf.subfunctions[j]
+                     for j in self._component_indices(i))
+
+    def _component_indices(self, i):
+        """
+        Return the indices into the local mixed function storage
+        corresponding to the i-th local function.
+        """
+        V = self.local_function_spaces[i]
+        offset = sum(len(W) for W in self.local_function_spaces[:i])
+        return tuple(offset + j for j in range(len(V)))
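+
+    # Illustrative usage (comment only), assuming an Ensemble `ensemble`
+    # with two members and a FunctionSpace V on ensemble.comm:
+    #
+    #     ef = EnsembleFunction(ensemble, [V, V])  # two functions per member
+    #     ef.subfunctions[0].assign(1.0)
+    #     with ef.vec_ro() as v:
+    #         print(v.norm())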
+
+    @PETSc.Log.EventDecorator()
+    def riesz_representation(self, riesz_map="L2", **kwargs):
+        """
+        Return the Riesz representation of this :class:`EnsembleFunction`
+        with respect to the given Riesz map.
+
+        Parameters
+        ----------
+
+        riesz_map
+            The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable.
+
+        kwargs
+            Other arguments, passed to the riesz_representation
+            of each local (co)function.
+        """
+        DualType = {
+            EnsembleFunction: EnsembleCofunction,
+            EnsembleCofunction: EnsembleFunction,
+        }[type(self)]
+        Vdual = [V.dual() for V in self.local_function_spaces]
+        riesz = DualType(self.ensemble, Vdual)
+        for uriesz, uself in zip(riesz.subfunctions, self.subfunctions):
+            uriesz.assign(uself.riesz_representation(riesz_map=riesz_map, **kwargs))
+        return riesz
+
+    @PETSc.Log.EventDecorator()
+    def assign(self, other, subsets=None):
+        r"""Set the :class:`EnsembleFunction` to the value of another
+        :class:`EnsembleFunction` other.
+
+        Parameters
+        ----------
+
+        other
+            The :class:`EnsembleFunction` to assign from.
+
+        subsets
+            An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`.
+            The values of each local function will then only
+            be assigned on the nodes on the corresponding subset.
+        """
+        if type(other) is not type(self):
+            raise ValueError(
+                f"Cannot assign {type(self)} from {type(other)}")
+        if subsets:
+            for i in range(self.local_size):
+                self.subfunctions[i].assign(
+                    other.subfunctions[i], subset=subsets[i])
+        else:
+            for i in range(self.local_size):
+                self.subfunctions[i].assign(other.subfunctions[i])
+        return self
+
+    @PETSc.Log.EventDecorator()
+    def copy(self):
+        """
+        Return a deep copy of the :class:`EnsembleFunction`.
+        """
+        new = type(self)(self.ensemble, self.local_function_spaces)
+        new.assign(self)
+        return new
+
+    @PETSc.Log.EventDecorator()
+    def zero(self, subsets=None):
+        """
+        Set values to zero.
+
+        Parameters
+        ----------
+
+        subsets
+            An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`.
+            The values of each local function will then only
+            be zeroed on the nodes on the corresponding subset.
+        """
+        if subsets:
+            for i in range(self.local_size):
+                self.subfunctions[i].zero(subsets[i])
+        else:
+            for u in self.subfunctions:
+                u.zero()
+        return self
+
+    @PETSc.Log.EventDecorator()
+    def __iadd__(self, other):
+        for us, uo in zip(self.subfunctions, other.subfunctions):
+            us.assign(us + uo)
+        return self
+
+    @PETSc.Log.EventDecorator()
+    def __imul__(self, other):
+        if type(other) is type(self):
+            for us, uo in zip(self.subfunctions, other.subfunctions):
+                us.assign(us*uo)
+        else:
+            for us in self.subfunctions:
+                us *= other
+        return self
+
+    @PETSc.Log.EventDecorator()
+    def __add__(self, other):
+        new = self.copy()
+        for i in range(self.local_size):
+            new.subfunctions[i] += other.subfunctions[i]
+        return new
+
+    @PETSc.Log.EventDecorator()
+    def __mul__(self, other):
+        new = self.copy()
+        if type(other) is type(self):
+            for i in range(self.local_size):
+                new.subfunctions[i].assign(other.subfunctions[i]*self.subfunctions[i])
+        else:
+            for i in range(self.local_size):
+                new.subfunctions[i].assign(other*self.subfunctions[i])
+        return new
+
+    @PETSc.Log.EventDecorator()
+    def __rmul__(self, other):
+        return self.__mul__(other)
+
+    @contextmanager
+    def vec(self):
+        """
+        Context manager for the global PETSc Vec with read/write access.
+
+        It is invalid to access the Vec outside of a context manager.
+        """
+        # _fbuf.vec shares the same storage as _vec, so we need this
+        # nested context manager so that the data gets copied to/from
+        # the Function.dat storage and _vec.
+        # However, this copy is done without _vec knowing, so we have
+        # to manually increment the state.
+        with self._fbuf.dat.vec:
+            self._vec.stateIncrease()
+            yield self._vec
+
+    @contextmanager
+    def vec_ro(self):
+        """
+        Context manager for the global PETSc Vec with read only access.
+
+        It is invalid to access the Vec outside of a context manager.
+        """
+        # _fbuf.vec shares the same storage as _vec, so we need this
+        # nested context manager to make sure that the data gets copied
+        # from the Function.dat storage into _vec.
+        with self._fbuf.dat.vec_ro:
+            self._vec.stateIncrease()
+            yield self._vec
+
+    @contextmanager
+    def vec_wo(self):
+        """
+        Context manager for the global PETSc Vec with write only access.
+
+        It is invalid to access the Vec outside of a context manager.
+        """
+        # _fbuf.vec shares the same storage as _vec, so we need this
+        # nested context manager to make sure that the data gets copied
+        # from _vec back into the Function.dat storage.
+        with self._fbuf.dat.vec_wo:
+            yield self._vec
+
+
+class EnsembleFunction(EnsembleFunctionBase):
+    """
+    A mixed finite element Function distributed over an ensemble.
+
+    Parameters
+    ----------
+
+    ensemble
+        The ensemble communicator. The subfunctions are distributed
+        over the different ensemble members.
+
+    function_spaces
+        A list of function spaces for each function on the
+        local ensemble member.
+    """
+    def __init__(self, ensemble, function_spaces):
+        if not all(is_primal(V) for V in function_spaces):
+            raise TypeError(
+                "EnsembleFunction must be created using primal FunctionSpaces")
+        super().__init__(ensemble, function_spaces)
+
+
+class EnsembleCofunction(EnsembleFunctionBase):
+    """
+    A mixed finite element Cofunction distributed over an ensemble.
+
+    Parameters
+    ----------
+
+    ensemble
+        The ensemble communicator. The subcofunctions are distributed
+        over the different ensemble members.
+
+    function_spaces
+        A list of dual function spaces for each cofunction on the
+        local ensemble member.
+ """ + def __init__(self, ensemble, function_spaces): + if not all(is_dual(V) for V in function_spaces): + raise TypeError( + "EnsembleCofunction must be created using dual FunctionSpaces") + super().__init__(ensemble, function_spaces) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py new file mode 100644 index 0000000000..bf0530293f --- /dev/null +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -0,0 +1,325 @@ +import pytest +import firedrake as fd +from firedrake.__future__ import interpolate +from firedrake.adjoint import ( + continue_annotation, pause_annotation, stop_annotating, set_working_tape, + Control, taylor_test, ReducedFunctional, AllAtOnceReducedFunctional) + + +def function_space(comm): + """DG0 periodic advection""" + mesh = fd.PeriodicUnitIntervalMesh(nx, comm=comm) + return fd.FunctionSpace(mesh, "DG", 0) + + +def timestepper(V): + """Implicit midpoint timestepper for the advection equation""" + qn = fd.Function(V, name="qn") + qn1 = fd.Function(V, name="qn1") + + def mass(q, phi): + return fd.inner(q, phi)*fd.dx + + def tendency(q, phi): + u = fd.as_vector([vconst]) + n = fd.FacetNormal(V.mesh()) + un = fd.Constant(0.5)*(fd.dot(u, n) + abs(fd.dot(u, n))) + return (- q*fd.div(phi*u)*fd.dx + + fd.jump(phi)*fd.jump(un*q)*fd.dS) + + # midpoint rule + q = fd.TrialFunction(V) + phi = fd.TestFunction(V) + + qh = fd.Constant(0.5)*(q + qn) + eqn = mass(q - qn, phi) + fd.Constant(dt)*tendency(qh, phi) + + stepper = fd.LinearVariationalSolver( + fd.LinearVariationalProblem( + fd.lhs(eqn), fd.rhs(eqn), qn1, + constant_jacobian=True)) + + return qn, qn1, stepper + + +def prod2(w): + """generate weighted inner products to pass to FourDVarReducedFunctional""" + def n2(x): + return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) + return n2 + + +prodB = prod2(0.1) # background error +prodR = prod2(10.) 
+
+
+"""Advecting velocity"""
+velocity = 1
+vconst = fd.Constant(velocity)
+
+"""Number of cells"""
+nx = 16
+
+"""Timestep size"""
+cfl = 2.3523
+dx = 1.0/nx
+dt = cfl*dx/velocity
+
+"""Number of observations and the number of timesteps between them
+(plus one extra observation at the initial time)"""
+observation_frequency = 5
+observation_n = 6
+observation_times = [i*observation_frequency*dt
+                     for i in range(observation_n+1)]
+
+
+def nlocal_observations(ensemble):
+    """How many observations on the current ensemble member"""
+    esize = ensemble.ensemble_comm.size
+    erank = ensemble.ensemble_comm.rank
+    if esize == 1:
+        return observation_n + 1
+    assert (observation_n % esize == 0), "Must be able to split observations across ensemble"  # noqa: E501
+    return observation_n//esize + (1 if erank == 0 else 0)
+
+
+def analytic_solution(V, t, mag=1.0, phase=0.0):
+    """Exact advection of a sine wave after time t"""
+    x, = fd.SpatialCoordinate(V.mesh())
+    return fd.Function(V).interpolate(
+        mag*fd.sin(2*fd.pi*((x + phase) - vconst*t)))
+
+
+def analytic_series(V, tshift=0.0, mag=1.0, phase=0.0, ensemble=None):
+    """Timeseries of the analytic solution"""
+    series = [analytic_solution(V, t+tshift, mag=mag, phase=phase)
+              for t in observation_times]
+
+    if ensemble is None:
+        return series
+    else:
+        nlocal_obs = nlocal_observations(ensemble)
+        rank = ensemble.ensemble_comm.rank
+        offset = (0 if rank == 0 else rank*nlocal_obs + 1)
+
+        efunc = fd.EnsembleFunction(
+            ensemble, [V for _ in range(nlocal_obs)])
+
+        for e, s in zip(efunc.subfunctions,
+                        series[offset:offset+nlocal_obs]):
+            e.assign(s)
+        return efunc
+
+
+def observation_errors(V):
+    """Generate functions to evaluate the observation error
+    at each observation time"""
+
+    observation_locations = [
+        [x] for x in [0.13, 0.18, 0.34, 0.36, 0.49, 0.61, 0.72, 0.99]
+    ]
+
+    observation_mesh = fd.VertexOnlyMesh(V.mesh(), observation_locations)
+    Vobs = fd.FunctionSpace(observation_mesh, "DG", 0)
+
+    # observation operator
+    def H(x):
+        return fd.assemble(interpolate(x, Vobs))
+
+    # ground truth
+    targets = analytic_series(V)
+
+    # take observations
+    y = [H(x) for x in targets]
+
+    # generate function to evaluate observation error at observation time i
+    def observation_error(i):
+        def obs_err(x):
+            return fd.Function(Vobs).assign(H(x) - y[i])
+        return obs_err
+
+    return observation_error
+
+
+def background(V):
+    """Prior for the initial condition"""
+    return analytic_solution(V, t=0, mag=0.9, phase=0.1)
+
+
+def m(V, ensemble=None):
+    """The expansion points for the Taylor test"""
+    return analytic_series(V, tshift=0.1, mag=1.1, phase=-0.2,
+                           ensemble=ensemble)
+
+
+def h(V, ensemble=None):
+    """The perturbation direction for the Taylor test"""
+    return analytic_series(V, tshift=0.3, mag=0.1, phase=0.3,
+                           ensemble=ensemble)
+
+
+def fdvar_pyadjoint(V):
+    """Build a pyadjoint ReducedFunctional for the 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # One control for each observation time
+    controls = [fd.Function(V)
+                for _ in range(len(observation_times))]
+
+    # Prior
+    bkg = background(V)
+
+    controls[0].assign(bkg)
+
+    # generate the observation error functions
+    obs_errors = observation_errors(V)
+
+    # start building the 4DVar system
+    continue_annotation()
+    set_working_tape()
+
+    # background error
+    J = prodB(controls[0] - bkg)
+
+    # initial observation error
+    J += prodR(obs_errors(0)(controls[0]))
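+
+    # What is accumulated below is the usual weak-constraint 4DVar sum
+    # (written informally; the weighted norms are the prod* functions above):
+    #   J(x) = ||x_0 - x_b||_B + sum_i ||y_i - H(x_i)||_R
+    #                          + sum_i ||M(x_{i-1}) - x_i||_Q
+    # with one observation term and one model error term per stage.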
+
+    # record observation stages
+    for i in range(1, len(controls)):
+        qn.assign(controls[i-1])
+
+        # forward model propagation
+        for _ in range(observation_frequency):
+            qn1.assign(qn)
+            stepper.solve()
+            qn.assign(qn1)
+
+        # we need to smuggle the state over to the next
+        # control without the tape seeing, so that we
+        # can continue the timeseries through the next
+        # stage while the tape still treats the forward
+        # model in each stage as independent.
+        with stop_annotating():
+            controls[i].assign(qn)
+
+        # model error for this stage
+        J += prodQ(qn - controls[i])
+
+        # observation error
+        J += prodR(obs_errors(i)(controls[i]))
+
+    pause_annotation()
+
+    Jhat = ReducedFunctional(
+        J, [Control(c) for c in controls])
+
+    return Jhat
+
+
+def fdvar_firedrake(V, ensemble):
+    """Build an AllAtOnceReducedFunctional for the 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # One control for each observation time
+    nlocal_obs = nlocal_observations(ensemble)
+
+    control = fd.EnsembleFunction(
+        ensemble, [V for _ in range(nlocal_obs)])
+
+    # Prior
+    bkg = background(V)
+
+    if ensemble.ensemble_comm.rank == 0:
+        control.subfunctions[0].assign(bkg)
+
+    # generate the observation error functions
+    obs_errors = observation_errors(V)
+
+    # start building the 4DVar system
+    continue_annotation()
+    set_working_tape()
+
+    # create the 4DVar reduced functional and record
+    # the background and initial observation functionals
+    Jhat = AllAtOnceReducedFunctional(
+        Control(control),
+        background_iprod=prodB,
+        observation_iprod=prodR,
+        observation_err=obs_errors(0),
+        weak_constraint=True)
+
+    # record observation stages
+    with Jhat.recording_stages() as stages:
+
+        # loop over stages
+        for stage, ctx in stages:
+            # start the forward model from this stage's control
+            qn.assign(stage.control)
+
+            # forward model propagation
+            for _ in range(observation_frequency):
+                qn1.assign(qn)
+                stepper.solve()
+                qn.assign(qn1)
+
+            # take the observation at the end of the stage
+            obs_err = obs_errors(stage.observation_index)
+            stage.set_observation(qn, obs_err,
+                                  observation_iprod=prodR,
+                                  forward_model_iprod=prodQ)
+
+    pause_annotation()
+
+    return Jhat
+
+
+@pytest.mark.parallel(nprocs=[1, 2, 3, 4])
+def test_advection():
+    main_test_advection()
+
+
+def main_test_advection():
+    global_comm = fd.COMM_WORLD
+    if global_comm.size in (1, 2):  # time serial
+        nspace = global_comm.size
+    elif global_comm.size == 3:  # time parallel
+        nspace = 1
+    elif global_comm.size == 4:  # space-time parallel
+        nspace = 2
+
+    ensemble = fd.Ensemble(global_comm, nspace)
+    V = function_space(ensemble.comm)
+
+    erank = ensemble.ensemble_comm.rank
+
+    # only set up the reference pyadjoint rf on the first ensemble member
+    if erank == 0:
+        Jhat_pyadj = fdvar_pyadjoint(V)
+        mp = m(V)
+        hp = h(V)
+        # make sure we've set up the reference rf correctly
+        assert taylor_test(Jhat_pyadj, mp, hp) > 1.99
+
+    Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None)
+    Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None)
+
+    Jhat_aaorf = fdvar_firedrake(V, ensemble)
+
+    ma = m(V, ensemble)
+    ha = h(V, ensemble)
+
+    eps = 1e-12
+    # Does evaluating the functional match the reference rf?
+    assert abs(Jpm - Jhat_aaorf(ma)) < eps
+    assert abs(Jph - Jhat_aaorf(ha)) < eps
+
+    # If we match the functional, then passing the Taylor test
+    # should mean we match the derivative too.
+    assert taylor_test(Jhat_aaorf, ma, ha) > 1.99
+
+
+if __name__ == '__main__':
+    main_test_advection()
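+
+
+# To exercise one of the parallel configurations by hand (assuming an MPI
+# launcher is available), run the module directly, e.g. the time-parallel
+# case with a three-member ensemble:
+#   mpiexec -n 3 python test_4dvar_reduced_functional.py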