Skip to content

Commit

Permalink
Add hierarchy to DataTree outputs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696255795
  • Loading branch information
Nush395 authored and Torax team committed Nov 14, 2024
1 parent fd7538b commit 7f26a3f
Show file tree
Hide file tree
Showing 57 changed files with 364 additions and 295 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"chex>=0.1.85",
"equinox>=0.11.3",
"PyYAML>=6.0.1",
"xarray>=2023.12.0",
# DataTree is only working with development version of xarray for now.
"xarray @ git+ssh://git@github.com/pydata/xarray",
"netcdf4>=1.6.5,<1.7.1",
"h5netcdf>=1.3.0",
"scipy>=1.12.0",
Expand Down
4 changes: 2 additions & 2 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def main(_):
)
build_time = time.time() - start_time
start_time = time.time()
_, output_file = _call_sim_app_main(
output_file = _call_sim_app_main(
sim=sim,
output_dir=output_dir,
log_sim_progress=log_sim_progress,
Expand Down Expand Up @@ -534,7 +534,7 @@ def main(_):
)
else:
start_time = time.time()
_, output_file = _call_sim_app_main(
output_file = _call_sim_app_main(
sim=sim,
output_dir=new_runtime_params.output_dir,
log_sim_progress=log_sim_progress,
Expand Down
150 changes: 103 additions & 47 deletions torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ToraxSimOutputs:


# Core profiles.
CORE_PROFILES = "core_profiles"
TEMP_EL = "temp_el"
TEMP_EL_RIGHT_BC = "temp_el_right_bc"
TEMP_ION = "temp_ion"
Expand Down Expand Up @@ -79,6 +80,7 @@ class ToraxSimOutputs:
NREF = "nref"

# Core transport.
CORE_TRANSPORT = "core_transport"
CHI_FACE_ION = "chi_face_ion"
CHI_FACE_EL = "chi_face_el"
D_FACE_EL = "d_face_el"
Expand All @@ -99,51 +101,81 @@ class ToraxSimOutputs:
TIME = "time"

# Post processed outputs
POST_PROCESSED_OUTPUTS = "post_processed_outputs"
Q_FUSION = "Q_fusion"

# Simulation error state.
SIM_ERROR = "sim_error"

# Sources.
CORE_SOURCES = "core_sources"

def safe_load_dataset(filepath: str) -> xr.Dataset:

def safe_load_dataset(filepath: str) -> xr.DataTree:
with open(filepath, "rb") as f:
with xr.open_dataset(f) as ds_open:
ds = ds_open.compute()
return ds
with xr.open_datatree(f) as dt_open:
dt = dt_open.compute()
return dt


def load_state_file(
filepath: str,
) -> xr.Dataset:
) -> xr.DataTree:
"""Loads a state file from a filepath."""
if os.path.exists(filepath):
ds = safe_load_dataset(filepath)
data_tree = safe_load_dataset(filepath)
logging.info("Loading state file %s", filepath)
return ds
return data_tree
else:
raise ValueError(f"File {filepath} does not exist.")


def concat_datatrees(
    tree1: xr.DataTree,
    tree2: xr.DataTree,
) -> xr.DataTree:
    """Concatenate two xr.DataTrees along the time dimension.

    When a time step is present in both trees, the values coming from
    `tree1` are the ones that are kept.

    Args:
        tree1: The first xr.DataTree to concatenate.
        tree2: The second xr.DataTree to concatenate.

    Returns:
        An xr.DataTree obtained by applying the time-wise concatenation to
        the datasets of the two trees via `xr.map_over_datasets`.
    """

    def _join_along_time(
        earlier: xr.Dataset,
        later: xr.Dataset,
    ) -> xr.Dataset:
        """Concatenate two xr.Datasets along TIME, dropping repeated times."""
        # "minimal" restricts the concat to variables that already carry the
        # time dimension, so non time indexed variables are not stacked.
        joined = xr.concat([earlier, later], dim=TIME, data_vars="minimal")
        # keep="first" retains the entry from the earlier dataset for any
        # repeated time step (e.g. the restart step), which holds the more
        # complete information such as transport and post processed outputs.
        return joined.drop_duplicates(dim=TIME, keep="first")

    return xr.map_over_datasets(_join_along_time, tree1, tree2)


def stitch_state_files(
file_restart: runtime_params.FileRestart,
ds: xr.Dataset,
) -> xr.Dataset:
"""Stitches a restarted state file to the beginning of its source sim."""
previous_ds = load_state_file(
file_restart.filename,
)
file_restart: runtime_params.FileRestart, datatree: xr.DataTree
) -> xr.DataTree:
"""Stitch a datatree to the end of a previous state file.
Args:
file_restart: Contains information on a file this sim was restarted from.
datatree: The xr.DataTree to stitch to the end of the previous state file.
Returns:
An xr.DataTree containing the stitched dataset.
"""
previous_datatree = load_state_file(file_restart.filename)
# Reduce the previous state to all times up to and including the first time
# step in this sim output. We use the first time of this sim output instead
# of file_restart.time because we are uncertain if file_restart.time is the
# exact time of the first time step in this sim output (it takes the nearest
# time).
previous_ds = previous_ds.sel(time=slice(None, ds.time[0]))
# Do a minimal concat to avoid concatting any non time indexed vars.
ds = xr.concat([previous_ds, ds], dim=TIME, data_vars="minimal")
# Drop the duplicate restart time step. Using "first" imposes
# keeping the restart state from the previous simulation, which contains
# more complete information e.g. transport and post processed outputs.
ds = ds.drop_duplicates(dim=TIME, keep="first")
return ds
previous_datatree = previous_datatree.sel(time=slice(None, datatree.time[0]))
return concat_datatrees(previous_datatree, datatree)


class StateHistory:
Expand Down Expand Up @@ -343,8 +375,8 @@ def simulation_output_to_xr(
self,
geo: geometry.Geometry,
file_restart: runtime_params.FileRestart | None = None,
) -> xr.Dataset:
"""Build an xr.Dataset of the simulation output.
) -> xr.DataTree:
"""Build an xr.DataTree of the simulation output.
Args:
geo: The geometry of the simulation. This is used to retrieve the TORAX
Expand All @@ -354,15 +386,27 @@ def simulation_output_to_xr(
beginning of this sim output.
Returns:
An xr.Dataset of the simulation output. The dataset contains the following
coordinates:
An xr.DataTree containing a single top level xr.Dataset and four child
datasets. The top level dataset contains the following variables:
- time: The time of the simulation.
- rho_face_norm: The normalized toroidal coordinate on the face grid.
- rho_cell_norm: The normalized toroidal coordinate on the cell grid.
- rho_face: The toroidal coordinate on the face grid.
- rho_cell: The toroidal coordinate on the cell grid.
The dataset contains data variables for quantities in the CoreProfiles,
CoreTransport, and CoreSources, as well as time and the sim_error state.
- vpr: The volume derivative w.r.t rho_norm.
- spr: The surface derivative w.r.t rho_norm.
- vpr_face: The volume derivative w.r.t rho_face_norm.
- spr_face: The surface derivative w.r.t rho_face_norm.
- sim_error: The simulation error state.
The child datasets contain the following variables:
- core_profiles: Contains data variables for quantities in the
CoreProfiles.
- core_transport: Contains data variables for quantities in the
CoreTransport.
- core_sources: Contains data variables for quantities in the
CoreSources.
- post_processed_outputs: Contains data variables for quantities in the
PostProcessedOutputs.
"""
# TODO(b/338033916). Extend outputs with:
# Post-processed integrals, more geo outputs.
Expand All @@ -387,29 +431,41 @@ def simulation_output_to_xr(
VPR_FACE: xr.DataArray(geo.vpr_face, dims=[RHO_FACE], name=VPR_FACE),
SPR_FACE: xr.DataArray(geo.spr_face, dims=[RHO_FACE], name=SPR_FACE),
}
coords = {
TIME: time,
RHO_FACE_NORM: rho_face_norm,
RHO_CELL_NORM: rho_cell_norm,
RHO_FACE: rho_face,
RHO_CELL: rho_cell,
}

# Update dict with flattened StateHistory dataclass containers
xr_dict.update(self._get_core_profiles(geo))
xr_dict.update(self._save_core_transport(geo))
existing_keys = set(xr_dict.keys())
xr_dict.update(self._save_core_sources(geo, existing_keys))
xr_dict.update(self._save_post_processed_outputs(geo))

ds = xr.Dataset(
xr_dict,
coords={
TIME: time,
RHO_FACE_NORM: rho_face_norm,
RHO_CELL_NORM: rho_cell_norm,
RHO_FACE: rho_face,
RHO_CELL: rho_cell,
core_profiles_ds = xr.Dataset(
self._get_core_profiles(geo), coords=coords
)
core_transport_ds = xr.Dataset(
self._save_core_transport(geo), coords=coords
)
core_sources_ds = xr.Dataset(
self._save_core_sources(geo, set(xr_dict.keys())), coords=coords,
)
post_processed_outputs_ds = xr.Dataset(
self._save_post_processed_outputs(geo), coords=coords
)
xr_dict.update({SIM_ERROR: self.sim_error.value})
data_tree = xr.DataTree(
children={
CORE_PROFILES: xr.DataTree(dataset=core_profiles_ds),
CORE_TRANSPORT: xr.DataTree(dataset=core_transport_ds),
CORE_SOURCES: xr.DataTree(dataset=core_sources_ds),
POST_PROCESSED_OUTPUTS: xr.DataTree(
dataset=post_processed_outputs_ds
),
},
dataset=xr.Dataset(xr_dict, coords=coords),
)

if file_restart is not None and file_restart.stitch:
ds = stitch_state_files(file_restart, ds)
data_tree = stitch_state_files(file_restart, data_tree)

# Add sim_error as a new variable
ds[SIM_ERROR] = self.sim_error.value

return ds
return data_tree
Loading

0 comments on commit 7f26a3f

Please sign in to comment.