Skip to content

Commit

Permalink
PR #453: from firedrakeproject/IMEX_multistage
Browse files Browse the repository at this point in the history
Implement multistage IMEX schemes
  • Loading branch information
tommbendall authored Nov 8, 2023
2 parents 2f9b04a + 7a14129 commit c209deb
Show file tree
Hide file tree
Showing 5 changed files with 598 additions and 84 deletions.
74 changes: 71 additions & 3 deletions gusto/common_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
Provides some basic forms for discretising various common terms in equations for
geophysical fluid dynamics."""

from firedrake import dx, dot, grad, div, inner, outer, cross, curl
from firedrake import (dx, dot, grad, div, inner, outer, cross, curl, split,
TestFunction, TestFunctions, TrialFunction)
from firedrake.fml import subject, drop
from gusto.configuration import TransportEquationType
from gusto.labels import transport, transporting_velocity, diffusion
from gusto.labels import (transport, transporting_velocity, diffusion,
prognostic, linearisation)

__all__ = ["advection_form", "continuity_form", "vector_invariant_form",
"kinetic_energy_form", "advection_equation_circulation_form",
"diffusion_form", "linear_advection_form", "linear_continuity_form"]
"diffusion_form", "linear_advection_form", "linear_continuity_form",
"split_continuity_form"]


def advection_form(test, q, ubar):
Expand Down Expand Up @@ -194,3 +198,67 @@ def diffusion_form(test, q, kappa):
form = -inner(test, div(kappa*grad(q)))*dx

return diffusion(form)


def split_continuity_form(equation):
u"""
Loops through terms in a given equation, and splits all continuity terms
into advective and divergence terms.
This describes splitting ∇.(u*q) terms into u.∇q and q(∇.u),
for transporting velocity u and transported q.
Args:
equation (:class:`PrognosticEquation`): the model's equation.
Returns:
:class:`PrognosticEquation`: the model's equation.
"""

for t in equation.residual:
if (t.get(transport) == TransportEquationType.conservative):
# Get fields and test functions
subj = t.get(subject)
prognostic_field_name = t.get(prognostic)
if hasattr(equation, "field_names"):
idx = equation.field_names.index(prognostic_field_name)
W = equation.function_space
test = TestFunctions(W)[idx]
q = split(subj)[idx]
else:
W = equation.function_space
test = TestFunction(W)
q = subj
# u is either a prognostic or prescribed field
if (hasattr(equation, "field_names")
and 'u' in equation.field_names):
u_idx = equation.field_names.index('u')
uadv = split(equation.X)[u_idx]
elif 'u' in equation.prescribed_fields._field_names:
uadv = equation.prescribed_fields('u')
else:
raise ValueError('Cannot get velocity field')

# Create new advective and divergence terms
adv_term = prognostic(advection_form(test, q, uadv), prognostic_field_name)
div_term = prognostic(test*q*div(uadv)*dx, prognostic_field_name)

# Add linearisations of new terms if required
if (t.has_label(linearisation)):
u_trial = TrialFunction(W)[u_idx]
qbar = split(equation.X_ref)[idx]
# Add linearisation to adv_term
linear_adv_term = linear_advection_form(test, qbar, u_trial)
adv_term = linearisation(adv_term, linear_adv_term)
# Add linearisation to div_term
linear_div_term = transporting_velocity(qbar*test*div(u_trial)*dx, u_trial)
div_term = linearisation(div_term, linear_div_term)

# Add new terms onto residual
equation.residual += subject(adv_term + div_term, subj)
# Drop old term
equation.residual = equation.residual.label_map(
lambda t: t.get(transport) == TransportEquationType.conservative,
map_if_true=drop)

return equation
1 change: 1 addition & 0 deletions gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,

# Depth transport term
D_adv = prognostic(continuity_form(phi, D, u), 'D')

# Transport term needs special linearisation
if self.linearisation_map(D_adv.terms[0]):
linear_D_adv = linear_continuity_form(phi, H, u_trial)
Expand Down
2 changes: 2 additions & 0 deletions gusto/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def __call__(self, target, value=None):
# ---------------------------------------------------------------------------- #

time_derivative = Label("time_derivative")
implicit = Label("implicit")
explicit = Label("explicit")
transport = Label("transport", validator=lambda value: type(value) == TransportEquationType)
diffusion = Label("diffusion")
transporting_velocity = Label("transporting_velocity", validator=lambda value: type(value) in [Function, ufl.tensors.ListTensor])
Expand Down
Loading

0 comments on commit c209deb

Please sign in to comment.