From c014222adc1b613bf10ca1b2c8759fb04d29b6fc Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 25 May 2024 18:57:05 +0200
Subject: [PATCH 1/5] Implement on_batch_transfer logic to normalize data
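
Standardization moves out of WeatherDataset and into the model: the flux
statistics are saved as separate files by create_parameter_weights.py,
loaded together with the rest of the static data, and every batch is
normalized on the device in Lightning's on_after_batch_transfer hook
instead of per-sample on the CPU in __getitem__.

A minimal sketch of the pattern, for review purposes only (the class name
and the buffer registration below are illustrative, not the actual
ARModel):

    import pytorch_lightning as pl

    class NormalizedModel(pl.LightningModule):
        def __init__(self, data_mean, data_std, flux_mean, flux_std):
            super().__init__()
            # Buffers move to the device together with the module
            self.register_buffer("data_mean", data_mean)
            self.register_buffer("data_std", data_std)
            self.register_buffer("flux_mean", flux_mean)
            self.register_buffer("flux_std", flux_std)

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # Lightning calls this right after the batch reaches the device
            init_states, target_states, forcing = batch
            init_states = (init_states - self.data_mean) / self.data_std
            target_states = (target_states - self.data_mean) / self.data_std
            forcing = (forcing - self.flux_mean) / self.flux_std
            return init_states, target_states, forcing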
for one-step differences...") - ds_standard = WeatherDataset( - config_loader.dataset.name, - split="train", - subsample_step=1, - pred_length=63, - standardize=True, - ) # Re-load with standardization - loader_standard = torch.utils.data.DataLoader( - ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers - ) used_subsample_len = (65 // args.step_length) * args.step_length diff_means = [] diff_squares = [] - for init_batch, target_batch, _ in tqdm(loader_standard): + for init_batch, target_batch, _ in tqdm(loader): + # normalize the batch + init_batch = (init_batch - mean) / std + target_batch = (target_batch - mean) / std + + batch = torch.cat((init_batch, target_batch), dim=1) batch = torch.cat( (init_batch, target_batch), dim=1 ) # (N_batch, N_t', N_grid, d_features) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9cda9fc2..c1dce738 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -197,6 +197,19 @@ def common_step(self, batch): return prediction, target_states, pred_std + def on_after_batch_transfer(self, batch, dataloader_idx): + """Normalize Batch data after transferring to the device.""" + init_states, target_states, forcing_features = batch + init_states = (init_states - self.data_mean) / self.data_std + target_states = (target_states - self.data_mean) / self.data_std + forcing_features = (forcing_features - self.flux_mean) / self.flux_std + batch = ( + init_states, + target_states, + forcing_features, + ) + return batch + def training_step(self, batch): """ Train on single batch diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 836b04ed..3a5ffa8e 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -8,31 +8,6 @@ from tueplots import bundles, figsizes -def load_dataset_stats(dataset_name, device="cpu"): - """ - Load arrays with stored dataset statistics from pre-processing - """ - static_dir_path = os.path.join("data", dataset_name, "static") - - def loads_file(fn): - return torch.load( - os.path.join(static_dir_path, fn), map_location=device - ) - - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) - - flux_stats = loads_file("flux_stats.pt") # (2,) - flux_mean, flux_std = flux_stats - - return { - "data_mean": data_mean, - "data_std": data_std, - "flux_mean": flux_mean, - "flux_std": flux_std, - } - - def load_static_data(dataset_name, device="cpu"): """ Load static files related to dataset @@ -64,6 +39,9 @@ def loads_file(fn): data_mean = loads_file("parameter_mean.pt") # (d_features,) data_std = loads_file("parameter_std.pt") # (d_features,) + flux_mean = loads_file("flux_mean.pt") # (,) + flux_std = loads_file("flux_std.pt") # (,) + # Load loss weighting vectors param_weights = torch.tensor( np.load(os.path.join(static_dir_path, "parameter_weights.npy")), @@ -78,6 +56,8 @@ def loads_file(fn): "step_diff_std": step_diff_std, "data_mean": data_mean, "data_std": data_std, + "flux_mean": flux_mean, + "flux_std": flux_std, "param_weights": param_weights, } diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index a782806b..4b1fe7e9 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -7,9 +7,6 @@ import numpy as np import torch -# First-party -from neural_lam import utils - class WeatherDataset(torch.utils.data.Dataset): """ @@ -29,7 +26,6 @@ def __init__( pred_length=19, split="train", subsample_step=3, - standardize=True, subset=False, control_only=False, ): 
---
 neural_lam/vis.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 2b6abf15..8c9ca77c 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -87,7 +87,7 @@ def plot_prediction(
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     # Plot pred and target
@@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
 
     fig, ax = plt.subplots(
         figsize=(5, 4.8),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     ax.coastlines()  # Add coastline outlines

From 7802553d6330b745f6e256835effaa0eed0e47e8 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Fri, 7 Jun 2024 21:08:50 +0200
Subject: [PATCH 3/5] fixed some issues after merge with main
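
After the merge, the statistics pass can run on multiple ranks: each
rank accumulates means and second moments over its shard of the data,
all_gather_object collects them, and rank 0 reorders the gathered values
with the PaddedWeatherDataset indices so that padding duplicates do not
bias the averages.

The reduction itself relies on the identity std = sqrt(E[x^2] - E[x]^2).
A small self-contained check with toy numbers (not real data; exact here
because both batches contribute equally):

    import torch

    means = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]      # per-batch E[x]
    squares = [torch.tensor([2.0, 5.0]), torch.tensor([10.0, 17.0])]  # per-batch E[x**2]

    mean = torch.stack(means).mean(dim=0)             # E[x] = [2., 3.]
    second_moment = torch.stack(squares).mean(dim=0)  # E[x**2] = [6., 11.]
    std = torch.sqrt(second_moment - mean**2)         # sqrt([2., 2.])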
---
 create_parameter_weights.py | 81 ++++++++++++++++++++++++-------------
 1 file changed, 52 insertions(+), 29 deletions(-)

diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 93461eb7..9725ffe1 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -43,7 +43,7 @@ def __len__(self):
     def get_original_indices(self):
         return self.original_indices
 
-    def get_original_window_indices(self, step_length):
+    def get_window_indices(self, step_length):
         return [
             i // step_length
             for i in range(len(self.original_indices) * step_length)
@@ -120,10 +120,9 @@ def save_stats(
     flux_mean = torch.mean(flux_means)  # (,)
     flux_second_moment = torch.mean(flux_squares)  # (,)
     flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # (,)
-    torch.save(
-        torch.stack((flux_mean, flux_std)).cpu(),
-        os.path.join(static_dir_path, "flux_stats.pt"),
-    )
+    print("Saving mean, std.-dev., flux_mean, flux_std...")
+    torch.save(flux_mean, os.path.join(static_dir_path, "flux_mean.pt"))
+    torch.save(flux_std, os.path.join(static_dir_path, "flux_std.pt"))
 
 
 def main():
@@ -208,7 +207,6 @@ def main():
         split="train",
         subsample_step=1,
         pred_length=63,
-        standardize=False,
     )
     if distributed:
         ds = PaddedWeatherDataset(
@@ -264,33 +262,60 @@ def main():
         dist.all_gather_object(flux_means_gathered, flux_means)
         dist.all_gather_object(flux_squares_gathered, flux_squares)
 
-    flux_mean = torch.mean(torch.stack(flux_means))  # (,)
-    flux_second_moment = torch.mean(torch.stack(flux_squares))  # (,)
-    flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # (,)
-    flux_stats = torch.stack((flux_mean, flux_std))
+        if rank == 0:
+            means_gathered, squares_gathered = torch.cat(
+                means_gathered, dim=0
+            ), torch.cat(squares_gathered, dim=0)
+            flux_means_gathered, flux_squares_gathered = torch.tensor(
+                flux_means_gathered
+            ), torch.tensor(flux_squares_gathered)
+
+            original_indices = ds.get_original_indices()
+            means, squares = [means_gathered[i] for i in original_indices], [
+                squares_gathered[i] for i in original_indices
+            ]
+            flux_means, flux_squares = [
+                flux_means_gathered[i] for i in original_indices
+            ], [flux_squares_gathered[i] for i in original_indices]
+    else:
+        means = [torch.cat(means, dim=0)]  # (N_batch, d_features,)
+        squares = [torch.cat(squares, dim=0)]  # (N_batch, d_features,)
+        flux_means = [torch.tensor(flux_means)]  # (N_batch,)
+        flux_squares = [torch.tensor(flux_squares)]  # (N_batch,)
 
-    print("Saving mean, std.-dev, flux_stats...")
-    torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
-    torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
-    torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
+    if rank == 0:
+        save_stats(
+            static_dir_path,
+            means,
+            squares,
+            flux_means,
+            flux_squares,
+            "parameter",
+        )
+
+    if distributed:
+        dist.barrier()
 
-    # Compute mean and std.-dev. of one-step differences across the dataset
     print("Computing mean and std.-dev. for one-step differences...")
-    ds_standard = WeatherDataset(
-        config_loader.dataset.name,
-        split="train",
-        subsample_step=1,
-        pred_length=63,
-        standardize=True,
-    )  # Re-load with standardization
-    loader_standard = torch.utils.data.DataLoader(
-        ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
-    )
+
     used_subsample_len = (65 // args.step_length) * args.step_length
 
+    mean = torch.load(os.path.join(static_dir_path, "parameter_mean.pt"))
+    std = torch.load(os.path.join(static_dir_path, "parameter_std.pt"))
+    if distributed:
+        mean, std = mean.to(device), std.to(device)
     diff_means = []
     diff_squares = []
-    for init_batch, target_batch, _ in tqdm(loader_standard):
+    for init_batch, target_batch, _ in tqdm(loader):
+        if distributed:
+            init_batch, target_batch = (
+                init_batch.to(device),
+                target_batch.to(device),
+            )
+        # normalize the batch
+        init_batch = (init_batch - mean) / std
+        target_batch = (target_batch - mean) / std
+
         batch = torch.cat(
             (init_batch, target_batch), dim=1
         )  # (N_batch, N_t', N_grid, d_features)
@@ -331,9 +356,7 @@ def main():
         ).view(
             -1, *diff_squares[0].shape
         )
-        original_indices = ds_standard.get_original_window_indices(
-            args.step_length
-        )
+        original_indices = ds.get_window_indices(args.step_length)
         diff_means, diff_squares = [
             diff_means_gathered[i] for i in original_indices
         ], [diff_squares_gathered[i] for i in original_indices]

From ef3acc27fe6e0fdffb130a3b8b4921e7bfd306d4 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 9 Jun 2024 11:45:34 +0200
Subject: [PATCH 4/5] improve docstring
---
 neural_lam/models/ar_model.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index ecb7dd13..e79fb3e8 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
+
 import wandb
 
 # First-party
@@ -195,7 +196,8 @@ def common_step(self, batch):
         return prediction, target_states, pred_std
 
     def on_after_batch_transfer(self, batch, dataloader_idx):
-        """Normalize Batch data after transferring to the device."""
+        """Normalize batch data to mean 0, std 1 after transferring to the
+        device."""
         init_states, target_states, forcing_features = batch
         init_states = (init_states - self.data_mean) / self.data_std
         target_states = (target_states - self.data_mean) / self.data_std
@@ -222,8 +224,11 @@ def training_step(self, batch):
 
         log_dict = {"train_loss": batch_loss}
         self.log_dict(
-            log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
-        )
+            log_dict,
+            prog_bar=True,
+            on_step=True,
+            on_epoch=True,
+            sync_dist=True)
         return batch_loss
 
     def all_gather_cat(self, tensor_to_gather):
@@ -351,8 +356,9 @@ def test_step(self, batch, batch_idx):
         ):
             # Need to plot more example predictions
             n_additional_examples = min(
-                prediction.shape[0], self.n_example_pred - self.plotted_examples
-            )
+                prediction.shape[0],
+                self.n_example_pred
+                - self.plotted_examples)
 
             self.plot_examples(
                 batch, n_additional_examples, prediction=prediction
@@ -574,10 +580,14 @@ def on_test_epoch_end(self):
             )
             for loss_map in mean_spatial_loss
         ]
-        pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
+        pdf_loss_maps_dir = os.path.join(
+            wandb.run.dir, "spatial_loss_maps")
         os.makedirs(pdf_loss_maps_dir, exist_ok=True)
         for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
-            fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
+            fig.savefig(
+                os.path.join(
+                    pdf_loss_maps_dir,
+                    f"loss_t{t_i}.pdf"))
         # save mean spatial loss as .pt file also
         torch.save(
             mean_spatial_loss.cpu(),

From 86e6cbacf16517419e69cc43ec0e41dced553c52 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 9 Jun 2024 11:49:02 +0200
Subject: [PATCH 5/5] linter

---
 neural_lam/models/ar_model.py | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index e79fb3e8..5b63746d 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,7 +6,6 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
-
 import wandb
 
 # First-party
@@ -224,11 +223,8 @@ def training_step(self, batch):
 
         log_dict = {"train_loss": batch_loss}
         self.log_dict(
-            log_dict,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            sync_dist=True)
+            log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
+        )
         return batch_loss
 
     def all_gather_cat(self, tensor_to_gather):
@@ -356,9 +352,8 @@ def test_step(self, batch, batch_idx):
         ):
             # Need to plot more example predictions
             n_additional_examples = min(
-                prediction.shape[0],
-                self.n_example_pred
-                - self.plotted_examples)
+                prediction.shape[0], self.n_example_pred - self.plotted_examples
+            )
 
             self.plot_examples(
                 batch, n_additional_examples, prediction=prediction
@@ -580,14 +575,10 @@ def on_test_epoch_end(self):
             )
             for loss_map in mean_spatial_loss
         ]
-        pdf_loss_maps_dir = os.path.join(
-            wandb.run.dir, "spatial_loss_maps")
+        pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
         os.makedirs(pdf_loss_maps_dir, exist_ok=True)
         for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
-            fig.savefig(
-                os.path.join(
-                    pdf_loss_maps_dir,
-                    f"loss_t{t_i}.pdf"))
+            fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
         # save mean spatial loss as .pt file also
         torch.save(
             mean_spatial_loss.cpu(),