Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645544950
  • Loading branch information
The swirl_lm Authors authored and john-qingwang committed Jun 22, 2024
1 parent 36bf777 commit 11decd8
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 436 deletions.
8 changes: 8 additions & 0 deletions swirl_lm/base/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ def _update_additional_states(
updated_additional_states = additional_states
params = common_kwargs['params']

# Clear source terms computed in the previous step.
for varname in updated_additional_states:
if not varname.startswith('src_'):
continue
zeros = tf.nest.map_structure(tf.zeros_like,
updated_additional_states[varname])
updated_additional_states[varname] = zeros

# Update BC additional states. Note currently this is only done
# for the nonreflecting BC and will be a no-op if there is no nonreflecting
# BC present.
Expand Down
5 changes: 0 additions & 5 deletions swirl_lm/boundary_condition/rayleigh_damping_layer.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,6 @@ message RayleighDampingLayer {
// The name of the state to be used as a sponge target.
string target_state_name = 5;
}
// An indicator of whether this is the only forcing term of this variable:
// 'true' if it's the only force and need to replace previous values;
// 'false' if it's not the only force and need to be added after other
// forces are updated
optional bool override = 3;
// An indicator of whether the sponge force is applied to a primitive
// variable or a conservative one. If 'true', the force term is generated
// for the primitive variable; otherwise the force term is for a
Expand Down
63 changes: 6 additions & 57 deletions swirl_lm/boundary_condition/rayleigh_damping_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,6 @@ def get_sponge_target_name(varname: Text) -> Text:
return 'sponge_target_{}'.format(varname)


def target_value_lib_from_proto(
sponges: _RayleighDampingLayerSeq) -> TargetValueLib:
"""Generates a target value library from the proto.
Args:
sponges: A sequence of materialized sponge layer protos.
Returns:
A dictionary with keys being the variable names, and values being the target
value in the sponge layer.
"""
lib = {}
for sponge in sponges:
for info in sponge.variable_info:
if info.HasField('target_value'):
lib.update({info.name: info.target_value})
elif info.HasField('target_state_name'):
lib.update({info.name: info.target_state_name})
else:
lib.update({info.name: None})
return lib


def variable_type_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap:
"""Generates a library for the type of the variable from the proto.
Expand All @@ -125,25 +102,6 @@ def variable_type_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap:
return out


def target_status_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap:
"""Generates a sponge forcing status library from the proto.
Args:
sponges: A sequence of materialized sponge layer protos.
Returns:
A dictionary with keys being the name of the forcing term, and values being
the target behavior of the forcing term. `False` if there are other forcing
terms need to be combined with the sponge force, and `True` if the sponge
force is the only forcing term for that variable.
"""
out = {}
for sponge in sponges:
for info in sponge.variable_info:
out[get_sponge_force_name(info.name)] = info.override
return out


def _get_beta_name_from_sponge_info(
sponge: rayleigh_damping_layer_pb2.RayleighDampingLayer) -> str:
# Use default name 'sponge_beta' if beta_name is not explicitly set.
Expand Down Expand Up @@ -304,8 +262,6 @@ def __init__(
periodic_dims: An optional list of booleans indicating the periodic
dimensions.
"""
self._target_values = target_value_lib_from_proto(sponge_infos)
self._target_status = target_status_lib_from_proto(sponge_infos)
self._is_primitive = variable_type_lib_from_proto(sponge_infos)
self._beta_name_by_var = beta_name_by_var(sponge_infos)
self._sponge_info_map = sponge_info_map(sponge_infos)
Expand All @@ -318,7 +274,7 @@ def __init__(

logging.info(
'Sponge layer will be applied for the following variables with '
'following values: %r', self._target_values)
'following values: %s', self._sponge_info_map)

def _get_sponge_force(
self,
Expand Down Expand Up @@ -414,14 +370,10 @@ def additional_states_update_fn(
) -> FlowFieldMap:
"""Updates the forcing term due to the sponge layer.
The forcing term will replace or add to the existing forcing term in
The forcing term will be added to the existing forcing term in
`additional_states` for variables that are in the scope of
`self._target_values`, following the indicator stated in
`self._target_status`: if `False`, the sponge force will be added to the
input force with the same name; if `True`, the sponge force will override
the existing one. If other forcing terms needs to be applied to a same
variable, the values of all these forcing terms needs to be updated ahead of
the sponge forces.
`self._sponge_info_map` following the specification of target values stored
in the values of this dictionary.
Args:
kernel_op: An object holding a library of kernel operations.
Expand Down Expand Up @@ -482,10 +434,7 @@ def add_to_additional_states(
sponge_force = tf.nest.map_structure(
tf.math.multiply, states['rho'], sponge_force
)
if self._target_status[sponge_name]:
additional_states_updated.update({sponge_name: sponge_force})
else:
additional_states_updated.update(
{sponge_name: add_to_additional_states(sponge_name, sponge_force)})
additional_states_updated.update(
{sponge_name: add_to_additional_states(sponge_name, sponge_force)})

return additional_states_updated
3 changes: 1 addition & 2 deletions swirl_lm/boundary_condition/simulated_turbulent_inflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def __init__(self, params: parameters_lib.SwirlLMParameters):

if self._model_params.WhichOneof('operation') == 'generation':
# Get the index of the plane to extract the inflow data.
mesh_size = (self._params.dx, self._params.dy,
self._params.dz)[self._inflow_dim]
mesh_size = self._params.grid_spacings[self._inflow_dim]
mesh_count = (self._params.nx, self._params.ny,
self._params.nz)[self._inflow_dim]
idx = int(self._model_params.generation.location // mesh_size)
Expand Down
49 changes: 41 additions & 8 deletions swirl_lm/example/fire/fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,13 @@ def cubic_obstacles(
types.FlowFieldMap]


def get_init_rho_f(ground_elevation: tf.Tensor, fuel_bed_height: float,
fuel_density: float, fuel_start_x: float,
dz: float) -> wildfire_utils.InitFn:
def get_init_rho_f(
ground_elevation: tf.Tensor,
fuel_bed_height: float,
fuel_density: float,
fuel_start_x: float,
dz: float | tf.Tensor,
) -> wildfire_utils.InitFn:
"""Returns the initializer function for rho_f."""

# We assume grid coordinates in zz are integer multiples of dz and use +/-0.1
Expand All @@ -371,7 +375,9 @@ def get_init_rho_f(ground_elevation: tf.Tensor, fuel_bed_height: float,
# fuel cell will be in the halo.

quantized_ground_elevation = tf.math.ceil(ground_elevation / dz) * dz - 2 * dz
num_full_cells = int(np.floor(fuel_bed_height / dz))
num_full_cells = tf.cast(
tf.floor(fuel_bed_height / dz), quantized_ground_elevation.dtype
)
# Note that the number of full cells is one more than given by
# fuel_bed_height because we also put fuel into the boundary cell (i.e.,
# the cell that intersects with the terrain).
Expand Down Expand Up @@ -1172,10 +1178,12 @@ def init_y_o(xx, yy, zz, lx, ly, lz, coord):
def init_rho_m(xx, yy, zz, lx, ly, lz, coord):
"""Generates initial moisture `rho_m` field."""
del xx, yy, lx, ly, lz, coord
# In case of stretched grid, we use the first grid spacing as reference
# here.
return tf.compat.v1.where(
tf.math.logical_and(
zz <= ground_elevation + self.fire_utils.fuel_bed_height,
zz >= ground_elevation - self.config.dz,
zz >= ground_elevation - self.config.z[1],
),
self.fire_utils.moisture_density * tf.ones_like(zz),
tf.zeros_like(zz),
Expand Down Expand Up @@ -1347,9 +1355,34 @@ def init_ignition_kernel(xx, yy, zz, lx, ly, lz, coord):
assert (
'rho_f' in self.config.additional_state_keys
), 'Fuel height is none zero but rho_f is not included in the config.'
init_rho_f = get_init_rho_f(
ground_elevation, self.fire_utils.fuel_bed_height,
self.fire_utils.fuel_density, self.fuel_start_x, self.config.dz)

# In case of stretched grid, we assign the same fuel density at all node
# points below the fuel height.
if self.config.use_stretched_grid[2]:

def init_rho_f(xx, yy, zz, lx, ly, lz, coord):
"""Generates initial fuel density `rho_f` field."""
del yy, lx, ly, lz, coord
rho_f = tf.where(
tf.math.logical_and(
zz <= ground_elevation + self.fire_utils.fuel_bed_height,
zz >= ground_elevation - self.config.z[1],
),
self.fire_utils.fuel_density * tf.ones_like(zz),
tf.zeros_like(zz),
)
return tf.where(
tf.greater_equal(xx, self.fuel_start_x), rho_f, tf.zeros_like(xx)
)

else:
init_rho_f = get_init_rho_f(
ground_elevation,
self.fire_utils.fuel_bed_height,
self.fire_utils.fuel_density,
self.fuel_start_x,
self.config.dz,
)
output.update({
'rho_f':
self.fire_utils.states_init(coordinates, init_rho_f, 'CONSTANT'),
Expand Down
Loading

0 comments on commit 11decd8

Please sign in to comment.