From 6406402f1367f4029f2b2a01c382a8fdb4113c74 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 15 Nov 2024 12:07:29 +0000 Subject: [PATCH 01/23] aaorf - initial parallel impl - observation stages distributed over ensemble - context manager for recording the forward model and passing data between stages - iterator to record each stage and set the observation --- .../adjoint/all_at_once_reduced_functional.py | 1016 +++++++++++------ 1 file changed, 680 insertions(+), 336 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 429000b2a2..6a4e84617c 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -1,12 +1,47 @@ from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ stop_annotating, no_annotations, get_working_tape, set_working_tape from pyadjoint.enlisting import Enlist +from firedrake import Ensemble from functools import wraps, cached_property from typing import Callable, Optional +from contextlib import contextmanager +from mpi4py import MPI __all__ = ['AllAtOnceReducedFunctional'] +@set_working_tape(decorator=True) +def isolated_rf(operation, control, + functional_name=None, + control_name=None): + """ + Return a ReducedFunctional where the functional is `operation` applied + to a copy of `control`, and the tape contains only `operation`. + """ + controls = Enlist(control) + control_names = Enlist(control_name) + + with stop_annotating(): + control_copies = [control._ad_copy() for control in controls] + + if control_names: + for control, name in zip(control_copies, control_names): + _rename(control, name) + + if len(control_copies) == 1: + functional = operation(control_copies[0]) + control = Control(control_copies[0]) + else: + functional = operation(control_copies) + control = [Control(control) for control in control_copies] + + if functional_name: + _rename(functional, functional_name) + + return ReducedFunctional( + functional, control) + + def sc_passthrough(func): """ Wraps standard ReducedFunctional methods to differentiate strong or @@ -41,6 +76,18 @@ def _ad_sub(left, right): return result +def _scalarSend(comm, x, **kwargs): + from numpy import ones + comm.Send(x*ones(1, dtype=type(x)), **kwargs) + + +def _scalarRecv(comm, dtype=float, **kwargs): + from numpy import zeros + xtmp = zeros(1, dtype=dtype) + comm.Recv(xtmp, **kwargs) + return xtmp[0] + + class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. @@ -56,6 +103,9 @@ class AllAtOnceReducedFunctional(ReducedFunctional): The initial condition :math:`x_{0}`. Starting value is used as the background (prior) data :math:`x_{b}`. + nlocal_stages + The number of observation stages on the local ensemble member. + background_iprod The inner product to calculate the background error functional from the background error :math:`x_{0} - x_{b}`. Can include the @@ -76,8 +126,10 @@ class AllAtOnceReducedFunctional(ReducedFunctional): weak_constraint Whether to use the weak or strong constraint 4DVar formulation. - tape - The tape to record on. + ensemble + The ensemble communicator to parallelise over. None for no time parallelism. + If `ensemble` is provided, then `background_iprod`, `observation_err` and + `observation_iprod` must only be provided on ensemble rank 0. See Also -------- @@ -85,12 +137,14 @@ class AllAtOnceReducedFunctional(ReducedFunctional): """ def __init__(self, control: Control, - background_iprod: Callable[[OverloadedType], AdjFloat], + nlocal_stages: int, + background_iprod: Optional[Callable[[OverloadedType], AdjFloat]], observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, weak_constraint: bool = True, tape: Optional[Tape] = None, - _annotate_accumulation: bool = False): + _annotate_accumulation: bool = False, + ensemble: Optional[Ensemble] = None): self.tape = get_working_tape() if tape is None else tape @@ -103,87 +157,87 @@ def __init__(self, control: Control, if self.weak_constraint: self._annotate_accumulation = _annotate_accumulation - - # new tape for background error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes + self._accumulation_started = False + + self.nlocal_stages = nlocal_stages + + self.ensemble = ensemble + self.trank = ensemble.ensemble_comm.rank if ensemble else 0 + self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 + + self.stages = [] # The record of each observation stage + self.controls = [] # The solution at the beginning of each time-chunk + + # first rank sets up functionals for background initial observations + if self.trank == 0: + self.controls.append(control) + + # RF to recalculate error vector (x_0 - x_b) + self.background_error = isolated_rf( + operation=lambda x0: _ad_sub(x0, self.background), + control=control, + functional_name="bkg_err_vec", + control_name="Control_0_bkg_copy") + + # RF to recalculate inner product |x_0 - x_b|_B + self.background_norm = isolated_rf( + operation=background_iprod, + control=self.background_error.functional, + control_name="bkg_err_vec_copy") + + if self.initial_observations: + + # RF to recalculate error vector (H(x_0) - y_0) + self.initial_observation_error = isolated_rf( + operation=observation_err, + control=control, + functional_name="obs_err_vec_0", + control_name="Control_0_obs_copy") + + # RF to recalculate inner product |H(x_0) - y_0|_R + self.initial_observation_norm = isolated_rf( + operation=observation_iprod, + control=self.initial_observation_error.functional, + functional_name="obs_err_vec_0_copy") + + else: + # create halo for previous state with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_bkg_copy") - - # vector of x_0 - x_b - bkg_err_vec = _ad_sub(control_copy, self.background) - _rename(bkg_err_vec, "bkg_err_vec") - - # RF to recover x_0 - x_b - self.background_error = ReducedFunctional( - bkg_err_vec, Control(control_copy), tape=tape) - - # new tape for background error reduction - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - bkg_err_vec_copy = bkg_err_vec._ad_copy() - _rename(bkg_err_vec_copy, "bkg_err_vec_copy") - - # inner product |x_0 - x_b|_B - bkg_err = background_iprod(bkg_err_vec_copy) + self.xprev = control.copy_data() + self.control_prev = Control(self.xprev) - # RF to recover |x_0 - x_b|_B - self.background_rf = ReducedFunctional( - bkg_err, Control(bkg_err_vec_copy), tape=tape) + if background_iprod is not None: + raise ValueError("Only the first ensemble rank needs `background_iprod`") + if observation_iprod is not None: + raise ValueError("Only the first ensemble rank needs `observation_iprod`") + if observation_err is not None: + raise ValueError("Only the first ensemble rank needs `observation_err`") - self.controls = [control] # The solution at the beginning of each time-chunk - self.states = [] # The model propogation at the end of each time-chunk - self.forward_model_stages = [] # ReducedFunctional for each model propogation (returns state) - self.forward_model_errors = [] # Inner product for model errors (possibly including error covariance) - self.forward_model_rfs = [] # Inner product for model errors (possibly including error covariance) - self.observation_errors = [] # ReducedFunctional for each observation set (returns observation error) - self.observation_rfs = [] # Inner product for observation errors (possibly including error covariance) + # create all controls on local ensemble member + with stop_annotating(): + for _ in range(nlocal_stages): + self.controls.append(Control(control.copy_data())) - if self.initial_observations: + # halo for the derivative from the next chunk + if self.ensemble and self.trank != self.nchunks - 1: + self.xnext = control.copy_data() - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_obs_copy") - - # vector of H(x_0) - y_0 - obs_err_vec = observation_err(control_copy) - _rename(obs_err_vec, "obs_err_vec_0") - - # RF to recover H(x_0) - y_0 - self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, "obs_err_vec_0_copy") - - # inner product |H(x_0) - y_0|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_0) - y_0|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - set_working_tape() - self._stage_tape = get_working_tape() + # new tape for the initial stage + if self.trank == 0: + self.stages.append( + WeakObservationStage(self.controls[0], index=0)) + else: + self._stage_tape = None else: self._annotate_accumulation = True + self._accumulation_started = False # initial conditions guess to be updated self.controls = Enlist(control) + self.tape = get_working_tape() if tape is None else tape + # Strong constraint functional to be converted to ReducedFunctional later # penalty for straying from prior @@ -195,155 +249,6 @@ def __init__(self, control: Control, self._accumulate_functional( observation_iprod(observation_err(control.control))) - def set_observation(self, state: OverloadedType, - observation_err: Callable[[OverloadedType], OverloadedType], - observation_iprod: Callable[[OverloadedType], AdjFloat], - forward_model_iprod: Optional[Callable[[OverloadedType], AdjFloat]]): - """ - Record an observation at the time of `state`. - - Parameters - ---------- - - state - The state at the current observation time. - - observation_err - Given a state :math:`x`, returns the observations error - :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are - the observations at the current observation time and - :math:`\\mathcal{H}_{i}` is the observation operator for the - current observation time. - - observation_iprod - The inner product to calculate the observation error functional - from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. - Can include the error covariance matrix. - - forward_model_iprod - The inner product to calculate the model error functional from - the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can - include the error covariance matrix. Ignored if using the strong - constraint formulation. - """ - if self.weak_constraint: - - stage_index = len(self.controls) - - # Cut the tape into seperate time-chunks. - # State is output from previous control i.e. forward model - # propogation over previous time-chunk. - - # get the tape used for this stage and make sure its the right one - prev_stage_tape = get_working_tape() - if prev_stage_tape is not self._stage_tape: - raise ValueError( - "Working tape at the end of the observation stage" - " differs from the tape at the stage beginning." - ) - - # # record forward propogation - with set_working_tape(prev_stage_tape.copy()) as tape: - prev_control = self.controls[-1] - self.forward_model_stages.append(ReducedFunctional( - state._ad_copy(), controls=prev_control, tape=tape) - ) - - # Beginning of next time-chunk is the control for this observation - # and the state at the end of the next time-chunk. - with stop_annotating(): - # smuggle initial guess at this time into the control without the tape seeing - next_control_state = state._ad_copy() - _rename(next_control_state, f"Control_{len(self.controls)}") - next_control = Control(next_control_state) - self.controls.append(next_control) - - # model error links time-chunks by depending on both the - # previous and current controls - - # new tape for model error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - state_copy = state._ad_copy() - _rename(state_copy, f"state_{stage_index}_copy") - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_model_copy") - - # vector of M_i - x_i - model_err_vec = _ad_sub(state_copy, next_control_copy) - _rename(model_err_vec, f"model_err_vec_{stage_index}") - - # RF to recover M_i - x_i - fmcontrols = [Control(state_copy), Control(next_control_copy)] - self.forward_model_errors.append(ReducedFunctional( - model_err_vec, fmcontrols, tape=tape) - ) - - # new tape for model error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - model_err_vec_copy = model_err_vec._ad_copy() - _rename(model_err_vec_copy, f"model_err_vec_{stage_index}_copy") - - # inner product |M_i - x_i|_Q - model_err = forward_model_iprod(model_err_vec_copy) - - # RF to recover |M_i - x_i|_Q - self.forward_model_rfs.append(ReducedFunctional( - model_err, Control(model_err_vec_copy), tape=tape) - ) - - # Observations after tape cut because this is now a control, not a state - - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_obs_copy") - - # vector of H(x_i) - y_i - obs_err_vec = observation_err(next_control_copy) - _rename(obs_err_vec, f"obs_err_vec_{stage_index}") - - # RF to recover H(x_i) - y_i - self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(next_control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, f"obs_err_vec_{stage_index}_copy") - - # inner product |H(x_i) - y_i|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_i) - y_i|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - - set_working_tape() - self._stage_tape = get_working_tape() - - # Look we're starting this time-chunk from an "unrelated" value... really! - state.assign(next_control.control) - - else: - - if hasattr(self, "_strong_reduced_functional"): - msg = "Cannot add observations once strong constraint ReducedFunctional instantiated" - raise ValueError(msg) - self._accumulate_functional( - observation_iprod(observation_err(state))) - @cached_property def strong_reduced_functional(self): """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system. @@ -362,10 +267,10 @@ def __getattr__(self, attr): """ If using strong constraint then grab attributes from self.strong_reduced_functional. """ - if self.weak_constraint: + # hasattr calls getattr, so check self.__dir__ directly here to avoid recursion + if self.weak_constraint or "_strong_reduced_functional" not in dir(self): raise AttributeError(f"'{type(self)}' object has no attribute '{attr}'") - else: - return getattr(self.strong_reduced_functional, attr) + return getattr(self.strong_reduced_functional, attr) @sc_passthrough @no_annotations @@ -387,48 +292,59 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - # controls are updated by the sub ReducedFunctionals - # so we don't need to do it ourselves - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. - # Model i propogates from i-1 to i. - # Observation i is at i. - for c, v in zip(self.controls, values): c.control.assign(v) - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] + # post messages for control of forward model propogation on next chunk + trank = self.trank + if self.ensemble: + src = trank - 1 + dst = trank + 1 + + if trank != self.nchunks - 1: + self.ensemble.isend( + self.controls[-1].control, dest=dst, tag=dst) - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) + if trank != 0: + recv_reqs = self.ensemble.irecv( + self.xprev, source=src, tag=trank) - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) + # first "control" is the halo + if self.ensemble and trank != 0: + values = [self.xprev, *values] # Initial condition functionals - bkg_err_vec = self.background_error(values[0]) - J = self.background_rf(bkg_err_vec) + if trank == 0: + Jlocal = ( + self.background_norm( + self.background_error(values[0])) + ) + + # observations at time 0 + if self.initial_observations: + Jlocal += ( + self.initial_observation_norm( + self.initial_observation_error(values[0])) + ) + else: + Jlocal = 0. - # observations at time 0 - if self.initial_observations: - obs_err_vec = observation_errors[0](values[0]) - J += observation_rfs[0](obs_err_vec) + # evaluate all stages on chunk except first + for i in range(1, len(self.stages)): + Jlocal += self.stages[i](values[i:i+2]) - for i in range(1, len(observation_rfs)): - prev_control = values[i-1] - this_control = values[i] + # wait for halo swap to finish + if trank != 0: + MPI.Request.Waitall(recv_reqs) - # observation error - do we match the 'real world'? - obs_err_vec = observation_errors[i](this_control) - J += observation_rfs[i](obs_err_vec) + # evaluate first stage model on chunk now we have data + Jlocal += self.stages[0](values[0:2]) - # Model error - does propogation from last control match this control? - Mi = model_stages[i](prev_control) - model_err_vec = model_errors[i]([Mi, this_control]) - J += model_rfs[i](model_err_vec) + # sum all stages + if self.ensemble: + J = self.ensemble.ensemble_comm.allreduce(Jlocal) + else: + J = Jlocal return J @@ -458,55 +374,65 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): derivatives = [] # chaining ReducedFunctionals means we need to pass Cofunctions not Functions + options = options or {} intermediate_options = { 'riesz_representation': None, **{k: v for k, v in options.items() if (k != 'riesz_representation')} } - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. Model i propogates from i-1 to i. - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] - - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) - - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) - # initial condition derivatives - bkg_deriv = self.background_rf.derivative(adj_input=adj_input, - options=intermediate_options) - derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, - options=options)) - - # observations at time 0 - if self.initial_observations: - obs_deriv = observation_rfs[0].derivative(adj_input=adj_input, - options=intermediate_options) - derivatives[0] += observation_errors[0].derivative(adj_input=obs_deriv, - options=options) - - for i in range(1, len(observation_rfs)): - obs_deriv = observation_rfs[i].derivative(adj_input=adj_input, - options=intermediate_options) - derivatives.append(observation_errors[i].derivative(adj_input=obs_deriv, + if self.trank == 0: + bkg_deriv = self.background_norm.derivative(adj_input=adj_input, + options=intermediate_options) + derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, options=options)) - # derivative of model error will split: - # wrt x_i through error vector - # wrt x_i-1 through stage propogation - model_deriv = model_rfs[i].derivative(adj_input=adj_input, - options=intermediate_options) - model_err_derivs = model_errors[i].derivative(adj_input=model_deriv, - options=intermediate_options) - model_stage_deriv = model_stages[i].derivative(adj_input=model_err_derivs[0], - options=options) + # observations at time 0 + if self.initial_observations: + obs_deriv = self.initial_observation_norm.derivative(adj_input=adj_input, + options=intermediate_options) + derivatives[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, + options=options) + + # evaluate first forward model, which contributes to previous chunk + derivs = self.stages[0].derivative(adj_input=adj_input, options=options) - derivatives[i-1] += model_stage_deriv - derivatives[i] += model_err_derivs[1].riesz_representation() + if self.trank == 0: + derivatives[0] += derivs[0] + else: + derivatives.append(derivs[0]) + derivatives.append(derivs[1]) + + # post the derivative halo exchange + if self.ensemble: + src = self.trank + 1 + dst = self.trank - 1 + + if self.trank != 0: + self.ensemble.isend( + derivatives[0], dest=dst, tag=dst) + + if self.trank != self.nchunks - 1: + recv_reqs = self.ensemble.irecv( + self.xnext, source=src, tag=self.trank) + + # # evaluate all forward models on chunk except first while halo in flight + for i in range(1, len(self.stages)): + derivs = self.stages[i].derivative(adj_input=adj_input, options=options) + derivatives[i] += derivs[0] + derivatives.append(derivs[1]) + + # finish the derivative halo exchange + if self.ensemble: + if self.trank != self.nchunks - 1: + MPI.Request.Waitall(recv_reqs) + derivatives[-1] += self.xnext + + # we don't own the control for the halo, so remove it from the + # list of local derivatives once the communication has finished + if self.trank != 0: + derivatives.pop(0) return derivatives @@ -529,40 +455,19 @@ def hessian(self, m_dot: OverloadedType, options: dict = {}): A dictionary of options. To find a list of available options have a look at the specific control type. + rtype: + Whether to evaluate: + - the model error ('model'), + - the observation error ('obs'), + - both model and observation errors (None). + Returns ------- pyadjoint.OverloadedType The action of the Hessian in the direction m_dot. Should be an instance of the same type as the control. """ - # create a list of overloaded types to put hessian into - hessians = [] - - kwargs = {'options': options} - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. Model i propogates from i-1 to i. - model_rfs = [None, *self.forward_model_rfs] - - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) - - # initial condition hessians - hessians.append( - self.background_rf.hessian(m_dot[0], **kwargs)) - - if self.initial_observations: - hessians[0] += observation_rfs[0].hessian(m_dot[0], **kwargs) - - for i in range(1, len(model_rfs)): - hessians.append(observation_rfs[i].hessian(m_dot[i], **kwargs)) - - mhess = model_rfs[i].hessian(m_dot[i-1:i+1], **kwargs) - - hessians[i-1] += mhess[0] - hessians[i] += mhess[1] - - return hessians + raise ValueError("Not implemented yet") @no_annotations def hessian_matrix(self): @@ -571,22 +476,461 @@ def hessian_matrix(self): raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix") raise NotImplementedError - @sc_passthrough - @no_annotations - def optimize_tape(self): - for rf in (self.background_error, - self.background_rf, - *self.observation_errors, - *self.observation_rfs, - *self.forward_model_stages, - *self.forward_model_errors, - *self.forward_model_rfs): - rf.optimize_tape() - def _accumulate_functional(self, val): if not self._annotate_accumulation: return - if hasattr(self, '_total_functional'): + if self._accumulation_started: self._total_functional += val else: self._total_functional = val + self._accumulation_started = True + + @contextmanager + def recording_stages(self, sequential=True, **kwargs): + if not sequential: + raise ValueError("Recording stages concurrently not yet implemented") + + # indices of stage in global and local list + stage_kwargs = {k: v for k, v in kwargs.items()} + stage_kwargs['local_index'] = 0 + stage_kwargs['global_index'] = 0 + + # record over ensemble + if self.weak_constraint: + + trank = self.trank + + # later ranks recv forward state and kwargs + if trank > 0: + tcomm = self.ensemble.ensemble_comm + src = trank-1 + with stop_annotating(): + self.ensemble.recv(self.xprev, source=src, tag=trank+000) + + for i, (k, v) in enumerate(stage_kwargs.items()): + stage_kwargs[k] = _scalarRecv( + tcomm, dtype=type(v), source=src, tag=trank+i*100) + # restart local stage counter + stage_kwargs['local_index'] = 0 + + # subsequent ranks start from halo + controls = self.controls if trank == 0 else [self.control_prev, *self.controls] + + stage_sequence = ObservationStageSequence( + controls, self, stage_kwargs, sequential, weak_constraint=True) + + yield stage_sequence + + # grab the stages now they have been taped + self.stages = stage_sequence.stages + + # send forward state and kwargs + if self.ensemble and trank != self.nchunks - 1: + with stop_annotating(): + tcomm = self.ensemble.ensemble_comm + dst = trank+1 + + state = self.stages[-1].controls[1].control + self.ensemble.send(state, dest=dst, tag=dst+000) + + for i, k in enumerate(stage_kwargs.keys()): + v = getattr(stage_sequence.ctx, k) + _scalarSend( + tcomm, v, dest=dst, tag=dst+i*100) + + else: # strong constraint + + yield ObservationStageSequence( + self.controls, self, stage_kwargs, + sequential=True, weak_constraint=False) + + +class ObservationStageSequence: + def __init__(self, controls: Control, + aaorf: AllAtOnceReducedFunctional, + stage_kwargs: dict = None, + sequential: bool = True, + weak_constraint: bool = True): + self.controls = controls + self.nstages = len(controls) - 1 + self.aaorf = aaorf + self.ctx = StageContext(**(stage_kwargs or {})) + self.index = 0 + self.weak_constraint = weak_constraint + if weak_constraint: + self.stages = [] + + def __iter__(self): + return self + + def __next__(self): + + if self.weak_constraint: + # start of the next stage + next_control = self.controls[self.index] + + # smuggle state forward and increment stage indices + if self.index > 0: + self.ctx.local_index += 1 + self.ctx.global_index += 1 + + state = self.stages[-1].controls[1].control + with stop_annotating(): + next_control.control.assign(state) + + # stop after we've recorded all stages + if self.index >= self.nstages: + raise StopIteration + self.index += 1 + + stage = WeakObservationStage(next_control, index=self.ctx.global_index) + self.stages.append(stage) + + else: # strong constraint + + # increment stage indices + if self.index > 0: + self.ctx.local_index += 1 + self.ctx.global_index += 1 + + # stop after we've recorded all stages + if self.index >= self.nstages: + raise StopIteration + self.index += 1 + + # dummy control to "start" stage from + control = (self.aaorf.controls[0].control if self.index == 0 + else self._prev_stage.state) + + stage = StrongObservationStage(control, self.aaorf) + self._prev_stage = stage + + return stage, self.ctx + + +class StageContext: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class StrongObservationStage: + """ + Record an observation for strong constraint 4DVar at the time of `state`. + + Parameters + ---------- + + aaorf + The strong constraint AllAtOnceReducedFunctional. + + """ + + def __init__(self, control: OverloadedType, + aaorf: AllAtOnceReducedFunctional): + self.aaorf = aaorf + self.control = control + + def set_observation(self, state: OverloadedType, + observation_err: Callable[[OverloadedType], OverloadedType], + observation_iprod: Callable[[OverloadedType], AdjFloat]): + """ + Record an observation at the time of `state`. + + Parameters + ---------- + + state + The state at the current observation time. + + observation_err + Given a state :math:`x`, returns the observations error + :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are + the observations at the current observation time and + :math:`\\mathcal{H}_{i}` is the observation operator for the + current observation time. + + observation_iprod + The inner product to calculate the observation error functional + from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. + Can include the error covariance matrix. + """ + if hasattr(self.aaorf, "_strong_reduced_functional"): + raise ValueError("Cannot add observations once strong" + " constraint ReducedFunctional instantiated") + self.aaorf._accumulate_functional( + observation_iprod(observation_err(state))) + self.state = state + + +class WeakObservationStage: + """ + A single stage for weak constraint 4DVar at the time of `state`. + + Records the forward model propogation from the control at the beginning + of the stage, and the model and observation errors at the end of the stage. + + Parameters + ---------- + + control + The control x_{i-1} at the beginning of the stage + + index + Optional integer to name controls and functionals with + + """ + def __init__(self, control: Control, + index: Optional[int] = None): + # "control" to use as initial condition. + # Not actual `Control` for consistency with strong constraint + self.control = control.control + + self.controls = Enlist(control) + self.index = index + set_working_tape() + self._stage_tape = get_working_tape() + + def set_observation(self, state: OverloadedType, + observation_err: Callable[[OverloadedType], OverloadedType], + observation_iprod: Callable[[OverloadedType], AdjFloat], + forward_model_iprod: Callable[[OverloadedType], AdjFloat]): + """ + Record an observation at the time of `state`. + + Parameters + ---------- + + state + The state at the current observation time. + + observation_err + Given a state :math:`x`, returns the observations error + :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are + the observations at the current observation time and + :math:`\\mathcal{H}_{i}` is the observation operator for the + current observation time. + + observation_iprod + The inner product to calculate the observation error functional + from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. + Can include the error covariance matrix. + + forward_model_iprod + The inner product to calculate the model error functional from + the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can + include the error covariance matrix. + """ + # get the tape used for this stage and make sure its the right one + stage_tape = get_working_tape() + if stage_tape is not self._stage_tape: + raise ValueError( + "Working tape at the end of the observation stage" + " differs from the tape at the stage beginning." + ) + + # record forward propogation + with set_working_tape(stage_tape.copy()) as tape: + self.forward_model = ReducedFunctional( + state._ad_copy(), controls=self.controls[0], tape=tape) + + # Beginning of next time-chunk is the control for this observation + # and the state at the end of the next time-chunk. + with stop_annotating(): + # smuggle initial guess at this time into the control without the tape seeing + self.controls.append(Control(state._ad_copy())) + if self.index: + _rename(self.controls[-1].control, f"Control_{self.index}") + + # model error links time-chunks by depending on both the + # previous and current controls + + # RF to recalculate error vector (M_i - x_i) + names = { + 'functional_name': f"model_err_vec_{self.index}", + 'control_name': [f"state_{self.index}_copy", + f"Control_{self.index}_model_copy"] + } if self.index else {} + + self.model_error = isolated_rf( + operation=lambda controls: _ad_sub(*controls), + control=[state, self.controls[-1].control], + **names) + + # RF to recalculate inner product |M_i - x_i|_Q + names = { + 'control_name': f"model_err_vec_{self.index}_copy" + } if self.index else {} + + self.model_norm = isolated_rf( + operation=forward_model_iprod, + control=self.model_error.functional, + **names) + + # Observations after tape cut because this is now a control, not a state + + # RF to recalculate error vector (H(x_i) - y_i) + names = { + 'functional_name': f"obs_err_vec_{self.index}", + 'control_name': f"Control_{self.index}_obs_copy" + } if self.index else {} + + self.observation_error = isolated_rf( + operation=observation_err, + control=self.controls[-1], + **names) + + # RF to recalculate inner product |H(x_i) - y_i|_R + names = { + 'functional_name': "obs_err_vec_{self.index}_copy" + } if self.index else {} + self.observation_norm = isolated_rf( + operation=observation_iprod, + control=self.observation_error.functional, + **names) + + # remove the stage initial condition "control" now we've finished recording + delattr(self, "control") + + @no_annotations + def __call__(self, values: OverloadedType, + rftype: Optional[str] = None): + """Computes the reduced functional with supplied control value. + + Parameters + ---------- + + values + If you have multiple controls this should be a list of new values + for each control in the order you listed the controls to the constructor. + If you have a single control it can either be a list or a single object. + Each new value should have the same type as the corresponding control. + + rtype: + Whether to evaluate: + - the model error ('model'), + - the observation error ('obs'), + - both model and observation errors (None). + + Returns + ------- + pyadjoint.OverloadedType + The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. + + """ + J = 0.0 + + # evaluate model error + if (rftype is None) or (rftype == 'model'): + Mi = self.forward_model(values[0]) + J += self.model_norm(self.model_error([Mi, values[1]])) + + # evaluate observation errors + if (rftype is None) or (rftype == 'obs'): + J += self.observation_norm(self.observation_error(values[1])) + + return J + + @no_annotations + def derivative(self, adj_input: float = 1.0, options: dict = {}, + rftype: Optional[str] = None): + """Returns the derivative of the functional w.r.t. the control. + Using the adjoint method, the derivative of the functional with + respect to the control, around the last supplied value of the + control, is computed and returned. + + Parameters + ---------- + adj_input + The adjoint input. + + options + Additional options for the derivative computation. + + rtype: + Whether to evaluate: + - the model error ('model'), + - the observation error ('obs'), + - both model and observation errors (None). + + Returns + ------- + pyadjoint.OverloadedType + The derivative with respect to the control. + Should be an instance of the same type as the control. + """ + # create a list of overloaded types to put derivative into + derivatives = [] + + # chaining ReducedFunctionals means we need to pass Cofunctions not Functions + options = options or {} + intermediate_options = { + 'riesz_representation': None, + **{k: v for k, v in options.items() + if (k != 'riesz_representation')} + } + + if (rftype is None) or (rftype == 'model'): + # derivative of reduction + dm_norm = self.model_norm.derivative(adj_input=adj_input, + options=intermediate_options) + + # derivative of difference splits into (Mi, xi) + dm_errors = self.model_error.derivative(adj_input=dm_norm, + options=intermediate_options) + + # derivative through the forward model wrt to xprev + dm_forward = self.forward_model.derivative(adj_input=dm_errors[0], + options=options) + + derivatives.append(dm_forward) + derivatives.append(dm_errors[1].riesz_representation()) + + if (rftype is None) or (rftype == 'obs'): + # derivative of reduction + do_norm = self.observation_norm.derivative(adj_input=adj_input, + options=intermediate_options) + # derivative of error + do_error = self.observation_error.derivative(adj_input=do_norm, + options=options) + + if len(derivatives) == 0: + derivatives.append(None) + derivatives.append(do_error) + else: + derivatives[1] += do_error + + return derivatives + + @no_annotations + def hessian(self, m_dot: OverloadedType, options: dict = {}, + rftype: Optional[str] = None): + """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control, is computed and returned. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + rtype: + Whether to evaluate: + - the model error ('model'), + - the observation error ('obs'), + - both model and observation errors (None). + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + """ + pass From 806a6054f3391158019c73d6a767ede68741f176 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 19 Nov 2024 15:59:07 +0000 Subject: [PATCH 02/23] EnsembleFunction - initial (incomplete) impl --- firedrake/__init__.py | 1 + .../adjoint/all_at_once_reduced_functional.py | 163 +++++----- firedrake/adjoint_utils/__init__.py | 1 + firedrake/adjoint_utils/ensemblefunction.py | 69 ++++ firedrake/ensemblefunction.py | 295 ++++++++++++++++++ 5 files changed, 444 insertions(+), 85 deletions(-) create mode 100644 firedrake/adjoint_utils/ensemblefunction.py create mode 100644 firedrake/ensemblefunction.py diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 0fc9aeeed6..b34a4f3d31 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -116,6 +116,7 @@ from firedrake.vector import * from firedrake.version import __version__ as ver, __version_info__, check # noqa: F401 from firedrake.ensemble import * +from firedrake.ensemblefunction import * from firedrake.randomfunctiongen import * from firedrake.external_operators import * from firedrake.progress_bar import ProgressBar # noqa: F401 diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 6a4e84617c..2c5ad915ea 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -1,7 +1,6 @@ from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ stop_annotating, no_annotations, get_working_tape, set_working_tape from pyadjoint.enlisting import Enlist -from firedrake import Ensemble from functools import wraps, cached_property from typing import Callable, Optional from contextlib import contextmanager @@ -100,11 +99,11 @@ class AllAtOnceReducedFunctional(ReducedFunctional): ---------- control - The initial condition :math:`x_{0}`. Starting value is used as the - background (prior) data :math:`x_{b}`. + The :class:`EnsembleFunction` for the control x_{i} at the initial + condition and at the end of each observation stage. - nlocal_stages - The number of observation stages on the local ensemble member. + control + The background (prior) data for the initial condition :math:`x_{b}`. background_iprod The inner product to calculate the background error functional @@ -126,56 +125,56 @@ class AllAtOnceReducedFunctional(ReducedFunctional): weak_constraint Whether to use the weak or strong constraint 4DVar formulation. - ensemble - The ensemble communicator to parallelise over. None for no time parallelism. - If `ensemble` is provided, then `background_iprod`, `observation_err` and - `observation_iprod` must only be provided on ensemble rank 0. - See Also -------- :class:`pyadjoint.ReducedFunctional`. """ def __init__(self, control: Control, - nlocal_stages: int, + background: OverloadedType, background_iprod: Optional[Callable[[OverloadedType], AdjFloat]], observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, weak_constraint: bool = True, tape: Optional[Tape] = None, - _annotate_accumulation: bool = False, - ensemble: Optional[Ensemble] = None): + _annotate_accumulation: bool = False): self.tape = get_working_tape() if tape is None else tape self.weak_constraint = weak_constraint self.initial_observations = observation_err is not None - # We need a copy for the prior, but this shouldn't be part of the tape with stop_annotating(): - self.background = control.copy_data() + self.background = background._ad_copy() + _rename(self.background, "Background") if self.weak_constraint: self._annotate_accumulation = _annotate_accumulation self._accumulation_started = False - self.nlocal_stages = nlocal_stages - + ensemble = control.ensemble self.ensemble = ensemble self.trank = ensemble.ensemble_comm.rank if ensemble else 0 self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 + x = control.control.subfunctions + self.x = x + + self.control = control + self._controls = tuple(Control(xi) for xi in x) + + # first control on rank 0 is initial conditions, not end of observation stage + self.nlocal_stages = len(x) - (1 if self.trank == 0 else 0) + self.stages = [] # The record of each observation stage - self.controls = [] # The solution at the beginning of each time-chunk # first rank sets up functionals for background initial observations if self.trank == 0: - self.controls.append(control) # RF to recalculate error vector (x_0 - x_b) self.background_error = isolated_rf( operation=lambda x0: _ad_sub(x0, self.background), - control=control, + control=x[0], functional_name="bkg_err_vec", control_name="Control_0_bkg_copy") @@ -190,7 +189,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (H(x_0) - y_0) self.initial_observation_error = isolated_rf( operation=observation_err, - control=control, + control=x[0], functional_name="obs_err_vec_0", control_name="Control_0_obs_copy") @@ -201,10 +200,6 @@ def __init__(self, control: Control, functional_name="obs_err_vec_0_copy") else: - # create halo for previous state - with stop_annotating(): - self.xprev = control.copy_data() - self.control_prev = Control(self.xprev) if background_iprod is not None: raise ValueError("Only the first ensemble rank needs `background_iprod`") @@ -213,21 +208,16 @@ def __init__(self, control: Control, if observation_err is not None: raise ValueError("Only the first ensemble rank needs `observation_err`") - # create all controls on local ensemble member - with stop_annotating(): - for _ in range(nlocal_stages): - self.controls.append(Control(control.copy_data())) + # create halo for previous state + if self.ensemble and self.trank != 0: + with stop_annotating(): + self.xprev = x[0]._ad_copy() + self._control_prev = Control(self.xprev) # halo for the derivative from the next chunk if self.ensemble and self.trank != self.nchunks - 1: - self.xnext = control.copy_data() - - # new tape for the initial stage - if self.trank == 0: - self.stages.append( - WeakObservationStage(self.controls[0], index=0)) - else: - self._stage_tape = None + with stop_annotating(): + self.xnext = x[0]._ad_copy() else: self._annotate_accumulation = True @@ -292,53 +282,54 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - for c, v in zip(self.controls, values): - c.control.assign(v) + self.control.assign(values) + trank = self.trank + + # first "control" for later ranks is the halo + if self.ensemble and trank != 0: + x = [self.xprev, *self.x] + else: + x = [*self.x] # post messages for control of forward model propogation on next chunk - trank = self.trank if self.ensemble: src = trank - 1 dst = trank + 1 if trank != self.nchunks - 1: self.ensemble.isend( - self.controls[-1].control, dest=dst, tag=dst) + x[-1], dest=dst, tag=dst) if trank != 0: recv_reqs = self.ensemble.irecv( self.xprev, source=src, tag=trank) - # first "control" is the halo - if self.ensemble and trank != 0: - values = [self.xprev, *values] - # Initial condition functionals if trank == 0: Jlocal = ( self.background_norm( - self.background_error(values[0])) + self.background_error(x[0])) ) # observations at time 0 if self.initial_observations: Jlocal += ( self.initial_observation_norm( - self.initial_observation_error(values[0])) + self.initial_observation_error(x[0])) ) else: Jlocal = 0. # evaluate all stages on chunk except first for i in range(1, len(self.stages)): - Jlocal += self.stages[i](values[i:i+2]) + Jlocal += self.stages[i](x[i:i+2]) # wait for halo swap to finish if trank != 0: MPI.Request.Waitall(recv_reqs) # evaluate first stage model on chunk now we have data - Jlocal += self.stages[0](values[0:2]) + Jlocal += self.stages[0](x[0:2]) # sum all stages if self.ensemble: @@ -370,8 +361,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): The derivative with respect to the control. Should be an instance of the same type as the control. """ + trank = self.trank # create a list of overloaded types to put derivative into - derivatives = [] + derivatives = self.control._ad_copy() + derivatives.zero() + + if self.ensemble and trank != 0: + self.xprev.zero() + derivs = [self.xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} @@ -382,57 +381,50 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): } # initial condition derivatives - if self.trank == 0: + if trank == 0: bkg_deriv = self.background_norm.derivative(adj_input=adj_input, options=intermediate_options) - derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, - options=options)) + derivs[0] += self.background_error.derivative(adj_input=bkg_deriv, + options=options) # observations at time 0 if self.initial_observations: obs_deriv = self.initial_observation_norm.derivative(adj_input=adj_input, options=intermediate_options) - derivatives[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, - options=options) + derivs[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, + options=options) # evaluate first forward model, which contributes to previous chunk - derivs = self.stages[0].derivative(adj_input=adj_input, options=options) + sderiv0 = self.stages[0].derivative(adj_input=adj_input, options=options) - if self.trank == 0: - derivatives[0] += derivs[0] - else: - derivatives.append(derivs[0]) - derivatives.append(derivs[1]) + derivs[0] += sderiv0[0] + derivs[1] += sderiv0[1] # post the derivative halo exchange + from firedrake import norm if self.ensemble: - src = self.trank + 1 - dst = self.trank - 1 + src = trank + 1 + dst = trank - 1 - if self.trank != 0: + if trank != 0: self.ensemble.isend( - derivatives[0], dest=dst, tag=dst) + derivs[0], dest=dst, tag=dst) - if self.trank != self.nchunks - 1: + if trank != self.nchunks - 1: recv_reqs = self.ensemble.irecv( - self.xnext, source=src, tag=self.trank) + self.xnext, source=src, tag=trank) # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): - derivs = self.stages[i].derivative(adj_input=adj_input, options=options) - derivatives[i] += derivs[0] - derivatives.append(derivs[1]) + sderiv = self.stages[i].derivative(adj_input=adj_input, options=options) + derivs[i] += sderiv[0] + derivs[i+1] += sderiv[1] # finish the derivative halo exchange if self.ensemble: - if self.trank != self.nchunks - 1: + if trank != self.nchunks - 1: MPI.Request.Waitall(recv_reqs) - derivatives[-1] += self.xnext - - # we don't own the control for the halo, so remove it from the - # list of local derivatives once the communication has finished - if self.trank != 0: - derivatives.pop(0) + derivs[-1] += self.xnext return derivatives @@ -514,10 +506,10 @@ def recording_stages(self, sequential=True, **kwargs): stage_kwargs['local_index'] = 0 # subsequent ranks start from halo - controls = self.controls if trank == 0 else [self.control_prev, *self.controls] + controls = self._controls if trank == 0 else [self._control_prev, *self._controls] stage_sequence = ObservationStageSequence( - controls, self, stage_kwargs, sequential, weak_constraint=True) + controls, self, stage_kwargs, sequential) yield stage_sequence @@ -541,23 +533,21 @@ def recording_stages(self, sequential=True, **kwargs): else: # strong constraint yield ObservationStageSequence( - self.controls, self, stage_kwargs, - sequential=True, weak_constraint=False) + self.controls, self, stage_kwargs, sequential=True) class ObservationStageSequence: def __init__(self, controls: Control, aaorf: AllAtOnceReducedFunctional, stage_kwargs: dict = None, - sequential: bool = True, - weak_constraint: bool = True): + sequential: bool = True): self.controls = controls self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) self.index = 0 - self.weak_constraint = weak_constraint - if weak_constraint: + self.weak_constraint = aaorf.weak_constraint + if self.weak_constraint: self.stages = [] def __iter__(self): @@ -792,6 +782,9 @@ def set_observation(self, state: OverloadedType, # remove the stage initial condition "control" now we've finished recording delattr(self, "control") + # stop the stage tape recording anything else + set_working_tape() + @no_annotations def __call__(self, values: OverloadedType, rftype: Optional[str] = None): diff --git a/firedrake/adjoint_utils/__init__.py b/firedrake/adjoint_utils/__init__.py index 3b3426a850..c71da61ade 100644 --- a/firedrake/adjoint_utils/__init__.py +++ b/firedrake/adjoint_utils/__init__.py @@ -12,3 +12,4 @@ from firedrake.adjoint_utils.solving import * # noqa: F401 from firedrake.adjoint_utils.mesh import * # noqa: F401 from firedrake.adjoint_utils.checkpointing import * # noqa: F401 +from firedrake.adjoint_utils.ensemblefunction import * # noqa: F401 diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py new file mode 100644 index 0000000000..ee1e4a4ce3 --- /dev/null +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -0,0 +1,69 @@ +from pyadjoint.overloaded_type import OverloadedType +from functools import wraps + + +class EnsembleFunctionMixin(OverloadedType): + + @staticmethod + def _ad_annotate_init(init): + @wraps(init) + def wrapper(self, *args, **kwargs): + OverloadedType.__init__(self) + init(self, *args, **kwargs) + return wrapper + + @staticmethod + def _ad_to_list(m): + raise ValueError("NotImplementedYet") + + @staticmethod + def _ad_assign_numpy(dst, src, offset): + raise ValueError("NotImplementedYet") + + def _ad_dot(self, other, options=None): + # local dot product + ldot = sum( + uself._ad_dot(uother, options=options) + for uself, uother in zip(self.subfunctions, + other.subfunctions)) + # global dot product + gdot = self.ensemble.ensemble_comm.allreduce(ldot) + return gdot + + def _ad_add(self, other): + new = self.copy() + new += other + return new + + def _ad_mul(self, other): + new = self.copy() + # `self` can be a Cofunction in which case only left multiplication with a scalar is allowed. + other = other._fbuf if type(other) is type(self) else other + new._fbuf.assign(other*new._fbuf) + return new + + def _ad_iadd(self, other): + self += other + return self + + def _ad_imul(self, other): + self *= other + return self + + def _ad_copy(self): + return self.copy() + + def _ad_convert_riesz(self, value, options=None): + raise ValueError("NotImplementedYet") + + def _ad_from_petsc(self, vec): + with self.vec_wo as self_v: + vec.copy(result=self_v) + + def _ad_to_petsc(self, vec=None): + with self.vec_ro as self_v: + if vec: + self_v.copy(result=vec) + else: + vec = self_v.copy() + return vec diff --git a/firedrake/ensemblefunction.py b/firedrake/ensemblefunction.py new file mode 100644 index 0000000000..33ed7e7b2f --- /dev/null +++ b/firedrake/ensemblefunction.py @@ -0,0 +1,295 @@ +from firedrake.petsc import PETSc +from firedrake.adjoint_utils import EnsembleFunctionMixin +from firedrake.functionspace import MixedFunctionSpace +from firedrake.function import Function +from ufl.duals import is_primal, is_dual +from pyop2 import MixedDat + +from functools import cached_property +from contextlib import contextmanager + +__all__ = ("EnsembleFunction", "EnsembleCofunction") + + +class EnsembleFunctionBase(EnsembleFunctionMixin): + """ + A mixed finite element (co)function distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The sub(co)functions are distributed + over the different ensemble members. + + function_spaces + A list of function spaces for each (co)function on the + local ensemble member. + """ + + @PETSc.Log.EventDecorator() + @EnsembleFunctionMixin._ad_annotate_init + def __init__(self, ensemble, function_spaces): + self.ensemble = ensemble + self.local_function_spaces = function_spaces + self.local_size = len(function_spaces) + + # the local functions are stored as a big mixed space + self._function_space = MixedFunctionSpace(function_spaces) + self._fbuf = Function(self._function_space) + + # create a Vec containing the data for all functions on all + # ensemble members. Because we use the Vec of each local mixed + # function as the storage, if the data in the Function Vec + # is valid then the data in the EnsembleFunction Vec is valid. + + with self._fbuf.dat.vec as fvec: + local_size = self._function_space.node_set.size + sizes = (local_size, PETSc.DETERMINE) + self._vec = PETSc.Vec().createWithArray(fvec.array, + size=sizes, + comm=ensemble.global_comm) + self._vec.setFromOptions() + + @cached_property + def subfunctions(self): + """ + The (co)functions on the local ensemble member + """ + def local_function(i): + V = self.local_function_spaces[i] + usubs = self._subcomponents(i) + if len(usubs) == 1: + dat = usubs[0].dat + else: + dat = MixedDat((u.dat for u in usubs)) + return Function(V, val=dat) + + self._subfunctions = tuple(local_function(i) + for i in range(self.local_size)) + return self._subfunctions + + def _subcomponents(self, i): + """ + Return the subfunctions of the local mixed function storage + corresponding to the i-th local function. + """ + return tuple(self._fbuf.subfunctions[j] + for j in self._component_indices(i)) + + def _component_indices(self, i): + """ + Return the indices into the local mixed function storage + corresponding to the i-th local function. + """ + V = self.local_function_spaces[i] + offset = sum(len(V) for V in self.local_function_spaces[:i]) + return tuple(offset + i for i in range(len(V))) + + @PETSc.Log.EventDecorator() + def riesz_representation(self, riesz_map="L2", **kwargs): + """ + Return the Riesz representation of this :class:`EnsembleFunction` + with respect to the given Riesz map. + + Parameters + ---------- + + riesz_map + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + + kwargs + other arguments to be passed to the firedrake.riesz_map. + """ + DualType = { + EnsembleFunction: EnsembleCofunction, + EnsembleCofunction: EnsembleFunction, + }[type(self)] + Vdual = [V.dual() for V in self.local_function_spaces] + riesz = DualType(self.ensemble, Vdual) + for u in riesz.subfunctions: + u.assign(u.riesz_representation(riesz_map=riesz_map, **kwargs)) + return riesz + + @PETSc.Log.EventDecorator() + def assign(self, other, subsets=None): + r"""Set the :class:`EnsembleFunction` to the value of another + :class:`EnsembleFunction` other. + + Parameters + ---------- + + other + The :class:`EnsembleFunction` to assign from. + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. + """ + if type(other) is not type(self): + raise ValueError( + f"Cannot assign {type(self)} from {type(other)}") + if subsets: + for i in range(self.local_size): + self.subfunctions[i].assign( + other.subfunctions[i], subset=subsets[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]) + return self + + @PETSc.Log.EventDecorator() + def copy(self): + """ + Return a deep copy of the :class:`EnsembleFunction`. + """ + new = type(self)(self.ensemble, self.local_function_spaces) + new.assign(self) + return new + + @PETSc.Log.EventDecorator() + def zero(self, subsets=None): + """ + Set values to zero. + + Parameters + ---------- + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. + """ + if subsets: + for i in range(self.local_size): + self.subfunctions[i].zero(subsets[i]) + else: + for u in self.subfunctions: + u.zero() + return self + + @PETSc.Log.EventDecorator() + def __iadd__(self, other): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us + uo) + return self + + @PETSc.Log.EventDecorator() + def __imul__(self, other): + if type(other) is type(self): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us*uo) + else: + for us in self.subfunctions: + us *= other + return self + + @PETSc.Log.EventDecorator() + def __add__(self, other): + new = self.copy() + for i in range(self.local_size): + new.subfunctions[i] += other.subfunctions[i] + return new + + @PETSc.Log.EventDecorator() + def __mul__(self, other): + new = self.copy() + if type(other) is type(self): + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]*self.subfunctions[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other*self.subfunctions[i]) + return new + + @PETSc.Log.EventDecorator() + def __rmul__(self, other): + return self.__mul__(other) + + @contextmanager + def vec(self): + """ + Context manager for the global PETSc Vec with read/write access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager so that the data gets copied to/from + # the Function.dat storage and _vec. + # However, this copy is done without _vec knowing, so we have + # to manually increment the state. + with self._fbuf.dat.vec: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_ro(self): + """ + Context manager for the global PETSc Vec with read only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # to the Function.dat storage and _vec. + with self._fbuf.dat.vec_ro: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_wo(self): + """ + Context manager for the global PETSc Vec with write only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # from the Function.dat storage and _vec. + with self._fbuf.dat.vec_wo: + yield self._vec + + +class EnsembleFunction(EnsembleFunctionBase): + """ + A mixed finite element Function distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subfunctions are distributed + over the different ensemble members. + + function_spaces + A list of function spaces for each function on the + local ensemble member. + """ + def __init__(self, ensemble, function_spaces): + if not all(is_primal(V) for V in function_spaces): + raise TypeError( + "EnsembleFunction must be created using primal FunctionSpaces") + super().__init__(ensemble, function_spaces) + + +class EnsembleCofunction(EnsembleFunctionBase): + """ + A mixed finite element Cofunction distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subcofunctions are distributed + over the different ensemble members. + + function_spaces + A list of dual function spaces for each cofunction on the + local ensemble member. + """ + def __init__(self, ensemble, function_spaces): + if not all(is_dual(V) for V in function_spaces): + raise TypeError( + "EnsembleCofunction must be created using dual FunctionSpaces") + super().__init__(ensemble, function_spaces) From 3515a52bd1957ae0221442542b17b30384377e34 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 19 Nov 2024 17:35:07 +0000 Subject: [PATCH 03/23] aaorf - attach indices to stages and make obs0_err/norm not error --- .../adjoint/all_at_once_reduced_functional.py | 144 ++++++++++-------- 1 file changed, 78 insertions(+), 66 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 2c5ad915ea..54874d9d2e 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -102,25 +102,28 @@ class AllAtOnceReducedFunctional(ReducedFunctional): The :class:`EnsembleFunction` for the control x_{i} at the initial condition and at the end of each observation stage. - control - The background (prior) data for the initial condition :math:`x_{b}`. - background_iprod The inner product to calculate the background error functional from the background error :math:`x_{0} - x_{b}`. Can include the - error covariance matrix. + error covariance matrix. Only used on ensemble rank 0. + + background + The background (prior) data for the initial condition :math:`x_{b}`. + If not provided, the value of the first Function on the first ensemble + of the EnsembleFunction will be used. observation_err Given a state :math:`x`, returns the observations error :math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are the observations at the initial time and :math:`\\mathcal{H}_{0}` is - the observation operator for the initial time. Optional. + the observation operator for the initial time. Only used on + ensemble rank 0. Optional. observation_iprod The inner product to calculate the observation error functional from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`. Can include the error covariance matrix. Must be provided if - observation_err is provided. + observation_err is provided. Only used on ensemble rank 0 weak_constraint Whether to use the weak or strong constraint 4DVar formulation. @@ -131,8 +134,8 @@ class AllAtOnceReducedFunctional(ReducedFunctional): """ def __init__(self, control: Control, - background: OverloadedType, background_iprod: Optional[Callable[[OverloadedType], AdjFloat]], + background: Optional[OverloadedType] = None, observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, weak_constraint: bool = True, @@ -145,7 +148,10 @@ def __init__(self, control: Control, self.initial_observations = observation_err is not None with stop_annotating(): - self.background = background._ad_copy() + if background: + self.background = background._ad_copy() + else: + self.background = control.control.subfunctions[0]._ad_copy() _rename(self.background, "Background") if self.weak_constraint: @@ -199,15 +205,6 @@ def __init__(self, control: Control, control=self.initial_observation_error.functional, functional_name="obs_err_vec_0_copy") - else: - - if background_iprod is not None: - raise ValueError("Only the first ensemble rank needs `background_iprod`") - if observation_iprod is not None: - raise ValueError("Only the first ensemble rank needs `observation_iprod`") - if observation_err is not None: - raise ValueError("Only the first ensemble rank needs `observation_err`") - # create halo for previous state if self.ensemble and self.trank != 0: with stop_annotating(): @@ -308,15 +305,13 @@ def __call__(self, values: OverloadedType): if trank == 0: Jlocal = ( self.background_norm( - self.background_error(x[0])) - ) + self.background_error(x[0]))) # observations at time 0 if self.initial_observations: Jlocal += ( self.initial_observation_norm( - self.initial_observation_error(x[0])) - ) + self.initial_observation_error(x[0]))) else: Jlocal = 0. @@ -382,26 +377,28 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # initial condition derivatives if trank == 0: - bkg_deriv = self.background_norm.derivative(adj_input=adj_input, - options=intermediate_options) - derivs[0] += self.background_error.derivative(adj_input=bkg_deriv, - options=options) + bkg_deriv = self.background_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.background_error.derivative( + adj_input=bkg_deriv, options=options) # observations at time 0 if self.initial_observations: - obs_deriv = self.initial_observation_norm.derivative(adj_input=adj_input, - options=intermediate_options) - derivs[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, - options=options) + obs_deriv = self.initial_observation_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.initial_observation_error.derivative( + adj_input=obs_deriv, options=options) # evaluate first forward model, which contributes to previous chunk - sderiv0 = self.stages[0].derivative(adj_input=adj_input, options=options) + sderiv0 = self.stages[0].derivative( + adj_input=adj_input, options=options) derivs[0] += sderiv0[0] derivs[1] += sderiv0[1] # post the derivative halo exchange - from firedrake import norm if self.ensemble: src = trank + 1 dst = trank - 1 @@ -416,7 +413,9 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): - sderiv = self.stages[i].derivative(adj_input=adj_input, options=options) + sderiv = self.stages[i].derivative( + adj_input=adj_input, options=options) + derivs[i] += sderiv[0] derivs[i+1] += sderiv[1] @@ -484,14 +483,15 @@ def recording_stages(self, sequential=True, **kwargs): # indices of stage in global and local list stage_kwargs = {k: v for k, v in kwargs.items()} - stage_kwargs['local_index'] = 0 - stage_kwargs['global_index'] = 0 # record over ensemble if self.weak_constraint: trank = self.trank + if trank == 0: + global_index = 0 + # later ranks recv forward state and kwargs if trank > 0: tcomm = self.ensemble.ensemble_comm @@ -502,20 +502,18 @@ def recording_stages(self, sequential=True, **kwargs): for i, (k, v) in enumerate(stage_kwargs.items()): stage_kwargs[k] = _scalarRecv( tcomm, dtype=type(v), source=src, tag=trank+i*100) - # restart local stage counter - stage_kwargs['local_index'] = 0 + + global_index = _scalarRecv( + tcomm, dtype=int, source=src, tag=trank+i*9000) # subsequent ranks start from halo controls = self._controls if trank == 0 else [self._control_prev, *self._controls] stage_sequence = ObservationStageSequence( - controls, self, stage_kwargs, sequential) + controls, self, global_index, stage_kwargs, sequential) yield stage_sequence - # grab the stages now they have been taped - self.stages = stage_sequence.stages - # send forward state and kwargs if self.ensemble and trank != self.nchunks - 1: with stop_annotating(): @@ -530,6 +528,9 @@ def recording_stages(self, sequential=True, **kwargs): _scalarSend( tcomm, v, dest=dst, tag=dst+i*100) + _scalarSend( + tcomm, global_index, dest=dst, tag=dst+i*9000) + else: # strong constraint yield ObservationStageSequence( @@ -539,6 +540,7 @@ def recording_stages(self, sequential=True, **kwargs): class ObservationStageSequence: def __init__(self, controls: Control, aaorf: AllAtOnceReducedFunctional, + global_index: int, stage_kwargs: dict = None, sequential: bool = True): self.controls = controls @@ -547,8 +549,8 @@ def __init__(self, controls: Control, self.ctx = StageContext(**(stage_kwargs or {})) self.index = 0 self.weak_constraint = aaorf.weak_constraint - if self.weak_constraint: - self.stages = [] + self.global_index = global_index + self.local_index = 0 def __iter__(self): return self @@ -559,12 +561,14 @@ def __next__(self): # start of the next stage next_control = self.controls[self.index] + stages = self.aaorf.stages + # smuggle state forward and increment stage indices if self.index > 0: - self.ctx.local_index += 1 - self.ctx.global_index += 1 + self.local_index += 1 + self.global_index += 1 - state = self.stages[-1].controls[1].control + state = stages[-1].controls[1].control with stop_annotating(): next_control.control.assign(state) @@ -573,15 +577,17 @@ def __next__(self): raise StopIteration self.index += 1 - stage = WeakObservationStage(next_control, index=self.ctx.global_index) - self.stages.append(stage) + stage = WeakObservationStage(next_control, + local_index=self.local_index, + global_index=self.global_index) + stages.append(stage) else: # strong constraint # increment stage indices if self.index > 0: - self.ctx.local_index += 1 - self.ctx.global_index += 1 + self.local_index += 1 + self.global_index += 1 # stop after we've recorded all stages if self.index >= self.nstages: @@ -666,18 +672,24 @@ class WeakObservationStage: control The control x_{i-1} at the beginning of the stage - index - Optional integer to name controls and functionals with + local_index + The index of this stage in the timeseries on the + local ensemble member. + + global_index + The index of this stage in the global timeseries. """ def __init__(self, control: Control, - index: Optional[int] = None): + local_index: Optional[int] = None, + global_index: Optional[int] = None): # "control" to use as initial condition. # Not actual `Control` for consistency with strong constraint self.control = control.control self.controls = Enlist(control) - self.index = index + self.local_index = local_index + self.global_index = global_index set_working_tape() self._stage_tape = get_working_tape() @@ -729,18 +741,18 @@ def set_observation(self, state: OverloadedType, with stop_annotating(): # smuggle initial guess at this time into the control without the tape seeing self.controls.append(Control(state._ad_copy())) - if self.index: - _rename(self.controls[-1].control, f"Control_{self.index}") + if self.global_index: + _rename(self.controls[-1].control, f"Control_{self.global_index}") # model error links time-chunks by depending on both the # previous and current controls # RF to recalculate error vector (M_i - x_i) names = { - 'functional_name': f"model_err_vec_{self.index}", - 'control_name': [f"state_{self.index}_copy", - f"Control_{self.index}_model_copy"] - } if self.index else {} + 'functional_name': f"model_err_vec_{self.global_index}", + 'control_name': [f"state_{self.global_index}_copy", + f"Control_{self.global_index}_model_copy"] + } if self.global_index else {} self.model_error = isolated_rf( operation=lambda controls: _ad_sub(*controls), @@ -749,8 +761,8 @@ def set_observation(self, state: OverloadedType, # RF to recalculate inner product |M_i - x_i|_Q names = { - 'control_name': f"model_err_vec_{self.index}_copy" - } if self.index else {} + 'control_name': f"model_err_vec_{self.global_index}_copy" + } if self.global_index else {} self.model_norm = isolated_rf( operation=forward_model_iprod, @@ -761,9 +773,9 @@ def set_observation(self, state: OverloadedType, # RF to recalculate error vector (H(x_i) - y_i) names = { - 'functional_name': f"obs_err_vec_{self.index}", - 'control_name': f"Control_{self.index}_obs_copy" - } if self.index else {} + 'functional_name': f"obs_err_vec_{self.global_index}", + 'control_name': f"Control_{self.global_index}_obs_copy" + } if self.global_index else {} self.observation_error = isolated_rf( operation=observation_err, @@ -772,8 +784,8 @@ def set_observation(self, state: OverloadedType, # RF to recalculate inner product |H(x_i) - y_i|_R names = { - 'functional_name': "obs_err_vec_{self.index}_copy" - } if self.index else {} + 'functional_name': "obs_err_vec_{self.global_index}_copy" + } if self.global_index else {} self.observation_norm = isolated_rf( operation=observation_iprod, control=self.observation_error.functional, From aba4f6d67a48217cc0030efa43c10f199d98226a Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 20 Nov 2024 12:46:38 +0000 Subject: [PATCH 04/23] ensemble sequential context manager --- .../adjoint/all_at_once_reduced_functional.py | 99 ++++++++----------- firedrake/ensemble.py | 62 +++++++++++- 2 files changed, 101 insertions(+), 60 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 54874d9d2e..977764722b 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -1,5 +1,5 @@ from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ - stop_annotating, no_annotations, get_working_tape, set_working_tape + stop_annotating, get_working_tape, set_working_tape from pyadjoint.enlisting import Enlist from functools import wraps, cached_property from typing import Callable, Optional @@ -9,7 +9,7 @@ __all__ = ['AllAtOnceReducedFunctional'] -@set_working_tape(decorator=True) +@set_working_tape() def isolated_rf(operation, control, functional_name=None, control_name=None): @@ -75,18 +75,6 @@ def _ad_sub(left, right): return result -def _scalarSend(comm, x, **kwargs): - from numpy import ones - comm.Send(x*ones(1, dtype=type(x)), **kwargs) - - -def _scalarRecv(comm, dtype=float, **kwargs): - from numpy import zeros - xtmp = zeros(1, dtype=dtype) - comm.Recv(xtmp, **kwargs) - return xtmp[0] - - class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. @@ -260,7 +248,7 @@ def __getattr__(self, attr): return getattr(self.strong_reduced_functional, attr) @sc_passthrough - @no_annotations + @stop_annotating() def __call__(self, values: OverloadedType): """Computes the reduced functional with supplied control value. @@ -335,7 +323,7 @@ def __call__(self, values: OverloadedType): return J @sc_passthrough - @no_annotations + @stop_annotating() def derivative(self, adj_input: float = 1.0, options: dict = {}): """Returns the derivative of the functional w.r.t. the control. Using the adjoint method, the derivative of the functional with @@ -428,7 +416,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): return derivatives @sc_passthrough - @no_annotations + @stop_annotating() def hessian(self, m_dot: OverloadedType, options: dict = {}): """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. @@ -460,7 +448,7 @@ def hessian(self, m_dot: OverloadedType, options: dict = {}): """ raise ValueError("Not implemented yet") - @no_annotations + @stop_annotating() def hessian_matrix(self): # Other reduced functionals don't have this. if not self.weak_constraint: @@ -477,59 +465,52 @@ def _accumulate_functional(self, val): self._accumulation_started = True @contextmanager - def recording_stages(self, sequential=True, **kwargs): + def recording_stages(self, sequential=True, **stage_kwargs): if not sequential: raise ValueError("Recording stages concurrently not yet implemented") - # indices of stage in global and local list - stage_kwargs = {k: v for k, v in kwargs.items()} - # record over ensemble if self.weak_constraint: trank = self.trank - if trank == 0: - global_index = 0 + global_index = 0 + with stop_annotating(): + xhalo = self.x[0]._ad_copy() - # later ranks recv forward state and kwargs - if trank > 0: - tcomm = self.ensemble.ensemble_comm - src = trank-1 - with stop_annotating(): - self.ensemble.recv(self.xprev, source=src, tag=trank+000) + # add our data onto the users context data + ekwargs = {k: v for k, v in stage_kwargs.items()} + ekwargs['global_index'] = global_index + ekwargs['xhalo'] = xhalo - for i, (k, v) in enumerate(stage_kwargs.items()): - stage_kwargs[k] = _scalarRecv( - tcomm, dtype=type(v), source=src, tag=trank+i*100) + # proceed one ensemble rank at a time + with self.ensemble.sequential(**ekwargs) as ectx: - global_index = _scalarRecv( - tcomm, dtype=int, source=src, tag=trank+i*9000) + # later ranks start from halo + if trank == 0: + controls = self._controls + else: + controls = [self._control_prev, *self._controls] + with stop_annotating(): + controls[0].assign(ectx.xhalo) - # subsequent ranks start from halo - controls = self._controls if trank == 0 else [self._control_prev, *self._controls] + # grab the user's data from the ensemble context + local_stage_kwargs = { + k: getattr(ectx, k) for k in stage_kwargs.keys() + } - stage_sequence = ObservationStageSequence( - controls, self, global_index, stage_kwargs, sequential) + # initialise iterator for local stages + stage_sequence = ObservationStageSequence( + controls, self, ectx.global_index, + local_stage_kwargs, sequential) - yield stage_sequence + # let the user record the local stages + yield stage_sequence - # send forward state and kwargs - if self.ensemble and trank != self.nchunks - 1: + # send the state forward with stop_annotating(): - tcomm = self.ensemble.ensemble_comm - dst = trank+1 - state = self.stages[-1].controls[1].control - self.ensemble.send(state, dest=dst, tag=dst+000) - - for i, k in enumerate(stage_kwargs.keys()): - v = getattr(stage_sequence.ctx, k) - _scalarSend( - tcomm, v, dest=dst, tag=dst+i*100) - - _scalarSend( - tcomm, global_index, dest=dst, tag=dst+i*9000) + ectx.xhalo.assign(state) else: # strong constraint @@ -558,11 +539,11 @@ def __iter__(self): def __next__(self): if self.weak_constraint: + stages = self.aaorf.stages + # start of the next stage next_control = self.controls[self.index] - stages = self.aaorf.stages - # smuggle state forward and increment stage indices if self.index > 0: self.local_index += 1 @@ -797,7 +778,7 @@ def set_observation(self, state: OverloadedType, # stop the stage tape recording anything else set_working_tape() - @no_annotations + @stop_annotating() def __call__(self, values: OverloadedType, rftype: Optional[str] = None): """Computes the reduced functional with supplied control value. @@ -836,7 +817,7 @@ def __call__(self, values: OverloadedType, return J - @no_annotations + @stop_annotating() def derivative(self, adj_input: float = 1.0, options: dict = {}, rftype: Optional[str] = None): """Returns the derivative of the functional w.r.t. the control. @@ -907,7 +888,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, return derivatives - @no_annotations + @stop_annotating() def hessian(self, m_dot: OverloadedType, options: dict = {}, rftype: Optional[str] = None): """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py index f847be51bf..5bb5519044 100644 --- a/firedrake/ensemble.py +++ b/firedrake/ensemble.py @@ -1,8 +1,11 @@ import weakref +from contextlib import contextmanager +from itertools import zip_longest from firedrake.petsc import PETSc +from firedrake.function import Function +from firedrake.cofunction import Cofunction from pyop2.mpi import MPI, internal_comm -from itertools import zip_longest __all__ = ("Ensemble", ) @@ -283,3 +286,60 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag) for dat in frecv.dat]) return requests + + @contextmanager + def sequential(self, **kwargs): + """ + Context manager for executing code on each ensemble + member in turn. + + Any data in `kwargs` will be made available in the context + and will be communicated forward after each ensemble member + exits. Firedrake Functions/Cofunctions will be send with the + corresponding Ensemble methods. + + with ensemble.sequential(index=0) as ctx: + print(ensemble.ensemble_comm.rank, ctx.index) + ctx.index += 2 + + Would print: + 0 0 + 1 2 + 2 4 + 3 6 + ... etc ... + + """ + rank = self.ensemble_comm.rank + first_rank = (rank == 0) + last_rank = (rank == self.ensemble_comm.size - 1) + + if not first_rank: + src = rank - 1 + for i, (k, v) in enumerate(kwargs.items()): + recv_kwargs = {'source': src, 'tag': rank+i*100} + if isinstance(v, (Function, Cofunction)): + self.recv(kwargs[k], **recv_kwargs) + else: + kwargs[k] = self.ensemble_comm.recv( + **recv_kwargs) + + ctx = _EnsembleContext(**kwargs) + + yield ctx + + if not last_rank: + dst = rank + 1 + for i, v in enumerate((getattr(ctx, k) + for k in kwargs.keys())): + send_kwargs = {'dest': dst, 'tag': dst+i*100} + if isinstance(v, (Function, Cofunction)): + self.send(v, **send_kwargs) + else: + self.ensemble_comm.send(v, **send_kwargs) + + +class _EnsembleContext: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) From 60374835e0fceeb2dca666215d1ac5898ca7bc49 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Wed, 20 Nov 2024 14:29:17 +0000 Subject: [PATCH 05/23] REMOVE BEFORE MERGE: pyadjoint branch --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0fff428f43..ae6b7d48fe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,6 +83,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint JHopeCollins/set_working_tape_decorator || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From b42c8fc8b2fbe75621a0598d1911651466503c1e Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Wed, 20 Nov 2024 14:29:44 +0000 Subject: [PATCH 06/23] REMOVE BEFORE MERGE: pyadjoint branch --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae6b7d48fe..590f52b192 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,7 +83,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ - --package-branch pyadjoint JHopeCollins/set_working_tape_decorator + --package-branch pyadjoint JHopeCollins/set_working_tape_decorator \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From 61f3e75820b832a651ca6cff5fc88eb7eec31bf6 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 26 Nov 2024 06:46:11 +0000 Subject: [PATCH 07/23] Update .github/workflows/build.yml --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 590f52b192..0fff428f43 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,7 +83,6 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ - --package-branch pyadjoint JHopeCollins/set_working_tape_decorator \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From d581787611ab23a1b336bbd2d00b3dab697ee3e9 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 29 Nov 2024 15:27:41 +0000 Subject: [PATCH 08/23] aaorf - ensemblefunction and pyadjoint minimize --- .../adjoint/all_at_once_reduced_functional.py | 3 +- firedrake/adjoint_utils/ensemblefunction.py | 30 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 977764722b..0b78f5e236 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -155,6 +155,7 @@ def __init__(self, control: Control, self.x = x self.control = control + self.controls = [control] self._controls = tuple(Control(xi) for xi in x) # first control on rank 0 is initial conditions, not end of observation stage @@ -267,7 +268,7 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - self.control.assign(values) + self.control.assign(values[0] if isinstance(values, list) else values) trank = self.trank # first "control" for later ranks is the halo diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py index ee1e4a4ce3..1c39e04517 100644 --- a/firedrake/adjoint_utils/ensemblefunction.py +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -1,4 +1,7 @@ from pyadjoint.overloaded_type import OverloadedType +from firedrake.petsc import PETSc +from .checkpointing import disk_checkpointing + from functools import wraps @@ -14,11 +17,27 @@ def wrapper(self, *args, **kwargs): @staticmethod def _ad_to_list(m): - raise ValueError("NotImplementedYet") + with m.vec_ro() as gvec: + lcomm = PETSc.COMM_SELF + gsize = gvec.size + lvec = PETSc.Vec().createSeq(gsize, comm=lcomm) + is_ = PETSc.IS().createStride(gsize, 0, 1, comm=lcomm) + + mode = PETSc.InsertMode.INSERT_VALUES + scatter = PETSc.Scatter().create(gvec, is_, lvec, None) + scatter.scatterBegin(gvec, lvec, addv=mode) + scatter.scatterEnd(gvec, lvec, addv=mode) + + return lvec.array_r.tolist() @staticmethod def _ad_assign_numpy(dst, src, offset): - raise ValueError("NotImplementedYet") + with dst.vec_wo() as vec: + begin, end = vec.owner_range + src_array = src[offset + begin: offset + end] + vec.array[:] = src_array + offset += vec.size + return dst, offset def _ad_dot(self, other, options=None): # local dot product @@ -56,6 +75,13 @@ def _ad_copy(self): def _ad_convert_riesz(self, value, options=None): raise ValueError("NotImplementedYet") + def _ad_create_checkpoint(self): + if disk_checkpointing(): + raise NotImplementedError( + "Disk checkpointing not implemented for EnsembleFunctions") + else: + return self.copy() + def _ad_from_petsc(self, vec): with self.vec_wo as self_v: vec.copy(result=self_v) From 6360250ca3a1ad97fce69439c67645a98fe26dad Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 4 Dec 2024 12:25:16 +0000 Subject: [PATCH 09/23] 4dvar rf test --- firedrake/adjoint/__init__.py | 1 + .../adjoint/all_at_once_reduced_functional.py | 200 +++++++---- firedrake/adjoint_utils/ensemblefunction.py | 6 + .../test_4dvar_reduced_functional.py | 325 ++++++++++++++++++ 4 files changed, 458 insertions(+), 74 deletions(-) create mode 100644 tests/firedrake/regression/test_4dvar_reduced_functional.py diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index c48b990420..08f6a3c2ec 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -38,6 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \ UFLEqualityConstraint # noqa F401 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 +from firedrake.adjoint.all_at_once_reduced_functional import AllAtOnceReducedFunctional # noqa F401 import numpy_adjoint # noqa F401 import firedrake.ufl_expr import types diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 0b78f5e236..42f3f90965 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -5,11 +5,12 @@ from typing import Callable, Optional from contextlib import contextmanager from mpi4py import MPI +from firedrake.petsc import PETSc __all__ = ['AllAtOnceReducedFunctional'] -@set_working_tape() +# @set_working_tape() # ends up using old_tape = None because evaluates when imported - need separate decorator def isolated_rf(operation, control, functional_name=None, control_name=None): @@ -17,28 +18,29 @@ def isolated_rf(operation, control, Return a ReducedFunctional where the functional is `operation` applied to a copy of `control`, and the tape contains only `operation`. """ - controls = Enlist(control) - control_names = Enlist(control_name) + with set_working_tape(): + controls = Enlist(control) + control_names = Enlist(control_name) - with stop_annotating(): - control_copies = [control._ad_copy() for control in controls] + with stop_annotating(): + control_copies = [control._ad_copy() for control in controls] - if control_names: - for control, name in zip(control_copies, control_names): - _rename(control, name) + if control_names: + for control, name in zip(control_copies, control_names): + _rename(control, name) - if len(control_copies) == 1: - functional = operation(control_copies[0]) - control = Control(control_copies[0]) - else: - functional = operation(control_copies) - control = [Control(control) for control in control_copies] + if len(control_copies) == 1: + functional = operation(control_copies[0]) + control = Control(control_copies[0]) + else: + functional = operation(control_copies) + control = [Control(control) for control in control_copies] - if functional_name: - _rename(functional, functional_name) + if functional_name: + _rename(functional, functional_name) - return ReducedFunctional( - functional, control) + return ReducedFunctional( + functional, control) def sc_passthrough(func): @@ -151,15 +153,16 @@ def __init__(self, control: Control, self.trank = ensemble.ensemble_comm.rank if ensemble else 0 self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 - x = control.control.subfunctions - self.x = x + self._cbuf = control.copy() + _x = self._cbuf.subfunctions + self._x = _x + self._controls = tuple(Control(xi) for xi in _x) self.control = control self.controls = [control] - self._controls = tuple(Control(xi) for xi in x) # first control on rank 0 is initial conditions, not end of observation stage - self.nlocal_stages = len(x) - (1 if self.trank == 0 else 0) + self.nlocal_stages = len(_x) - (1 if self.trank == 0 else 0) self.stages = [] # The record of each observation stage @@ -169,7 +172,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (x_0 - x_b) self.background_error = isolated_rf( operation=lambda x0: _ad_sub(x0, self.background), - control=x[0], + control=_x[0], functional_name="bkg_err_vec", control_name="Control_0_bkg_copy") @@ -184,7 +187,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (H(x_0) - y_0) self.initial_observation_error = isolated_rf( operation=observation_err, - control=x[0], + control=_x[0], functional_name="obs_err_vec_0", control_name="Control_0_obs_copy") @@ -197,13 +200,13 @@ def __init__(self, control: Control, # create halo for previous state if self.ensemble and self.trank != 0: with stop_annotating(): - self.xprev = x[0]._ad_copy() + self.xprev = _x[0]._ad_copy() self._control_prev = Control(self.xprev) # halo for the derivative from the next chunk if self.ensemble and self.trank != self.nchunks - 1: with stop_annotating(): - self.xnext = x[0]._ad_copy() + self.xnext = _x[0]._ad_copy() else: self._annotate_accumulation = True @@ -268,14 +271,21 @@ def __call__(self, values: OverloadedType): The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - self.control.assign(values[0] if isinstance(values, list) else values) + value = values[0] if isinstance(values, list) else values + + if not isinstance(value, type(self.control.control)): + raise ValueError(f"Value must be of type {type(self.control.control)} not type {type(value)}") + + self.control.update(value) + self._cbuf.assign(value) + trank = self.trank # first "control" for later ranks is the halo if self.ensemble and trank != 0: - x = [self.xprev, *self.x] + x = [self.xprev, *self._x] else: - x = [*self.x] + x = [*self._x] # post messages for control of forward model propogation on next chunk if self.ensemble: @@ -346,49 +356,53 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): Should be an instance of the same type as the control. """ trank = self.trank - # create a list of overloaded types to put derivative into - derivatives = self.control._ad_copy() - derivatives.zero() - - if self.ensemble and trank != 0: - self.xprev.zero() - derivs = [self.xprev, *derivatives.subfunctions] - else: - derivs = [*derivatives.subfunctions] # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} intermediate_options = { - 'riesz_representation': None, + 'riesz_representation': 'l2', **{k: v for k, v in options.items() if (k != 'riesz_representation')} } - # initial condition derivatives - if trank == 0: - bkg_deriv = self.background_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.background_error.derivative( - adj_input=bkg_deriv, options=options) - - # observations at time 0 - if self.initial_observations: - obs_deriv = self.initial_observation_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.initial_observation_error.derivative( - adj_input=obs_deriv, options=options) - # evaluate first forward model, which contributes to previous chunk sderiv0 = self.stages[0].derivative( adj_input=adj_input, options=options) + # create the derivative in the right primal or dual space + from ufl.duals import is_primal, is_dual + if is_primal(sderiv0[0]): + from firedrake.ensemblefunction import EnsembleFunction + derivatives = EnsembleFunction( + self.ensemble, self.control.local_function_spaces) + else: + if not is_dual(sderiv0[0]): + raise ValueError( + "Do not know how to handle stage derivative which is not primal or dual") + from firedrake.ensemblefunction import EnsembleCofunction + derivatives = EnsembleCofunction( + self.ensemble, [V.dual() for V in self.control.local_function_spaces]) + + derivatives.zero() + + if self.ensemble: + with stop_annotating(): + xprev = derivatives.subfunctions[0]._ad_copy() + xnext = derivatives.subfunctions[0]._ad_copy() + xprev.zero() + xnext.zero() + if trank != 0: + derivs = [xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] + + # start accumulating the complete derivative derivs[0] += sderiv0[0] derivs[1] += sderiv0[1] # post the derivative halo exchange if self.ensemble: + # halos sent backward in time src = trank + 1 dst = trank - 1 @@ -398,7 +412,23 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): if trank != self.nchunks - 1: recv_reqs = self.ensemble.irecv( - self.xnext, source=src, tag=trank) + xnext, source=src, tag=trank) + + # initial condition derivatives + if trank == 0: + bkg_deriv = self.background_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.background_error.derivative( + adj_input=bkg_deriv, options=options) + + # observations at time 0 + if self.initial_observations: + obs_deriv = self.initial_observation_norm.derivative( + adj_input=adj_input, options=intermediate_options) + + derivs[0] += self.initial_observation_error.derivative( + adj_input=obs_deriv, options=options) # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): @@ -412,7 +442,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): if self.ensemble: if trank != self.nchunks - 1: MPI.Request.Waitall(recv_reqs) - derivs[-1] += self.xnext + derivs[-1] += xnext return derivatives @@ -475,13 +505,17 @@ def recording_stages(self, sequential=True, **stage_kwargs): trank = self.trank - global_index = 0 + # index of "previous" stage and observation in global series + global_index = -1 + observation_index = 0 if self.initial_observations else -1 with stop_annotating(): - xhalo = self.x[0]._ad_copy() + xhalo = self._x[0]._ad_copy() - # add our data onto the users context data + # add our data onto the user's context data ekwargs = {k: v for k, v in stage_kwargs.items()} ekwargs['global_index'] = global_index + ekwargs['observation_index'] = observation_index + ekwargs['xhalo'] = xhalo # proceed one ensemble rank at a time @@ -503,6 +537,7 @@ def recording_stages(self, sequential=True, **stage_kwargs): # initialise iterator for local stages stage_sequence = ObservationStageSequence( controls, self, ectx.global_index, + ectx.observation_index, local_stage_kwargs, sequential) # let the user record the local stages @@ -512,6 +547,12 @@ def recording_stages(self, sequential=True, **stage_kwargs): with stop_annotating(): state = self.stages[-1].controls[1].control ectx.xhalo.assign(state) + # grab the user's information to send forward + for k in local_stage_kwargs.keys(): + setattr(ectx, k, getattr(stage_sequence.ctx, k)) + # increment the global indices for the last local stage + ectx.global_index = self.stages[-1].global_index + ectx.observation_index = self.stages[-1].observation_index else: # strong constraint @@ -523,16 +564,17 @@ class ObservationStageSequence: def __init__(self, controls: Control, aaorf: AllAtOnceReducedFunctional, global_index: int, + observation_index: int, stage_kwargs: dict = None, sequential: bool = True): self.controls = controls self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) - self.index = 0 self.weak_constraint = aaorf.weak_constraint self.global_index = global_index - self.local_index = 0 + self.observation_index = observation_index + self.local_index = -1 def __iter__(self): return self @@ -542,34 +584,36 @@ def __next__(self): if self.weak_constraint: stages = self.aaorf.stages + # increment global indices + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 + # start of the next stage - next_control = self.controls[self.index] + next_control = self.controls[self.local_index] # smuggle state forward and increment stage indices - if self.index > 0: - self.local_index += 1 - self.global_index += 1 - + if self.local_index > 0: state = stages[-1].controls[1].control with stop_annotating(): next_control.control.assign(state) # stop after we've recorded all stages - if self.index >= self.nstages: + if self.local_index >= self.nstages: raise StopIteration - self.index += 1 stage = WeakObservationStage(next_control, local_index=self.local_index, - global_index=self.global_index) + global_index=self.global_index, + observation_index=self.observation_index) stages.append(stage) else: # strong constraint # increment stage indices - if self.index > 0: - self.local_index += 1 - self.global_index += 1 + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 # stop after we've recorded all stages if self.index >= self.nstages: @@ -583,6 +627,7 @@ def __next__(self): stage = StrongObservationStage(control, self.aaorf) self._prev_stage = stage + return stage, self.ctx @@ -661,10 +706,16 @@ class WeakObservationStage: global_index The index of this stage in the global timeseries. + observation_index + The index of the observation at the end of this stage in + the global timeseries. May be different from global_index if + an observation is taken at the initial time. + """ def __init__(self, control: Control, local_index: Optional[int] = None, - global_index: Optional[int] = None): + global_index: Optional[int] = None, + observation_index: Optional[int] = None): # "control" to use as initial condition. # Not actual `Control` for consistency with strong constraint self.control = control.control @@ -672,6 +723,7 @@ def __init__(self, control: Control, self.controls = Enlist(control) self.local_index = local_index self.global_index = global_index + self.observation_index = observation_index set_working_tape() self._stage_tape = get_working_tape() diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py index 1c39e04517..ba6882f054 100644 --- a/firedrake/adjoint_utils/ensemblefunction.py +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -82,6 +82,12 @@ def _ad_create_checkpoint(self): else: return self.copy() + def _ad_restore_at_checkpoint(self, checkpoint): + if isinstance(checkpoint, type(self)): + return checkpoint + raise NotImplementedError( + "Checkpointing not implemented for EnsembleFunctions") + def _ad_from_petsc(self, vec): with self.vec_wo as self_v: vec.copy(result=self_v) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py new file mode 100644 index 0000000000..bf0530293f --- /dev/null +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -0,0 +1,325 @@ +import pytest +import firedrake as fd +from firedrake.__future__ import interpolate +from firedrake.adjoint import ( + continue_annotation, pause_annotation, stop_annotating, set_working_tape, + Control, taylor_test, ReducedFunctional, AllAtOnceReducedFunctional) + + +def function_space(comm): + """DG0 periodic advection""" + mesh = fd.PeriodicUnitIntervalMesh(nx, comm=comm) + return fd.FunctionSpace(mesh, "DG", 0) + + +def timestepper(V): + """Implicit midpoint timestepper for the advection equation""" + qn = fd.Function(V, name="qn") + qn1 = fd.Function(V, name="qn1") + + def mass(q, phi): + return fd.inner(q, phi)*fd.dx + + def tendency(q, phi): + u = fd.as_vector([vconst]) + n = fd.FacetNormal(V.mesh()) + un = fd.Constant(0.5)*(fd.dot(u, n) + abs(fd.dot(u, n))) + return (- q*fd.div(phi*u)*fd.dx + + fd.jump(phi)*fd.jump(un*q)*fd.dS) + + # midpoint rule + q = fd.TrialFunction(V) + phi = fd.TestFunction(V) + + qh = fd.Constant(0.5)*(q + qn) + eqn = mass(q - qn, phi) + fd.Constant(dt)*tendency(qh, phi) + + stepper = fd.LinearVariationalSolver( + fd.LinearVariationalProblem( + fd.lhs(eqn), fd.rhs(eqn), qn1, + constant_jacobian=True)) + + return qn, qn1, stepper + + +def prod2(w): + """generate weighted inner products to pass to FourDVarReducedFunctional""" + def n2(x): + return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) + return n2 + + +prodB = prod2(0.1) # background error +prodR = prod2(10.) # observation error +prodQ = prod2(1.0) # model error + + +"""Advecting velocity""" +velocity = 1 +vconst = fd.Constant(velocity) + +"""Number of cells""" +nx = 16 + +"""Timestep size""" +cfl = 2.3523 +dx = 1.0/nx +dt = cfl*dx/velocity + +"""How many times / how often we take observations +(one extra at initial time)""" +observation_frequency = 5 +observation_n = 6 +observation_times = [i*observation_frequency*dt + for i in range(observation_n+1)] + + +def nlocal_observations(ensemble): + """How many observations on the current ensemble member""" + esize = ensemble.ensemble_comm.size + erank = ensemble.ensemble_comm.rank + if esize == 1: + return observation_n + 1 + assert (observation_n % esize == 0), "Must be able to split observations across ensemble" # noqa: E501 + return observation_n//esize + (1 if erank == 0 else 0) + + +def analytic_solution(V, t, mag=1.0, phase=0.0): + """Exact advection of sin wave after time t""" + x, = fd.SpatialCoordinate(V.mesh()) + return fd.Function(V).interpolate( + mag*fd.sin(2*fd.pi*((x + phase) - vconst*t))) + + +def analytic_series(V, tshift=0.0, mag=1.0, phase=0.0, ensemble=None): + """Timeseries of the analytic solution""" + series = [analytic_solution(V, t+tshift, mag=mag, phase=phase) + for t in observation_times] + + if ensemble is None: + return series + else: + nlocal_obs = nlocal_observations(ensemble) + rank = ensemble.ensemble_comm.rank + offset = (0 if rank == 0 else rank*nlocal_obs + 1) + + efunc = fd.EnsembleFunction( + ensemble, [V for _ in range(nlocal_obs)]) + + for e, s in zip(efunc.subfunctions, + series[offset:offset+nlocal_obs]): + e.assign(s) + return efunc + + +def observation_errors(V): + """List of functions to evaluate the observation error + at each observation time""" + + observation_locations = [ + [x] for x in [0.13, 0.18, 0.34, 0.36, 0.49, 0.61, 0.72, 0.99] + ] + + observation_mesh = fd.VertexOnlyMesh(V.mesh(), observation_locations) + Vobs = fd.FunctionSpace(observation_mesh, "DG", 0) + + # observation operator + def H(x): + return fd.assemble(interpolate(x, Vobs)) + + # ground truth + targets = analytic_series(V) + + # take observations + y = [H(x) for x in targets] + + # generate function to evaluate observation error at observation time i + def observation_error(i): + def obs_err(x): + return fd.Function(Vobs).assign(H(x) - y[i]) + return obs_err + + return observation_error + + +def background(V): + """Prior for initial condition""" + return analytic_solution(V, t=0, mag=0.9, phase=0.1) + + +def m(V, ensemble=None): + """The expansion points for the Taylor test""" + return analytic_series(V, tshift=0.1, mag=1.1, phase=-0.2, + ensemble=ensemble) + + +def h(V, ensemble=None): + """The perturbation direction for the Taylor test""" + return analytic_series(V, tshift=0.3, mag=0.1, phase=0.3, + ensemble=ensemble) + + +def fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # One control for each observation time + controls = [fd.Function(V) + for _ in range(len(observation_times))] + + # Prior + bkg = background(V) + + controls[0].assign(bkg) + + # generate ground truths + obs_errors = observation_errors(V) + + # start building the 4DVar system + continue_annotation() + set_working_tape() + + # background error + J = prodB(controls[0] - bkg) + + # initial observation error + J += prodR(obs_errors(0)(controls[0])) + + # record observation stages + for i in range(1, len(controls)): + qn.assign(controls[i-1]) + + # forward model propogation + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # we need to smuggle the state over to next + # control without the tape seeing so that we + # can continue the timeseries through the next + # stage but with the tape thinking that the + # forward model in each stage is independent. + with stop_annotating(): + controls[i].assign(qn) + + # model error for this stage + J += prodQ(qn - controls[i]) + + # observation error + J += prodR(obs_errors(i)(controls[i])) + + pause_annotation() + + Jhat = ReducedFunctional( + J, [Control(c) for c in controls]) + + return Jhat + + +def fdvar_firedrake(V, ensemble): + """Build an AllAtOnceReducedFunctional for the 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # One control for each observation time + + nlocal_obs = nlocal_observations(ensemble) + + control = fd.EnsembleFunction( + ensemble, [V for _ in range(nlocal_obs)]) + + # Prior + bkg = background(V) + + if ensemble.ensemble_comm.rank == 0: + control.subfunctions[0].assign(bkg) + + # generate ground truths + obs_errors = observation_errors(V) + + # start building the 4DVar system + continue_annotation() + set_working_tape() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = AllAtOnceReducedFunctional( + Control(control), + background_iprod=prodB, + observation_iprod=prodR, + observation_err=obs_errors(0), + weak_constraint=True) + + # record observation stages + with Jhat.recording_stages() as stages: + + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # take observation + obs_err = obs_errors(stage.observation_index) + stage.set_observation(qn, obs_err, + observation_iprod=prodR, + forward_model_iprod=prodQ) + + pause_annotation() + + return Jhat + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) +def test_advection(): + main_test_advection() + + +def main_test_advection(): + global_comm = fd.COMM_WORLD + if global_comm.size in (1, 2): # time serial + nspace = global_comm.size + elif global_comm.size == 3: # time parallel + nspace = 1 + elif global_comm.size == 4: # space-time parallel + nspace = 2 + + ensemble = fd.Ensemble(global_comm, nspace) + V = function_space(ensemble.comm) + + erank = ensemble.ensemble_comm.rank + + # only setup the reference pyadjoint rf on the first ensemble member + if erank == 0: + Jhat_pyadj = fdvar_pyadjoint(V) + mp = m(V) + hp = h(V) + # make sure we've set up the reference rf correctly + assert taylor_test(Jhat_pyadj, mp, hp) > 1.99 + + Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None) + Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None) + + Jhat_aaorf = fdvar_firedrake(V, ensemble) + + ma = m(V, ensemble) + ha = h(V, ensemble) + + eps = 1e-12 + # Does evaluating the functional match the reference rf? + assert abs(Jpm - Jhat_aaorf(ma)) < eps + assert abs(Jph - Jhat_aaorf(ha)) < eps + + # If we match the functional, then passing the taylor test + # should mean we match the derivative too. + assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + + +if __name__ == '__main__': + main_test_advection() From d15eba60f70ef769c47de9ee19f9d58f65a79387 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 17 Dec 2024 08:45:32 +0000 Subject: [PATCH 10/23] fix assign visitor for cofunction?? --- firedrake/assign.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/firedrake/assign.py b/firedrake/assign.py index 8d1b30681a..892ed5b1b0 100644 --- a/firedrake/assign.py +++ b/firedrake/assign.py @@ -100,6 +100,9 @@ def component_tensor(self, o, a, _): def coefficient(self, o): return ((o, 1),) + def cofunction(self, o): + return ((o, 1),) + def constant_value(self, o): return ((o, 1),) From 74a09747bf4570f30e874a8b65e772386df2cf0c Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 17 Dec 2024 08:46:01 +0000 Subject: [PATCH 11/23] aaorf - pass riesz_representation through chained rfs properly --- .../adjoint/all_at_once_reduced_functional.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 42f3f90965..437bd4623f 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -5,7 +5,6 @@ from typing import Callable, Optional from contextlib import contextmanager from mpi4py import MPI -from firedrake.petsc import PETSc __all__ = ['AllAtOnceReducedFunctional'] @@ -77,6 +76,20 @@ def _ad_sub(left, right): return result +def _intermediate_options(final_options): + """ + Options set for the intermediate stages of a chain of ReducedFunctionals + + Takes all elements of the final_options except riesz_representation, + which is set to prevent returning derivatives to the primal space. + """ + return { + 'riesz_representation': None, + **{k: v for k, v in final_options.items() + if (k != 'riesz_representation')} + } + + class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. @@ -359,11 +372,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = { - 'riesz_representation': 'l2', - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } + intermediate_options = _intermediate_options(options) # evaluate first forward model, which contributes to previous chunk sderiv0 = self.stages[0].derivative( @@ -627,7 +636,6 @@ def __next__(self): stage = StrongObservationStage(control, self.aaorf) self._prev_stage = stage - return stage, self.ctx @@ -903,11 +911,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = { - 'riesz_representation': None, - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } + intermediate_options = _intermediate_options(options) if (rftype is None) or (rftype == 'model'): # derivative of reduction @@ -922,8 +926,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, dm_forward = self.forward_model.derivative(adj_input=dm_errors[0], options=options) + sentinel = -12345 + riesz_map = options.get('riesz_representation', sentinel) derivatives.append(dm_forward) - derivatives.append(dm_errors[1].riesz_representation()) + if riesz_map != sentinel: + if riesz_map is None: + derivatives.append(dm_errors[1]) + else: + derivatives.append(dm_errors[1].riesz_representation(riesz_map)) + else: + derivatives.append(dm_errors[1].riesz_representation()) if (rftype is None) or (rftype == 'obs'): # derivative of reduction From 6c4c5b74de76382694c8baa91f9165eb913cfef5 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 17 Dec 2024 10:44:04 +0000 Subject: [PATCH 12/23] aaorf - delegate converting derivative from intermediate type to pyadjoint --- firedrake/adjoint/all_at_once_reduced_functional.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 437bd4623f..0dd2ccafe8 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -926,16 +926,11 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, dm_forward = self.forward_model.derivative(adj_input=dm_errors[0], options=options) - sentinel = -12345 - riesz_map = options.get('riesz_representation', sentinel) derivatives.append(dm_forward) - if riesz_map != sentinel: - if riesz_map is None: - derivatives.append(dm_errors[1]) - else: - derivatives.append(dm_errors[1].riesz_representation(riesz_map)) - else: - derivatives.append(dm_errors[1].riesz_representation()) + + # dm_errors is still in the dual space, so we need to convert it to the + # type that the user has requested - this will be the type of dm_forward. + derivatives.append(dm_forward._ad_convert_type(dm_errors[1], options)) if (rftype is None) or (rftype == 'obs'): # derivative of reduction From c22531facbb6742602edac097ce6f1c907175a92 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Mon, 23 Dec 2024 09:42:45 +0000 Subject: [PATCH 13/23] aaorf - restore strong constraint 4dvar --- .../adjoint/all_at_once_reduced_functional.py | 103 +++++++----- firedrake/ensemble.py | 5 +- .../test_4dvar_reduced_functional.py | 154 ++++++++++++++++-- 3 files changed, 202 insertions(+), 60 deletions(-) diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 0dd2ccafe8..db7d23642a 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -1,6 +1,9 @@ from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ stop_annotating, get_working_tape, set_working_tape from pyadjoint.enlisting import Enlist +from firedrake.function import Function +from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction + from functools import wraps, cached_property from typing import Callable, Optional from contextlib import contextmanager @@ -93,17 +96,14 @@ def _intermediate_options(final_options): class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. - Creates either the strong constraint or weak constraint system incrementally + Creates either the strong constraint or weak constraint system by logging observations through the initial forward model run. - Warning: Weak constraint 4DVar not implemented yet. - Parameters ---------- control - The :class:`EnsembleFunction` for the control x_{i} at the initial - condition and at the end of each observation stage. + The :class:`EnsembleFunction` for the control x_{i} at the initial condition and at the end of each observation stage. background_iprod The inner product to calculate the background error functional @@ -150,17 +150,22 @@ def __init__(self, control: Control, self.weak_constraint = weak_constraint self.initial_observations = observation_err is not None - with stop_annotating(): - if background: - self.background = background._ad_copy() - else: - self.background = control.control.subfunctions[0]._ad_copy() - _rename(self.background, "Background") - if self.weak_constraint: self._annotate_accumulation = _annotate_accumulation self._accumulation_started = False + if not isinstance(control.control, EnsembleFunction): + raise TypeError( + "Control for weak constraint 4DVar must be an EnsembleFunction" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control.subfunctions[0]._ad_copy() + _rename(self.background, "Background") + ensemble = control.ensemble self.ensemble = ensemble self.trank = ensemble.ensemble_comm.rank if ensemble else 0 @@ -225,11 +230,21 @@ def __init__(self, control: Control, self._annotate_accumulation = True self._accumulation_started = False + if not isinstance(control.control, Function): + raise TypeError( + "Control for strong constraint 4DVar must be a Function" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control._ad_copy() + _rename(self.background, "Background") + # initial conditions guess to be updated self.controls = Enlist(control) - self.tape = get_working_tape() if tape is None else tape - # Strong constraint functional to be converted to ReducedFunctional later # penalty for straying from prior @@ -249,10 +264,10 @@ def strong_reduced_functional(self): before all observations are recorded. """ if self.weak_constraint: - msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar" + msg = "Strong constraint ReducedFunctional cannot be instantiated for weak constraint 4DVar" raise AttributeError(msg) self._strong_reduced_functional = ReducedFunctional( - self._total_functional, self.controls, tape=self.tape) + self._total_functional, self.controls.delist(), tape=self.tape) return self._strong_reduced_functional def __getattr__(self, attr): @@ -381,14 +396,12 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # create the derivative in the right primal or dual space from ufl.duals import is_primal, is_dual if is_primal(sderiv0[0]): - from firedrake.ensemblefunction import EnsembleFunction derivatives = EnsembleFunction( self.ensemble, self.control.local_function_spaces) else: if not is_dual(sderiv0[0]): raise ValueError( "Do not know how to handle stage derivative which is not primal or dual") - from firedrake.ensemblefunction import EnsembleCofunction derivatives = EnsembleCofunction( self.ensemble, [V.dual() for V in self.control.local_function_spaces]) @@ -505,7 +518,7 @@ def _accumulate_functional(self, val): self._accumulation_started = True @contextmanager - def recording_stages(self, sequential=True, **stage_kwargs): + def recording_stages(self, sequential=True, nstages=None, **stage_kwargs): if not sequential: raise ValueError("Recording stages concurrently not yet implemented") @@ -566,7 +579,9 @@ def recording_stages(self, sequential=True, **stage_kwargs): else: # strong constraint yield ObservationStageSequence( - self.controls, self, stage_kwargs, sequential=True) + self.controls, self, global_index=-1, + observation_index=0 if self.initial_observations else -1, + stage_kwargs=stage_kwargs, nstages=nstages) class ObservationStageSequence: @@ -575,29 +590,34 @@ def __init__(self, controls: Control, global_index: int, observation_index: int, stage_kwargs: dict = None, - sequential: bool = True): + nstages: Optional[int] = None): self.controls = controls - self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) self.weak_constraint = aaorf.weak_constraint self.global_index = global_index self.observation_index = observation_index self.local_index = -1 + self.nstages = (len(controls) - 1 if self.weak_constraint + else nstages) def __iter__(self): return self def __next__(self): + # increment global indices + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 + + # stop after we've recorded all stages + if self.local_index >= self.nstages: + raise StopIteration + if self.weak_constraint: stages = self.aaorf.stages - # increment global indices - self.local_index += 1 - self.global_index += 1 - self.observation_index += 1 - # start of the next stage next_control = self.controls[self.local_index] @@ -607,10 +627,6 @@ def __next__(self): with stop_annotating(): next_control.control.assign(state) - # stop after we've recorded all stages - if self.local_index >= self.nstages: - raise StopIteration - stage = WeakObservationStage(next_control, local_index=self.local_index, global_index=self.global_index, @@ -619,21 +635,15 @@ def __next__(self): else: # strong constraint - # increment stage indices - self.local_index += 1 - self.global_index += 1 - self.observation_index += 1 - - # stop after we've recorded all stages - if self.index >= self.nstages: - raise StopIteration - self.index += 1 - # dummy control to "start" stage from - control = (self.aaorf.controls[0].control if self.index == 0 + control = (self.aaorf.controls[0].control if self.local_index == 0 else self._prev_stage.state) - stage = StrongObservationStage(control, self.aaorf) + stage = StrongObservationStage( + control, self.aaorf, + index=self.local_index, + observation_index=self.observation_index) + self._prev_stage = stage return stage, self.ctx @@ -658,9 +668,13 @@ class StrongObservationStage: """ def __init__(self, control: OverloadedType, - aaorf: AllAtOnceReducedFunctional): + aaorf: AllAtOnceReducedFunctional, + index: Optional[int] = None, + observation_index: Optional[int] = None): self.aaorf = aaorf self.control = control + self.index = index + self.observation_index = observation_index def set_observation(self, state: OverloadedType, observation_err: Callable[[OverloadedType], OverloadedType], @@ -691,6 +705,7 @@ def set_observation(self, state: OverloadedType, " constraint ReducedFunctional instantiated") self.aaorf._accumulate_functional( observation_iprod(observation_err(state))) + # save the user's state to hand back for beginning of next stage self.state = state diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py index 5bb5519044..265ad8ba2b 100644 --- a/firedrake/ensemble.py +++ b/firedrake/ensemble.py @@ -291,12 +291,13 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r def sequential(self, **kwargs): """ Context manager for executing code on each ensemble - member in turn. + member consecutively by `ensemble_comm.rank`. Any data in `kwargs` will be made available in the context and will be communicated forward after each ensemble member - exits. Firedrake Functions/Cofunctions will be send with the + exits. Firedrake Functions/Cofunctions will be sent with the corresponding Ensemble methods. + For example: with ensemble.sequential(index=0) as ctx: print(ensemble.ensemble_comm.rank, ctx.index) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py index bf0530293f..894a25849c 100644 --- a/tests/firedrake/regression/test_4dvar_reduced_functional.py +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -2,8 +2,15 @@ import firedrake as fd from firedrake.__future__ import interpolate from firedrake.adjoint import ( - continue_annotation, pause_annotation, stop_annotating, set_working_tape, - Control, taylor_test, ReducedFunctional, AllAtOnceReducedFunctional) + continue_annotation, pause_annotation, stop_annotating, + set_working_tape, get_working_tape, Control, taylor_test, + ReducedFunctional, AllAtOnceReducedFunctional) + + +@pytest.fixture(autouse=True) +def clear_tape_teardown(): + yield + get_working_tape().clear_tape() def function_space(comm): @@ -159,8 +166,95 @@ def h(V, ensemble=None): ensemble=ensemble) -def fdvar_pyadjoint(V): - """Build a pyadjoint ReducedFunctional for the 4DVar system""" +def strong_fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the strong constraint 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # prior data + bkg = background(V) + control = bkg.copy(deepcopy=True) + + # generate ground truths + obs_errors = observation_errors(V) + + continue_annotation() + set_working_tape() + + # background functional + J = prodB(control - bkg) + + # initial observation functional + J += prodR(obs_errors(0)(control)) + + qn.assign(control) + + # record observation stages + for i in range(1, len(observation_times)): + + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # observation functional + J += prodR(obs_errors(i)(qn)) + + pause_annotation() + + Jhat = ReducedFunctional(J, Control(control)) + + return Jhat + + +def strong_fdvar_firedrake(V): + """Build an AllAtOnceReducedFunctional for the strong constraint 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # prior data + bkg = background(V) + control = bkg.copy(deepcopy=True) + + # generate ground truths + obs_errors = observation_errors(V) + + continue_annotation() + set_working_tape() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = AllAtOnceReducedFunctional( + Control(control), + background_iprod=prodB, + observation_iprod=prodR, + observation_err=obs_errors(0), + weak_constraint=False) + + # record observation stages + with Jhat.recording_stages(nstages=len(observation_times)-1) as stages: + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + obs_index = stage.index + 1 + + # take observation + stage.set_observation(qn, obs_errors(obs_index), + observation_iprod=prodR) + + pause_annotation() + return Jhat + + +def weak_fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the weak constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # One control for each observation time @@ -217,8 +311,8 @@ def fdvar_pyadjoint(V): return Jhat -def fdvar_firedrake(V, ensemble): - """Build an AllAtOnceReducedFunctional for the 4DVar system""" +def weak_fdvar_firedrake(V, ensemble): + """Build an AllAtOnceReducedFunctional for the weak constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # One control for each observation time @@ -276,12 +370,34 @@ def fdvar_firedrake(V, ensemble): return Jhat -@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) -def test_advection(): - main_test_advection() +def main_test_strong_4dvar_advection(): + V = function_space(fd.COMM_WORLD) + + # setup the reference pyadjoint rf + Jhat_pyadj = strong_fdvar_pyadjoint(V) + mp = m(V)[0] + hp = h(V)[0] + # make sure we've set up the reference rf correctly + assert taylor_test(Jhat_pyadj, mp, hp) > 1.99 -def main_test_advection(): + Jhat_aaorf = strong_fdvar_firedrake(V) + + ma = m(V)[0] + ha = h(V)[0] + + eps = 1e-12 + + # Does evaluating the functional match the reference rf? + assert abs(Jhat_pyadj(mp) - Jhat_aaorf(ma)) < eps + assert abs(Jhat_pyadj(hp) - Jhat_aaorf(ha)) < eps + + # If we match the functional, then passing the taylor test + # should mean that we match the derivative too. + assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + + +def main_test_weak_4dvar_advection(): global_comm = fd.COMM_WORLD if global_comm.size in (1, 2): # time serial nspace = global_comm.size @@ -297,7 +413,7 @@ def main_test_advection(): # only setup the reference pyadjoint rf on the first ensemble member if erank == 0: - Jhat_pyadj = fdvar_pyadjoint(V) + Jhat_pyadj = weak_fdvar_pyadjoint(V) mp = m(V) hp = h(V) # make sure we've set up the reference rf correctly @@ -306,7 +422,7 @@ def main_test_advection(): Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None) Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None) - Jhat_aaorf = fdvar_firedrake(V, ensemble) + Jhat_aaorf = weak_fdvar_firedrake(V, ensemble) ma = m(V, ensemble) ha = h(V, ensemble) @@ -317,9 +433,19 @@ def main_test_advection(): assert abs(Jph - Jhat_aaorf(ha)) < eps # If we match the functional, then passing the taylor test - # should mean we match the derivative too. + # should mean that we match the derivative too. assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 +@pytest.mark.parallel(nprocs=[1, 2]) +def test_strong_4dvar_advection(): + main_test_strong_4dvar_advection() + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) +def test_weak_4dvar_advection(): + main_test_weak_4dvar_advection() + + if __name__ == '__main__': - main_test_advection() + main_test_strong_4dvar_advection() From 93fdc83e1a5ef7dff0cdbeb29dc1ac986cb33dd5 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Mon, 23 Dec 2024 11:11:10 +0000 Subject: [PATCH 14/23] rename aao_rf -> 4dvar_rf --- firedrake/adjoint/__init__.py | 2 +- ...functional.py => fourdvar_reduced_functional.py} | 12 ++++++------ .../regression/test_4dvar_reduced_functional.py | 13 ++++++------- 3 files changed, 13 insertions(+), 14 deletions(-) rename firedrake/adjoint/{all_at_once_reduced_functional.py => fourdvar_reduced_functional.py} (99%) diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index 08f6a3c2ec..9155a93c37 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -38,7 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \ UFLEqualityConstraint # noqa F401 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 -from firedrake.adjoint.all_at_once_reduced_functional import AllAtOnceReducedFunctional # noqa F401 +from firedrake.adjoint.fourdvar_reduced_functional import FourDVarReducedFunctional # noqa F401 import numpy_adjoint # noqa F401 import firedrake.ufl_expr import types diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py similarity index 99% rename from firedrake/adjoint/all_at_once_reduced_functional.py rename to firedrake/adjoint/fourdvar_reduced_functional.py index db7d23642a..046f5d3f2f 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from mpi4py import MPI -__all__ = ['AllAtOnceReducedFunctional'] +__all__ = ['FourDVarReducedFunctional'] # @set_working_tape() # ends up using old_tape = None because evaluates when imported - need separate decorator @@ -54,7 +54,7 @@ def sc_passthrough(func): is instantiated then passes args/kwargs through to the corresponding strong_reduced_functional method. - If using weak constraint, returns the AllAtOnceReducedFunctional + If using weak constraint, returns the FourDVarReducedFunctional method definition. """ @wraps(func) @@ -93,7 +93,7 @@ def _intermediate_options(final_options): } -class AllAtOnceReducedFunctional(ReducedFunctional): +class FourDVarReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. Creates either the strong constraint or weak constraint system @@ -586,7 +586,7 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs): class ObservationStageSequence: def __init__(self, controls: Control, - aaorf: AllAtOnceReducedFunctional, + aaorf: FourDVarReducedFunctional, global_index: int, observation_index: int, stage_kwargs: dict = None, @@ -663,12 +663,12 @@ class StrongObservationStage: ---------- aaorf - The strong constraint AllAtOnceReducedFunctional. + The strong constraint FourDVarReducedFunctional. """ def __init__(self, control: OverloadedType, - aaorf: AllAtOnceReducedFunctional, + aaorf: FourDVarReducedFunctional, index: Optional[int] = None, observation_index: Optional[int] = None): self.aaorf = aaorf diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py index 894a25849c..399f56dc7d 100644 --- a/tests/firedrake/regression/test_4dvar_reduced_functional.py +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -4,7 +4,7 @@ from firedrake.adjoint import ( continue_annotation, pause_annotation, stop_annotating, set_working_tape, get_working_tape, Control, taylor_test, - ReducedFunctional, AllAtOnceReducedFunctional) + ReducedFunctional, FourDVarReducedFunctional) @pytest.fixture(autouse=True) @@ -207,7 +207,7 @@ def strong_fdvar_pyadjoint(V): def strong_fdvar_firedrake(V): - """Build an AllAtOnceReducedFunctional for the strong constraint 4DVar system""" + """Build an FourDVarReducedFunctional for the strong constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # prior data @@ -223,7 +223,7 @@ def strong_fdvar_firedrake(V): # create 4DVar reduced functional and record # background and initial observation functionals - Jhat = AllAtOnceReducedFunctional( + Jhat = FourDVarReducedFunctional( Control(control), background_iprod=prodB, observation_iprod=prodR, @@ -243,9 +243,8 @@ def strong_fdvar_firedrake(V): stepper.solve() qn.assign(qn1) - obs_index = stage.index + 1 - # take observation + obs_index = stage.observation_index stage.set_observation(qn, obs_errors(obs_index), observation_iprod=prodR) @@ -312,7 +311,7 @@ def weak_fdvar_pyadjoint(V): def weak_fdvar_firedrake(V, ensemble): - """Build an AllAtOnceReducedFunctional for the weak constraint 4DVar system""" + """Build an FourDVarReducedFunctional for the weak constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # One control for each observation time @@ -338,7 +337,7 @@ def weak_fdvar_firedrake(V, ensemble): # create 4DVar reduced functional and record # background and initial observation functionals - Jhat = AllAtOnceReducedFunctional( + Jhat = FourDVarReducedFunctional( Control(control), background_iprod=prodB, observation_iprod=prodR, From a33651a652c80224f5099ee168c4d397f9c884ee Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 2 Jan 2025 14:04:05 +0000 Subject: [PATCH 15/23] aaorf - make sure that the user facing control tracks the internal control during observation recording stage --- .../adjoint/fourdvar_reduced_functional.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py index 046f5d3f2f..d54ff2410d 100644 --- a/firedrake/adjoint/fourdvar_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -171,6 +171,9 @@ def __init__(self, control: Control, self.trank = ensemble.ensemble_comm.rank if ensemble else 0 self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 + # because we need to manually evaluate the different bits + # of the functional, we need an internal set of controls + # to use for the stage ReducedFunctionals self._cbuf = control.copy() _x = self._cbuf.subfunctions self._x = _x @@ -305,6 +308,7 @@ def __call__(self, values: OverloadedType): raise ValueError(f"Value must be of type {type(self.control.control)} not type {type(value)}") self.control.update(value) + # put the new value into our internal set of controls to pass to each stage self._cbuf.assign(value) trank = self.trank @@ -576,6 +580,10 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs): ectx.global_index = self.stages[-1].global_index ectx.observation_index = self.stages[-1].observation_index + # make sure that self.control now holds the + # values of the initial timeseris + self.control.assign(self._cbuf) + else: # strong constraint yield ObservationStageSequence( @@ -606,27 +614,29 @@ def __iter__(self): def __next__(self): - # increment global indices + # increment global indices. self.local_index += 1 self.global_index += 1 self.observation_index += 1 - # stop after we've recorded all stages - if self.local_index >= self.nstages: - raise StopIteration - if self.weak_constraint: stages = self.aaorf.stages - # start of the next stage + # control for the start of the next stage. next_control = self.controls[self.local_index] - # smuggle state forward and increment stage indices + # smuggle state forward into aaorf's next control. if self.local_index > 0: state = stages[-1].controls[1].control with stop_annotating(): next_control.control.assign(state) + # now we know that the aaorf's controls have + # been updated from the previous stage's controls, + # we can check if we need to exit. + if self.local_index >= self.nstages: + raise StopIteration + stage = WeakObservationStage(next_control, local_index=self.local_index, global_index=self.global_index, @@ -635,6 +645,10 @@ def __next__(self): else: # strong constraint + # stop after we've recorded all stages + if self.local_index >= self.nstages: + raise StopIteration + # dummy control to "start" stage from control = (self.aaorf.controls[0].control if self.local_index == 0 else self._prev_stage.state) From 37a46ac6b53ee05876f0f6ede992e620ab2a077f Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 2 Jan 2025 16:09:52 +0000 Subject: [PATCH 16/23] revert "fix" for previous unknown (jetlag induced) "bug" --- firedrake/assign.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/firedrake/assign.py b/firedrake/assign.py index 892ed5b1b0..8d1b30681a 100644 --- a/firedrake/assign.py +++ b/firedrake/assign.py @@ -100,9 +100,6 @@ def component_tensor(self, o, a, _): def coefficient(self, o): return ((o, 1),) - def cofunction(self, o): - return ((o, 1),) - def constant_value(self, o): return ((o, 1),) From 9a6f65d988c419d9cbfa8bb8f54a0d46af1fd4ea Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 2 Jan 2025 16:26:17 +0000 Subject: [PATCH 17/23] use SimpleNameSpace instead of handrolled context class --- firedrake/adjoint/fourdvar_reduced_functional.py | 9 ++------- firedrake/ensemble.py | 9 ++------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py index d54ff2410d..858b8aeaf3 100644 --- a/firedrake/adjoint/fourdvar_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -6,6 +6,7 @@ from functools import wraps, cached_property from typing import Callable, Optional +from types import SimpleNamespace from contextlib import contextmanager from mpi4py import MPI @@ -601,7 +602,7 @@ def __init__(self, controls: Control, nstages: Optional[int] = None): self.controls = controls self.aaorf = aaorf - self.ctx = StageContext(**(stage_kwargs or {})) + self.ctx = SimpleNamespace(**(stage_kwargs or {})) self.weak_constraint = aaorf.weak_constraint self.global_index = global_index self.observation_index = observation_index @@ -663,12 +664,6 @@ def __next__(self): return stage, self.ctx -class StageContext: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - class StrongObservationStage: """ Record an observation for strong constraint 4DVar at the time of `state`. diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py index 265ad8ba2b..a644e9e65c 100644 --- a/firedrake/ensemble.py +++ b/firedrake/ensemble.py @@ -1,6 +1,7 @@ import weakref from contextlib import contextmanager from itertools import zip_longest +from types import SimpleNamespace from firedrake.petsc import PETSc from firedrake.function import Function @@ -325,7 +326,7 @@ def sequential(self, **kwargs): kwargs[k] = self.ensemble_comm.recv( **recv_kwargs) - ctx = _EnsembleContext(**kwargs) + ctx = SimpleNamespace(**kwargs) yield ctx @@ -338,9 +339,3 @@ def sequential(self, **kwargs): self.send(v, **send_kwargs) else: self.ensemble_comm.send(v, **send_kwargs) - - -class _EnsembleContext: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) From 0a40dc854ae1fd366644a58db17cfbf919f05739 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 2 Jan 2025 17:13:11 +0000 Subject: [PATCH 18/23] aaorf - fix docstring --- firedrake/adjoint/fourdvar_reduced_functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py index 858b8aeaf3..6b902fb9b5 100644 --- a/firedrake/adjoint/fourdvar_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -104,7 +104,8 @@ class FourDVarReducedFunctional(ReducedFunctional): ---------- control - The :class:`EnsembleFunction` for the control x_{i} at the initial condition and at the end of each observation stage. + The :class:`.EnsembleFunction` for the control x_{i} at the initial condition + and at the end of each observation stage. background_iprod The inner product to calculate the background error functional @@ -113,8 +114,8 @@ class FourDVarReducedFunctional(ReducedFunctional): background The background (prior) data for the initial condition :math:`x_{b}`. - If not provided, the value of the first Function on the first ensemble - of the EnsembleFunction will be used. + If not provided, the value of the first subfunction on the first ensemble + member of the control :class:`.EnsembleFunction` will be used. observation_err Given a state :math:`x`, returns the observations error From 15f6444402f3a7ecbbaf4f97e55340c70ee71851 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 3 Jan 2025 13:53:38 +0000 Subject: [PATCH 19/23] skip 4dvar tests using adjoint in complex mode --- tests/firedrake/regression/test_4dvar_reduced_functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py index 399f56dc7d..afaa9560da 100644 --- a/tests/firedrake/regression/test_4dvar_reduced_functional.py +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -436,11 +436,13 @@ def main_test_weak_4dvar_advection(): assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done @pytest.mark.parallel(nprocs=[1, 2]) def test_strong_4dvar_advection(): main_test_strong_4dvar_advection() +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done @pytest.mark.parallel(nprocs=[1, 2, 3, 4]) def test_weak_4dvar_advection(): main_test_weak_4dvar_advection() From 8e1c69ced63ddfe1c578194cf53eb3528489d48f Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Mon, 6 Jan 2025 11:30:07 +0000 Subject: [PATCH 20/23] 4DVarRF Hessian action --- .../adjoint/composite_reduced_functional.py | 173 ++++++++++++ .../adjoint/fourdvar_reduced_functional.py | 255 ++++++++++++------ firedrake/adjoint_utils/function.py | 4 +- firedrake/ensemble.py | 9 +- .../test_4dvar_reduced_functional.py | 21 +- 5 files changed, 374 insertions(+), 88 deletions(-) create mode 100644 firedrake/adjoint/composite_reduced_functional.py diff --git a/firedrake/adjoint/composite_reduced_functional.py b/firedrake/adjoint/composite_reduced_functional.py new file mode 100644 index 0000000000..a4eeb74445 --- /dev/null +++ b/firedrake/adjoint/composite_reduced_functional.py @@ -0,0 +1,173 @@ +from firedrake.adjoint import stop_annotating, get_working_tape +from pyadjoint.enlisting import Enlist + + +def intermediate_options(options): + """ + Options set for the intermediate stages of a chain of ReducedFunctionals + + Takes all elements of the options except riesz_representation, which + is set to None to prevent returning derivatives to the primal space. + """ + return { + **{k: v for k, v in (options or {}).items() + if k != 'riesz_representation'}, + 'riesz_representation': None + } + + +def compute_tlm(J, m, m_dot, options=None, tape=None): + """ + Compute the tangent linear model of J in a direction m_dot at the current value of m + + Args: + J (OverloadedType): The objective functional. + m (list or instance of Control): The (list of) controls. + m_dot (list or instance of the control type): The direction in which to compute the Hessian. + options (dict): A dictionary of options. To find a list of available options + have a look at the specific control type. + tape: The tape to use. Default is the current tape. + + Returns: + OverloadedType: The tangent linear with respect to the control in direction m_dot. + Should be an instance of the same type as the control. + """ + tape = tape or get_working_tape() + + # reset tlm values + tape.reset_tlm_values() + + m = Enlist(m) + m_dot = Enlist(m_dot) + + # set initial tlm values + for mi, mdi in zip(m, m_dot): + mi.tlm_value = mdi + + # evaluate tlm + with stop_annotating(): + with tape.marked_nodes(m): + tape.evaluate_tlm(markings=True) + + # return functional's tlm + return J._ad_convert_type(J.block_variable.tlm_value, + options=options or {}) + + +def compute_hessian(J, m, options=None, tape=None, hessian_value=0.): + """ + Compute the Hessian of J at the current value of m with the current tlm values on the tape. + + Args: + J (OverloadedType): The objective functional. + m (list or instance of Control): The (list of) controls. + options (dict): A dictionary of options. To find a list of available options + have a look at the specific control type. + tape: The tape to use. Default is the current tape. + + Returns: + OverloadedType: The second derivative with respect to the control in direction m_dot. + Should be an instance of the same type as the control. + """ + tape = tape or get_working_tape() + + # reset hessian values + tape.reset_hessian_values() + + m = Enlist(m) + + # set initial hessian_value + J.block_variable.hessian_value = J._ad_convert_type( + hessian_value, options=intermediate_options(options)) + + # evaluate hessian + with stop_annotating(): + with tape.marked_nodes(m): + tape.evaluate_hessian(markings=True) + + # return controls' hessian values + return m.delist([v.get_hessian(options=options or {}) for v in m]) + + +def tlm(rf, m_dot, options=None): + """Returns the action of the tangent linear model of the functional w.r.t. the control on a vector m_dot. + + Args: + m_dot ([OverloadedType]): The direction in which to compute the + action of the tangent linear model. + options (dict): A dictionary of options. To find a list of + available options have a look at the specific control type. + + Returns: + OverloadedType: The action of the tangent linear model in the direction m_dot. + Should be an instance of the same type as the control. + """ + return compute_tlm(rf.functional, rf.controls, m_dot, + tape=rf.tape, options=options) + + +def hessian(rf, options=None, hessian_value=0.): + """Returns the action of the Hessian of the functional w.r.t. the control. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control and the last tlm values, is computed and returned. + + Args: + options (dict): A dictionary of options. To find a list of + available options have a look at the specific control type. + hessian_value: The Hessian value to initialise the accumulation + from the functional block variable. + + Returns: + OverloadedType: The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + """ + return rf.controls.delist( + compute_hessian(rf.functional, rf.controls, + tape=rf.tape, options=options, + hessian_value=hessian_value)) + + +class CompositeReducedFunctional: + def __init__(self, rf1, rf2): + self.rf1 = rf1 + self.rf2 = rf2 + + def __call__(self, values): + return self.rf2(self.rf1(values)) + + def derivative(self, adj_input=1.0, options=None): + deriv2 = self.rf2.derivative( + adj_input=adj_input, options=intermediate_options(options)) + deriv1 = self.rf1.derivative( + adj_input=deriv2, options=options or {}) + return deriv1 + + def tlm(self, m_dot, options=None): + tlm1 = self._eval_tlm( + self.rf1, m_dot, intermediate_options(options)), + tlm2 = self._eval_tlm( + self.rf2, tlm1, options) + return tlm2 + + def hessian(self, m_dot, options=None, evaluate_tlm=True): + if evaluate_tlm: + self.tlm(m_dot, options=intermediate_options(options)) + hess2 = self._eval_hessian( + self.rf2, 0., intermediate_options(options)) + hess1 = self._eval_hessian( + self.rf1, hess2, options or {}) + return hess1 + + def _eval_tlm(self, rf, m_dot, options): + if isinstance(rf, CompositeReducedFunctional): + return rf.tlm(m_dot, options=options) + else: + return tlm(rf, m_dot=m_dot, options=options) + + def _eval_hessian(self, rf, hessian_value, options): + if isinstance(rf, CompositeReducedFunctional): + return rf.hessian(None, options, evaluate_tlm=False) + else: + return hessian(rf, hessian_value=hessian_value, options=options) diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py index 6b902fb9b5..dbd7553778 100644 --- a/firedrake/adjoint/fourdvar_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -3,6 +3,8 @@ from pyadjoint.enlisting import Enlist from firedrake.function import Function from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction +from firedrake.adjoint.composite_reduced_functional import ( + CompositeReducedFunctional, tlm, hessian, intermediate_options) from functools import wraps, cached_property from typing import Callable, Optional @@ -21,27 +23,23 @@ def isolated_rf(operation, control, Return a ReducedFunctional where the functional is `operation` applied to a copy of `control`, and the tape contains only `operation`. """ - with set_working_tape(): + with stop_annotating(): controls = Enlist(control) - control_names = Enlist(control_name) - - with stop_annotating(): - control_copies = [control._ad_copy() for control in controls] + control_copies = [control._ad_copy() for control in controls] - if control_names: - for control, name in zip(control_copies, control_names): - _rename(control, name) + if control_name: + for control, name in zip(control_copies, Enlist(control_name)): + _rename(control, name) - if len(control_copies) == 1: - functional = operation(control_copies[0]) - control = Control(control_copies[0]) - else: - functional = operation(control_copies) - control = [Control(control) for control in control_copies] + with set_working_tape(): + functional = operation(controls.delist(control_copies)) if functional_name: _rename(functional, functional_name) + control = controls.delist([Control(control_copy) + for control_copy in control_copies]) + return ReducedFunctional( functional, control) @@ -80,20 +78,6 @@ def _ad_sub(left, right): return result -def _intermediate_options(final_options): - """ - Options set for the intermediate stages of a chain of ReducedFunctionals - - Takes all elements of the final_options except riesz_representation, - which is set to prevent returning derivatives to the primal space. - """ - return { - 'riesz_representation': None, - **{k: v for k, v in final_options.items() - if (k != 'riesz_representation')} - } - - class FourDVarReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. @@ -176,7 +160,7 @@ def __init__(self, control: Control, # because we need to manually evaluate the different bits # of the functional, we need an internal set of controls # to use for the stage ReducedFunctionals - self._cbuf = control.copy() + self._cbuf = control.copy_data() _x = self._cbuf.subfunctions self._x = _x self._controls = tuple(Control(xi) for xi in _x) @@ -205,6 +189,10 @@ def __init__(self, control: Control, control=self.background_error.functional, control_name="bkg_err_vec_copy") + # compose background reduced functionals to evaluate both together + self.background_rf = CompositeReducedFunctional( + self.background_error, self.background_norm) + if self.initial_observations: # RF to recalculate error vector (H(x_0) - y_0) @@ -220,6 +208,10 @@ def __init__(self, control: Control, control=self.initial_observation_error.functional, functional_name="obs_err_vec_0_copy") + # compose initial observation reduced functionals to evaluate both together + self.initial_observation_rf = CompositeReducedFunctional( + self.initial_observation_error, self.initial_observation_norm) + # create halo for previous state if self.ensemble and self.trank != 0: with stop_annotating(): @@ -336,15 +328,11 @@ def __call__(self, values: OverloadedType): # Initial condition functionals if trank == 0: - Jlocal = ( - self.background_norm( - self.background_error(x[0]))) + Jlocal = self.background_rf(x[0]) # observations at time 0 if self.initial_observations: - Jlocal += ( - self.initial_observation_norm( - self.initial_observation_error(x[0]))) + Jlocal += self.initial_observation_rf(x[0]) else: Jlocal = 0. @@ -393,7 +381,6 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = _intermediate_options(options) # evaluate first forward model, which contributes to previous chunk sderiv0 = self.stages[0].derivative( @@ -444,19 +431,13 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # initial condition derivatives if trank == 0: - bkg_deriv = self.background_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.background_error.derivative( - adj_input=bkg_deriv, options=options) + derivs[0] += self.background_rf.derivative( + adj_input=adj_input, options=options) # observations at time 0 if self.initial_observations: - obs_deriv = self.initial_observation_norm.derivative( - adj_input=adj_input, options=intermediate_options) - - derivs[0] += self.initial_observation_error.derivative( - adj_input=obs_deriv, options=options) + derivs[0] += self.initial_observation_rf.derivative( + adj_input=adj_input, options=options) # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): @@ -505,7 +486,84 @@ def hessian(self, m_dot: OverloadedType, options: dict = {}): The action of the Hessian in the direction m_dot. Should be an instance of the same type as the control. """ - raise ValueError("Not implemented yet") + trank = self.trank + + hess = self.control.copy_data() + hess.zero() + + # set up arrays including halos + if trank == 0: + hs = [*hess.subfunctions] + mdot = [*m_dot[0].subfunctions] + else: + hprev = hess.subfunctions[0].copy(deepcopy=True) + mprev = m_dot[0].subfunctions[0].copy(deepcopy=True) + hs = [hprev, *hess.subfunctions] + mdot = [mprev, *m_dot[0].subfunctions] + + if trank != self.nchunks - 1: + hnext = hess.subfunctions[0].copy(deepcopy=True) + + # send m_dot halo forward + if self.ensemble: + src = trank - 1 + dst = trank + 1 + + if trank != self.nchunks - 1: + self.ensemble.isend( + mdot[-1], dest=dst, tag=dst) + + if trank != 0: + recv_reqs = self.ensemble.irecv( + mdot[0], source=src, tag=trank) + + # hessian actions at the initial condition + if trank == 0: + hs[0] += self.background_rf.hessian( + mdot[0], options=options) + + if self.initial_observations: + hs[0] += self.initial_observation_rf.hessian( + mdot[0], options=options) + + # evaluate all stages on chunk except first + for i in range(1, len(self.stages)): + hms = self.stages[i].hessian( + mdot[i:i+2], options=options) + + hs[i] += hms[0] + hs[i+1] += hms[1] + + # wait for halo swap to finish + if trank != 0: + MPI.Request.Waitall(recv_reqs) + + # evaluate first stage on chunk now we have the halo + hms = self.stages[0].hessian( + mdot[:2], options=options) + + hs[0] += hms[0] + hs[1] += hms[1] + + # send result halo backward + if self.ensemble: + src = trank + 1 + dst = trank - 1 + + if trank != 0: + self.ensemble.isend( + hs[0], dest=dst, tag=dst) + + if trank != self.nchunks - 1: + recv_reqs = self.ensemble.irecv( + hnext, source=src, tag=trank) + + # finish the result halo + if trank != self.nchunks - 1: + MPI.Request.Waitall(recv_reqs) + hs[-1] += hnext + + return hess @stop_annotating() def hessian_matrix(self): @@ -836,6 +894,10 @@ def set_observation(self, state: OverloadedType, control=self.model_error.functional, **names) + # compose model error reduced functionals to evaluate both together + self.model_error_rf = CompositeReducedFunctional( + self.model_error, self.model_norm) + # Observations after tape cut because this is now a control, not a state # RF to recalculate error vector (H(x_i) - y_i) @@ -858,6 +920,10 @@ def set_observation(self, state: OverloadedType, control=self.observation_error.functional, **names) + # compose observation reduced functionals to evaluate both together + self.observation_rf = CompositeReducedFunctional( + self.observation_error, self.observation_norm) + # remove the stage initial condition "control" now we've finished recording delattr(self, "control") @@ -893,13 +959,13 @@ def __call__(self, values: OverloadedType, J = 0.0 # evaluate model error - if (rftype is None) or (rftype == 'model'): - Mi = self.forward_model(values[0]) - J += self.model_norm(self.model_error([Mi, values[1]])) + if rftype in (None, 'model'): + J += self.model_error_rf( + [self.forward_model(values[0]), values[1]]) # evaluate observation errors - if (rftype is None) or (rftype == 'obs'): - J += self.observation_norm(self.observation_error(values[1])) + if rftype in (None, 'obs'): + J += self.observation_rf(values[1]) return J @@ -936,40 +1002,34 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = _intermediate_options(options) + ioptions = intermediate_options(options) - if (rftype is None) or (rftype == 'model'): - # derivative of reduction - dm_norm = self.model_norm.derivative(adj_input=adj_input, - options=intermediate_options) - - # derivative of difference splits into (Mi, xi) - dm_errors = self.model_error.derivative(adj_input=dm_norm, - options=intermediate_options) + if rftype in (None, 'model'): + # derivative of reduction and difference + model_err_derivs = self.model_error_rf.derivative( + adj_input=adj_input, options=ioptions) # derivative through the forward model wrt to xprev - dm_forward = self.forward_model.derivative(adj_input=dm_errors[0], - options=options) + model_forward_deriv = self.forward_model.derivative( + adj_input=model_err_derivs[0], options=options) - derivatives.append(dm_forward) + derivatives.append(model_forward_deriv) - # dm_errors is still in the dual space, so we need to convert it to the - # type that the user has requested - this will be the type of dm_forward. - derivatives.append(dm_forward._ad_convert_type(dm_errors[1], options)) + # model_err_derivs is still in the dual space, so we need to convert it to the + # type that the user has requested - this will be the type of model_forward_deriv. + derivatives.append( + model_forward_deriv._ad_convert_type( + model_err_derivs[1], options)) - if (rftype is None) or (rftype == 'obs'): - # derivative of reduction - do_norm = self.observation_norm.derivative(adj_input=adj_input, - options=intermediate_options) - # derivative of error - do_error = self.observation_error.derivative(adj_input=do_norm, - options=options) + if rftype in (None, 'obs'): + obs_deriv = self.observation_rf.derivative( + adj_input=adj_input, options=options) if len(derivatives) == 0: derivatives.append(None) - derivatives.append(do_error) + derivatives.append(obs_deriv) else: - derivatives[1] += do_error + derivatives[1] += obs_deriv return derivatives @@ -1004,4 +1064,45 @@ def hessian(self, m_dot: OverloadedType, options: dict = {}, The action of the Hessian in the direction m_dot. Should be an instance of the same type as the control. """ - pass + hessian_value = [] + + if rftype in (None, 'model'): + hessian_value.extend(self._model_hessian( + m_dot, options=options)) + + if rftype in (None, 'obs'): + obs_hessian = self.observation_rf.hessian( + m_dot[1], options=options) + if len(hessian_value) == 0: + hessian_value.append(None) + hessian_value.append(obs_hessian) + else: + hessian_value[1] += obs_hessian + + return hessian_value + + def _model_hessian(self, m_dot, options): + iopts = intermediate_options(options) + + # TLM for model from mdot[0] + forward_tlm = tlm(self.forward_model, m_dot[0], + options=iopts) + + # combine model TLM and mdot[1] + mdot_error = [forward_tlm, m_dot[1]] + + # Hessian (dual) for error + error_hessian = self.model_error_rf.hessian( + mdot_error, options=iopts, evaluate_tlm=True) + + # Hessian for model + model_hessian = hessian( + self.forward_model, options=options, + hessian_value=error_hessian[0]) + + # combine model Hessian and converted error Hessian + return [ + model_hessian, + model_hessian._ad_convert_type(error_hessian[1], + options=options) + ] diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 5e87751d36..e69e5045e9 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -233,14 +233,14 @@ def _ad_convert_riesz(self, value, options=None): return Function(V) if not isinstance(value, (Cofunction, Function)): - raise TypeError("Expected a Cofunction or a Function") + raise TypeError(f"Expected a Cofunction or a Function not a {type(value)}") if riesz_representation == "l2": return Function(V, val=value.dat) elif riesz_representation in ("L2", "H1"): if not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") + raise TypeError(f"Expected a Cofunction not a {type(value)}") ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py index a644e9e65c..53049fe7d7 100644 --- a/firedrake/ensemble.py +++ b/firedrake/ensemble.py @@ -289,7 +289,7 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r return requests @contextmanager - def sequential(self, **kwargs): + def sequential(self, synchronise=False, **kwargs): """ Context manager for executing code on each ensemble member consecutively by `ensemble_comm.rank`. @@ -328,7 +328,12 @@ def sequential(self, **kwargs): ctx = SimpleNamespace(**kwargs) - yield ctx + if synchronise: + self.global_comm.Barrier() + yield ctx + self.global_comm.Barrier() + else: + yield ctx if not last_rank: dst = rank + 1 diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py index 399f56dc7d..a8bc0486a3 100644 --- a/tests/firedrake/regression/test_4dvar_reduced_functional.py +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -3,8 +3,9 @@ from firedrake.__future__ import interpolate from firedrake.adjoint import ( continue_annotation, pause_annotation, stop_annotating, - set_working_tape, get_working_tape, Control, taylor_test, + set_working_tape, get_working_tape, Control, taylor_test, taylor_to_dict, ReducedFunctional, FourDVarReducedFunctional) +from numpy import mean @pytest.fixture(autouse=True) @@ -52,7 +53,7 @@ def tendency(q, phi): def prod2(w): """generate weighted inner products to pass to FourDVarReducedFunctional""" def n2(x): - return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) + return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx)**2 return n2 @@ -391,9 +392,12 @@ def main_test_strong_4dvar_advection(): assert abs(Jhat_pyadj(mp) - Jhat_aaorf(ma)) < eps assert abs(Jhat_pyadj(hp) - Jhat_aaorf(ha)) < eps - # If we match the functional, then passing the taylor test + # If we match the functional, then passing the taylor tests # should mean that we match the derivative too. - assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + taylor = taylor_to_dict(Jhat_aaorf, ma, ha) + assert mean(taylor['R0']['Rate']) > 0.9 + assert mean(taylor['R1']['Rate']) > 1.9 + assert mean(taylor['R2']['Rate']) > 2.9 def main_test_weak_4dvar_advection(): @@ -426,14 +430,17 @@ def main_test_weak_4dvar_advection(): ma = m(V, ensemble) ha = h(V, ensemble) - eps = 1e-12 + eps = 1e-10 # Does evaluating the functional match the reference rf? assert abs(Jpm - Jhat_aaorf(ma)) < eps assert abs(Jph - Jhat_aaorf(ha)) < eps - # If we match the functional, then passing the taylor test + # If we match the functional, then passing the taylor tests # should mean that we match the derivative too. - assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + taylor = taylor_to_dict(Jhat_aaorf, ma, ha) + assert mean(taylor['R0']['Rate']) > 0.9 + assert mean(taylor['R1']['Rate']) > 1.9 + assert mean(taylor['R2']['Rate']) > 2.9 @pytest.mark.parallel(nprocs=[1, 2]) From d8db9da3225f42e616a05b24d40469f9bf460367 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 7 Jan 2025 14:52:59 +0000 Subject: [PATCH 21/23] TO REVERT: pyadjoint branch --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..f6beb41c4d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint JHopeCollins/mark_evaluate_tlm \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From 01e52c97e12701a8cffd5eb5b30e94576f9bf37b Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 7 Jan 2025 15:31:32 +0000 Subject: [PATCH 22/23] docstrings for composite and 4dvar rfs --- .../adjoint/composite_reduced_functional.py | 249 ++++++++++++++---- .../adjoint/fourdvar_reduced_functional.py | 12 +- 2 files changed, 210 insertions(+), 51 deletions(-) diff --git a/firedrake/adjoint/composite_reduced_functional.py b/firedrake/adjoint/composite_reduced_functional.py index a4eeb74445..d09dfebe9f 100644 --- a/firedrake/adjoint/composite_reduced_functional.py +++ b/firedrake/adjoint/composite_reduced_functional.py @@ -1,13 +1,25 @@ -from firedrake.adjoint import stop_annotating, get_working_tape +from pyadjoint import stop_annotating, get_working_tape, OverloadedType, Control, Tape, ReducedFunctional from pyadjoint.enlisting import Enlist +from typing import Optional -def intermediate_options(options): +def intermediate_options(options: dict): """ Options set for the intermediate stages of a chain of ReducedFunctionals Takes all elements of the options except riesz_representation, which is set to None to prevent returning derivatives to the primal space. + + Parameters + ---------- + options + The dictionary of options provided by the user + + Returns + ------- + dict + The options for ReducedFunctionals at intermediate stages + """ return { **{k: v for k, v in (options or {}).items() @@ -16,21 +28,36 @@ def intermediate_options(options): } -def compute_tlm(J, m, m_dot, options=None, tape=None): +def compute_tlm(J: OverloadedType, + m: Control, + m_dot: OverloadedType, + options: Optional[dict] = None, + tape: Optional[Tape] = None): """ Compute the tangent linear model of J in a direction m_dot at the current value of m - Args: - J (OverloadedType): The objective functional. - m (list or instance of Control): The (list of) controls. - m_dot (list or instance of the control type): The direction in which to compute the Hessian. - options (dict): A dictionary of options. To find a list of available options - have a look at the specific control type. - tape: The tape to use. Default is the current tape. + Parameters + ---------- + + J + The objective functional. + m + The (list of) :class:`pyadjoint.Control` for the functional. + m_dot + The direction in which to compute the Hessian. + Must be a (list of) :class:`pyadjoint.OverloadedType`. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + tape + The tape to use. Default is the current tape. + + Returns + ------- + pyadjoint.OverloadedType + The tangent linear with respect to the control in direction m_dot. + Should be an instance of the same type as the control. - Returns: - OverloadedType: The tangent linear with respect to the control in direction m_dot. - Should be an instance of the same type as the control. """ tape = tape or get_working_tape() @@ -54,20 +81,34 @@ def compute_tlm(J, m, m_dot, options=None, tape=None): options=options or {}) -def compute_hessian(J, m, options=None, tape=None, hessian_value=0.): +def compute_hessian(J: OverloadedType, + m: Control, + options: Optional[dict] = None, + tape: Optional[Tape] = None, + hessian_value: Optional[OverloadedType] = 0.): """ Compute the Hessian of J at the current value of m with the current tlm values on the tape. - Args: - J (OverloadedType): The objective functional. - m (list or instance of Control): The (list of) controls. - options (dict): A dictionary of options. To find a list of available options - have a look at the specific control type. - tape: The tape to use. Default is the current tape. + Parameters + ---------- + J + The objective functional. + m + The (list of) :class:`pyadjoint.Control` for the functional. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + tape + The tape to use. Default is the current tape. + hessian_value + The initial hessian_value to start accumulating from. + + Returns + ------- + pyadjoint.OverloadedType + The second derivative with respect to the control in direction m_dot. + Should be an instance of the same type as the control. - Returns: - OverloadedType: The second derivative with respect to the control in direction m_dot. - Should be an instance of the same type as the control. """ tape = tape or get_working_tape() @@ -89,39 +130,56 @@ def compute_hessian(J, m, options=None, tape=None, hessian_value=0.): return m.delist([v.get_hessian(options=options or {}) for v in m]) -def tlm(rf, m_dot, options=None): +def tlm(rf: ReducedFunctional, + m_dot: OverloadedType, + options: Optional[dict] = None): """Returns the action of the tangent linear model of the functional w.r.t. the control on a vector m_dot. - Args: - m_dot ([OverloadedType]): The direction in which to compute the - action of the tangent linear model. - options (dict): A dictionary of options. To find a list of - available options have a look at the specific control type. + Parameters + ---------- + rf + The :class:`pyadjoint.ReducedFunctional` to evaluate the tlm of. + m_dot + The direction in which to compute the action of the tangent linear model. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + Returns + ------- + pyadjoint.OverloadedType + The action of the tangent linear model in the direction m_dot. + Should be an instance of the same type as the control. - Returns: - OverloadedType: The action of the tangent linear model in the direction m_dot. - Should be an instance of the same type as the control. """ return compute_tlm(rf.functional, rf.controls, m_dot, tape=rf.tape, options=options) -def hessian(rf, options=None, hessian_value=0.): +def hessian(rf: ReducedFunctional, + options: Optional[dict] = None, + hessian_value: Optional[OverloadedType] = 0.): """Returns the action of the Hessian of the functional w.r.t. the control. Using the second-order adjoint method, the action of the Hessian of the functional with respect to the control, around the last supplied value of the control and the last tlm values, is computed and returned. - Args: - options (dict): A dictionary of options. To find a list of - available options have a look at the specific control type. - hessian_value: The Hessian value to initialise the accumulation - from the functional block variable. + Parameters + ---------- + rf + The :class:`pyadjoint.ReducedFunctional` to evaluate the tlm of. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + hessian_value + The initial hessian_value to start accumulating from. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian. Should be an instance of the same type as the control. - Returns: - OverloadedType: The action of the Hessian in the direction m_dot. - Should be an instance of the same type as the control. """ return rf.controls.delist( compute_hessian(rf.functional, rf.controls, @@ -130,28 +188,129 @@ def hessian(rf, options=None, hessian_value=0.): class CompositeReducedFunctional: + """Class representing the composition of two reduced functionals. + + For two reduced functionals J1: X->Y and J2: Y->Z, this is a convenience + class representing the composition J12: X->Z = J2(J1(x)) and providing + methods for the evaluation, derivative, tlm, and hessian action of J12. + + Parameters + ---------- + rf1 + The first :class:`pyadjoint.ReducedFunctional` in the composition. + rf2 + The second :class:`pyadjoint.ReducedFunctional` in the composition. + The control for rf2 must have the same type as the functional of rf1. + + """ def __init__(self, rf1, rf2): self.rf1 = rf1 self.rf2 = rf2 - def __call__(self, values): + def __call__(self, values: OverloadedType): + """Computes the reduced functional with supplied control value. + + Parameters + ---------- + + values + If you have multiple controls this should be a list of new values + for each control in the order you listed the controls to the constructor. + If you have a single control it can either be a list or a single object. + Each new value should have the same type as the corresponding control. + + Returns + ------- + pyadjoint.OverloadedType + The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. + + """ return self.rf2(self.rf1(values)) - def derivative(self, adj_input=1.0, options=None): + def derivative(self, adj_input: Optional[float] = 1.0, options: Optional[dict] = None): + """Returns the derivative of the functional w.r.t. the control. + Using the adjoint method, the derivative of the functional with + respect to the control, around the last supplied value of the + control, is computed and returned. + + Parameters + ---------- + adj_input + The adjoint input. + + options + Additional options for the derivative computation. + + Returns + ------- + pyadjoint.OverloadedType + The derivative with respect to the control. + Should be an instance of the same type as the control. + + """ deriv2 = self.rf2.derivative( adj_input=adj_input, options=intermediate_options(options)) deriv1 = self.rf1.derivative( adj_input=deriv2, options=options or {}) return deriv1 - def tlm(self, m_dot, options=None): + def tlm(self, m_dot: OverloadedType, options: Optional[dict] = None): + """Returns the action of the tangent linear model of the functional w.r.t. the control on a vector m_dot. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + + """ tlm1 = self._eval_tlm( self.rf1, m_dot, intermediate_options(options)), tlm2 = self._eval_tlm( self.rf2, tlm1, options) return tlm2 - def hessian(self, m_dot, options=None, evaluate_tlm=True): + def hessian(self, m_dot: OverloadedType, + options: Optional[dict] = None, + evaluate_tlm: Optional[bool] = True): + """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control, is computed and returned. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + evaluate_tlm + If True, the tlm values on the tape will be reset and evaluated before + the Hessian action is evaluated. If False, the existing tlm values on + the tape will be used. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + + """ if evaluate_tlm: self.tlm(m_dot, options=intermediate_options(options)) hess2 = self._eval_hessian( diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py index dbd7553778..82effbf514 100644 --- a/firedrake/adjoint/fourdvar_reduced_functional.py +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -82,7 +82,7 @@ class FourDVarReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. Creates either the strong constraint or weak constraint system - by logging observations through the initial forward model run. + by logging observations through the initial time propagator run. Parameters ---------- @@ -313,7 +313,7 @@ def __call__(self, values: OverloadedType): else: x = [*self._x] - # post messages for control of forward model propogation on next chunk + # post messages for control of time propagator on next chunk if self.ensemble: src = trank - 1 dst = trank + 1 @@ -382,7 +382,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - # evaluate first forward model, which contributes to previous chunk + # evaluate first time propagator, which contributes to previous chunk sderiv0 = self.stages[0].derivative( adj_input=adj_input, options=options) @@ -439,7 +439,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): derivs[0] += self.initial_observation_rf.derivative( adj_input=adj_input, options=options) - # # evaluate all forward models on chunk except first while halo in flight + # # evaluate all time stages on chunk except first while halo in flight for i in range(1, len(self.stages)): sderiv = self.stages[i].derivative( adj_input=adj_input, options=options) @@ -781,7 +781,7 @@ class WeakObservationStage: """ A single stage for weak constraint 4DVar at the time of `state`. - Records the forward model propogation from the control at the beginning + Records the time propagator from the control at the beginning of the stage, and the model and observation errors at the end of the stage. Parameters @@ -1009,7 +1009,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, model_err_derivs = self.model_error_rf.derivative( adj_input=adj_input, options=ioptions) - # derivative through the forward model wrt to xprev + # derivative through the time propagator wrt to xprev model_forward_deriv = self.forward_model.derivative( adj_input=model_err_derivs[0], options=options) From 021506d8042e443859cab326d438035ce5b98fe5 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 7 Jan 2025 16:41:40 +0000 Subject: [PATCH 23/23] add markings argument to FunctionMergeBlock.evaluate_tlm --- firedrake/adjoint_utils/blocks/function.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index e31a0c4567..dcb02da108 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -242,11 +242,13 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, else: return adj_inputs[0] - def evaluate_tlm(self): + def evaluate_tlm(self, markings=False): tlm_input = self.get_dependencies()[0].tlm_value if tlm_input is None: return output = self.get_outputs()[0] + if markings and not output.marked_in_path: + return fs = output.output.function_space() f = type(output.output)(fs) output.add_tlm_output(