From 7e228955861a33e190bc88034ca14be775e3a73b Mon Sep 17 00:00:00 2001 From: jshipton Date: Tue, 17 Dec 2024 15:57:03 +0000 Subject: [PATCH] start to move parareal over --- .../time_discretisation.py | 33 +++++- gusto/timestepping/__init__.py | 3 +- gusto/timestepping/parareal.py | 101 ++++++++++++++++++ gusto/timestepping/timestepper.py | 29 +---- 4 files changed, 136 insertions(+), 30 deletions(-) create mode 100644 gusto/timestepping/parareal.py diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index d7b4d4aab..900b42b6d 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -12,13 +12,14 @@ Constant, NonlinearVariationalProblem, NonlinearVariationalSolver) from firedrake.fml import (replace_subject, replace_test_function, Term, - all_terms, drop) + all_terms, drop, LabelledForm) from firedrake.formmanipulation import split_form from firedrake.utils import cached_property from gusto.core.configuration import EmbeddedDGOptions, RecoveryOptions from gusto.core.labels import (time_derivative, prognostic, physics_label, - mass_weighted, nonlinear_time_derivative) + mass_weighted, nonlinear_time_derivative, + transporting_velocity) from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.time_discretisation.wrappers import * @@ -324,6 +325,34 @@ def setup(self, equation, apply_bcs=True, *active_labels): self.x_out = Function(self.fs) self.x1 = Function(self.fs) + def setup_transporting_velocity(self, uadv): + self.residual = self.residual.label_map( + lambda t: t.has_label(transporting_velocity), + map_if_true=lambda t: + Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}), + t.labels) + ) + + self.residual = transporting_velocity.update_value(self.residual, uadv) + # Now also replace transporting velocity in the terms that are + # contained in labels + for idx, t in enumerate(self.residual.terms): + if t.has_label(transporting_velocity): + for label in t.labels.keys(): + if type(t.labels[label]) is LabelledForm: + t.labels[label] = t.labels[label].label_map( + lambda s: s.has_label(transporting_velocity), + map_if_true=lambda s: + Term(ufl.replace( + s.form, + {s.get(transporting_velocity): uadv}), + s.labels + ) + ) + + self.residual.terms[idx].labels[label] = \ + transporting_velocity.update_value(t.labels[label], uadv) + @property def nlevels(self): return 1 diff --git a/gusto/timestepping/__init__.py b/gusto/timestepping/__init__.py index cb216fa33..31a78b513 100644 --- a/gusto/timestepping/__init__.py +++ b/gusto/timestepping/__init__.py @@ -1,3 +1,4 @@ from gusto.timestepping.timestepper import * # noqa +from gusto.timestepping.parareal import * # noqa from gusto.timestepping.split_timestepper import * # noqa -from gusto.timestepping.semi_implicit_quasi_newton import * # noqa \ No newline at end of file +from gusto.timestepping.semi_implicit_quasi_newton import * # noqa diff --git a/gusto/timestepping/parareal.py b/gusto/timestepping/parareal.py new file mode 100644 index 000000000..6e16141da --- /dev/null +++ b/gusto/timestepping/parareal.py @@ -0,0 +1,101 @@ +from firedrake import Function +from gusto.core.fields import Fields + + +class PararealFields(object): + + def __init__(self, equation, nlevels): + levels = [str(n) for n in range(nlevels+1)] + self.add_fields(equation, levels) + + def add_fields(self, equation, levels): + if levels is None: + levels = self.levels + for level in levels: + try: + x = getattr(self, level) + x.add_field(equation.field_name, equation.function_space) + except AttributeError: + setattr(self, level, Fields(equation)) + + def __call__(self, n): + return getattr(self, str(n)) + + +class Parareal(object): + + def __init__(self, domain, coarse_scheme, fine_scheme, nG, nF, + n_intervals, max_its): + + assert coarse_scheme.nlevels == 1 + assert fine_scheme.nlevels == 1 + self.nlevels = 1 + + self.coarse_scheme = coarse_scheme + self.coarse_scheme.dt.assign(domain.dt/n_intervals/nG) + self.fine_scheme = fine_scheme + self.fine_scheme.dt.assign(domain.dt/n_intervals/nG) + self.nG = nG + self.nF = nF + self.n_intervals = n_intervals + self.max_its = max_its + + def setup(self, equation, apply_bcs=True, *active_labels): + self.coarse_scheme.fixed_subcycles = self.nG + self.coarse_scheme.setup(equation, apply_bcs, *active_labels) + self.fine_scheme.fixed_subcycles = self.nF + self.fine_scheme.setup(equation, apply_bcs, *active_labels) + self.x = PararealFields(equation, self.n_intervals) + self.xF = PararealFields(equation, self.n_intervals) + self.xn = Function(equation.function_space) + self.xGk = PararealFields(equation, self.n_intervals) + self.xGkm1 = PararealFields(equation, self.n_intervals) + self.xFn = Function(equation.function_space) + self.xFnp1 = Function(equation.function_space) + self.name = equation.field_name + + def setup_transporting_velocity(self, uadv): + self.coarse_scheme.setup_transporting_velocity(uadv) + self.fine_scheme.setup_transporting_velocity(uadv) + + def apply(self, x_out, x_in): + + self.xn.assign(x_in) + x0 = self.x(0)(self.name) + x0.assign(x_in) + xF0 = self.xF(0)(self.name) + xF0.assign(x_in) + + # compute first guess from coarse scheme + for n in range(self.n_intervals): + print("computing first coarse guess for interval: ", n) + # apply coarse scheme and save data as initial conditions for fine + xGnp1 = self.xGkm1(n+1)(self.name) + self.coarse_scheme.apply(xGnp1, self.xn) + xnp1 = self.x(n+1)(self.name) + xnp1.assign(xGnp1) + self.xn.assign(xnp1) + + for k in range(self.max_its): + + # apply fine scheme in each interval using previously + # calculated coarse data + for n in range(k, self.n_intervals): + print("computing fine guess for iteration and interval: ", k, n) + self.xFn.assign(self.x(n)(self.name)) + xFnp1 = self.xF(n+1)(self.name) + self.fine_scheme.apply(xFnp1, self.xFn) + + # compute correction + for n in range(k, self.n_intervals): + xn = self.x(n)(self.name) + xGk = self.xGk(n+1)(self.name) + # compute new coarse guess + self.coarse_scheme.apply(xGk, xn) + xnp1 = self.x(n+1)(self.name) + xGkm1 = self.xGkm1(n+1)(self.name) + xFnp1 = self.xF(n+1)(self.name) + xnp1.assign(xGk - xGkm1 + xFnp1) + xGkm1.assign(xGk) + + x_out.assign(xnp1) diff --git a/gusto/timestepping/timestepper.py b/gusto/timestepping/timestepper.py index 3c619753b..e2df54e06 100644 --- a/gusto/timestepping/timestepper.py +++ b/gusto/timestepping/timestepper.py @@ -7,7 +7,7 @@ from gusto.equations import PrognosticEquationSet from gusto.core import TimeLevelFields, StateFields from gusto.core.io import TimeData -from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity +from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity, transporting_velocity from gusto.core.logging import logger from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation from gusto.spatial_methods.transport_methods import TransportMethod @@ -127,32 +127,7 @@ def setup_transporting_velocity(self, scheme): else: uadv = self.transporting_velocity - scheme.residual = scheme.residual.label_map( - lambda t: t.has_label(transporting_velocity), - map_if_true=lambda t: - Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}), t.labels) - ) - - scheme.residual = transporting_velocity.update_value(scheme.residual, uadv) - - # Now also replace transporting velocity in the terms that are - # contained in labels - for idx, t in enumerate(scheme.residual.terms): - if t.has_label(transporting_velocity): - for label in t.labels.keys(): - if type(t.labels[label]) is LabelledForm: - t.labels[label] = t.labels[label].label_map( - lambda s: s.has_label(transporting_velocity), - map_if_true=lambda s: - Term(ufl.replace( - s.form, - {s.get(transporting_velocity): uadv}), - s.labels - ) - ) - - scheme.residual.terms[idx].labels[label] = \ - transporting_velocity.update_value(t.labels[label], uadv) + scheme.setup_transporting_velocity(uadv) def log_timestep(self): """