Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start to move parareal over #597

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
Constant, NonlinearVariationalProblem,
NonlinearVariationalSolver)
from firedrake.fml import (replace_subject, replace_test_function, Term,
all_terms, drop)
all_terms, drop, LabelledForm)
from firedrake.formmanipulation import split_form
from firedrake.utils import cached_property

from gusto.core.configuration import EmbeddedDGOptions, RecoveryOptions
from gusto.core.labels import (time_derivative, prognostic, physics_label,
mass_weighted, nonlinear_time_derivative)
mass_weighted, nonlinear_time_derivative,
transporting_velocity)
from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual
from gusto.time_discretisation.wrappers import *

Expand Down Expand Up @@ -324,6 +325,34 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.x_out = Function(self.fs)
self.x1 = Function(self.fs)

def setup_transporting_velocity(self, uadv):
self.residual = self.residual.label_map(
lambda t: t.has_label(transporting_velocity),
map_if_true=lambda t:
Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}),
t.labels)
)

self.residual = transporting_velocity.update_value(self.residual, uadv)
# Now also replace transporting velocity in the terms that are
# contained in labels
for idx, t in enumerate(self.residual.terms):
if t.has_label(transporting_velocity):
for label in t.labels.keys():
if type(t.labels[label]) is LabelledForm:
t.labels[label] = t.labels[label].label_map(
lambda s: s.has_label(transporting_velocity),
map_if_true=lambda s:
Term(ufl.replace(
s.form,
{s.get(transporting_velocity): uadv}),
s.labels
)
)

self.residual.terms[idx].labels[label] = \
transporting_velocity.update_value(t.labels[label], uadv)

@property
def nlevels(self):
return 1
Expand Down
3 changes: 2 additions & 1 deletion gusto/timestepping/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from gusto.timestepping.timestepper import * # noqa
from gusto.timestepping.parareal import * # noqa
from gusto.timestepping.split_timestepper import * # noqa
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
101 changes: 101 additions & 0 deletions gusto/timestepping/parareal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from firedrake import Function
from gusto.core.fields import Fields


class PararealFields(object):

def __init__(self, equation, nlevels):
levels = [str(n) for n in range(nlevels+1)]
self.add_fields(equation, levels)

def add_fields(self, equation, levels):
if levels is None:
levels = self.levels
for level in levels:
try:
x = getattr(self, level)
x.add_field(equation.field_name, equation.function_space)
except AttributeError:
setattr(self, level, Fields(equation))

def __call__(self, n):
return getattr(self, str(n))


class Parareal(object):

def __init__(self, domain, coarse_scheme, fine_scheme, nG, nF,
n_intervals, max_its):

assert coarse_scheme.nlevels == 1
assert fine_scheme.nlevels == 1
self.nlevels = 1

self.coarse_scheme = coarse_scheme
self.coarse_scheme.dt.assign(domain.dt/n_intervals/nG)
self.fine_scheme = fine_scheme
self.fine_scheme.dt.assign(domain.dt/n_intervals/nG)
self.nG = nG
self.nF = nF
self.n_intervals = n_intervals
self.max_its = max_its

def setup(self, equation, apply_bcs=True, *active_labels):
self.coarse_scheme.fixed_subcycles = self.nG
self.coarse_scheme.setup(equation, apply_bcs, *active_labels)
self.fine_scheme.fixed_subcycles = self.nF
self.fine_scheme.setup(equation, apply_bcs, *active_labels)
self.x = PararealFields(equation, self.n_intervals)
self.xF = PararealFields(equation, self.n_intervals)
self.xn = Function(equation.function_space)
self.xGk = PararealFields(equation, self.n_intervals)
self.xGkm1 = PararealFields(equation, self.n_intervals)
self.xFn = Function(equation.function_space)
self.xFnp1 = Function(equation.function_space)
self.name = equation.field_name

def setup_transporting_velocity(self, uadv):
self.coarse_scheme.setup_transporting_velocity(uadv)
self.fine_scheme.setup_transporting_velocity(uadv)

def apply(self, x_out, x_in):

self.xn.assign(x_in)
x0 = self.x(0)(self.name)
x0.assign(x_in)
xF0 = self.xF(0)(self.name)
xF0.assign(x_in)

# compute first guess from coarse scheme
for n in range(self.n_intervals):
print("computing first coarse guess for interval: ", n)
# apply coarse scheme and save data as initial conditions for fine
xGnp1 = self.xGkm1(n+1)(self.name)
self.coarse_scheme.apply(xGnp1, self.xn)
xnp1 = self.x(n+1)(self.name)
xnp1.assign(xGnp1)
self.xn.assign(xnp1)

for k in range(self.max_its):

# apply fine scheme in each interval using previously
# calculated coarse data
for n in range(k, self.n_intervals):
print("computing fine guess for iteration and interval: ", k, n)
self.xFn.assign(self.x(n)(self.name))
xFnp1 = self.xF(n+1)(self.name)
self.fine_scheme.apply(xFnp1, self.xFn)

# compute correction
for n in range(k, self.n_intervals):
xn = self.x(n)(self.name)
xGk = self.xGk(n+1)(self.name)
# compute new coarse guess
self.coarse_scheme.apply(xGk, xn)
xnp1 = self.x(n+1)(self.name)
xGkm1 = self.xGkm1(n+1)(self.name)
xFnp1 = self.xF(n+1)(self.name)
xnp1.assign(xGk - xGkm1 + xFnp1)
xGkm1.assign(xGk)

x_out.assign(xnp1)
29 changes: 2 additions & 27 deletions gusto/timestepping/timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

from abc import ABCMeta, abstractmethod, abstractproperty
from firedrake import Function, Projector, split
from firedrake.fml import drop, Term, LabelledForm

Check failure on line 5 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

gusto/timestepping/timestepper.py:5:1: F401 'firedrake.fml.Term' imported but unused

Check failure on line 5 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

gusto/timestepping/timestepper.py:5:1: F401 'firedrake.fml.LabelledForm' imported but unused
from pyop2.profiling import timed_stage
from gusto.equations import PrognosticEquationSet
from gusto.core import TimeLevelFields, StateFields
from gusto.core.io import TimeData
from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity
from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity, transporting_velocity

Check failure on line 10 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F811

gusto/timestepping/timestepper.py:10:1: F811 redefinition of unused 'transporting_velocity' from line 10

Check failure on line 10 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

gusto/timestepping/timestepper.py:10:1: F401 'gusto.core.labels.transporting_velocity' imported but unused
from gusto.core.logging import logger
from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation
from gusto.spatial_methods.transport_methods import TransportMethod
import ufl

Check failure on line 14 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

gusto/timestepping/timestepper.py:14:1: F401 'ufl' imported but unused

__all__ = ["BaseTimestepper", "Timestepper", "PrescribedTransport"]

Expand Down Expand Up @@ -39,7 +39,7 @@
self.io.log_parameters(equation)

@abstractproperty
def transporting_velocity(self):

Check failure on line 42 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F811

gusto/timestepping/timestepper.py:42:5: F811 redefinition of unused 'transporting_velocity' from line 10
return NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -127,32 +127,7 @@
else:
uadv = self.transporting_velocity

scheme.residual = scheme.residual.label_map(
lambda t: t.has_label(transporting_velocity),
map_if_true=lambda t:
Term(ufl.replace(t.form, {t.get(transporting_velocity): uadv}), t.labels)
)

scheme.residual = transporting_velocity.update_value(scheme.residual, uadv)

# Now also replace transporting velocity in the terms that are
# contained in labels
for idx, t in enumerate(scheme.residual.terms):
if t.has_label(transporting_velocity):
for label in t.labels.keys():
if type(t.labels[label]) is LabelledForm:
t.labels[label] = t.labels[label].label_map(
lambda s: s.has_label(transporting_velocity),
map_if_true=lambda s:
Term(ufl.replace(
s.form,
{s.get(transporting_velocity): uadv}),
s.labels
)
)

scheme.residual.terms[idx].labels[label] = \
transporting_velocity.update_value(t.labels[label], uadv)
scheme.setup_transporting_velocity(uadv)

def log_timestep(self):
"""
Expand Down Expand Up @@ -317,7 +292,7 @@
super().__init__(equation=equation, io=io)

@property
def transporting_velocity(self):

Check failure on line 295 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F811

gusto/timestepping/timestepper.py:295:5: F811 redefinition of unused 'transporting_velocity' from line 10
return "prognostic"

def setup_fields(self):
Expand Down Expand Up @@ -383,7 +358,7 @@
self.velocity_apply = None

@property
def transporting_velocity(self):

Check failure on line 361 in gusto/timestepping/timestepper.py

View workflow job for this annotation

GitHub Actions / Run linter

F811

gusto/timestepping/timestepper.py:361:5: F811 redefinition of unused 'transporting_velocity' from line 10
return self.fields('u')

def setup_fields(self):
Expand Down
Loading