
Commit

Code update
PiperOrigin-RevId: 563233314
The swirl_lm Authors authored and bcgazen committed Sep 6, 2023
1 parent e41fac2 commit 7f6765b
Showing 20 changed files with 1,119 additions and 248 deletions.
202 changes: 110 additions & 92 deletions swirl_lm/base/driver.py
@@ -17,7 +17,7 @@
import functools
import os
import time
from typing import Any, Dict, Optional, Tuple, TypeVar
from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar

from absl import flags
from absl import logging
@@ -40,9 +40,11 @@
'files prefix. This will be suffixed with the field '
'components and step count.', allow_override=True)
flags.DEFINE_string(
'data_load_prefix', '/tmp/data', 'The input `ser` or `h5` '
'files prefix. This will be suffixed with the field '
'components and step count.', allow_override=True)
'data_load_prefix', '', 'If non-empty, the input `ser` or `h5` '
'files prefix from where the initial state will be loaded. This will be '
'suffixed with the field components and step count. If set, the directory '
'portion of the prefix has to be different from the directory portion of '
'--data_dump_prefix.', allow_override=True)
flags.DEFINE_bool(
'apply_data_load_filter', False,
'If True, only variables with names provided in field `states_from_file` '
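A hedged sketch of the new flag semantics (paths and step are made up for illustration): the directory portion of --data_load_prefix must differ from that of --data_dump_prefix, and restart files are looked up under the load directory joined with the loading step.

# Hypothetical illustration only; mirrors the checks added in solver() below.
import os

data_dump_prefix = '/tmp/run_b/data'   # hypothetical --data_dump_prefix
data_load_prefix = '/tmp/run_a/data'   # hypothetical --data_load_prefix
loading_step = 1000                    # hypothetical params.loading_step

output_dir, _ = os.path.split(data_dump_prefix)   # '/tmp/run_b'
input_dir, _ = os.path.split(data_load_prefix)    # '/tmp/run_a'
assert input_dir != output_dir, 'load and dump directories must differ'
loading_subdir = os.path.join(input_dir, str(loading_step))  # '/tmp/run_a/1000'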
@@ -387,7 +389,48 @@ def solver(
with strategy.scope():
step_id = tf.Variable(params.start_step, dtype=tf.int32)
output_dir, filename_prefix = os.path.split(FLAGS.data_dump_prefix)
input_dir, input_filename_prefix = os.path.split(FLAGS.data_load_prefix)

if FLAGS.data_load_prefix:
input_dir, input_filename_prefix = os.path.split(FLAGS.data_load_prefix)
# This is a potential user error where both read and output directories
# are pointing to the same place. We do not allow this to happen as some
# files could get overwritten and it can be very confusing.
if input_dir == output_dir:
raise ValueError(
'Please check your configuration. The loading directory is '
f'set to be the same as the output directory {output_dir}, this '
'will cause confusion and potentially over-write important data. '
'If you are trying to continue the simulation run using a previous '
'simulation step from a different run, please use a separate '
'output directory. To have a separate output directory, the '
'directory portions of --data_load_prefix and --data_dump_prefix '
'need to be different.')
# If a loading directory is specified, we check if the step directory to
# read from exists. The step id for data to read from the input directory is
# provided by the `loading_step` in `params`. If it exists, we *assume* the
# needed files are there and will proceed to read (if it turns out files are
# missing, the job will just fail).
loading_subdir = os.path.join(input_dir, str(params.loading_step))
states_from_file = (
list(params.states_from_file) if params.states_from_file else None
)
if not tf.io.gfile.exists(loading_subdir):
raise ValueError(
f'--data_load_prefix was set to {FLAGS.data_load_prefix} and '
f'loading step is {params.loading_step} but no restart files are '
f'found in {loading_subdir}.')
read_state_from_input_dir = functools.partial(
driver_tpu.distributed_read_state,
strategy,
logical_coordinates=logical_coordinates,
output_dir=input_dir,
filename_prefix=input_filename_prefix,
states_from_file=states_from_file,
)
logging.info('read_state_from_input_dir function created.')
else:
input_dir = None
read_state_from_input_dir = None

logging.info('Getting checkpoint_manager.')
ckpt_manager = get_checkpoint_manager(
@@ -404,13 +447,25 @@ def solver(
def step_id_value():
return tf.constant(step_id.numpy(), tf.int32)

write_state = functools.partial(
driver_tpu.distributed_write_state,
strategy,
logical_coordinates=logical_coordinates,
output_dir=output_dir,
filename_prefix=filename_prefix)
logging.info('write_state function created.')
def write_state_and_sync(
state: Tuple[Structure],
step_id: Array,
data_dump_filter: Optional[Sequence[str]] = None,
):
write_status = driver_tpu.distributed_write_state(
strategy,
state,
logical_coordinates=logical_coordinates,
output_dir=output_dir,
filename_prefix=filename_prefix,
step_id=step_id,
data_dump_filter=data_dump_filter)

# This will block until all replicas are done writing.
replica_id_write_status = []
for i in range(num_replicas):
replica_id_write_status.append(write_status[i]['replica_id'].numpy())
return replica_id_write_status

read_state = functools.partial(
driver_tpu.distributed_read_state,
@@ -420,20 +475,6 @@ def step_id_value():
filename_prefix=filename_prefix)
logging.info('read_state function created.')

states_from_file = (
list(params.states_from_file) if params.states_from_file else None
)

read_state_from_input_dir = functools.partial(
driver_tpu.distributed_read_state,
strategy,
logical_coordinates=logical_coordinates,
output_dir=input_dir,
filename_prefix=input_filename_prefix,
states_from_file=states_from_file,
)

logging.info('read_state_from_input_dir function created.')
t_start = time.time()

# Wrapping `init_fn` with tf.function so it is not retraced unnecessarily for
@@ -443,7 +484,7 @@ def step_id_value():
strategy, value_fn=tf.function(init_fn),
logical_coordinates=logical_coordinates)

# Accessing the values in state to synchornize the client so the main thread
# Accessing the values in state to synchronize the client so the main thread
# will wait here until the `state` is initialized and all remote operations
# are done.
replica_values = state['replica_id'].values
@@ -452,18 +493,18 @@ def step_id_value():
logging.info('Initialization stage done. Took %f secs.',
t_post_init - t_start)

def block_on_write(write_status):
"""This will block until the all replicas are done writing."""
replica_id_write_status = []
for i in range(num_replicas):
replica_id_write_status.append(write_status[i]['replica_id'].numpy())
return replica_id_write_status
write_initial_state = False

# Restore from an existing checkpoint if present.
if ckpt_manager.latest_checkpoint:
# The checkpoint restore updates the `step_id` variable; which is then used
# to read in the state.
logging.info('Detected checkpoint. Starting `restore_or_initialize`.')
if input_dir is not None:
logging.info('--data_load_prefix was set to %s but not using it '
'because checkpoint was detected in the data dump '
'directory %s.', FLAGS.data_load_prefix,
FLAGS.data_dump_prefix)
ckpt_manager.restore_or_initialize()
state = read_state(state=_local_state(strategy, state),
step_id=step_id_value())
@@ -473,62 +514,39 @@ def block_on_write(write_status):
'`restoring-checkpoint-if-necessary` stage '
'done with reading checkpoint. Replicas are: %s',
str(replica_id_values))
# Override initial state with state from a previous run if requested.
elif input_dir is not None:
logging.info('--data_load_prefix is set to %s, loading from %s at step %s, '
'and overriding the default initialized state.',
FLAGS.data_load_prefix, input_dir, params.loading_step)
state = read_state_from_input_dir(
state=_local_state(strategy, state),
step_id=tf.constant(params.loading_step),
)
write_initial_state = True
# This is to sync the client code to the worker execution.
replica_id_values = state['replica_id'].values
logging.info(
'`restoring-checkpoint-if-necessary` stage '
'done with reading from load directory. Replicas are: %s',
str(replica_id_values))
logging.info('Read states from %s at %i', input_dir, params.loading_step)
# Use default initial state.
else:
# In case we're not restoring from a checkpoint, we do the following two
# steps:
# 1. Check if the step directory to read from exists. The step id
# for data to read from the input directory is provided by the
# `loading_step` in `params`. If it exists, we *assume* the needed files
# are there and will proceed to read (if it turns out files are missing,
# the job will just fail).
loading_subdir = os.path.join(input_dir, str(params.loading_step))
if tf.io.gfile.exists(loading_subdir):
# This is a potential user error where both read and output directories
# are pointing to the same place. We do not allow this to happen as some
# files could get overwritten and it can be very confusing.
if input_dir == output_dir:
raise ValueError(
'Please check your configuration. The loading directory is '
f'set to be the same as the output directory {output_dir}, this '
'will cause confusion and potentially over-write important data. '
'If you are trying to continue the simulation run using a previous '
'simulation step from a different run, please use a separate '
'output directory. If this is a rerun of a previously failed / '
'crashed job, you probably should remove the incomplete '
f'sub-directory {loading_subdir} generated from the failed/crashed '
'run first before rerunning the job.')
# Discard the default initialized state and load from the input_dir.
state = read_state_from_input_dir(
state=_local_state(strategy, state),
step_id=tf.constant(params.loading_step),
)
# This is to sync the client code to the worker execution.
replica_id_values = state['replica_id'].values
logging.info(
'`restoring-checkpoint-if-necessary` stage '
'done with reading from load directory. Replicas are: %s',
str(replica_id_values))
logging.info('Read states from %s at %i', input_dir, params.loading_step)
else:
logging.info(
'No restart files found in %s at %i. Proceeding with default'
' initializations for all variables.',
input_dir,
params.loading_step,
)

# 2. Write the initial state in the output directory.
logging.info('No checkpoint detected. Starting `write_state`.')
write_status = write_state(state=_local_state(strategy, state),
step_id=step_id_value())
# This is used to sync the client code to the remote workers asynchronous
# executions.
write_initial_state = True
logging.info('No checkpoint was found and --data_load_prefix was not set. '
'Proceeding with default initializations for all variables.')

if write_initial_state:
logging.info('Starting `write_state` for the initial state.')
write_status = write_state_and_sync(state=_local_state(strategy, state),
step_id=step_id_value())
logging.info(
'`restoring-checkpoint-if-necessary` stage '
'done with writing initial steps. Write status are: %s',
str(block_on_write(write_status)))
'done with writing initial steps. Write status are: %s', write_status)

t_post_restore = time.time()
logging.info('restore-if-necessary or write. Took %f secs.',
logging.info('restore-if-necessary-or-write took %f secs.',
t_post_restore - t_post_init)

if params.num_steps < 0:
Expand Down Expand Up @@ -574,19 +592,19 @@ def block_on_write(write_status):
# is a multiple of the checkpoint interval, else just record, a possibly
# shortened version of the current state.
if (step_id_value() - params.start_step) % checkpoint_interval == 0:
write_status = write_state(_local_state(strategy, state),
step_id=step_id_value())
write_status = write_state_and_sync(_local_state(strategy, state),
step_id=step_id_value())
logging.info('`Post cycle writing full state done. '
'Write status are: %s', str(block_on_write(write_status)))
'Write status are: %s', write_status)
ckpt_manager.save()
else:
# Note, the first time this is called retracing will occur for the
# subgraphs in `distributed_write_state` if data_dump_filter is not `None`.
write_status = write_state(_local_state(strategy, state),
step_id=step_id_value(),
data_dump_filter=data_dump_filter)
write_status = write_state_and_sync(_local_state(strategy, state),
step_id=step_id_value(),
data_dump_filter=data_dump_filter)
logging.info('`Post cycle writing filtered state done. '
'Write status are: %s', str(block_on_write(write_status)))
'Write status are: %s', write_status)
t2 = time.time()
logging.info('Writing output & checkpoint took %f secs.', t2 - t1)

23 changes: 17 additions & 6 deletions swirl_lm/base/initializer.py
@@ -270,9 +270,20 @@ def gen_forcing_fn(xx, yy, zz, lx, ly, lz, coord):
return partial_mesh_for_core(params, coordinate, gen_forcing_fn, perm)


def subgrid_slice(subgrid_size: int,
coordinate: int,
halo_width: Optional[int] = 1) -> slice:
def subgrid_slice_indices(
subgrid_size: int,
coordinate: int,
halo_width: int = 1,
) -> Tuple[int, int]:
"""Determines the start and end indices for slicing."""
core_subgrid_size = subgrid_size - 2 * halo_width
start = coordinate * core_subgrid_size
return start, start + subgrid_size


def subgrid_slice(
subgrid_size: int, coordinate: int, halo_width: Optional[int] = 1
) -> slice:
"""Returns the slice of a field corresponding to `coordinate`.
Args:
@@ -284,9 +295,9 @@ def subgrid_slice(subgrid_size: int,
The subgrid slice corresponding to the given subgrid coordinate (including
halo).
"""
core_subgrid_size = subgrid_size - 2 * halo_width
start = coordinate * core_subgrid_size
return slice(start, start + subgrid_size)
start, end = subgrid_slice_indices(subgrid_size, coordinate, halo_width)

return slice(start, end)


def three_d_subgrid_slices(
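A small worked example of the refactored slicing helpers (hypothetical sizes; assumes the module is importable as swirl_lm.base.initializer): with a subgrid of 10 points, 1-point halos, and coordinate 2, the core size is 8, so the slice spans [16, 26).

# Hypothetical usage sketch:
#   core_subgrid_size = 10 - 2 * 1 = 8
#   start = 2 * 8 = 16, end = 16 + 10 = 26
from swirl_lm.base import initializer

start, end = initializer.subgrid_slice_indices(
    subgrid_size=10, coordinate=2, halo_width=1)
assert (start, end) == (16, 26)
assert initializer.subgrid_slice(10, 2, 1) == slice(16, 26)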
30 changes: 16 additions & 14 deletions swirl_lm/equations/pressure.py
@@ -469,11 +469,11 @@ def _pressure_corrector_update(
self,
replica_id: tf.Tensor,
replicas: np.ndarray,
states,
states: FlowFieldMap,
additional_states: FlowFieldMap,
rho_info: DensityInfo,
subiter: tf.Tensor = None,
) -> Tuple[FlowFieldVal, monitor.MonitorDataType]: # pytype: disable=annotation-type-mismatch
# pylint: disable=line-too-long
"""Updates the pressure correction.
This method follows the approach introduced in:
@@ -493,9 +493,11 @@ def _pressure_corrector_update(
Args:
replica_id: The ID number of the replica.
replicas: A numpy array that maps a replica's grid coordinate to its
replica_id, e.g. replicas[0, 0, 0] = 0, replicas[0, 0, 1] = 2.
replica_id, e.g. replicas[0, 0, 0] = 0, replicas[0, 0, 1] = 1.
states: A dictionary that holds flow field variables from the latest
prediction.
additional_states: A dictionary that holds helper variables required by
the Poisson solver.
rho_info: The density information required by the pressure solver. For
constant density, `rho_info` is an instance of `ConstantDensityInfo`
which contains the value of the density as a float. For variable
@@ -613,8 +615,10 @@ def add_factor(
})

else:
raise ValueError('`rho_info` has to be either `ConstantDensityInfo` or '
'`VariableDensityInfo`.')
raise ValueError(
'`rho_info` has to be either `ConstantDensityInfo` or '
f'`VariableDensityInfo`, but {rho_info} is provided.'
)

# pylint: disable=g-complex-comprehension
return [(div_i + drho_dt_i - src_rho_i)
@@ -664,11 +668,10 @@ def dp_exchange_halos(dpr,):

dp0 = [tf.zeros_like(b_i) for b_i in b]

helper_vars = dict(additional_states)
if (self._thermodynamics.solver_mode ==
thermodynamics_pb2.Thermodynamics.ANELASTIC):
helper_vars = {poisson_solver.VARIABLE_COEFF: states['rho']}
else:
helper_vars = None
helper_vars[poisson_solver.VARIABLE_COEFF] = states['rho']

# Note that the solution that is denoted as `dp` from the Poisson solver has
# different meanings under different modes of thermodynamics. In the low
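A minimal sketch of the new helper_vars behavior above (stand-in values only; VARIABLE_COEFF stands in for poisson_solver.VARIABLE_COEFF): additional_states are now always forwarded to the Poisson solver, and in ANELASTIC mode the density is added as the variable coefficient without mutating the caller's dict.

# Stand-in sketch; not the swirl_lm API.
additional_states = {'solver_helper': 'bc-or-solver-info'}   # hypothetical entry
states = {'rho': 'rho-field'}                                # stand-in flow field
VARIABLE_COEFF = 'variable_coefficient'                      # stand-in constant

helper_vars = dict(additional_states)   # shallow copy, caller's dict untouched
anelastic = True                        # pretend solver_mode is ANELASTIC
if anelastic:
  helper_vars[VARIABLE_COEFF] = states['rho']

assert VARIABLE_COEFF not in additional_states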
@@ -917,8 +920,7 @@ def step(
Returns:
A dictionary with the updated pressure and pressure corrector.
"""
del additional_states

drho_dt = tf.nest.map_structure(tf.zeros_like, states['rho'])
if (self._thermodynamics.solver_mode ==
thermodynamics_pb2.Thermodynamics.LOW_MACH):
exchange_halos = functools.partial(
@@ -957,13 +959,13 @@ def drho_filter_fn(
drho_dt = [drho_i / self._params.dt for drho_i in drho]
elif (self._thermodynamics.solver_mode ==
thermodynamics_pb2.Thermodynamics.ANELASTIC):
drho_dt = [tf.zeros_like(rho_i) for rho_i in states['rho']]
drho_dt = tf.nest.map_structure(tf.zeros_like, states['rho'])

rho_info = VariableDensityInfo(drho_dt)

dp, monitor_vars = self._pressure_corrector_update(replica_id, replicas,
states, rho_info,
subiter)
dp, monitor_vars = self._pressure_corrector_update(
replica_id, replicas, states, additional_states, rho_info, subiter
)

states_updated = {
'p': tf.nest.map_structure(lambda p_, dp_: p_ + dp_, states['p'], dp),