Skip to content

Commit

Permalink
start to move parareal over
Browse files Browse the repository at this point in the history
  • Loading branch information
jshipton committed Dec 17, 2024
1 parent 595551d commit 7e22895
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 30 deletions.
33 changes: 31 additions & 2 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
Constant, NonlinearVariationalProblem,
NonlinearVariationalSolver)
from firedrake.fml import (replace_subject, replace_test_function, Term,
all_terms, drop)
all_terms, drop, LabelledForm)
from firedrake.formmanipulation import split_form
from firedrake.utils import cached_property

from gusto.core.configuration import EmbeddedDGOptions, RecoveryOptions
from gusto.core.labels import (time_derivative, prognostic, physics_label,
mass_weighted, nonlinear_time_derivative)
mass_weighted, nonlinear_time_derivative,
transporting_velocity)
from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual
from gusto.time_discretisation.wrappers import *

Expand Down Expand Up @@ -324,6 +325,34 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.x_out = Function(self.fs)
self.x1 = Function(self.fs)

def setup_transporting_velocity(self, uadv):
self.residual = self.residual.label_map(
lambda t: t.has_label(transporting_velocity),
map_if_true=lambda t:
Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}),
t.labels)
)

self.residual = transporting_velocity.update_value(self.residual, uadv)
# Now also replace transporting velocity in the terms that are
# contained in labels
for idx, t in enumerate(self.residual.terms):
if t.has_label(transporting_velocity):
for label in t.labels.keys():
if type(t.labels[label]) is LabelledForm:
t.labels[label] = t.labels[label].label_map(
lambda s: s.has_label(transporting_velocity),
map_if_true=lambda s:
Term(ufl.replace(
s.form,
{s.get(transporting_velocity): uadv}),
s.labels
)
)

self.residual.terms[idx].labels[label] = \
transporting_velocity.update_value(t.labels[label], uadv)

@property
def nlevels(self):
return 1
Expand Down
3 changes: 2 additions & 1 deletion gusto/timestepping/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from gusto.timestepping.timestepper import * # noqa
from gusto.timestepping.parareal import * # noqa
from gusto.timestepping.split_timestepper import * # noqa
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
101 changes: 101 additions & 0 deletions gusto/timestepping/parareal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from firedrake import Function
from gusto.core.fields import Fields


class PararealFields(object):

def __init__(self, equation, nlevels):
levels = [str(n) for n in range(nlevels+1)]
self.add_fields(equation, levels)

def add_fields(self, equation, levels):
if levels is None:
levels = self.levels
for level in levels:
try:
x = getattr(self, level)
x.add_field(equation.field_name, equation.function_space)
except AttributeError:
setattr(self, level, Fields(equation))

def __call__(self, n):
return getattr(self, str(n))


class Parareal(object):

def __init__(self, domain, coarse_scheme, fine_scheme, nG, nF,
n_intervals, max_its):

assert coarse_scheme.nlevels == 1
assert fine_scheme.nlevels == 1
self.nlevels = 1

self.coarse_scheme = coarse_scheme
self.coarse_scheme.dt.assign(domain.dt/n_intervals/nG)
self.fine_scheme = fine_scheme
self.fine_scheme.dt.assign(domain.dt/n_intervals/nG)
self.nG = nG
self.nF = nF
self.n_intervals = n_intervals
self.max_its = max_its

def setup(self, equation, apply_bcs=True, *active_labels):
self.coarse_scheme.fixed_subcycles = self.nG
self.coarse_scheme.setup(equation, apply_bcs, *active_labels)
self.fine_scheme.fixed_subcycles = self.nF
self.fine_scheme.setup(equation, apply_bcs, *active_labels)
self.x = PararealFields(equation, self.n_intervals)
self.xF = PararealFields(equation, self.n_intervals)
self.xn = Function(equation.function_space)
self.xGk = PararealFields(equation, self.n_intervals)
self.xGkm1 = PararealFields(equation, self.n_intervals)
self.xFn = Function(equation.function_space)
self.xFnp1 = Function(equation.function_space)
self.name = equation.field_name

def setup_transporting_velocity(self, uadv):
self.coarse_scheme.setup_transporting_velocity(uadv)
self.fine_scheme.setup_transporting_velocity(uadv)

def apply(self, x_out, x_in):

self.xn.assign(x_in)
x0 = self.x(0)(self.name)
x0.assign(x_in)
xF0 = self.xF(0)(self.name)
xF0.assign(x_in)

# compute first guess from coarse scheme
for n in range(self.n_intervals):
print("computing first coarse guess for interval: ", n)
# apply coarse scheme and save data as initial conditions for fine
xGnp1 = self.xGkm1(n+1)(self.name)
self.coarse_scheme.apply(xGnp1, self.xn)
xnp1 = self.x(n+1)(self.name)
xnp1.assign(xGnp1)
self.xn.assign(xnp1)

for k in range(self.max_its):

# apply fine scheme in each interval using previously
# calculated coarse data
for n in range(k, self.n_intervals):
print("computing fine guess for iteration and interval: ", k, n)
self.xFn.assign(self.x(n)(self.name))
xFnp1 = self.xF(n+1)(self.name)
self.fine_scheme.apply(xFnp1, self.xFn)

# compute correction
for n in range(k, self.n_intervals):
xn = self.x(n)(self.name)
xGk = self.xGk(n+1)(self.name)
# compute new coarse guess
self.coarse_scheme.apply(xGk, xn)
xnp1 = self.x(n+1)(self.name)
xGkm1 = self.xGkm1(n+1)(self.name)
xFnp1 = self.xF(n+1)(self.name)
xnp1.assign(xGk - xGkm1 + xFnp1)
xGkm1.assign(xGk)

x_out.assign(xnp1)
29 changes: 2 additions & 27 deletions gusto/timestepping/timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gusto.equations import PrognosticEquationSet
from gusto.core import TimeLevelFields, StateFields
from gusto.core.io import TimeData
from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity
from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity, transporting_velocity

Check failure on line 10 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F811

gusto/timestepping/timestepper.py:10:1: F811 redefinition of unused 'transporting_velocity' from line 10

Check failure on line 10 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

gusto/timestepping/timestepper.py:10:1: F401 'gusto.core.labels.transporting_velocity' imported but unused
from gusto.core.logging import logger
from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation
from gusto.spatial_methods.transport_methods import TransportMethod
Expand Down Expand Up @@ -127,32 +127,7 @@ def setup_transporting_velocity(self, scheme):
else:
uadv = self.transporting_velocity

scheme.residual = scheme.residual.label_map(
lambda t: t.has_label(transporting_velocity),
map_if_true=lambda t:
Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}), t.labels)
)

scheme.residual = transporting_velocity.update_value(scheme.residual, uadv)

# Now also replace transporting velocity in the terms that are
# contained in labels
for idx, t in enumerate(scheme.residual.terms):
if t.has_label(transporting_velocity):
for label in t.labels.keys():
if type(t.labels[label]) is LabelledForm:
t.labels[label] = t.labels[label].label_map(
lambda s: s.has_label(transporting_velocity),
map_if_true=lambda s:
Term(ufl.replace(
s.form,
{s.get(transporting_velocity): uadv}),
s.labels
)
)

scheme.residual.terms[idx].labels[label] = \
transporting_velocity.update_value(t.labels[label], uadv)
scheme.setup_transporting_velocity(uadv)

def log_timestep(self):
"""
Expand Down

0 comments on commit 7e22895

Please sign in to comment.