25 normalize data on gpu #39

Open · wants to merge 6 commits into base: main
README.md (3 changes: 2 additions & 1 deletion)
@@ -221,7 +221,8 @@ data
│ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
│ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
│ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
│ ├── flux_mean.pt - Mean of solar flux forcing (create_parameter_weights.py)
│ ├── flux_std.pt - Std.-dev. of solar flux forcing (create_parameter_weights.py)
│ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
├── dataset2
├── ...
create_parameter_weights.py (67 changes: 24 additions & 43 deletions)
@@ -43,7 +43,7 @@ def __len__(self):
def get_original_indices(self):
return self.original_indices

def get_original_window_indices(self, step_length):
def get_window_indices(self, step_length):
return [
i // step_length
for i in range(len(self.original_indices) * step_length)
@@ -120,10 +120,9 @@ def save_stats(
flux_mean = torch.mean(flux_means) # (,)
flux_second_moment = torch.mean(flux_squares) # (,)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,)
torch.save(
torch.stack((flux_mean, flux_std)).cpu(),
os.path.join(static_dir_path, "flux_stats.pt"),
)
print("Saving mean, std.-dev, flux_mean, flux_std...")
torch.save(flux_mean, os.path.join(static_dir_path, "flux_mean.pt"))
torch.save(flux_std, os.path.join(static_dir_path, "flux_std.pt"))


def main():
@@ -208,7 +207,6 @@ def main():
split="train",
subsample_step=1,
pred_length=63,
standardize=False,
)
if distributed:
ds = PaddedWeatherDataset(
@@ -298,44 +296,29 @@ def main():
if distributed:
dist.barrier()

if rank == 0:
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
config_loader.dataset.name,
split="train",
subsample_step=1,
pred_length=63,
standardize=True,
) # Re-load with standardization
if distributed:
ds_standard = PaddedWeatherDataset(
ds_standard,
world_size,
args.batch_size,
)
sampler_standard = DistributedSampler(
ds_standard, num_replicas=world_size, rank=rank, shuffle=False
)
else:
sampler_standard = None
loader_standard = torch.utils.data.DataLoader(
ds_standard,
args.batch_size,
shuffle=False,
num_workers=args.n_workers,
sampler=sampler_standard,
)
used_subsample_len = (65 // args.step_length) * args.step_length
print("Computing mean and std.-dev. for one-step differences...")

diff_means, diff_squares = [], []
used_subsample_len = (65 // args.step_length) * args.step_length

for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
mean = torch.load(os.path.join(static_dir_path, "parameter_mean.pt"))
std = torch.load(os.path.join(static_dir_path, "parameter_std.pt"))
if distributed:
mean, std = mean.to(device), std.to(device)
diff_means = []
diff_squares = []
for init_batch, target_batch, _ in tqdm(loader):
if distributed:
init_batch, target_batch = init_batch.to(device), target_batch.to(
device
init_batch, target_batch = (
init_batch.to(device),
target_batch.to(device),
)
# (N_batch, N_t', N_grid, d_features)
batch = torch.cat((init_batch, target_batch), dim=1)
# normalize the batch
init_batch = (init_batch - mean) / std
target_batch = (target_batch - mean) / std

batch = torch.cat(
(init_batch, target_batch), dim=1
) # (N_batch, N_t', N_grid, d_features)
# Note: batch contains only 1h-steps
stepped_batch = torch.cat(
[
@@ -373,9 +356,7 @@ def main():
).view(
-1, *diff_squares[0].shape
)
original_indices = ds_standard.get_original_window_indices(
args.step_length
)
original_indices = ds.get_window_indices(args.step_length)
diff_means, diff_squares = [
diff_means_gathered[i] for i in original_indices
], [diff_squares_gathered[i] for i in original_indices]
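The net effect of the create_parameter_weights.py changes is that one-step-difference statistics are now computed from the raw loader output and standardized on the device with the previously saved parameter_mean.pt / parameter_std.pt, instead of re-instantiating a second WeatherDataset with standardize=True. A rough sketch of that flow, ignoring the step-length subsampling, padding and distributed gathering shown above (the function name, `loader` argument and dataset path are placeholders, not part of the PR):

```python
import os

import torch


def compute_diff_stats(loader, static_dir_path, device):
    """Standardize batches on the device, then accumulate mean/std of
    one-step differences (condensed from the loop in the diff above)."""
    mean = torch.load(os.path.join(static_dir_path, "parameter_mean.pt")).to(device)
    std = torch.load(os.path.join(static_dir_path, "parameter_std.pt")).to(device)

    diff_means, diff_squares = [], []
    for init_batch, target_batch, _ in loader:
        # The dataset no longer standardizes, so do it here, on the device
        init_batch = (init_batch.to(device) - mean) / std
        target_batch = (target_batch.to(device) - mean) / std
        batch = torch.cat((init_batch, target_batch), dim=1)
        # (N_batch, N_t, N_grid, d_features)

        diffs = batch[:, 1:] - batch[:, :-1]  # one-step differences
        diff_means.append(torch.mean(diffs, dim=(1, 2)))  # (N_batch, d_features)
        diff_squares.append(torch.mean(diffs**2, dim=(1, 2)))

    diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0)  # (d_features,)
    diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0)
    diff_std = torch.sqrt(diff_second_moment - diff_mean**2)  # (d_features,)
    return diff_mean, diff_std
```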
neural_lam/models/ar_model.py (14 changes: 14 additions & 0 deletions)
@@ -194,6 +194,20 @@ def common_step(self, batch):

return prediction, target_states, pred_std

def on_after_batch_transfer(self, batch, dataloader_idx):
"""Normalize batch data to mean 0, std 1 after transferring to the
device."""
init_states, target_states, forcing_features = batch
init_states = (init_states - self.data_mean) / self.data_std
target_states = (target_states - self.data_mean) / self.data_std
forcing_features = (forcing_features - self.flux_mean) / self.flux_std
Collaborator commented:
Now all forcing seems to be normalized with the flux statistics, but there is more forcing than flux. Note how this was only applied to the flux in WeatherDataset before.
Have you tested that this gives exactly the same tensors as before? (e.g. save the first batch to disk on main, check out this branch, save the first batch again and compare).

@sadamov (Collaborator, Author) replied on Jun 9, 2024:
True, the forcings are now handled differently. I suggest implementing new logic to handle forcings in #54: the user can define combined_vars that share statistics and also define vars that should not be normalized.
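A rough sketch of the comparison suggested above (file names are placeholders): save the first batch to disk on main, save it again on this branch after on_after_batch_transfer has run, and compare element-wise.

```python
import torch

ref_init, ref_target, ref_forcing = torch.load("first_batch_main.pt")
new_init, new_target, new_forcing = torch.load("first_batch_branch.pt")

for name, a, b in (
    ("init_states", ref_init, new_init),
    ("target_states", ref_target, new_target),
    ("forcing", ref_forcing, new_forcing),
):
    # allclose rather than exact equality: the two code paths may differ by rounding
    print(name, "matches:", torch.allclose(a, b, atol=1e-6))
```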

batch = (
init_states,
target_states,
forcing_features,
)
return batch

def training_step(self, batch):
"""
Train on single batch
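For this hook to work, self.data_mean, self.data_std, self.flux_mean and self.flux_std must already live on the same device as the batch. One way to arrange that (a sketch, not necessarily how this PR wires it up) is to register the statistics returned by utils.load_static_data as non-persistent buffers, so Lightning moves them together with the model:

```python
import pytorch_lightning as pl

from neural_lam import utils


class ARModel(pl.LightningModule):
    def __init__(self, dataset_name):
        super().__init__()
        static = utils.load_static_data(dataset_name)
        for name in ("data_mean", "data_std", "flux_mean", "flux_std"):
            # Buffers follow the module across devices, so on_after_batch_transfer
            # sees them on the same device as the incoming batch
            self.register_buffer(name, static[name], persistent=False)
```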
neural_lam/utils.py (30 changes: 5 additions & 25 deletions)
@@ -9,31 +9,6 @@
from tueplots import bundles, figsizes


def load_dataset_stats(dataset_name, device="cpu"):
"""
Load arrays with stored dataset statistics from pre-processing
"""
static_dir_path = os.path.join("data", dataset_name, "static")

def loads_file(fn):
return torch.load(
os.path.join(static_dir_path, fn), map_location=device
)

data_mean = loads_file("parameter_mean.pt") # (d_features,)
data_std = loads_file("parameter_std.pt") # (d_features,)

flux_stats = loads_file("flux_stats.pt") # (2,)
flux_mean, flux_std = flux_stats

return {
"data_mean": data_mean,
"data_std": data_std,
"flux_mean": flux_mean,
"flux_std": flux_std,
}


def load_static_data(dataset_name, device="cpu"):
"""
Load static files related to dataset
@@ -65,6 +40,9 @@ def loads_file(fn):
data_mean = loads_file("parameter_mean.pt") # (d_features,)
data_std = loads_file("parameter_std.pt") # (d_features,)

flux_mean = loads_file("flux_mean.pt") # (,)
flux_std = loads_file("flux_std.pt") # (,)

# Load loss weighting vectors
param_weights = torch.tensor(
np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
@@ -79,6 +57,8 @@
"step_diff_std": step_diff_std,
"data_mean": data_mean,
"data_std": data_std,
"flux_mean": flux_mean,
"flux_std": flux_std,
"param_weights": param_weights,
}
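With flux_mean.pt and flux_std.pt loaded here, everything on_after_batch_transfer needs comes out of a single load_static_data call. A small usage sketch (the dataset name is a placeholder):

```python
from neural_lam import utils

static = utils.load_static_data("meps_example")
# Per-feature statistics for the state, scalar statistics for the solar flux forcing
print(static["data_mean"].shape, static["data_std"].shape)  # (d_features,) each
print(static["flux_mean"].shape, static["flux_std"].shape)  # torch.Size([]) each
```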

neural_lam/weather_dataset.py (22 changes: 0 additions & 22 deletions)
@@ -7,9 +7,6 @@
import numpy as np
import torch

# First-party
from neural_lam import utils


class WeatherDataset(torch.utils.data.Dataset):
"""
@@ -29,7 +26,6 @@ def __init__(
pred_length=19,
split="train",
subsample_step=3,
standardize=True,
subset=False,
control_only=False,
):
@@ -61,17 +57,6 @@ def __init__(
self.sample_length <= self.original_sample_length
), "Requesting too long time series samples"

# Set up for standardization
self.standardize = standardize
if standardize:
ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
ds_stats["data_mean"],
ds_stats["data_std"],
ds_stats["flux_mean"],
ds_stats["flux_std"],
)

# If subsample index should be sampled (only during training)
self.random_subsample = split == "train"

@@ -148,10 +133,6 @@ def __getitem__(self, idx):
sample = sample[init_id : (init_id + self.sample_length)]
# (sample_length, N_grid, d_features)

if self.standardize:
# Standardize sample
sample = (sample - self.data_mean) / self.data_std

# Split up sample in init. states and target states
init_states = sample[:2] # (2, N_grid, d_features)
target_states = sample[2:] # (sample_length-2, N_grid, d_features)
@@ -185,9 +166,6 @@ def __getitem__(self, idx):
-1
) # (N_t', dim_y, dim_x, 1)

if self.standardize:
flux = (flux - self.flux_mean) / self.flux_std

# Flatten and subsample flux forcing
flux = flux.flatten(1, 2) # (N_t, N_grid, 1)
flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1)