Dev/reservoir experiments #2363

Draft: wants to merge 46 commits into master from dev/reservoir-experiments
Commits (46)
9429f30
add tracker
AnnaKwa Oct 18, 2023
4e95c6c
add precip update
AnnaKwa Oct 18, 2023
c1cc245
delete unused methods
AnnaKwa Oct 18, 2023
fb9fcee
update config regtest output
AnnaKwa Oct 18, 2023
a787595
shorten class name to preciptracker
AnnaKwa Oct 24, 2023
fd3598d
add scalar summary of summed 2d variance to wandb log
AnnaKwa Oct 25, 2023
8cd34e4
allow reservoir model keys to be string repr of integers
AnnaKwa Oct 30, 2023
dba0410
fix key conversion
AnnaKwa Oct 30, 2023
a1b608a
Merge branch 'convert-str-keys-prog-config' into clean-res-precip
AnnaKwa Oct 30, 2023
9e77fc4
test config raises error on invalid key
AnnaKwa Oct 30, 2023
b0006ca
Merge branch 'clean-res-precip' into add-spatial-variance-metrics
AnnaKwa Oct 31, 2023
c126ab9
add config option for interval stepper to only apply to first n_calls
AnnaKwa Oct 31, 2023
8eba6e5
Merge branch 'interval-stepper-nsteps-option' into add-spatial-varian…
AnnaKwa Nov 1, 2023
bc65e0c
add config field
AnnaKwa Nov 1, 2023
c58d4c2
Merge branch 'interval-stepper-nsteps-option' into add-spatial-varian…
AnnaKwa Nov 1, 2023
08a7dae
rm extra post init
AnnaKwa Nov 1, 2023
c297bc4
clip z dim
AnnaKwa Nov 3, 2023
2552723
merge input taper option
AnnaKwa Nov 3, 2023
2252592
Don't clip outputs
AnnaKwa Nov 3, 2023
7c8e1e6
Merge branch 'res-input-taper' into dev/reservoir-experiments
AnnaKwa Nov 3, 2023
302661e
option to blend prediction with input state using taper config
AnnaKwa Nov 7, 2023
1ba700e
Fix keyerrors
AnnaKwa Nov 8, 2023
44c05ec
Fix test
AnnaKwa Nov 8, 2023
542aa04
fix typo
AnnaKwa Nov 8, 2023
d07a967
update regtest
AnnaKwa Nov 8, 2023
e7c243c
get validation working for tapered models
AnnaKwa Nov 8, 2023
c8c99f8
zero fill output level option
AnnaKwa Nov 8, 2023
c4400f1
Fix slow encoding
frodre Oct 27, 2023
6ed47a9
get clipped inputs working in offline validation
AnnaKwa Nov 8, 2023
9761844
fix output clipping in transformer
AnnaKwa Nov 8, 2023
8f7473d
Merge branch 'res-input-taper' into dev/reservoir-experiments
AnnaKwa Nov 8, 2023
068bb7a
use encoder/decoder call directly
AnnaKwa Nov 8, 2023
c7f4eba
Fix unit test
AnnaKwa Nov 9, 2023
7b8d10d
Merge branch 'encoder-speed-fix' into dev/reservoir-experiments
AnnaKwa Nov 9, 2023
35bc7bd
update precip with moistening from interval prescriber change
AnnaKwa Nov 13, 2023
2fd75a7
Merge branch 'prescriber-precip' into dev/reservoir-experiments
AnnaKwa Nov 13, 2023
6b32c30
save diags prior to update on all steps
AnnaKwa Nov 15, 2023
e9d96fb
fix precipitation_sum so it does not return zero when dq2 is empty
AnnaKwa Nov 15, 2023
2a066a6
Merge branch 'fix/postphysics-precip-zero-update' into dev/reservoir-…
AnnaKwa Nov 15, 2023
dd850c0
Only touch accumulated precip
AnnaKwa Nov 17, 2023
f14bf5e
Merge branch 'master' into dev/reservoir-experiments
AnnaKwa Nov 17, 2023
82137f3
remove old net_moistening line
AnnaKwa Nov 17, 2023
5d89d0e
fix save ranks variable renaming
AnnaKwa Nov 19, 2023
a919ff8
enable saving state at last timestep
AnnaKwa Dec 6, 2023
6f328ea
fix text missing arg
AnnaKwa Dec 6, 2023
87ce7e6
fix path to state dump
AnnaKwa Dec 7, 2023
46 changes: 41 additions & 5 deletions external/fv3fit/fv3fit/reservoir/adapters.py
@@ -1,13 +1,17 @@
from __future__ import annotations
from dataclasses import asdict
import fsspec
import numpy as np
import os
import typing
from typing import Iterable, Hashable, Sequence, Union, Mapping
from typing import Iterable, Hashable, Sequence, Union, Mapping, Optional
import xarray as xr
import yaml

import fv3fit
from fv3fit import Predictor
from fv3fit._shared import io
from fv3fit.reservoir.config import ClipZConfig
from .model import (
HybridReservoirComputingModel,
ReservoirComputingModel,
@@ -66,7 +70,10 @@ def output_array_to_ds(
).transpose(*output_dims)

def input_dataset_to_arrays(
self, inputs: xr.Dataset, variables: Iterable[Hashable]
self,
inputs: xr.Dataset,
variables: Iterable[Hashable],
clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None,
) -> Sequence[np.ndarray]:
# Converts from xr dataset to sequence of variable ndarrays expected by encoder
# Make sure the xy dimensions match the rank divider
@@ -80,6 +87,10 @@ def input_dataset_to_arrays(
da = transposed_inputs[variable]
if "z" not in da.dims:
da = da.expand_dims("z", axis=-1)
if clip_config is not None and variable in clip_config:
da = da.isel(
z=slice(clip_config[variable].start, clip_config[variable].stop)
)
input_arrs.append(da.values)
return input_arrs

@@ -118,6 +129,9 @@ def input_overlap(self):
def is_hybrid(self):
return False

def dump_state(self, path: str):
self.model.reservoir.dump_state(path)

def predict(self, inputs: xr.Dataset) -> xr.Dataset:
# inputs arg is not used, but is required by Predictor signature and prog run
prediction_arr = self.model.predict()
@@ -127,7 +141,7 @@ def predict(self, inputs: xr.Dataset) -> xr.Dataset:

def increment_state(self, inputs: xr.Dataset):
xy_input_arrs = self.model_adapter.input_dataset_to_arrays(
inputs, self.input_variables
inputs, self.input_variables,
) # x, y, feature dims
self.model.increment_state(xy_input_arrs)

@@ -159,12 +173,14 @@ def load(cls, path: str) -> "ReservoirDatasetAdapter":
@io.register("hybrid-reservoir-adapter")
class HybridReservoirDatasetAdapter(Predictor):
MODEL_DIR = "hybrid_reservoir_model"
CLIP_CONFIG_FILE = "clip_config.yaml"

def __init__(
self,
model: HybridReservoirComputingModel,
input_variables: Iterable[Hashable],
output_variables: Iterable[Hashable],
clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None,
) -> None:
"""Wraps a hybrid reservoir model to take in and return xarray datasets.
The initialization args for input and output variables are not used and
@@ -183,6 +199,10 @@ def __init__(
input_variables=self.input_variables,
output_variables=model.output_variables,
)
self.clip_config = clip_config

def dump_state(self, path: str):
self.model.reservoir.dump_state(path)

@property
def input_overlap(self):
@@ -195,7 +215,7 @@ def is_hybrid(self):

def predict(self, inputs: xr.Dataset) -> xr.Dataset:
xy_input_arrs = self.model_adapter.input_dataset_to_arrays(
inputs, self.model.hybrid_variables
inputs, self.model.hybrid_variables, clip_config=self.clip_config
) # x, y, feature dims

prediction_arr = self.model.predict(xy_input_arrs)
@@ -205,7 +225,7 @@ def predict(self, inputs: xr.Dataset) -> xr.Dataset:

def increment_state(self, inputs: xr.Dataset):
xy_input_arrs = self.model_adapter.input_dataset_to_arrays(
inputs, self.model.input_variables
inputs, self.model.input_variables, clip_config=self.clip_config
) # x, y, feature dims
self.model.increment_state(xy_input_arrs)

@@ -224,14 +244,30 @@ def get_model_from_subdomain(

def dump(self, path):
self.model.dump(os.path.join(path, self.MODEL_DIR))
if self.clip_config is not None:
clip_config_dict = {
var: asdict(var_config) for var, var_config in self.clip_config.items()
}
with fsspec.open(os.path.join(path, self.CLIP_CONFIG_FILE), "w") as f:
yaml.dump(clip_config_dict, f)

@classmethod
def load(cls, path: str) -> "HybridReservoirDatasetAdapter":
model = HybridReservoirComputingModel.load(os.path.join(path, cls.MODEL_DIR))
try:
with fsspec.open(os.path.join(path, cls.CLIP_CONFIG_FILE), "r") as f:
clip_config_dict = yaml.safe_load(f)
clip_config: Optional[Mapping[Hashable, ClipZConfig]] = {
var: ClipZConfig(**var_config)
for var, var_config in clip_config_dict.items()
}
except FileNotFoundError:
clip_config = None
adapter = cls(
input_variables=model.input_variables,
output_variables=model.output_variables,
model=model,
clip_config=clip_config,
)
return adapter

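The clip_config persistence added to dump() and load() above amounts to a plain YAML round-trip of the per-variable ClipZConfig dataclasses. A minimal standalone sketch of that round-trip (the file name and variable below are illustrative, not taken from the PR):

from dataclasses import asdict
import fsspec
import yaml
from fv3fit.reservoir.config import ClipZConfig

# illustrative clip configuration for one variable
clip_config = {"specific_humidity": ClipZConfig(start=None, stop=60)}

# dump: serialize each ClipZConfig to a plain dict and write YAML
with fsspec.open("clip_config.yaml", "w") as f:
    yaml.dump({var: asdict(cfg) for var, cfg in clip_config.items()}, f)

# load: rebuild the ClipZConfig objects, falling back to None when no file exists
try:
    with fsspec.open("clip_config.yaml", "r") as f:
        loaded = {var: ClipZConfig(**d) for var, d in yaml.safe_load(f).items()}
except FileNotFoundError:
    loaded = None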
14 changes: 13 additions & 1 deletion external/fv3fit/fv3fit/reservoir/config.py
@@ -1,6 +1,6 @@
import dacite
from dataclasses import dataclass, asdict
from typing import Sequence, Optional, Set, Tuple
from typing import Sequence, Optional, Set, Tuple, Mapping, Hashable
import fsspec
import yaml
from .._shared.training_config import Hyperparameters
@@ -65,6 +65,16 @@ class TransformerConfig:
hybrid: Optional[str] = None


@dataclass
class ClipZConfig:
""" Vertical levels **between** start and stop are kept,
levels outside start/stop are clipped off.
"""

start: Optional[int] = None
stop: Optional[int] = None


@dataclass
class ReservoirTrainingConfig(Hyperparameters):
"""
@@ -101,11 +111,13 @@ class ReservoirTrainingConfig(Hyperparameters):
n_timesteps_synchronize: int
input_noise: float
seed: int = 0
zero_fill_clipped_output_levels: bool = False
transformers: Optional[TransformerConfig] = None
n_jobs: Optional[int] = 1
square_half_hidden_state: bool = False
hybrid_variables: Optional[Sequence[str]] = None
mask_variable: Optional[str] = None
clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None
_METADATA_NAME = "reservoir_training_config.yaml"

def __post_init__(self):
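The new ClipZConfig holds the window of vertical indices to keep for a variable; levels outside the window are dropped with a standard slice, as in input_dataset_to_arrays above. A minimal sketch of its use, assuming an illustrative 79-level field (the variable name and level counts are not from the PR):

import numpy as np
import xarray as xr
from fv3fit.reservoir.config import ClipZConfig

# keep vertical levels 5 through 58 of a 79-level column
clip_config = {"air_temperature": ClipZConfig(start=5, stop=59)}

da = xr.DataArray(np.random.rand(6, 6, 79), dims=["x", "y", "z"])
variable = "air_temperature"
if variable in clip_config:
    da = da.isel(z=slice(clip_config[variable].start, clip_config[variable].stop))
assert da.sizes["z"] == 54  # levels outside the window are clipped off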
52 changes: 37 additions & 15 deletions external/fv3fit/fv3fit/reservoir/train.py
@@ -19,6 +19,8 @@
assure_txyz_dims,
SynchronziationTracker,
get_standard_normalizing_transformer,
clip_batch_data,
zero_fill_clipped_output_levels,
)
from .transformers import TransformerGroup, Transformer
from .._shared import register_training_function
@@ -48,6 +50,8 @@ def _add_input_noise(arr: np.ndarray, stddev: float) -> np.ndarray:
def _get_transformers(
sample_batch: Mapping[str, tf.Tensor], hyperparameters: ReservoirTrainingConfig
) -> TransformerGroup:
clipped_sample_batch = clip_batch_data(sample_batch, hyperparameters.clip_config)

# Load transformers with specified paths
transformers = {}
for variable_group in ["input", "output", "hybrid"]:
@@ -58,25 +62,25 @@ def _get_transformers(
# If input transformer not specified, always create a standard norm transform
if "input" not in transformers:
transformers["input"] = get_standard_normalizing_transformer(
hyperparameters.input_variables, sample_batch
hyperparameters.input_variables, clipped_sample_batch
)

# If output transformer not specified and output_variables != input_variables,
# create a separate standard norm transform
# Output is not clipped, so use the original sample batch
if hyperparameters.zero_fill_clipped_output_levels:
sample_batch = zero_fill_clipped_output_levels(
sample_batch, hyperparameters.clip_config
)
if "output" not in transformers:
if hyperparameters.output_variables != hyperparameters.input_variables:
transformers["output"] = get_standard_normalizing_transformer(
hyperparameters.output_variables, sample_batch
)
else:
transformers["output"] = transformers["input"]
transformers["output"] = get_standard_normalizing_transformer(
hyperparameters.output_variables, sample_batch
)

# If hybrid variables transformer not specified, and hybrid variables are defined,
# create a separate standard norm transform
if "hybrid" not in transformers:
if hyperparameters.hybrid_variables is not None:
transformers["hybrid"] = get_standard_normalizing_transformer(
hyperparameters.hybrid_variables, sample_batch
hyperparameters.hybrid_variables, clipped_sample_batch
)
else:
transformers["hybrid"] = transformers["input"]
@@ -115,9 +119,17 @@ def train_reservoir_model(
train_batches if isinstance(train_batches, Sequence) else [train_batches]
)
sample_batch = next(iter(train_batches_sequence[0]))
sample_X = get_ordered_X(sample_batch, hyperparameters.input_variables)
if hyperparameters.zero_fill_clipped_output_levels:
sample_batch = zero_fill_clipped_output_levels(
sample_batch, hyperparameters.clip_config
)

# Clipping is done inside this function to preserve full length outputs
transformers = _get_transformers(sample_batch, hyperparameters,)

clipped_sample_batch = clip_batch_data(sample_batch, hyperparameters.clip_config)
sample_X = get_ordered_X(clipped_sample_batch, hyperparameters.input_variables)

transformers = _get_transformers(sample_batch, hyperparameters)
subdomain_config = hyperparameters.subdomain

# sample_X[0] is the first data variable, shape elements 1:-1 are the x,y shape
@@ -131,7 +143,7 @@

if hyperparameters.mask_variable is not None:
input_mask_array: Optional[np.ndarray] = _get_input_mask_array(
hyperparameters.mask_variable, sample_batch, rank_divider
hyperparameters.mask_variable, clipped_sample_batch, rank_divider
)
else:
input_mask_array = None
@@ -157,9 +169,12 @@
)

for b, batch_data in enumerate(train_batches):
batch_data_clipped = clip_batch_data(
batch_data, hyperparameters.clip_config
)
input_time_series = process_batch_data(
variables=hyperparameters.input_variables,
batch_data=batch_data,
batch_data=batch_data_clipped,
rank_divider=rank_divider,
autoencoder=transformers.input,
trim_halo=False,
@@ -171,6 +186,11 @@
_output_rank_divider_with_overlap = rank_divider.get_new_zdim_rank_divider(
z_feature_size=transformers.output.n_latent_dims
)
# don't pass in clipped data here, as clipping is not enabled for outputs
if hyperparameters.zero_fill_clipped_output_levels:
batch_data = zero_fill_clipped_output_levels(
batch_data, hyperparameters.clip_config
)
output_time_series = process_batch_data(
variables=hyperparameters.output_variables,
batch_data=batch_data,
@@ -196,7 +216,7 @@

hybrid_time_series = process_batch_data(
variables=hyperparameters.hybrid_variables,
batch_data=batch_data,
batch_data=batch_data_clipped,
rank_divider=_hybrid_rank_divider_w_overlap,
autoencoder=transformers.hybrid,
trim_halo=True,
@@ -279,6 +299,7 @@ def train_reservoir_model(
model=model,
input_variables=model.input_variables,
output_variables=model.output_variables,
clip_config=hyperparameters.clip_config,
)

if validation_batches is not None and wandb.run is not None:
@@ -287,6 +308,7 @@
model,
val_batches=validation_batches,
n_synchronize=hyperparameters.n_timesteps_synchronize,
clip_config=hyperparameters.clip_config,
)
log_rmse_z_plots(ds_val, model.output_variables)
log_rmse_scalar_metrics(ds_val, model.output_variables)
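The training changes above lean on two helpers imported at the top of train.py, clip_batch_data and zero_fill_clipped_output_levels, whose implementations are not part of this diff. Below is a rough sketch of the behavior implied by their call sites; the signatures, array types, and details are assumptions for illustration, not the PR's actual code:

from typing import Hashable, Mapping, Optional
import numpy as np
from fv3fit.reservoir.config import ClipZConfig

def clip_batch_data(
    batch: Mapping[str, np.ndarray],
    clip_config: Optional[Mapping[Hashable, ClipZConfig]],
) -> Mapping[str, np.ndarray]:
    # Assumed behavior: trim the trailing (vertical) axis of configured variables,
    # pass all other variables through unchanged.
    if clip_config is None:
        return dict(batch)
    return {
        name: arr[..., clip_config[name].start : clip_config[name].stop]
        if name in clip_config
        else arr
        for name, arr in batch.items()
    }

def zero_fill_clipped_output_levels(
    batch: Mapping[str, np.ndarray],
    clip_config: Optional[Mapping[Hashable, ClipZConfig]],
) -> Mapping[str, np.ndarray]:
    # Assumed behavior: keep full-length columns but zero the levels outside the
    # clip window, so output transforms still see the original vertical extent.
    if clip_config is None:
        return dict(batch)
    filled = {}
    for name, arr in batch.items():
        if name in clip_config:
            arr = np.array(arr, copy=True)
            cfg = clip_config[name]
            if cfg.start is not None:
                arr[..., : cfg.start] = 0.0
            if cfg.stop is not None:
                arr[..., cfg.stop :] = 0.0
        filled[name] = arr
    return filled

Per the comments in the diff, outputs are never clipped; they are only optionally zero-filled, so predictions keep the model's full vertical extent while inputs and hybrid inputs are trimmed before encoding.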