Skip to content

Commit

Permalink
add 4 and 5 stage SSPRK3 schemes
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Dec 18, 2024
1 parent 6e9e2c0 commit 333f17b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
57 changes: 49 additions & 8 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,19 +532,37 @@ def __init__(

class SSPRK3(ExplicitRungeKutta):
u"""
Implements the 3-stage Strong-Stability-Preserving Runge-Kutta method
for solving ∂y/∂t = F(y). It can be written as: \n
Implements 3rd order Strong-Stability-Preserving Runge-Kutta methods
for solving ∂y/∂t = F(y). \n
The 3-stage method can be written as: \n
k0 = F[y^n] \n
k1 = F[y^n + dt*k1] \n
k2 = F[y^n + (1/4)*dt*(k0+k1)] \n
y^(n+1) = y^n + (1/6)*dt*(k0 + k1 + 4*k2) \n
The 4-stage method can be written as: \n
k0 = F[y^n] \n
k1 = F[y^n + (1/2)*dt*k1] \n
k2 = F[y^n + (1/2)*dt*(k0+k1)] \n
k3 = F[y^n + (1/6)*dt*(k0+k1+k2)] \n
y^(n+1) = y^n + (1/6)*dt*(k0 + k1 + k2 + 3*k3) \n
The 5-stage method can be written as: \n
k0 = F[y^n] \n
k1 = F[y^n + (1/2)*dt*k1] \n
k2 = F[y^n + (1/2)*dt*(k0+k1)] \n
k3 = F[y^n + (1/6)*dt*(k0+k1+k2)] \n
y^(n+1) = y^n + (1/6)*dt*(k0 + k1 + k2 + 3*k3) \n
"""
def __init__(
self, domain, field_name=None, subcycling_options=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, limiter=None, options=None,
augmentation=None
augmentation=None, stages=3
):
"""
Args:
Expand All @@ -569,13 +587,36 @@ def __init__(
augmentation (:class:`Augmentation`): allows the equation solved in
this time discretisation to be augmented, for instances with
extra terms of another auxiliary variable. Defaults to None.
stages (int, optional): number of stages: (3, 4, 5). Defaults to 3.
"""

butcher_matrix = np.array([
[1., 0., 0.],
[1./4., 1./4., 0.],
[1./6., 1./6., 2./3.]
])
if stages == 3:
butcher_matrix = np.array([
[1., 0., 0.],
[1./4., 1./4., 0.],
[1./6., 1./6., 2./3.]
])
self.cfl_limit = 1
elif stages == 4:
butcher_matrix = np.array([
[1./2., 0., 0., 0.],
[1./2., 1./2., 0., 0.],
[1./6., 1./6., 1./6., 0.],
[1./6., 1./6., 1./6., 1./2.]
])
self.cfl_limit = 2
elif stages == 5:
self.cfl_limit = 2.65062919294483
butcher_matrix = np.array([
[0.37726891511710, 0., 0., 0., 0.],
[0.37726891511710, 0.37726891511710, 0., 0., 0.],
[0.16352294089771, 0.16352294089771, 0.16352294089771, 0., 0.],
[0.14904059394856, 0.14831273384724, 0.14831273384724, 0.34217696850008, 0.],
[0.19707596384481, 0.11780316509765, 0.11709725193772, 0.27015874934251, 0.29786487010104]
])
else:
raise ValueError(f"{stages} stage 3rd order SSPRK not implemented")

super().__init__(domain, butcher_matrix, field_name=field_name,
subcycling_options=subcycling_options,
rk_formulation=rk_formulation,
Expand Down
25 changes: 20 additions & 5 deletions integration-tests/model/test_time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ def run(timestepper, tmax, f_end):

@pytest.mark.parametrize(
"scheme", [
"ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint", "QinZhang",
"ssprk3_increment_3", "ssprk3_predictor_3", "ssprk3_linear_3",
"ssprk3_increment_4", "ssprk3_predictor_4", "ssprk3_linear_4",
"ssprk3_increment_5", "ssprk3_predictor_5", "ssprk3_linear_5",
"TrapeziumRule", "ImplicitMidpoint", "QinZhang",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth", "Leapfrog",
"AdamsMoulton", "AdamsMoulton", "ssprk3_predictor", "ssprk3_linear"
"AdamsMoulton", "AdamsMoulton"
]
)
def test_time_discretisation(tmpdir, scheme, tracer_setup):
Expand All @@ -30,12 +33,24 @@ def test_time_discretisation(tmpdir, scheme, tracer_setup):
V = domain.spaces("DG")
eqn = AdvectionEquation(domain, V, "f")

if scheme == "ssprk3_increment":
if scheme == "ssprk3_increment_3":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.increment)
elif scheme == "ssprk3_predictor":
elif scheme == "ssprk3_predictor_3":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.predictor)
elif scheme == "ssprk3_linear":
elif scheme == "ssprk3_linear_3":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.linear)
if scheme == "ssprk3_increment_4":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.increment, stages=4)
elif scheme == "ssprk3_predictor_4":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.predictor, stages=4)
elif scheme == "ssprk3_linear_4":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.linear, stages=4)
if scheme == "ssprk3_increment_5":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.increment, stages=5)
elif scheme == "ssprk3_predictor_5":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.predictor, stages=5)
elif scheme == "ssprk3_linear_5":
transport_scheme = SSPRK3(domain, rk_formulation=RungeKuttaFormulation.linear, stages=5)
elif scheme == "TrapeziumRule":
transport_scheme = TrapeziumRule(domain)
elif scheme == "ImplicitMidpoint":
Expand Down

0 comments on commit 333f17b

Please sign in to comment.