diff --git a/swirl_lm/base/driver.py b/swirl_lm/base/driver.py index 27a8624..0c97551 100644 --- a/swirl_lm/base/driver.py +++ b/swirl_lm/base/driver.py @@ -313,6 +313,14 @@ def _update_additional_states( updated_additional_states = additional_states params = common_kwargs['params'] + # Clear source terms computed in the previous step. + for varname in updated_additional_states: + if not varname.startswith('src_'): + continue + zeros = tf.nest.map_structure(tf.zeros_like, + updated_additional_states[varname]) + updated_additional_states[varname] = zeros + # Update BC additional states. Note currently this is only done # for the nonreflecting BC and will be a no-op if there is no nonreflecting # BC present. diff --git a/swirl_lm/boundary_condition/rayleigh_damping_layer.proto b/swirl_lm/boundary_condition/rayleigh_damping_layer.proto index 8f3498b..0f7f7d8 100644 --- a/swirl_lm/boundary_condition/rayleigh_damping_layer.proto +++ b/swirl_lm/boundary_condition/rayleigh_damping_layer.proto @@ -63,11 +63,6 @@ message RayleighDampingLayer { // The name of the state to be used as a sponge target. string target_state_name = 5; } - // An indicator of whether this is the only forcing term of this variable: - // 'true' if it's the only force and need to replace previous values; - // 'false' if it's not the only force and need to be added after other - // forces are updated - optional bool override = 3; // An indicator of whether the sponge force is applied to a primitive // variable or a conservative one. If 'true', the force term is generated // for the primitive variable; otherwise the force term is for a diff --git a/swirl_lm/boundary_condition/rayleigh_damping_layer.py b/swirl_lm/boundary_condition/rayleigh_damping_layer.py index e2fc1d8..76e1f50 100644 --- a/swirl_lm/boundary_condition/rayleigh_damping_layer.py +++ b/swirl_lm/boundary_condition/rayleigh_damping_layer.py @@ -84,29 +84,6 @@ def get_sponge_target_name(varname: Text) -> Text: return 'sponge_target_{}'.format(varname) -def target_value_lib_from_proto( - sponges: _RayleighDampingLayerSeq) -> TargetValueLib: - """Generates a target value library from the proto. - - Args: - sponges: A sequence of materialized sponge layer protos. - - Returns: - A dictionary with keys being the variable names, and values being the target - value in the sponge layer. - """ - lib = {} - for sponge in sponges: - for info in sponge.variable_info: - if info.HasField('target_value'): - lib.update({info.name: info.target_value}) - elif info.HasField('target_state_name'): - lib.update({info.name: info.target_state_name}) - else: - lib.update({info.name: None}) - return lib - - def variable_type_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap: """Generates a library for the type of the variable from the proto. @@ -125,25 +102,6 @@ def variable_type_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap: return out -def target_status_lib_from_proto(sponges: _RayleighDampingLayerSeq) -> BoolMap: - """Generates a sponge forcing status library from the proto. - - Args: - sponges: A sequence of materialized sponge layer protos. - - Returns: - A dictionary with keys being the name of the forcing term, and values being - the target behavior of the forcing term. `False` if there are other forcing - terms need to be combined with the sponge force, and `True` if the sponge - force is the only forcing term for that variable. 
- """ - out = {} - for sponge in sponges: - for info in sponge.variable_info: - out[get_sponge_force_name(info.name)] = info.override - return out - - def _get_beta_name_from_sponge_info( sponge: rayleigh_damping_layer_pb2.RayleighDampingLayer) -> str: # Use default name 'sponge_beta' if beta_name is not explicitly set. @@ -304,8 +262,6 @@ def __init__( periodic_dims: An optional list of booleans indicating the periodic dimensions. """ - self._target_values = target_value_lib_from_proto(sponge_infos) - self._target_status = target_status_lib_from_proto(sponge_infos) self._is_primitive = variable_type_lib_from_proto(sponge_infos) self._beta_name_by_var = beta_name_by_var(sponge_infos) self._sponge_info_map = sponge_info_map(sponge_infos) @@ -318,7 +274,7 @@ def __init__( logging.info( 'Sponge layer will be applied for the following variables with ' - 'following values: %r', self._target_values) + 'following values: %s', self._sponge_info_map) def _get_sponge_force( self, @@ -414,14 +370,10 @@ def additional_states_update_fn( ) -> FlowFieldMap: """Updates the forcing term due to the sponge layer. - The forcing term will replace or add to the existing forcing term in + The forcing term will be added to the existing forcing term in `additional_states` for variables that are in the scope of - `self._target_values`, following the indicator stated in - `self._target_status`: if `False`, the sponge force will be added to the - input force with the same name; if `True`, the sponge force will override - the existing one. If other forcing terms needs to be applied to a same - variable, the values of all these forcing terms needs to be updated ahead of - the sponge forces. + `self._sponge_info_map` following the specification of target values stored + in the values of this dictionary. Args: kernel_op: An object holding a library of kernel operations. @@ -482,10 +434,7 @@ def add_to_additional_states( sponge_force = tf.nest.map_structure( tf.math.multiply, states['rho'], sponge_force ) - if self._target_status[sponge_name]: - additional_states_updated.update({sponge_name: sponge_force}) - else: - additional_states_updated.update( - {sponge_name: add_to_additional_states(sponge_name, sponge_force)}) + additional_states_updated.update( + {sponge_name: add_to_additional_states(sponge_name, sponge_force)}) return additional_states_updated diff --git a/swirl_lm/boundary_condition/simulated_turbulent_inflow.py b/swirl_lm/boundary_condition/simulated_turbulent_inflow.py index 9837693..c7a3d33 100644 --- a/swirl_lm/boundary_condition/simulated_turbulent_inflow.py +++ b/swirl_lm/boundary_condition/simulated_turbulent_inflow.py @@ -104,8 +104,7 @@ def __init__(self, params: parameters_lib.SwirlLMParameters): if self._model_params.WhichOneof('operation') == 'generation': # Get the index of the plane to extract the inflow data. 
-    mesh_size = (self._params.dx, self._params.dy,
-                 self._params.dz)[self._inflow_dim]
+    mesh_size = self._params.grid_spacings[self._inflow_dim]
     mesh_count = (self._params.nx, self._params.ny,
                   self._params.nz)[self._inflow_dim]
     idx = int(self._model_params.generation.location // mesh_size)
diff --git a/swirl_lm/example/fire/fire.py b/swirl_lm/example/fire/fire.py
index da08058..67037e5 100644
--- a/swirl_lm/example/fire/fire.py
+++ b/swirl_lm/example/fire/fire.py
@@ -347,9 +347,13 @@ def cubic_obstacles(
                      types.FlowFieldMap]


-def get_init_rho_f(ground_elevation: tf.Tensor, fuel_bed_height: float,
-                   fuel_density: float, fuel_start_x: float,
-                   dz: float) -> wildfire_utils.InitFn:
+def get_init_rho_f(
+    ground_elevation: tf.Tensor,
+    fuel_bed_height: float,
+    fuel_density: float,
+    fuel_start_x: float,
+    dz: float | tf.Tensor,
+) -> wildfire_utils.InitFn:
   """Returns the initializer function for rho_f."""

   # We assume grid coordinates in zz are integer multiples of dz and use +/-0.1
@@ -371,7 +375,9 @@ def get_init_rho_f(ground_elevation: tf.Tensor, fuel_bed_height: float,
   # fuel cell will be in the halo.
   quantized_ground_elevation = tf.math.ceil(ground_elevation / dz) * dz - 2 * dz

-  num_full_cells = int(np.floor(fuel_bed_height / dz))
+  num_full_cells = tf.cast(
+      tf.floor(fuel_bed_height / dz), quantized_ground_elevation.dtype
+  )
   # Note that the number of full cells is one more than given by
   # fuel_bed_height because we also put fuel into the boundary cell (i.e.,
   # the cell that intersects with the terrain).
@@ -1172,10 +1178,12 @@ def init_y_o(xx, yy, zz, lx, ly, lz, coord):
     def init_rho_m(xx, yy, zz, lx, ly, lz, coord):
       """Generates initial moisture `rho_m` field."""
       del xx, yy, lx, ly, lz, coord
+      # In the case of a stretched grid, we use the first grid spacing as the
+      # reference here.
       return tf.compat.v1.where(
           tf.math.logical_and(
               zz <= ground_elevation + self.fire_utils.fuel_bed_height,
-              zz >= ground_elevation - self.config.dz,
+              zz >= ground_elevation - self.config.z[1],
           ),
           self.fire_utils.moisture_density * tf.ones_like(zz),
           tf.zeros_like(zz),
       )
@@ -1347,9 +1355,34 @@ def init_ignition_kernel(xx, yy, zz, lx, ly, lz, coord):
         assert (
             'rho_f' in self.config.additional_state_keys
         ), 'Fuel height is nonzero but rho_f is not included in the config.'
-        init_rho_f = get_init_rho_f(
-            ground_elevation, self.fire_utils.fuel_bed_height,
-            self.fire_utils.fuel_density, self.fuel_start_x, self.config.dz)
+
+        # In the case of a stretched grid, we assign the same fuel density at
+        # all node points below the fuel height.
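+        # Unlike the uniform-grid `else` branch, which quantizes the fuel
+        # layer by the constant spacing `dz` via `get_init_rho_f`, no
+        # per-cell fuel-load quantization is attempted here.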
+        if self.config.use_stretched_grid[2]:
+
+          def init_rho_f(xx, yy, zz, lx, ly, lz, coord):
+            """Generates initial fuel density `rho_f` field."""
+            del yy, lx, ly, lz, coord
+            rho_f = tf.where(
+                tf.math.logical_and(
+                    zz <= ground_elevation + self.fire_utils.fuel_bed_height,
+                    zz >= ground_elevation - self.config.z[1],
+                ),
+                self.fire_utils.fuel_density * tf.ones_like(zz),
+                tf.zeros_like(zz),
+            )
+            return tf.where(
+                tf.greater_equal(xx, self.fuel_start_x), rho_f,
+                tf.zeros_like(xx)
+            )
+
+        else:
+          init_rho_f = get_init_rho_f(
+              ground_elevation,
+              self.fire_utils.fuel_bed_height,
+              self.fire_utils.fuel_density,
+              self.fuel_start_x,
+              self.config.dz,
+          )

         output.update({
             'rho_f':
                 self.fire_utils.states_init(coordinates, init_rho_f, 'CONSTANT'),
diff --git a/swirl_lm/example/firebench/compute_stats.py b/swirl_lm/example/firebench/compute_stats.py
deleted file mode 100644
index 1e1c497..0000000
--- a/swirl_lm/example/firebench/compute_stats.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# Copyright 2024 The swirl_lm Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Calculates min/max/mean values for a Swirl-LM dataset.
-
-For an overview of running Apache Beam pipelines on Google Cloud, see:
-
-https://cloud.google.com/dataflow/docs/guides/use-beam
-https://cloud.google.com/dataflow/docs/quickstarts/create-pipeline-python
-
-Steps to run:
-
-1. Build a custom container:
-
-  * Change to the directory swirl_lm/example/firebench/docker.
-
-  * Run:
-
-      gcloud builds submit --region=<REGION> \
-        --tag <REGION>-docker.pkg.dev/<PROJECT_ID>/<REPO>/<IMAGE_NAME>:<TAG>
-
-    <REGION>: For example, us-central1.
-    <PROJECT_ID>: Google Cloud project id. If the id contains ':'s, replace
-      them with slashes.
-    <REPO>: A new or existing repository name.
-    <IMAGE_NAME>: A new unique name for the image.
-    <TAG>: The tag for the version being created.
-
-    For more info see:
-    https://cloud.google.com/dataflow/docs/guides/build-container-image
-
-  * Verify that the image was built successfully by viewing the "Artifact
-    Registry" pages in the Cloud console.
-
-  * The image does not need to be rebuilt as long as new python dependencies
-    are not added.
-
-2. Launch the dataflow job:
-
-  * Verify that the machine used for the launch has the same version of python3
-    as the Beam image (see docker/Dockerfile) and all the requirements
-    (docker/requirements.txt) are installed. Using a virtual env is recommended
-    for setting up python and dependencies on the launch machine.
-
-    An alternative to setting up a virtual env is to start up a shell using the
-    docker image and launch from the docker image, though this is currently not
-    well tested.
-
-  * Change to the directory swirl_lm/example/firebench.
-
-      python3 compute_stats.py \
-        --input_path=gs://firebench/<INPUT_ZARR> \
-        --output_path=gs://<OUTPUT_ZARR> \
-        --pipeline_options="--runner=apache_beam.runners.dataflow.dataflow_runner.DataflowRunner,--project=<PROJECT_ID>,--temp_location=gs://<TEMP_PATH>,--staging_location=gs://<STAGING_PATH>,--region=<REGION>,--sdk_container_image=<IMAGE>@<DIGEST>,--sdk_location=container,--save_main_session"
-
-    <INPUT_ZARR>: Path to input zarr dataset.
-    <OUTPUT_ZARR>: Path to output zarr dataset.
-    <PROJECT_ID>: Google Cloud project id.
-    <TEMP_PATH>: Writable path in a GCS bucket.
-    <STAGING_PATH>: Writable path in a GCS bucket.
-    <REGION>: For example, us-central1.
-    <IMAGE>: <REGION>-docker.pkg.dev/<PROJECT_ID>/<REPO>/<IMAGE_NAME>
-      from step 1 *without* the tag.
-    <DIGEST>: Digest as output by gcloud builds or as shown on the "Artifact
-      Repository", e.g.,
-      sha256:6e1cf2a963132a240fd06f535c9f9e8cfb1353ca510b2df31cf2f32ff658a8c9
-
-"""
-
-
-from typing import Tuple
-
-from absl import app
-from absl import flags
-import apache_beam as beam
-import xarray
-import xarray_beam as xbeam
-
-
-# NOTE: To make top-level imports available to workers, we need to have
-# --save_main_session=True, but then Beam refuses to save flag values (via
-# pickling) so we can't assign flags to global variables as we normally do.
-flags.DEFINE_string('input_path', None, help='Input Zarr path')
-flags.DEFINE_string('output_path', None, help='Output Zarr path')
-flags.DEFINE_list(
-    'pipeline_options', ['--runner=DirectRunner'],
-    'A comma-separated list of command line arguments to be used as options'
-    ' for the Beam Pipeline.'
-)
-
-
-def compute_stats(
-    key: xbeam.Key, dataset: xarray.Dataset
-) -> Tuple[xbeam.Key, xarray.Dataset]:
-  """Computes spatial mean/min/max for all variables."""
-  spatial_dims = set(dataset.dims) - {'t'}
-  return (
-      key.with_offsets(x=None, y=None, z=None, stat=0),
-      xarray.concat(
-          [
-              dataset.mean(spatial_dims).assign_coords(stat='mean'),
-              dataset.min(spatial_dims).assign_coords(stat='min'),
-              dataset.max(spatial_dims).assign_coords(stat='max'),
-          ],
-          dim='stat',
-      ),
-  )
-
-
-class CombineStatsFn(beam.CombineFn):
-  """Combiner for mean/min/max.
-
-  Keeps track of count of datasets and the accumulated dataset. The accumulated
-  dataset keeps the sum of means (at stat='mean') from the input datasets, and
-  the min and the max. At the end of the combine stage, the mean is calculated
-  from the sum of the means and the count.
- """ - - def create_accumulator(self, *args, **kwargs): - return 0, None # Count of datasets, accumulated dataset - - def _merge_stats(self, left, right): - if left[0] == 0: - return right - if right[0] == 0: - return left - accumulator = xarray.concat( - [left[1].sel(stat='mean') + right[1].sel(stat='mean'), - xarray.where((left[1].sel(stat='min') < - right[1].sel(stat='min')), - left[1].sel(stat='min'), - right[1].sel(stat='min')), - xarray.where((left[1].sel(stat='max') > - right[1].sel(stat='max')), - left[1].sel(stat='max'), - right[1].sel(stat='max'))], dim='stat') - return left[0] + right[0], accumulator - - def add_input(self, mutable_accumulator, element, *args, **kwargs): - return self._merge_stats(mutable_accumulator, (1, element)) - - def merge_accumulators(self, accumulators, *args, **kwargs): - out = 0, None - for accumulator in accumulators: - out = self._merge_stats(out, accumulator) - return out - - def extract_output(self, accumulator, *args, **kwargs): - return xarray.concat( - [ - accumulator[1].sel(stat='mean') / accumulator[0], - accumulator[1].sel(stat='min'), - accumulator[1].sel(stat='max'), - ], - dim='stat', - ) - - -class ComputeStats(beam.PTransform): - """Main pipeline as a PTransform to make testing easier.""" - - def __init__(self, input_path: str, output_path: str): - self.input_path = input_path - self.output_path = output_path - - def expand(self, pcoll): - source_dataset, source_chunks = xbeam.open_zarr(self.input_path) - - template = ( - xbeam.make_template(source_dataset) - .isel(x=0, y=0, z=0, drop=True) - .expand_dims(stat=['mean', 'min', 'max']) - ) - - compute_stats_sizes = dict(source_dataset.sizes) - del compute_stats_sizes['x'] - del compute_stats_sizes['y'] - del compute_stats_sizes['z'] - compute_stats_sizes['stat'] = 3 - - compute_stats_chunks = dict(source_chunks) - del compute_stats_chunks['x'] - del compute_stats_chunks['y'] - del compute_stats_chunks['z'] - compute_stats_chunks['stat'] = -1 - - output_chunks = {'t': compute_stats_sizes['t'], 'stat': 3} - - return ( - pcoll - | xbeam.DatasetToChunks(source_dataset, source_chunks) - | beam.MapTuple(compute_stats) - | beam.CombinePerKey(CombineStatsFn()) - | xbeam.Rechunk( - compute_stats_sizes, - compute_stats_chunks, - output_chunks, - itemsize=8, - min_mem=0, # Small chunks are OK. - ) - | xbeam.ChunksToZarr(self.output_path, template, output_chunks) - ) - - -def main(args): - del args - pipeline_options = beam.options.pipeline_options.PipelineOptions( - flags.FLAGS.pipeline_options) - with beam.Pipeline(options=pipeline_options) as root: - _ = ( - root - | ComputeStats(flags.FLAGS.input_path, flags.FLAGS.output_path) - ) - - -if __name__ == '__main__': - app.run(main) diff --git a/swirl_lm/example/firebench/docker/Dockerfile b/swirl_lm/example/firebench/docker/Dockerfile deleted file mode 100644 index d023aa5..0000000 --- a/swirl_lm/example/firebench/docker/Dockerfile +++ /dev/null @@ -1,5 +0,0 @@ -FROM apache/beam_python3.11_sdk - -COPY requirements.txt . 
- -RUN pip install -r requirements.txt diff --git a/swirl_lm/example/firebench/docker/requirements.txt b/swirl_lm/example/firebench/docker/requirements.txt deleted file mode 100644 index 141f0a9..0000000 --- a/swirl_lm/example/firebench/docker/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -absl-py -apache-beam[gcp] -gcsfs -xarray -xarray_beam diff --git a/swirl_lm/numerics/convection.py b/swirl_lm/numerics/convection.py index fb5825a..f0e9c42 100644 --- a/swirl_lm/numerics/convection.py +++ b/swirl_lm/numerics/convection.py @@ -267,9 +267,7 @@ def face_interpolation( # wall, and the flux to be computed is a wall normal velocity component. if varname is not None and varname in (common.KEYS_VELOCITY[dim], common.KEYS_MOMENTUM[dim]): - nz = len(state) - nx, ny = state[0].get_shape().as_list() - n_grid = (nx, ny, nz)[dim] + n_grid = common_ops.get_shape(state)[dim] n_core = replicas.shape[dim] for face in range(2): diff --git a/swirl_lm/physics/combustion/igniter.py b/swirl_lm/physics/combustion/igniter.py index bf0e02d..6a2a477 100644 --- a/swirl_lm/physics/combustion/igniter.py +++ b/swirl_lm/physics/combustion/igniter.py @@ -151,19 +151,25 @@ def local_ignition_kernel_fn(schedule: tf.Tensor) -> tf.Tensor: return tf.compat.v1.where( tf.math.logical_and( tf.greater_equal(schedule, t - self._igniter_radius_in_time), - tf.less_equal(schedule, t + self._igniter_radius_in_time)), - tf.ones_like(schedule), tf.zeros_like(schedule)) + tf.less_equal(schedule, t + self._igniter_radius_in_time), + ), + tf.ones_like(schedule), + tf.zeros_like(schedule), + ) - ignition_kernel = [ - local_ignition_kernel_fn(schedule) for schedule in ignition_schedule - ] + ignition_kernel = tf.nest.map_structure( + local_ignition_kernel_fn, ignition_schedule + ) def trim_time_interval(kernel: tf.Tensor) -> tf.Tensor: """Limits ignition only in the time interval specified.""" return tf.cond( tf.math.logical_and( tf.greater_equal(t, self._start_time), - tf.less_equal(t, self._end_time)), lambda: kernel, - lambda: tf.zeros_like(kernel)) + tf.less_equal(t, self._end_time), + ), + lambda: kernel, + lambda: tf.zeros_like(kernel), + ) - return [trim_time_interval(kernel) for kernel in ignition_kernel] + return tf.nest.map_structure(trim_time_interval, ignition_kernel) diff --git a/swirl_lm/physics/radiation/optics/optics_base.py b/swirl_lm/physics/radiation/optics/optics_base.py index 6491a61..ab614d6 100644 --- a/swirl_lm/physics/radiation/optics/optics_base.py +++ b/swirl_lm/physics/radiation/optics/optics_base.py @@ -175,31 +175,6 @@ def n_gpt_sw(self) -> int: def solar_fraction_by_gpt(self) -> Sequence[float]: """Mapping from g-point to the fraction of total solar radiation.""" - def _slice_field( - self, - f: FlowFieldVal, - dim: int, - face: int, - idx: int, - ) -> FlowFieldVal: - """Slices a plane from `f` normal to `dim`.""" - face_slice = common_ops.get_face(f, dim, face, idx) - if isinstance(f, tf.Tensor) or dim != 2: - # Remove the outer list. 
- return face_slice[0] - return face_slice - - def _field_shape( - self, - f: FlowFieldVal, - ) -> FlowFieldVal: - """Returns the x, y, and z dimensions of the field `f`.""" - if isinstance(f, tf.Tensor): - input_shape = f.get_shape().as_list() - return input_shape[1:] + input_shape[:1] - else: - return f[0].get_shape().as_list() + [len(f)] - def _exchange_halos( self, replica_id: tf.Tensor, @@ -207,10 +182,17 @@ def _exchange_halos( f: FlowFieldVal, ) -> FlowFieldVal: """Exchanges halos, preserving the boundary values along the vertical.""" - boundary_vals = [ - [self._slice_field(f, self._g_dim, face, i) for i in range(self._halos)] - for face in range(2) - ] + boundary_vals = [] + # Lower boundary values. + boundary_vals.append([ + common_ops.slice_field(f, self._g_dim, i, size=1) + for i in range(self._halos) + ]) + # Top boundary values. + boundary_vals.append([ + common_ops.slice_field(f, self._g_dim, -i - 1, size=1) + for i in range(self._halos) + ]) # Reverse the order of the top boundary values to restore ascending order. boundary_vals[1].reverse() @@ -268,10 +250,10 @@ def _reconstruct_face_values( # Shift down to obtain the top cell face values and pad the top outermost # halo layer with a copy of the adjacent inner layer. f_top = self._shift_down_fn(f_bottom) - outermost_valid_top_layer = self._slice_field( - f_top, self._g_dim, face=1, idx=1 + outermost_valid_top_layer = common_ops.slice_field( + f_top, self._g_dim, -2, size=1 ) - shape = self._field_shape(f_top) + shape = common_ops.get_shape(f_top) # Update the last halo layer along the vertical. f_top = common_ops.tensor_scatter_1d_update( f_top, diff --git a/swirl_lm/physics/radiation/rte/monochromatic_two_stream.py b/swirl_lm/physics/radiation/rte/monochromatic_two_stream.py index a7a966b..8299fc4 100644 --- a/swirl_lm/physics/radiation/rte/monochromatic_two_stream.py +++ b/swirl_lm/physics/radiation/rte/monochromatic_two_stream.py @@ -547,8 +547,8 @@ def t_noscat_fn(tau: tf.Tensor) -> tf.Tensor: ) # Direct-beam flux incident on the surface. - flux_down_sfc = self.rte_utils.slice( - flux_down_direct, self.g_dim, self.halos, 0 + flux_down_sfc = common_ops.slice_field( + flux_down_direct, self.g_dim, self.halos, size=1 ) core_idx = common_ops.get_core_coordinate(replicas, replica_id)[self.g_dim] diff --git a/swirl_lm/physics/radiation/rte/rte_utils.py b/swirl_lm/physics/radiation/rte/rte_utils.py index 9bdcf18..6922fa1 100644 --- a/swirl_lm/physics/radiation/rte/rte_utils.py +++ b/swirl_lm/physics/radiation/rte/rte_utils.py @@ -63,20 +63,6 @@ def __init__( self.grid_size = (params.nx, params.ny, params.nz) self.halos = params.halo_width - def slice( - self, - f: types.FlowFieldVal, - dim: int, - idx: int, - face: int, - ) -> FlowFieldVal: - """Slices a plane from `f` normal to `dim`.""" - face_slice = common_ops.get_face(f, dim, face, idx) - if isinstance(f, tf.Tensor) or dim != 2: - # Remove the outer list. 
- return face_slice[0] - return face_slice - def _append( self, a: FlowFieldVal, @@ -204,17 +190,19 @@ def _local_recurrent_op( """ x = variables['x0'] - face = 0 if forward else 1 - for i in range(n): prev_idx = i - 1 + slice_idx = i if forward else -i - 1 plane_args = { - k: self.slice(v, dim, i, face) + k: common_ops.slice_field(v, dim, slice_idx, size=1) for k, v in variables.items() if k != 'x0' } + prev_slice_idx = prev_idx if forward else -prev_idx - 1 plane_args['x0'] = ( - x if i == 0 else self.slice(x, dim, prev_idx, face) + x + if i == 0 + else common_ops.slice_field(x, dim, prev_slice_idx, size=1) ) arg_lst = [ plane_args[k] for k in inspect.getfullargspec(recurrent_fn).args @@ -222,7 +210,8 @@ def _local_recurrent_op( next_layer = tf.nest.map_structure(recurrent_fn, *arg_lst) x = next_layer if i == 0 else self._append(x, next_layer, dim, forward) - last_local_layer = self.slice(x, dim, n - 1, face) + last_layer = -1 if forward else 0 + last_local_layer = common_ops.slice_field(x, dim, last_layer, size=1) return x, last_local_layer diff --git a/swirl_lm/physics/radiation/rte/two_stream.py b/swirl_lm/physics/radiation/rte/two_stream.py index ea85f43..287957f 100644 --- a/swirl_lm/physics/radiation/rte/two_stream.py +++ b/swirl_lm/physics/radiation/rte/two_stream.py @@ -36,6 +36,7 @@ from swirl_lm.physics.radiation.optics import optics from swirl_lm.physics.radiation.rte import monochromatic_two_stream import swirl_lm.physics.radiation.rte.rte_utils as utils +from swirl_lm.utility import common_ops from swirl_lm.utility import get_kernel_fn from swirl_lm.utility import grid_parametrization from swirl_lm.utility import types @@ -171,7 +172,7 @@ def solve_lw( # Create a plane for the surface temperature representation. sfc_temperature = tf.nest.map_structure( lambda x: sfc_temperature * tf.ones_like(x), - self._rte_utils.slice(pressure, self._g_dim, 0, 0) + common_ops.slice_field(pressure, self._g_dim, 0, size=1) ) def step_fn(igpt, cumulative_flux): @@ -208,11 +209,11 @@ def step_fn(igpt, cumulative_flux): ) ) # Boundary conditions. - sfc_src = planck_srcs.get('planck_src_sfc', self._rte_utils.slice( + sfc_src = planck_srcs.get('planck_src_sfc', common_ops.slice_field( planck_srcs['planck_src_bottom'], self._g_dim, - face=0, - idx=self._halos, + self._halos, + size=1 )) top_flux_down = tf.nest.map_structure( lambda x: self._top_flux_down_lw * tf.ones_like(x), sfc_src @@ -329,8 +330,8 @@ def step_fn(igpt, partial_fluxes): ) sfc_albedo = tf.nest.map_structure( lambda x: self._sfc_albedo * tf.ones_like(x), - self._rte_utils.slice( - sw_optical_props['optical_depth'], self._g_dim, 0, 0 + common_ops.slice_field( + sw_optical_props['optical_depth'], self._g_dim, 0, size=1 ), ) # Monochromatic top of atmosphere flux. diff --git a/swirl_lm/utility/common_ops.py b/swirl_lm/utility/common_ops.py index 5196aa9..24091cb 100644 --- a/swirl_lm/utility/common_ops.py +++ b/swirl_lm/utility/common_ops.py @@ -15,7 +15,7 @@ """Library for common operations.""" import enum import functools -from typing import Any, Callable, Dict, Iterable, Literal, Mapping, Sequence, Text, Tuple +from typing import Any, Callable, Dict, Iterable, Literal, Mapping, Optional, Sequence, Text, Tuple import numpy as np from swirl_lm.utility import types @@ -1654,6 +1654,106 @@ def cross_replica_gather(x: tf.Tensor, num_replicas: int) -> list[tf.Tensor]: return [gathered[i, ...] for i in range(num_replicas)] +def gather(x: tf.Tensor, indices: tf.Tensor) -> tf.Tensor: + """Retrieves values in `x` located at `indices`. 
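+
+  For example (illustrative values), with
+  `x = tf.reshape(tf.range(8.), (2, 2, 2))` and
+  `indices = tf.constant([[0, 1, 1], [1, 0, 0]])`, the result is `[3., 4.]`,
+  i.e. `x[0, 1, 1]` and `x[1, 0, 0]`.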
+
+  Args:
+    x: A 3D `tf.Tensor` from which values are gathered.
+    indices: A 2D `tf.Tensor` with shape (n_pts, 3), with the columns being the
+      indices along the first to last dimension of `x`. Note that repeated
+      index entries are allowed.
+
+  Returns:
+    Values in `x` at the locations specified by `indices`, as a vector whose
+    length matches the first dimension of `indices`. If `indices` is empty, a
+    vector of length 0 is returned.
+  """
+  if tf.shape(indices)[0] == 0:
+    return tf.constant([], dtype=x.dtype)
+
+  i, j, k = [
+      tf.cast(tf.one_hot(indices[:, l], depth=tf.shape(x)[l]), dtype=x.dtype)
+      for l in range(3)
+  ]
+
+  return tf.einsum('qi,qj,qk,ijk->q', i, j, k, x)
+
+
+def gather_from_mask(x: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
+  """Retrieves values in `x` corresponding to locations of 1s in `mask`.
+
+  Args:
+    x: A 3D `tf.Tensor` from which values are gathered.
+    mask: A 3D `tf.Tensor` of ones and zeros, with ones indicating the
+      locations where the values are gathered. Note that the shapes of `x` and
+      `mask` must be the same.
+
+  Returns:
+    A 1D vector storing the values corresponding to the locations of ones in
+    `mask`. Its length equals the number of ones in `mask`. Note that the
+    indices of ones in `mask` follow the order of the dimensions from first to
+    last. If no ones are found in `mask`, a vector of length 0 is returned.
+
+  Raises:
+    ValueError: If the shapes of `x` and `mask` are different.
+  """
+  return gather(x, tf.where(tf.less(tf.abs(mask - 1.0), 1e-6)))
+
+
+def scatter(
+    x: tf.Tensor, indices: tf.Tensor, shape: tf.Tensor, dtype: tf.DType
+) -> tf.Tensor:
+  """Generates a 3D tensor with values in `x` specified at `indices`.
+
+  Args:
+    x: A 1D `tf.Tensor` to be scattered into a 3D tensor.
+    indices: A 2D `tf.Tensor` with shape (n_pts, 3), with the columns being the
+      indices associated with the dimensions of the output 3D tensor. The first
+      dimensions of `x` and `indices` must be the same. Note that repeated
+      index entries are allowed, in which case the sum of the corresponding
+      values is scattered.
+    shape: A 1D array of length 3 specifying the shape of the 3D tensor in the
+      output.
+    dtype: The data type of the returned 3D tensor.
+
+  Returns:
+    A 3D tensor with the values in `x` scattered to the locations specified by
+    `indices`, with everywhere else being 0. If `indices` is empty, a 3D tensor
+    of all zeros is returned.
+  """
+  if tf.shape(indices)[0] == 0:
+    return tf.zeros(shape, dtype=dtype)
+
+  i, j, k = [
+      tf.cast(tf.one_hot(indices[:, l], depth=shape[l]), dtype=x.dtype)
+      for l in range(3)
+  ]
+
+  return tf.cast(tf.einsum('q,qi,qj,qk->ijk', x, i, j, k), dtype)
+
+
+def scatter_to_mask(
+    x: tf.Tensor, mask: tf.Tensor, dtype: tf.DType
+) -> tf.Tensor:
+  """Generates a 3D tensor with `x` scattered to locations of 1 in `mask`.
+
+  Args:
+    x: A 1D `tf.Tensor` to be scattered into a 3D tensor. The order of these
+      points must follow the row-major order of the indices of ones in `mask`.
+    mask: A 3D `tf.Tensor` of ones and zeros, with ones indicating the
+      locations where the values are scattered.
+    dtype: The data type of the returned 3D tensor.
+
+  Returns:
+    A 3D tensor with the values in `x` scattered to the locations specified by
+    ones in `mask`, with everywhere else being 0. If `mask` is all zeros, a 3D
+    tensor of all zeros is returned.
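+
+  For example (illustrative values), if `mask` is a 2x2x2 tensor with ones at
+  `[0, 0, 1]` and `[1, 1, 0]` and `x = tf.constant([7., 8.])`, the result has
+  7 at `[0, 0, 1]` (the first index in row-major order), 8 at `[1, 1, 0]`, and
+  0 everywhere else.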
+ """ + return scatter( + x, tf.where(tf.less(tf.abs(mask - 1.0), 1e-6)), tf.shape(mask), dtype + ) + + def pad( f: FlowFieldVal, paddings: Sequence[Sequence[int]], @@ -1687,11 +1787,68 @@ def pad( return lower_pad + list(padded) + upper_pad -def get_face(value: FlowFieldVal, - dim: int, - face: int, - index: int, - scaling_factor: float = 1.) -> FlowFieldVal: +def slice_field( + f: FlowFieldVal, + dim: int, + start_idx: int, + size: Optional[int] = None, +) -> FlowFieldVal: + """Slices the input field along the given dimension. + + Args: + f: A list of 2D `tf.Tensor` or a single 3D `tf.Tensor` representing a 3D + field. If a list of 2D `tf.Tensor`, the length of the list is `nz` and + each 2D `tf.Tensor` has the shape [nx, ny]. If a single 3D `tf.Tensor`, + its shape is [nz, nx, ny]. + dim: The dimension of the plane to slice, should be one of 0, 1, and 2. + start_idx: The index of the first point in the slice. If negative, it will + be counted from the end of the field along `dim`. + size: The optional length of the slice along `dim`. If not provided, the + slice will run from `start_idx` to the end of the array. + + Returns: + A slice having the same format as the input field but a dimension of `size` + along `dim`. + """ + shape = list(get_shape(f)) + start = [0, 0, 0] + + n = int(shape[dim]) + if start_idx < 0: + start_idx = n + start_idx + + if size is None: + size = n - start_idx + + start[dim] = start_idx + shape[dim] = size + + if isinstance(f, tf.Tensor): + # Handles the case of single 3D tensor. + shape = tf.roll(shape, shift=1, axis=0) + start = tf.roll(start, shift=1, axis=0) + return tf.slice(f, start, shape) + + # Handles the case of list of 2D tensors. + if dim in (0, 1): + return tf.nest.map_structure( + lambda x: tf.slice(x, start[:-1], shape[:-1]), f + ) + elif dim == 2: # Z + return f[start_idx : start_idx + size] + else: + raise ValueError( + f'`dim` has to be one of 0, 1, or 2. But {dim} is provided.' + ) + + +def get_face( + value: FlowFieldVal, + dim: int, + face: int, + index: int, + scaling_factor: float = 1.0, +) -> FlowFieldVal: """Gets the face in `value` that is `index` number of points from boundary. This function extracts the requested plane from a 3D tensor. @@ -1725,58 +1882,18 @@ def get_face(value: FlowFieldVal, the length - index'th plane is returned. The returned slice will be multiplied by `scaling_factor`. """ - if isinstance(value, tf.Tensor): - # Handles the case of single 3D tensor. - shifted_dim = (dim + 1) % 3 - shape = value.get_shape().as_list() - n = shape[shifted_dim] - start_idx = [0, 0, 0] - - if face == 0: # low - start_idx[shifted_dim] = index - elif face == 1: # high - start_idx[shifted_dim] = n - index - 1 - - shape[shifted_dim] = 1 - return [scaling_factor * tf.slice(value, start_idx, shape)] - - # Handles the case of list of 2D tensors. - nz = len(value) - if dim in (0, 1): - shape = value[0].get_shape().as_list() - n = shape[dim] - start_idx = [0, 0] - if face == 0: - start_idx[dim] = index - elif face == 1: - start_idx[dim] = n - index - 1 - else: - raise ValueError( - f'`face` has to be either 0 or 1. But {face} is provided.' - ) - shape[dim] = 1 - bc_value = [ - tf.nest.map_structure( - lambda x: scaling_factor * tf.slice(x, start_idx, shape), value - ) - ] - elif dim == 2: # Z - if face == 0: # low - bc_value = [ - scaling_factor * value[index], - ] - elif face == 1: # high - bc_value = [ - scaling_factor * value[nz - index - 1], - ] - else: - raise ValueError( - f'`face` has to be either 0 or 1. But {face} is provided.' 
- ) - else: + if face not in (0, 1): raise ValueError( - f'`dim` has to be one of 0, 1, or 2. But {dim} is provided.' + f'`face` has to be one of 0 or 1. But {face} is provided.' ) + index = index if face == 0 else -index - 1 + bc_value = [ + tf.nest.map_structure( + lambda x: scaling_factor * x, slice_field(value, dim, index, size=1) + ) + ] + if isinstance(value, Sequence) and dim == 2: + bc_value = bc_value[0] return bc_value # pytype: disable=bad-return-type diff --git a/swirl_lm/utility/text_util.py b/swirl_lm/utility/text_util.py index e98b421..290cacc 100644 --- a/swirl_lm/utility/text_util.py +++ b/swirl_lm/utility/text_util.py @@ -50,7 +50,8 @@ def seconds_to_string(total_seconds: float, separator: str = ' ', """ # Handle negative durations. if total_seconds < 0: - return '-' + seconds_to_string(-total_seconds, separator) + positive_str = seconds_to_string(-total_seconds, separator, precision, zero) + return '-' + positive_str if positive_str != zero else zero # Convert durations to integer nanoseconds to avoid floating errors. # (2^63 nanoseconds is approximately 300 years).
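
A minimal usage sketch of the new unified slicing helper (illustrative shapes
and values; assumes the [nz, nx, ny] layout documented in `slice_field`):

    import tensorflow as tf
    from swirl_lm.utility import common_ops

    f_3d = tf.reshape(tf.range(24.), (4, 3, 2))  # Single 3D tensor, [nz, nx, ny].
    f_lst = [f_3d[i] for i in range(4)]          # Equivalent list of 2D tensors.

    # The plane at z index 1, in both field representations (dim 2 is z).
    plane_3d = common_ops.slice_field(f_3d, dim=2, start_idx=1, size=1)
    plane_lst = common_ops.slice_field(f_lst, dim=2, start_idx=1, size=1)

    # Negative indices count from the high end; this is how the refactored
    # get_face expresses face=1: index i maps to start_idx = -i - 1.
    top_plane = common_ops.slice_field(f_3d, dim=2, start_idx=-1, size=1)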