diff --git a/README.md b/README.md index 1bdc6602..e91bd56b 100644 --- a/README.md +++ b/README.md @@ -221,7 +221,8 @@ data │ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py) │ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py) │ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py) -│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py) +│ ├── flux_mean.pt - Mean of solar flux forcing (create_parameter_weights.py) +│ ├── flux_std.pt - Std.-dev. of solar flux forcing (create_parameter_weights.py) │ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py) ├── dataset2 ├── ... diff --git a/create_parameter_weights.py b/create_parameter_weights.py index c85cd5a3..9725ffe1 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -43,7 +43,7 @@ def __len__(self): def get_original_indices(self): return self.original_indices - def get_original_window_indices(self, step_length): + def get_window_indices(self, step_length): return [ i // step_length for i in range(len(self.original_indices) * step_length) @@ -120,10 +120,9 @@ def save_stats( flux_mean = torch.mean(flux_means) # (,) flux_second_moment = torch.mean(flux_squares) # (,) flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) - torch.save( - torch.stack((flux_mean, flux_std)).cpu(), - os.path.join(static_dir_path, "flux_stats.pt"), - ) + print("Saving mean, std.-dev, flux_mean, flux_std...") + torch.save(flux_mean, os.path.join(static_dir_path, "flux_mean.pt")) + torch.save(flux_std, os.path.join(static_dir_path, "flux_std.pt")) def main(): @@ -208,7 +207,6 @@ def main(): split="train", subsample_step=1, pred_length=63, - standardize=False, ) if distributed: ds = PaddedWeatherDataset( @@ -298,44 +296,29 @@ def main(): if distributed: dist.barrier() - if rank == 0: - print("Computing mean and std.-dev. for one-step differences...") - ds_standard = WeatherDataset( - config_loader.dataset.name, - split="train", - subsample_step=1, - pred_length=63, - standardize=True, - ) # Re-load with standardization - if distributed: - ds_standard = PaddedWeatherDataset( - ds_standard, - world_size, - args.batch_size, - ) - sampler_standard = DistributedSampler( - ds_standard, num_replicas=world_size, rank=rank, shuffle=False - ) - else: - sampler_standard = None - loader_standard = torch.utils.data.DataLoader( - ds_standard, - args.batch_size, - shuffle=False, - num_workers=args.n_workers, - sampler=sampler_standard, - ) - used_subsample_len = (65 // args.step_length) * args.step_length + print("Computing mean and std.-dev. for one-step differences...") - diff_means, diff_squares = [], [] + used_subsample_len = (65 // args.step_length) * args.step_length - for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0): + mean = torch.load(os.path.join(static_dir_path, "parameter_mean.pt")) + std = torch.load(os.path.join(static_dir_path, "parameter_std.pt")) + if distributed: + mean, std = mean.to(device), std.to(device) + diff_means = [] + diff_squares = [] + for init_batch, target_batch, _ in tqdm(loader): if distributed: - init_batch, target_batch = init_batch.to(device), target_batch.to( - device + init_batch, target_batch = ( + init_batch.to(device), + target_batch.to(device), ) - # (N_batch, N_t', N_grid, d_features) - batch = torch.cat((init_batch, target_batch), dim=1) + # normalize the batch + init_batch = (init_batch - mean) / std + target_batch = (target_batch - mean) / std + + batch = torch.cat( + (init_batch, target_batch), dim=1 + ) # (N_batch, N_t', N_grid, d_features) # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ @@ -373,9 +356,7 @@ def main(): ).view( -1, *diff_squares[0].shape ) - original_indices = ds_standard.get_original_window_indices( - args.step_length - ) + original_indices = ds.get_window_indices(args.step_length) diff_means, diff_squares = [ diff_means_gathered[i] for i in original_indices ], [diff_squares_gathered[i] for i in original_indices] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 6ced211f..5b63746d 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -194,6 +194,20 @@ def common_step(self, batch): return prediction, target_states, pred_std + def on_after_batch_transfer(self, batch, dataloader_idx): + """Normalize batch data to mean 0, std 1 after transferring to the + device.""" + init_states, target_states, forcing_features = batch + init_states = (init_states - self.data_mean) / self.data_std + target_states = (target_states - self.data_mean) / self.data_std + forcing_features = (forcing_features - self.flux_mean) / self.flux_std + batch = ( + init_states, + target_states, + forcing_features, + ) + return batch + def training_step(self, batch): """ Train on single batch diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 59a529eb..5743cdfc 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -9,31 +9,6 @@ from tueplots import bundles, figsizes -def load_dataset_stats(dataset_name, device="cpu"): - """ - Load arrays with stored dataset statistics from pre-processing - """ - static_dir_path = os.path.join("data", dataset_name, "static") - - def loads_file(fn): - return torch.load( - os.path.join(static_dir_path, fn), map_location=device - ) - - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) - - flux_stats = loads_file("flux_stats.pt") # (2,) - flux_mean, flux_std = flux_stats - - return { - "data_mean": data_mean, - "data_std": data_std, - "flux_mean": flux_mean, - "flux_std": flux_std, - } - - def load_static_data(dataset_name, device="cpu"): """ Load static files related to dataset @@ -65,6 +40,9 @@ def loads_file(fn): data_mean = loads_file("parameter_mean.pt") # (d_features,) data_std = loads_file("parameter_std.pt") # (d_features,) + flux_mean = loads_file("flux_mean.pt") # (,) + flux_std = loads_file("flux_std.pt") # (,) + # Load loss weighting vectors param_weights = torch.tensor( np.load(os.path.join(static_dir_path, "parameter_weights.npy")), @@ -79,6 +57,8 @@ def loads_file(fn): "step_diff_std": step_diff_std, "data_mean": data_mean, "data_std": data_std, + "flux_mean": flux_mean, + "flux_std": flux_std, "param_weights": param_weights, } diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 3288ed67..beb01686 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -7,9 +7,6 @@ import numpy as np import torch -# First-party -from neural_lam import utils - class WeatherDataset(torch.utils.data.Dataset): """ @@ -29,7 +26,6 @@ def __init__( pred_length=19, split="train", subsample_step=3, - standardize=True, subset=False, control_only=False, ): @@ -61,17 +57,6 @@ def __init__( self.sample_length <= self.original_sample_length ), "Requesting too long time series samples" - # Set up for standardization - self.standardize = standardize - if standardize: - ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ds_stats["flux_mean"], - ds_stats["flux_std"], - ) - # If subsample index should be sampled (only duing training) self.random_subsample = split == "train" @@ -148,10 +133,6 @@ def __getitem__(self, idx): sample = sample[init_id : (init_id + self.sample_length)] # (sample_length, N_grid, d_features) - if self.standardize: - # Standardize sample - sample = (sample - self.data_mean) / self.data_std - # Split up sample in init. states and target states init_states = sample[:2] # (2, N_grid, d_features) target_states = sample[2:] # (sample_length-2, N_grid, d_features) @@ -185,9 +166,6 @@ def __getitem__(self, idx): -1 ) # (N_t', dim_y, dim_x, 1) - if self.standardize: - flux = (flux - self.flux_mean) / self.flux_std - # Flatten and subsample flux forcing flux = flux.flatten(1, 2) # (N_t, N_grid, 1) flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1)