
add validation logging #2205

Draft · wants to merge 11 commits into base: master
6 changes: 6 additions & 0 deletions external/fv3fit/fv3fit/reservoir/_reshaping.py
@@ -0,0 +1,6 @@
import numpy as np


def flatten_2d_keeping_columns_contiguous(arr: np.ndarray) -> np.ndarray:
    # Fortran ("F") order flattens column by column,
    # e.g. [[1, 2], [3, 4], [5, 6]] -> [1, 3, 5, 2, 4, 6]
    return np.reshape(arr, -1, "F")
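For reference, a minimal sketch (plain numpy, nothing assumed beyond the function above) contrasting column-major with numpy's default row-major flattening:

```python
import numpy as np

x = np.array([[1, 2], [3, 4], [5, 6]])
print(np.reshape(x, -1, "F"))  # column-major: [1 3 5 2 4 6]
print(np.reshape(x, -1, "C"))  # row-major (numpy default): [1 2 3 4 5 6]
```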
24 changes: 24 additions & 0 deletions external/fv3fit/fv3fit/reservoir/domain.py
@@ -221,3 +221,27 @@ def assure_same_dims(variable_tensors: Iterable[tf.Tensor]) -> Iterable[tf.Tensor]:
f"have either {max_dims} or {max_dims-1}."
)
return reshaped_tensors


def merge_subdomains(flat_prediction, rank_divider, latent_dims):
    # Reassemble a flat prediction vector into the full rank domain: split the
    # columns per subdomain, unstack each subdomain into its (x, y, z) block,
    # then stitch the blocks together according to the subdomain layout.
    subdomain_columns = flat_prediction.reshape(-1, rank_divider.n_subdomains)
    subdomain_predictions = []
    for s in range(rank_divider.n_subdomains):
        subdomain_prediction = rank_divider.unstack_subdomain(
            np.array([subdomain_columns[:, s]]), with_overlap=False
        )
        subdomain_predictions.append(subdomain_prediction[0])

    domain = []
    subdomain_without_overlap_shape = (
        rank_divider.subdomain_xy_size_without_overlap,
        rank_divider.subdomain_xy_size_without_overlap,
    )

    for z in range(latent_dims):
        domain_z_blocks = np.array(subdomain_predictions)[:, :, :, z].reshape(
            *rank_divider.subdomain_layout, *subdomain_without_overlap_shape
        )
        domain_z = np.concatenate(np.concatenate(domain_z_blocks, axis=1), axis=-1)
        domain.append(domain_z)
    return np.stack(domain, axis=0).transpose(1, 2, 0)
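To make the double `np.concatenate` stitch concrete, here is a minimal self-contained sketch, assuming a hypothetical 2x2 subdomain layout with 3x3 tiles and omitting RankDivider and the latent dimension:

```python
import numpy as np

layout = (2, 2)  # hypothetical subdomain layout
tile = 3         # hypothetical subdomain xy size without overlap

# one constant tile per subdomain so the result is easy to inspect
blocks = np.array([np.full((tile, tile), s) for s in range(4)]).reshape(
    *layout, tile, tile
)
# concatenate rows of subdomains, then columns, as in merge_subdomains
domain = np.concatenate(np.concatenate(blocks, axis=1), axis=-1)
assert domain.shape == (layout[0] * tile, layout[1] * tile)
# domain[:3, :3] is all 0s, domain[:3, 3:] all 1s,
# domain[3:, :3] all 2s, domain[3:, 3:] all 3s
```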
7 changes: 3 additions & 4 deletions external/fv3fit/fv3fit/reservoir/model.py
@@ -1,15 +1,16 @@
import fsspec
-from fv3fit.reservoir.readout import ReservoirComputingReadout
import os
from typing import Optional, Iterable, Hashable
import yaml

from fv3fit import Predictor
+from .readout import ReservoirComputingReadout
from .reservoir import Reservoir
from .domain import RankDivider
from fv3fit._shared import io
from .utils import square_even_terms
from .autoencoder import Autoencoder
+from ._reshaping import flatten_2d_keeping_columns_contiguous


@io.register("pure-reservoir")
@@ -60,10 +61,8 @@ def predict(self):
        readout_input = self.reservoir.state
        # For prediction over multiple subdomains (>1 column in reservoir state
        # array), flatten state into 1D vector before predicting
-        readout_input = readout_input.reshape(-1)
+        readout_input = flatten_2d_keeping_columns_contiguous(readout_input)
        prediction = self.readout.predict(readout_input).reshape(-1)

        return prediction

    def reset_state(self):
5 changes: 4 additions & 1 deletion external/fv3fit/fv3fit/reservoir/readout.py
@@ -63,6 +63,10 @@ def get_weights(self):
        coefficients, intercepts = W[:-1, :], W[-1, :]
        return coefficients, intercepts

    def predict(self, X):
        coefficients, intercepts = self.get_weights()
        return np.dot(X, coefficients) + intercepts


class ReservoirComputingReadout:
"""Readout layer of the reservoir computing model
@@ -113,7 +117,6 @@ def combine_readouts(readouts: Sequence[ReservoirComputingReadout]):

    # Concatenate the intercepts of individual readouts into single array
    combined_intercepts = np.concatenate(intercepts)
    return ReservoirComputingReadout(
        coefficients=combined_coefficients, intercepts=combined_intercepts,
    )
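The new `predict` on the readout is a plain affine map; a small sketch with hypothetical shapes shows what it computes:

```python
import numpy as np

X = np.random.rand(4, 3)  # hypothetical: 4 flattened reservoir states, 3 features
W = np.random.rand(3, 2)  # coefficients mapping 3 features to 2 outputs
b = np.random.rand(2)     # intercepts
y = np.dot(X, W) + b      # the same affine map predict() applies
assert y.shape == (4, 2)
```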
75 changes: 74 additions & 1 deletion external/fv3fit/fv3fit/reservoir/train.py
@@ -15,7 +15,14 @@
    ReservoirComputingReadout,
)
from .readout import combine_readouts
-from .domain import RankDivider, stack_time_series_samples, assure_same_dims
+from .domain import (
+    RankDivider,
+    stack_time_series_samples,
+    assure_same_dims,
+    merge_subdomains,
+)
+import wandb


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -47,6 +54,64 @@ def _get_ordered_X(X_mapping, variables):
    return assure_same_dims(ordered_tensors)


def _decode_columns(data, decoder):
    # Differs from encode_columns in that the decoder can predict multiple
    # outputs rather than a single latent vector. Expands a sequence of
    # N x M x L arrays (one per variable) into N x M x Vi arrays, where Vi is
    # the number of features (usually vertical levels) of variable i and
    # L << Vi is the smaller number of latent dimensions.
    reshaped = [_stack_array_preserving_last_dim(var) for var in data]
    decoded_reshaped = decoder.predict(reshaped)
    original_2d_shape = data[0].shape[:-1]
    decoded_data = []
    for var_data in decoded_reshaped:
        decoded_data.append(var_data.reshape(*original_2d_shape, -1))
    return decoded_data
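A shape-only sketch of that round trip, with a stub in place of the decoder (the 8x8 horizontal grid, L=2 latent dims, and V=5 output features are hypothetical):

```python
import numpy as np

latent = [np.random.rand(8, 8, 2)]                          # one variable, (x, y, L)
stacked = [a.reshape(-1, a.shape[-1]) for a in latent]      # (64, 2)
decoded = [np.random.rand(s.shape[0], 5) for s in stacked]  # stub decode to (64, V)
restored = [d.reshape(*latent[0].shape[:-1], -1) for d in decoded]  # back to (x, y, V)
assert restored[0].shape == (8, 8, 5)
```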


def validation_single_timestep(validation_batches, model, n_batches_burn):
    # Synchronize the reservoir state on the first n_batches_burn batches,
    # then compare a single-timestep prediction against the held-out truth.
    for b, batch_data in enumerate(validation_batches):
        if b < n_batches_burn:
            logger.info(f"Synchronizing on batch {b+1}")
            time_series_with_overlap, time_series_without_overlap = _process_batch_data(
                variables=model.input_variables,
                batch_data=batch_data,
                rank_divider=model.rank_divider,
                autoencoder=model.autoencoder,
            )
        else:
            X = _get_ordered_X(batch_data, model.input_variables)
            # truth is the batch's first timestep with the overlap halo trimmed
            truth = []
            overlap = model.rank_divider.overlap
            for var_data in X:
                last_timestep_in_batch = var_data[0]
                truth.append(
                    last_timestep_in_batch[overlap:-overlap, overlap:-overlap, :]
                )

            flat_prediction = model.predict()
            subdomain_predictions_latent_space = merge_subdomains(
                flat_prediction, model.rank_divider, model.autoencoder.n_latent_dims
            )
            prediction = _decode_columns(
                [subdomain_predictions_latent_space], model.autoencoder.decoder
            )
            truth = np.array(truth)
            prediction = np.array(prediction)
            val_log = {
                "truth": truth,
                "prediction": prediction,
            }
            wandb.log(
                {
                    "validation_single_timestep": val_log,
                    "val_loss": ((truth - prediction) ** 2).mean(),
                }
            )
    return
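The logged `val_loss` is the mean squared error over all variables, grid points, and levels; a minimal sketch with hypothetical array shapes:

```python
import numpy as np

truth = np.random.rand(1, 44, 44, 2)       # hypothetical (variable, x, y, z)
prediction = np.random.rand(1, 44, 44, 2)
val_loss = ((truth - prediction) ** 2).mean()  # scalar logged to wandb
```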


@register_training_function("pure-reservoir", ReservoirTrainingConfig)
def train_reservoir_model(
    hyperparameters: ReservoirTrainingConfig,
@@ -136,7 +201,15 @@ def train_reservoir_model(
        readout=readout,
        square_half_hidden_state=hyperparameters.square_half_hidden_state,
        rank_divider=rank_divider,
        autoencoder=autoencoder,
    )

    if validation_batches is not None and wandb.run is not None:
        logger.info("Single timestep validation")
        validation_single_timestep(
            validation_batches, model, hyperparameters.n_batches_burn
        )

    return model


9 changes: 9 additions & 0 deletions external/fv3fit/tests/reservoir/test__reshaping.py
@@ -0,0 +1,9 @@
import numpy as np
from fv3fit.reservoir._reshaping import flatten_2d_keeping_columns_contiguous


def test_flatten_2d_keeping_columns_contiguous():
    x = np.array([[1, 2], [3, 4], [5, 6]])
    np.testing.assert_array_equal(
        flatten_2d_keeping_columns_contiguous(x), np.array([1, 3, 5, 2, 4, 6])
    )
1 change: 1 addition & 0 deletions projects/reservoir/.envrc
@@ -0,0 +1 @@
export WANDB_PROJECT='reservoir-training'
9 changes: 6 additions & 3 deletions projects/reservoir/fv3/save_ranks.py
@@ -5,7 +5,7 @@
from tempfile import NamedTemporaryFile
import toolz

-from .cubed_sphere import CubedSphereDivider
+from cubed_sphere import CubedSphereDivider
import vcm

logging.basicConfig()
@@ -64,6 +64,7 @@ def _get_parser() -> argparse.ArgumentParser:
        help=("Number of timesteps to save per rank netcdf."),
    )
    parser.add_argument("--variables", type=str, nargs="+", default=[])
    parser.add_argument("--ranks", type=int, nargs="+", default=None)

    return parser

@@ -118,11 +119,13 @@ def get_ordered_dims_extent(dims: dict):
else:
    time_chunks = [list(data_times)]

+ranks = args.ranks or range(cubedsphere_divider.total_ranks)
for t, time_chunk in enumerate(time_chunks):
    data_time_slice = data.sel(time=time_chunk).load()
-    for r in range(cubedsphere_divider.total_ranks):
+    for r in ranks:
        output_dir = os.path.join(args.output_path, f"rank_{r}")
-        rank_output_path = os.path.join(output_dir, f"{t}.nc")
+        # zero-pad the time index so filename sorting matches time order
+        file_str = f"{t:02d}"
+        rank_output_path = os.path.join(output_dir, f"{file_str}.nc")
        rank_data = cubedsphere_divider.get_rank_data(
            data_time_slice, rank=r, overlap=args.overlap
        )
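The zero padding matters because the data configs below set `sort_files: True`, which presumably orders filenames lexicographically; a quick sketch of the failure mode it avoids:

```python
names = sorted(f"{t}.nc" for t in range(12))
print(names[:4])   # ['0.nc', '1.nc', '10.nc', '11.nc'] -- out of time order
padded = sorted(f"{t:02d}.nc" for t in range(12))
print(padded[:4])  # ['00.nc', '01.nc', '02.nc', '03.nc'] -- chronological
```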
35 changes: 35 additions & 0 deletions projects/reservoir/fv3/test_save_ranks.sh
@@ -0,0 +1,35 @@
#!/bin/bash


python -m save_ranks \
gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
gs://vcm-ml-scratch/annak/2023-02-27/rank_data/ \
2 \
2 \
--stop-time 20180815.000000 \
--variables air_temperature specific_humidity \
--time-chunks 40


python -m save_ranks \
gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
gs://vcm-ml-experiments/reservoir-computing-offline/data/n2f-25km/train/start_20190215_end_20190615 \
2 \
2 \
--start-time 20190215.000000 \
--stop-time 20190615.000000 \
--variables air_temperature specific_humidity \
--time-chunks 40 \
--ranks 0 1


python -m save_ranks \
gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
gs://vcm-ml-experiments/reservoir-computing-offline/data/n2f-25km/val/start_20190615_end_2019_0715 \
2 \
2 \
--start-time 20190615.000000 \
--stop-time 20190715.000000 \
--variables air_temperature specific_humidity \
--time-chunks 40 \
--ranks 0 1
8 changes: 8 additions & 0 deletions projects/reservoir/fv3/train.sh
@@ -0,0 +1,8 @@
#!/bin/bash

python -m fv3fit.train \
/home/AnnaK/fv3net/projects/reservoir/fv3/train_config.yaml \
/home/AnnaK/fv3net/projects/reservoir/fv3/train_data.yaml \
gs://vcm-ml-scratch/annak/2023-04-19/persistence_rc_no_encoder_T \
--validation-data-config /home/AnnaK/fv3net/projects/reservoir/fv3/train_data.yaml \
--no-wandb
35 changes: 35 additions & 0 deletions projects/reservoir/fv3/train_config.yaml
@@ -0,0 +1,35 @@

model_type: pure-reservoir
hyperparameters:
  n_jobs: 4
  # autoencoder_path: gs://vcm-ml-experiments/reservoir-computing-offline/2023-03-29/dense-autoencoder-train-full-year/trained_models/dense_autoencoder/autoencoder
  # autoencoder_path: gs://vcm-ml-scratch/annak/2023-03-03/trained_autoencoder
  seed: 0
  input_variables:
    - air_temperature
    # - specific_humidity
  output_variables:
    - air_temperature
    # - specific_humidity
  subdomain:
    layout:
      - 6
      - 6
    overlap: 2
    rank_dims:
      - time
      - x
      - y
      - z
  reservoir_hyperparameters:
    state_size: 6000
    adjacency_matrix_sparsity: 0.9
    spectral_radius: 0.7
    seed: 0
    input_coupling_sparsity: 0
    input_coupling_scaling: 0.0001
  readout_hyperparameters:
    l2: 0.05
  n_batches_burn: 1
  input_noise: 0.01
  square_half_hidden_state: True
11 changes: 11 additions & 0 deletions projects/reservoir/fv3/train_data.yaml
@@ -0,0 +1,11 @@

# url: gs://vcm-ml-scratch/annak/2023-02-27/rank_data/rank_0
# url: gs://vcm-ml-scratch/annak/2023-04-19/train
# url: gs://vcm-ml-scratch/annak/2023-02-22/rank_data/rank_1
url: gs://vcm-ml-scratch/annak/2023-04-19/persistence_netcdfs
dim_order:
- time
- x
- y
- z
varying_first_dim: True
sort_files: True
shuffle: False
10 changes: 10 additions & 0 deletions projects/reservoir/fv3/val_data.yaml
@@ -0,0 +1,10 @@

# url: gs://vcm-ml-scratch/annak/2023-02-27/rank_data/rank_0
url: gs://vcm-ml-scratch/annak/2023-04-19/persistence_netcdfs
dim_order:
- time
- x
- y
- z
varying_first_dim: True
sort_files: True