Skip to content

Commit

Permalink
Implement simultaneous transport with SIQN (#581)
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 authored Dec 4, 2024
1 parent e138b5e commit 564810d
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 54 deletions.
7 changes: 4 additions & 3 deletions gusto/equations/prognostic_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,13 @@ def add_tracers_to_prognostics(self, domain, active_tracers):
name of the active tracer.
"""

# Check if there are any conservatively transported tracers.
# If so, ensure that the reference density is indexed before this tracer.
# If there are any conservatively transported tracers, ensure
# that the reference density, if it is also an active tracer,
# is indexed earlier.
for i in range(len(active_tracers) - 1):
tracer = active_tracers[i]
if tracer.transport_eqn == TransportEquationType.tracer_conservative:
ref_density = next(x for x in active_tracers if x.name == tracer.density_name)
ref_density = next((x for x in active_tracers if x.name == tracer.density_name), tracer)
j = active_tracers.index(ref_density)
if j > i:
# Swap the indices of the tracer and the reference density
Expand Down
84 changes: 65 additions & 19 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,58 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.residual = equation.residual

if self.field_name is not None and hasattr(equation, "field_names"):
self.idx = equation.field_names.index(self.field_name)
self.fs = equation.spaces[self.idx]
self.residual = self.residual.label_map(
lambda t: t.get(prognostic) == self.field_name,
lambda t: Term(
split_form(t.form)[self.idx].form,
t.labels),
drop)
if isinstance(self.field_name, list):
# Multiple fields are being solved for simultaneously.
# This enables conservative transport to be implemented with SIQN.
# Use the full mixed space for self.fs, with the
# field_name, residual, and BCs being set up later.
self.fs = equation.function_space
self.idx = None
else:
self.idx = equation.field_names.index(self.field_name)
self.fs = equation.spaces[self.idx]
self.residual = self.residual.label_map(
lambda t: t.get(prognostic) == self.field_name,
lambda t: Term(
split_form(t.form)[self.idx].form,
t.labels),
drop)

else:
self.field_name = equation.field_name
self.fs = equation.function_space
self.idx = None

bcs = equation.bcs[self.field_name]

if len(active_labels) > 0:
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, *active_labels)),
map_if_false=drop)
if isinstance(self.field_name, list):
# Multiple fields are being solved for simultaneously.
# Keep all time derivative terms:
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)

# Only keep active labels for prognostics in the list
# of simultaneously transported variables:
for subname in self.field_name:
field_residual = self.residual.label_map(
lambda t: t.get(prognostic) == subname,
map_if_false=drop)

residual += field_residual.label_map(
lambda t: t.has_label(*active_labels),
map_if_false=drop)

self.residual = residual
else:
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, *active_labels)),
map_if_false=drop)

# Set the field name if using simultaneous transport.
if isinstance(self.field_name, list):
self.field_name = equation.field_name

bcs = equation.bcs[self.field_name]

self.evaluate_source = []
self.physics_names = []
Expand All @@ -175,7 +207,10 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# timestepper should be used instead.
if len(field_terms.label_map(lambda t: t.has_label(mass_weighted), map_if_false=drop)) > 0:
if len(field_terms.label_map(lambda t: not t.has_label(mass_weighted), map_if_false=drop)) > 0:
raise ValueError(f"Mass-weighted and non-mass-weighted terms are present in a timestepping equation for {field}. As these terms cannot be solved for simultaneously, a split timestepping method should be used instead.")
raise ValueError('Mass-weighted and non-mass-weighted terms are present in a '
+ f'timestepping equation for {field}. As these terms cannot '
+ 'be solved for simultaneously, a split timestepping method '
+ 'should be used instead.')
else:
# Replace the terms with a mass_weighted label with the
# mass_weighted form. It is important that the labels from
Expand All @@ -199,10 +234,11 @@ def setup(self, equation, apply_bcs=True, *active_labels):
for field, subwrapper in self.wrapper.subwrappers.items():

if field not in equation.field_names:
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set")
raise ValueError(f'The option defined for {field} is for a field '
+ 'that does not exist in the equation set.')

field_idx = equation.field_names.index(field)
subwrapper.setup(equation.spaces[field_idx], wrapper_bcs)
subwrapper.setup(equation.spaces[field_idx], equation.bcs[field])

# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space
Expand Down Expand Up @@ -244,9 +280,19 @@ def setup(self, equation, apply_bcs=True, *active_labels):
if not apply_bcs:
self.bcs = None
elif self.wrapper is not None:
# Transfer boundary conditions onto test function space
self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain)
for bc in bcs]
if self.wrapper_name == 'mixed_options':
# Define new Dirichlet BCs on the wrapper-modified
# mixed function space.
self.bcs = []
for idx, field_name in enumerate(self.equation.field_names):
for bc in equation.bcs[field_name]:
self.bcs.append(DirichletBC(self.fs.sub(idx),
bc.function_arg,
bc.sub_domain))
else:
# Transfer boundary conditions onto test function space
self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain)
for bc in bcs]
else:
self.bcs = bcs

Expand Down
95 changes: 63 additions & 32 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False

# Flag for if we have simultaneous transport
self.simult = False

# default is to not offcentre transporting velocity but if it
# is offcentred then use the same value as alpha
self.alpha_u = Constant(alpha) if off_centred_u else Constant(0.5)
Expand Down Expand Up @@ -148,15 +151,30 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.transported_fields = []
for scheme in transport_schemes:
assert scheme.nlevels == 1, "multilevel schemes not supported as part of this timestepping loop"
assert scheme.field_name in equation_set.field_names
self.active_transport.append((scheme.field_name, scheme))
self.transported_fields.append(scheme.field_name)
# Check that there is a corresponding transport method
method_found = False
for method in spatial_methods:
if scheme.field_name == method.variable and method.term_label == transport:
method_found = True
assert method_found, f'No transport method found for variable {scheme.field_name}'
if isinstance(scheme.field_name, list):
# This means that multiple fields are being transported simultaneously
self.simult = True
for subfield in scheme.field_name:
assert subfield in equation_set.field_names

# Check that there is a corresponding transport method for
# each field in the list
method_found = False
for method in spatial_methods:
if subfield == method.variable and method.term_label == transport:
method_found = True
assert method_found, f'No transport method found for variable {scheme.field_name}'
self.active_transport.append((scheme.field_name, scheme))
else:
assert scheme.field_name in equation_set.field_names

# Check that there is a corresponding transport method
method_found = False
for method in spatial_methods:
if scheme.field_name == method.variable and method.term_label == transport:
method_found = True
self.active_transport.append((scheme.field_name, scheme))
assert method_found, f'No transport method found for variable {scheme.field_name}'

self.diffusion_schemes = []
if diffusion_schemes is not None:
Expand Down Expand Up @@ -240,7 +258,11 @@ def transporting_velocity(self):
def setup_fields(self):
"""Sets up time levels n, star, p and np1"""
self.x = TimeLevelFields(self.equation, 1)
self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast"))
if self.simult is True:
# If there is any simultaneous transport, add an extra 'simult' field:
self.x.add_fields(self.equation, levels=("star", "p", "simult", "after_slow", "after_fast"))
else:
self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast"))
for aux_eqn, _ in self.auxiliary_equations_and_schemes:
self.x.add_fields(aux_eqn)
# Prescribed fields for auxiliary eqns should come from prognostics of
Expand Down Expand Up @@ -282,32 +304,44 @@ def copy_active_tracers(self, x_in, x_out):
for name in self.tracers_to_copy:
x_out(name).assign(x_in(name))

def transport_field(self, name, scheme, xstar, xp):
def transport_fields(self, outer, xstar, xp):
"""
Performs the transport of a field in xstar, placing the result in xp.
Transports all fields in xstar with a transport scheme
and places the result in xp.
Args:
name (str): the name of the field to be transported.
scheme (:class:`TimeDiscretisation`): the time discretisation used
for the transport.
outer (int): the outer loop iteration number
xstar (:class:`Fields`): the collection of state fields to be
transported.
xp (:class:`Fields`): the collection of state fields resulting from
the transport.
"""

if name == self.predictor:
# Pre-multiply this variable by (1 - dt*beta*div(u))
V = xstar(name).function_space()
field_out = Function(V)
self.predictor_interpolator.interpolate()
scheme.apply(field_out, self.predictor_field_in)

# xp is xstar plus the increment from the transported predictor
xp(name).assign(xstar(name) + field_out - self.predictor_field_in)
else:
# Standard transport
scheme.apply(xp(name), xstar(name))
for name, scheme in self.active_transport:
if isinstance(name, list):
# Transport multiple fields from xstar simultaneously.
# We transport the mixed function space from xstar to xsimult, then
# extract the updated fields and pass them to xp; this avoids overwriting
# any previously transported fields.
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: '
+ f'Simultaneous transport of {name}')
scheme.apply(self.x.simult(self.field_name), xstar(self.field_name))
for field_name in name:
xp(field_name).assign(self.x.simult(field_name))
else:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a single field from xstar and puts the result in xp
if name == self.predictor:
# Pre-multiply this variable by (1 - dt*beta*div(u))
V = xstar(name).function_space()
field_out = Function(V)
self.predictor_interpolator.interpolate()
scheme.apply(field_out, self.predictor_field_in)

# xp is xstar plus the increment from the transported predictor
xp(name).assign(xstar(name) + field_out - self.predictor_field_in)
else:
# Standard transport
scheme.apply(xp(name), xstar(name))

def update_reference_profiles(self):
"""
Expand Down Expand Up @@ -367,10 +401,7 @@ def timestep(self):
with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
for name, scheme in self.active_transport:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a field from xstar and puts result in xp
self.transport_field(name, scheme, xstar, xp)
self.transport_fields(outer, xstar, xp)

# Fast physics -----------------------------------------------------
x_after_fast(self.field_name).assign(xp(self.field_name))
Expand Down
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 564810d

Please sign in to comment.