diff --git a/src/anemoi/models/data/data_indices/collection.py b/src/anemoi/models/data/data_indices/collection.py new file mode 100644 index 0000000..a71ad01 --- /dev/null +++ b/src/anemoi/models/data/data_indices/collection.py @@ -0,0 +1,74 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import operator + +import yaml +from omegaconf import OmegaConf + +from anemoi.models.data.data_indices.index import BaseIndex +from anemoi.models.data.data_indices.index import DataIndex +from anemoi.models.data.data_indices.index import ModelIndex +from anemoi.models.data.data_indices.tensor import BaseTensorIndex +from anemoi.models.data.data_indices.tensor import InputTensorIndex +from anemoi.models.data.data_indices.tensor import OutputTensorIndex + + +class IndexCollection: + """Collection of data and model indices.""" + + def __init__(self, config, name_to_index) -> None: + self.config = OmegaConf.to_container(config, resolve=True) + + self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True) + self.diagnostic = ( + [] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True) + ) + + assert set(self.diagnostic).isdisjoint(self.forcing), ( + f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ", + "Please drop them at a dataset-level to exclude them from the training data.", + ) + self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1))) + name_to_index_model_input = { + name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic) + } + name_to_index_model_output = { + name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing) + } + + self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index) + self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output) + + def __repr__(self) -> str: + return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})" + + def __eq__(self, other): + if not isinstance(other, IndexCollection): + # don't attempt to compare against unrelated types + return NotImplemented + + return self.model == other.model and self.data == other.data + + def __getitem__(self, key): + return getattr(self, key) + + def todict(self): + return { + "data": self.data.todict(), + "model": self.model.todict(), + } + + @staticmethod + def representer(dumper, data): + return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data)) + + +for cls in [BaseTensorIndex, InputTensorIndex, OutputTensorIndex, BaseIndex, DataIndex, ModelIndex, IndexCollection]: + yaml.add_representer(cls, cls.representer) diff --git a/src/anemoi/models/data/data_indices/index.py b/src/anemoi/models/data/data_indices/index.py new file mode 100644 index 0000000..1a2f8ad --- /dev/null +++ b/src/anemoi/models/data/data_indices/index.py @@ -0,0 +1,93 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from anemoi.models.data.data_indices.tensor import InputTensorIndex
+from anemoi.models.data.data_indices.tensor import OutputTensorIndex
+
+
+class BaseIndex:
+    """Base class for data and model indices."""
+
+    def __init__(self) -> None:
+        self.input = NotImplementedError
+        self.output = NotImplementedError
+
+    def __eq__(self, other):
+        if not isinstance(other, BaseIndex):
+            # don't attempt to compare against unrelated types
+            return NotImplemented
+
+        return self.input == other.input and self.output == other.output
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(input={self.input}, output={self.output})"
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def todict(self):
+        return {
+            "input": self.input.todict(),
+            "output": self.output.todict(),
+        }
+
+    @staticmethod
+    def representer(dumper, data):
+        return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))
+
+
+class DataIndex(BaseIndex):
+    """Indexing for data variables."""
+
+    def __init__(self, diagnostic, forcing, name_to_index) -> None:
+        self._diagnostic = diagnostic
+        self._forcing = forcing
+        self._name_to_index = name_to_index
+        self.input = InputTensorIndex(
+            includes=forcing,
+            excludes=diagnostic,
+            name_to_index=name_to_index,
+        )
+
+        self.output = OutputTensorIndex(
+            includes=diagnostic,
+            excludes=forcing,
+            name_to_index=name_to_index,
+        )
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(diagnostic={self._diagnostic}, forcing={self._forcing}, name_to_index={self._name_to_index})"
+
+
+class ModelIndex(BaseIndex):
+    """Indexing for model variables."""
+
+    def __init__(self, diagnostic, forcing, name_to_index_model_input, name_to_index_model_output) -> None:
+        self._diagnostic = diagnostic
+        self._forcing = forcing
+        self._name_to_index_model_input = name_to_index_model_input
+        self._name_to_index_model_output = name_to_index_model_output
+        self.input = InputTensorIndex(
+            includes=forcing,
+            excludes=[],
+            name_to_index=name_to_index_model_input,
+        )
+
+        self.output = OutputTensorIndex(
+            includes=diagnostic,
+            excludes=[],
+            name_to_index=name_to_index_model_output,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(diagnostic={self._diagnostic}, forcing={self._forcing}, "
+            f"name_to_index_model_input={self._name_to_index_model_input}, "
+            f"name_to_index_model_output={self._name_to_index_model_output})"
+        )
diff --git a/src/anemoi/models/data/data_indices/tensor.py b/src/anemoi/models/data/data_indices/tensor.py
new file mode 100644
index 0000000..c7306cf
--- /dev/null
+++ b/src/anemoi/models/data/data_indices/tensor.py
@@ -0,0 +1,114 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+import torch
+
+
+class BaseTensorIndex:
+    """Indexing for variables in the data as index Tensors."""
+
+    def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
+        """Initialize indexing tensors from includes and excludes using name_to_index.
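+
+        Four index tensors are derived from ``name_to_index``: ``full`` (every variable except those in ``excludes``),
+        ``prognostic`` (every variable in neither ``includes`` nor ``excludes``), and the private ``_only`` and
+        ``_removed`` tensors holding the positions of the ``includes`` and ``excludes`` variables, which the
+        subclasses expose as ``forcing`` and ``diagnostic``.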
+ + Parameters + ---------- + includes : list[str] + Variables to include in the indexing that are exclusive to this indexing. + e.g. Forcing variables for the input indexing, diagnostic variables for the output indexing + excludes : list[str] + Variables to exclude from the indexing. + e.g. Diagnostic variables for the input indexing, forcing variables for the output indexing + name_to_index : dict[str, int] + Dictionary mapping variable names to their index in the Tensor. + """ + self.includes = includes + self.excludes = excludes + self.name_to_index = name_to_index + + assert set(self.excludes).issubset( + self.name_to_index.keys(), + ), f"Data indexing has invalid entries {[var for var in self.excludes if var not in self.name_to_index]}, not in dataset." + assert set(self.includes).issubset( + self.name_to_index.keys(), + ), f"Data indexing has invalid entries {[var for var in self.includes if var not in self.name_to_index]}, not in dataset." + + self.full = self._build_idx_from_excludes() + self._only = self._build_idx_from_includes() + self._removed = self._build_idx_from_includes(self.excludes) + self.prognostic = self._build_idx_prognostic() + self.diagnostic = NotImplementedError + self.forcing = NotImplementedError + + def __len__(self) -> int: + return len(self.full) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(includes={self.includes}, excludes={self.excludes}, name_to_index={self.name_to_index})" + + def __eq__(self, other): + if not isinstance(other, BaseTensorIndex): + # don't attempt to compare against unrelated types + return NotImplemented + + return ( + torch.allclose(self.full, other.full) + and torch.allclose(self._only, other._only) + and torch.allclose(self._removed, other._removed) + and torch.allclose(self.prognostic, other.prognostic) + and torch.allclose(self.diagnostic, other.diagnostic) + and torch.allclose(self.forcing, other.forcing) + and self.includes == other.includes + and self.excludes == other.excludes + ) + + def __getitem__(self, key): + return getattr(self, key) + + def todict(self): + return { + "full": self.full, + "prognostic": self.prognostic, + "diagnostic": self.diagnostic, + "forcing": self.forcing, + } + + @staticmethod + def representer(dumper, data): + return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data)) + + def _build_idx_from_excludes(self, excludes=None) -> "torch.Tensor[int]": + if excludes is None: + excludes = self.excludes + return torch.Tensor(sorted(i for name, i in self.name_to_index.items() if name not in excludes)).to(torch.int) + + def _build_idx_from_includes(self, includes=None) -> "torch.Tensor[int]": + if includes is None: + includes = self.includes + return torch.Tensor(sorted(self.name_to_index[name] for name in includes)).to(torch.int) + + def _build_idx_prognostic(self) -> "torch.Tensor[int]": + return self._build_idx_from_excludes(self.includes + self.excludes) + + +class InputTensorIndex(BaseTensorIndex): + """Indexing for input variables.""" + + def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None: + super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index) + self.forcing = self._only + self.diagnostic = self._removed + + +class OutputTensorIndex(BaseTensorIndex): + """Indexing for output variables.""" + + def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None: + super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index) + 
self.forcing = self._removed + self.diagnostic = self._only diff --git a/src/anemoi/models/data/data_module.py b/src/anemoi/models/data/data_module.py new file mode 100644 index 0000000..aca1681 --- /dev/null +++ b/src/anemoi/models/data/data_module.py @@ -0,0 +1,195 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os +from functools import cached_property + +import pytorch_lightning as pl +from anemoi.datasets.data import open_dataset +from omegaconf import DictConfig +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from anemoi.models.data.data_indices.collection import IndexCollection +from anemoi.models.data.dataset import NativeGridDataset +from anemoi.models.data.dataset import worker_init_func + +LOGGER = logging.getLogger(__name__) + + +class ECMLDataModule(pl.LightningDataModule): + """ECML data module for PyTorch Lightning.""" + + def __init__(self, config: DictConfig) -> None: + """Initialize ECML data module. + + Parameters + ---------- + config : DictConfig + Job configuration + """ + super().__init__() + LOGGER.setLevel(config.diagnostics.log.code.level) + + self.config = config + + # Determine the step size relative to the data frequency + frequency = self.config.data.frequency + timestep = self.config.data.timestep + assert ( + isinstance(frequency, str) and isinstance(timestep, str) and frequency[-1] == "h" and timestep[-1] == "h" + ), f"Error in format of timestep, {timestep}, or data frequency, {frequency}" + assert ( + int(timestep[:-1]) % int(frequency[:-1]) == 0 + ), f"Timestep isn't a multiple of data frequency, {timestep}, or data frequency, {frequency}" + self.timeincrement = int(timestep[:-1]) // int(frequency[:-1]) + LOGGER.info( + f"Timeincrement set to {self.timeincrement} for data with frequency, {frequency}, and timestep, {timestep}" + ) + + self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank + self.model_comm_group_id = ( + self.global_rank // self.config.hardware.num_gpus_per_model + ) # id of the model communication group the rank is participating in + self.model_comm_group_rank = ( + self.global_rank % self.config.hardware.num_gpus_per_model + ) # rank within one model communication group + total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes + assert ( + total_gpus + ) % self.config.hardware.num_gpus_per_model == 0, ( + f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}" + ) + self.model_comm_num_groups = ( + self.config.hardware.num_gpus_per_node + * self.config.hardware.num_nodes + // self.config.hardware.num_gpus_per_model + ) # number of model communication groups + LOGGER.debug( + "Rank %d model communication group number %d, with local model communication group rank %d", + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + ) + + # Set the maximum rollout to be expected + self.rollout = ( + self.config.training.rollout.max + if self.config.training.rollout.epoch_increment > 0 + else self.config.training.rollout.start + ) + + # Set the training end date if not specified + if self.config.dataloader.training.end is None: + LOGGER.info( + 
"No end date specified for training data, setting default before validation start date %s.", + self.config.dataloader.validation.start - 1, + ) + self.config.dataloader.training.end = self.config.dataloader.validation.start - 1 + + def _check_resolution(self, resolution) -> None: + assert ( + self.config.data.resolution.lower() == resolution.lower() + ), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}" + + @cached_property + def statistics(self) -> dict: + return self.dataset_train.statistics + + @cached_property + def metadata(self) -> dict: + return self.dataset_train.metadata + + @cached_property + def data_indices(self) -> IndexCollection: + return IndexCollection(self.config, self.dataset_train.name_to_index) + + @cached_property + def dataset_train(self) -> NativeGridDataset: + return self._get_dataset( + open_dataset(OmegaConf.to_container(self.config.dataloader.training, resolve=True)), label="train" + ) + + @cached_property + def dataset_validation(self) -> NativeGridDataset: + r = self.rollout + if self.config.diagnostics.eval.enabled: + r = max(r, self.config.diagnostics.eval.rollout) + assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( + f"Training end date {self.config.dataloader.training.end} is not before" + f"validation start date {self.config.dataloader.validation.start}" + ) + return self._get_dataset( + open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)), + shuffle=False, + rollout=r, + label="validation", + ) + + @cached_property + def dataset_test(self) -> NativeGridDataset: + assert self.config.dataloader.training.end < self.config.dataloader.test.start, ( + f"Training end date {self.config.dataloader.training.end} is not before" + f"test start date {self.config.dataloader.test.start}" + ) + assert self.config.dataloader.validation.end < self.config.dataloader.test.start, ( + f"Validation end date {self.config.dataloader.validation.end} is not before" + f"test start date {self.config.dataloader.test.start}" + ) + return self._get_dataset( + open_dataset(OmegaConf.to_container(self.config.dataloader.test, resolve=True)), + shuffle=False, + label="test", + ) + + def _get_dataset( + self, data_reader, shuffle: bool = True, rollout: int = 1, label: str = "generic" + ) -> NativeGridDataset: + r = max(rollout, self.rollout) + data = NativeGridDataset( + data_reader=data_reader, + rollout=r, + multistep=self.config.training.multistep_input, + timeincrement=self.timeincrement, + model_comm_group_rank=self.model_comm_group_rank, + model_comm_group_id=self.model_comm_group_id, + model_comm_num_groups=self.model_comm_num_groups, + shuffle=shuffle, + label=label, + logging=self.config.diagnostics.log.code.level, + ) + self._check_resolution(data.resolution) + return data + + def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: + assert stage in ["training", "validation", "test"] + return DataLoader( + ds, + batch_size=self.config.dataloader.batch_size[stage], + # number of worker processes + num_workers=self.config.dataloader.num_workers[stage], + # use of pinned memory can speed up CPU-to-GPU data transfers + # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning + pin_memory=True, + # worker initializer + worker_init_fn=worker_init_func, + # prefetch batches + prefetch_factor=self.config.dataloader.prefetch_factor, + persistent_workers=True, + ) + + def train_dataloader(self) -> DataLoader: + return 
self._get_dataloader(self.dataset_train, "training") + + def val_dataloader(self) -> DataLoader: + return self._get_dataloader(self.dataset_validation, "validation") + + def test_dataloader(self) -> DataLoader: + return self._get_dataloader(self.dataset_test, "test") diff --git a/src/anemoi/models/data/dataset.py b/src/anemoi/models/data/dataset.py new file mode 100644 index 0000000..b9cd693 --- /dev/null +++ b/src/anemoi/models/data/dataset.py @@ -0,0 +1,255 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os +import random +from functools import cached_property +from typing import Callable +from typing import Optional + +import numpy as np +import torch +from einops import rearrange +from torch.utils.data import IterableDataset +from torch.utils.data import get_worker_info + +from anemoi.models.utils.seeding import get_base_seed + +LOGGER = logging.getLogger(__name__) + + +class NativeGridDataset(IterableDataset): + """Iterable dataset for AnemoI data on the arbitrary grids.""" + + def __init__( + self, + data_reader: Callable, + rollout: int = 1, + multistep: int = 1, + timeincrement: int = 1, + model_comm_group_rank: int = 0, + model_comm_group_id: int = 0, + model_comm_num_groups: int = 1, + shuffle: bool = True, + label: str = "generic", + logging: str = "INFO", + ) -> None: + """Initialize (part of) the dataset state. + + Parameters + ---------- + data_reader : Callable + user function that opens and returns the zarr array data + rollout : int, optional + length of rollout window, by default 12 + multistep : int, optional + collate (t-1, ... t - multistep) into the input state vector, by default 1 + model_comm_group_rank : int, optional + process rank in the torch.distributed group (important when running on multiple GPUs), by default 0 + model_comm_group_id: int, optional + device group ID, default 0 + model_comm_num_groups : int, optional + total number of device groups, by default 1 + shuffle : bool, optional + Shuffle batches, by default True + + Raises + ------ + RuntimeError + Multistep value cannot be negative. + """ + LOGGER.setLevel(logging) + self.label = label + + self.data = data_reader + + self.rollout = rollout + self.timeincrement = timeincrement + + # lazy init + self.n_samples_per_epoch_total: int = 0 + self.n_samples_per_epoch_per_worker: int = 0 + + # DDP-relevant info + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.model_comm_group_id = model_comm_group_id + self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) + + # additional state vars (lazy init) + self.n_samples_per_worker = 0 + self.chunk_index_range: Optional[np.ndarray] = None + self.shuffle = shuffle + + # Data dimensions + self.multi_step = multistep + assert self.multi_step > 0, "Multistep value must be greater than zero." 
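+
+        # the dataset is laid out as (dates, variables, ensemble, gridpoints),
+        # so the ensemble axis sits at position 2 of data.shape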
+ self.ensemble_dim: int = 2 + self.ensemble_size = self.data.shape[self.ensemble_dim] + + @cached_property + def statistics(self) -> dict: + """Return dataset statistics.""" + return self.data.statistics + + @cached_property + def metadata(self) -> dict: + """Return dataset metadata.""" + return self.data.metadata() + + @cached_property + def name_to_index(self) -> dict: + """Return dataset statistics.""" + return self.data.name_to_index + + @cached_property + def resolution(self) -> dict: + """Return dataset resolution.""" + return self.data.resolution + + def per_worker_init(self, n_workers: int, worker_id: int) -> None: + """Called by worker_init_func on each copy of dataset. + + This initialises after the worker process has been spawned. + + Parameters + ---------- + n_workers : int + Number of workers + worker_id : int + Worker ID + """ + self.worker_id = worker_id + + # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs + len_corrected = len(self.data) - (self.rollout + (self.multi_step - 1)) * self.timeincrement + + # Divide this equally across shards (one shard per group!) + shard_size = len_corrected // self.model_comm_num_groups + shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement + shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) + + shard_len = shard_end - shard_start + self.n_samples_per_worker = shard_len // n_workers + + low = shard_start + worker_id * self.n_samples_per_worker + high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + + LOGGER.debug( + "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d", + worker_id, + os.getpid(), + self.global_rank, + self.model_comm_group_id, + low, + high, + ) + + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + + # each worker must have a different seed for its random number generator, + # otherwise all the workers will output exactly the same data + # should we check lightning env variable "PL_SEED_WORKERS" here? + # but we alwyas want to seed these anyways ... + + base_seed = get_base_seed() + + seed = ( + base_seed * (self.model_comm_group_id + 1) - worker_id + ) # note that test, validation etc. datasets get same seed + torch.manual_seed(seed) + random.seed(seed) + self.rng = np.random.default_rng(seed=seed) + sanity_rnd = self.rng.random(1) + + LOGGER.debug( + "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, group_rank %d, base_seed %d) using seed %d, sanity rnd %f", + worker_id, + self.label, + os.getpid(), + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + base_seed, + seed, + sanity_rnd, + ) + + def __iter__(self): + """Return an iterator over the dataset. + + The datasets are retrieved by ECML Tools from zarr files. This iterator yields + chunked batches for DDP and sharded training. + + Currently it receives data with an ensemble dimension, which is discarded for + now. (Until the code is "ensemble native".) 
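+
+        Yields
+        ------
+        torch.Tensor
+            Sample of shape (multistep + rollout, ensemble, gridpoints, variables).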
+ """ + if self.shuffle: + shuffled_chunk_indices = self.rng.choice( + self.chunk_index_range, size=self.n_samples_per_worker, replace=False + ) + else: + shuffled_chunk_indices = self.chunk_index_range + + LOGGER.debug( + "Worker pid %d, label %s, worker id %d, global_rank %d, model comm group %d, group_rank %d using indices[0:10]: %s", + os.getpid(), + self.label, + self.worker_id, + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + shuffled_chunk_indices[:10], + ) + + for i in shuffled_chunk_indices: + start = i - (self.multi_step - 1) * self.timeincrement + end = i + (self.rollout + 1) * self.timeincrement + + x = self.data[start : end : self.timeincrement] + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") + self.ensemble_dim = 1 + + yield torch.from_numpy(x) + + def __repr__(self) -> str: + return f""" + {super().__repr__()} + Dataset: {self.data} + Rollout: {self.rollout} + Multistep: {self.multi_step} + Timeincrement: {self.timeincrement} + """ + + +def worker_init_func(worker_id: int) -> None: + """Configures each dataset worker process. + + Calls WeatherBenchDataset.per_worker_init() on each dataset object. + + Parameters + ---------- + worker_id : int + Worker ID + + Raises + ------ + RuntimeError + If worker_info is None + """ + worker_info = get_worker_info() # information specific to each worker process + if worker_info is None: + LOGGER.error("worker_info is None! Set num_workers > 0 in your dataloader!") + raise RuntimeError + dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process. + dataset_obj.per_worker_init( + n_workers=worker_info.num_workers, + worker_id=worker_id, + ) diff --git a/src/anemoi/models/data/normalizer.py b/src/anemoi/models/data/normalizer.py new file mode 100644 index 0000000..da14ef8 --- /dev/null +++ b/src/anemoi/models/data/normalizer.py @@ -0,0 +1,187 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import logging +import warnings +from typing import Optional + +import numpy as np +import torch +from torch import nn + +LOGGER = logging.getLogger(__name__) + + +class InputNormalizer(nn.Module): + """Normalizes input data to zero mean and unit variance.""" + + def __init__(self, *, config, statistics: dict, data_indices: dict) -> None: + """Initialize the normalizer. + + Parameters + ---------- + zarr_metadata : Dict + Zarr metadata dictionary + """ + super().__init__() + LOGGER.setLevel(config.diagnostics.log.code.level) + + default = config.data.normalizer.default + method_config = {k: v for k, v in config.data.normalizer.items() if k != "default" and v is not None} + + if not method_config: + LOGGER.warning( + f"Normalizing: Using default method {default} for all variables not specified in the config." 
+ ) + + name_to_index = data_indices.data.input.name_to_index + + methods = { + variable: method + for method, variables in method_config.items() + if not isinstance(variables, str) + for variable in variables + } + + assert len(methods) == sum(len(v) for v in method_config.values()), ( + f"Error parsing methods in InputNormalizer methods ({len(methods)}) " + f"and entries in config ({sum(len(v) for v in method_config)}) do not match." + ) + + minimum = statistics["minimum"] + maximum = statistics["maximum"] + mean = statistics["mean"] + stdev = statistics["stdev"] + + n = minimum.size + assert maximum.size == n, (maximum.size, n) + assert mean.size == n, (mean.size, n) + assert stdev.size == n, (stdev.size, n) + + assert isinstance(methods, dict) + for name, method in methods.items(): + assert name in name_to_index, f"{name} is not a valid variable name" + assert method in [ + "mean-std", + # "robust", + "min-max", + "max", + "none", + ], f"{method} is not a valid normalisation method" + + _norm_add = np.zeros((n,), dtype=np.float32) + _norm_mul = np.ones((n,), dtype=np.float32) + + for name, i in name_to_index.items(): + m = methods.get(name, default) + if m == "mean-std": + LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.") + if stdev[i] < (mean[i] * 1e-6): + warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}") + _norm_mul[i] = 1 / stdev[i] + _norm_add[i] = -mean[i] / stdev[i] + + elif m == "min-max": + LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].") + x = maximum[i] - minimum[i] + if x < 1e-9: + warnings.warn(f"Normalizing: the field {name} seems to have only one value {maximum[i]}.") + _norm_mul[i] = 1 / x + _norm_add[i] = -minimum[i] / x + + elif m == "max": + LOGGER.debug(f"Normalizing: {name} is max-normalised to [0, 1].") + _norm_mul[i] = 1 / maximum[i] + + elif m == "none": + LOGGER.info(f"Normalizing: {name} is not normalized.") + + else: + raise ValueError[f"Unknown normalisation method for {name}: {m}"] + + # register buffer - this will ensure they get copied to the correct device(s) + self.register_buffer("_norm_mul", torch.from_numpy(_norm_mul), persistent=True) + self.register_buffer("_norm_add", torch.from_numpy(_norm_add), persistent=True) + self.register_buffer("_input_idx", data_indices.data.input.full, persistent=True) + self.register_buffer("_output_idx", data_indices.data.output.full, persistent=True) + + def normalize( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Normalizes an input tensor x of shape [..., nvars]. + + Normalization done in-place unless specified otherwise. + + The default usecase either assume the full batch tensor or the full input tensor. + A dataindex is based on the full data can be supplied to choose which variables to normalise. 
+ + Parameters + ---------- + x : torch.Tensor + Data to normalize + in_place : bool, optional + Normalize in-place, by default True + data_index : Optional[torch.Tensor], optional + Normalize only the specified indices, by default None + + Returns + ------- + torch.Tensor + _description_ + """ + if not in_place: + x = x.clone() + + if data_index is not None: + x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index] + elif x.shape[-1] == len(self._input_idx): + x[..., :] = x[..., :] * self._norm_mul[self._input_idx] + self._norm_add[self._input_idx] + else: + x[..., :] = x[..., :] * self._norm_mul + self._norm_add + return x + + def forward(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + return self.normalize(x, in_place=in_place) + + def denormalize( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Denormalizes an input tensor x of shape [..., nvars | nvars_pred]. + + Denormalization done in-place unless specified otherwise. + + The default usecase either assume the full batch tensor or the full output tensor. + A dataindex is based on the full data can be supplied to choose which variables to denormalise. + + Parameters + ---------- + x : torch.Tensor + Data to denormalize + in_place : bool, optional + Denormalize in-place, by default True + data_index : Optional[torch.Tensor], optional + Denormalize only the specified indices, by default None + + Returns + ------- + torch.Tensor + Denormalized data + """ + if not in_place: + x = x.clone() + + # Denormalize dynamic or full tensors + # input and predicted tensors have different shapes + # hence, we mask out the forcing indices + if data_index is not None: + x[..., :] = (x[..., :] - self._norm_add[data_index]) / self._norm_mul[data_index] + elif x.shape[-1] == len(self._output_idx): + x[..., :] = (x[..., :] - self._norm_add[self._output_idx]) / self._norm_mul[self._output_idx] + else: + x[..., :] = (x[..., :] - self._norm_add) / self._norm_mul + return x diff --git a/src/anemoi/models/data/scaling.py b/src/anemoi/models/data/scaling.py new file mode 100644 index 0000000..68d08ff --- /dev/null +++ b/src/anemoi/models/data/scaling.py @@ -0,0 +1,16 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import numpy as np + + +def pressure_level(plev) -> np.ndarray: + """Convert pressure levels to PyTorch Lightning scaling.""" + return np.array(plev) / 1000
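
A minimal usage sketch of the new indexing and normalisation modules, assuming an invented three-variable dataset; the variable names, statistics and config values below are illustrative only, while the import paths and call signatures come from the code above.

import numpy as np
import torch
from omegaconf import OmegaConf

from anemoi.models.data.data_indices.collection import IndexCollection
from anemoi.models.data.normalizer import InputNormalizer

# Hypothetical config: "cos_latitude" is a forcing variable, "tp" is diagnostic
# and max-normalised, everything else falls back to the mean-std default.
config = OmegaConf.create(
    {
        "data": {
            "forcing": ["cos_latitude"],
            "diagnostic": ["tp"],
            "normalizer": {"default": "mean-std", "max": ["tp"]},
        },
        "diagnostics": {"log": {"code": {"level": "INFO"}}},
    }
)

# Invented name-to-index mapping and per-variable statistics.
name_to_index = {"2t": 0, "cos_latitude": 1, "tp": 2}
data_indices = IndexCollection(config, name_to_index)
statistics = {
    "minimum": np.array([220.0, -1.0, 0.0], dtype=np.float32),
    "maximum": np.array([320.0, 1.0, 50.0], dtype=np.float32),
    "mean": np.array([285.0, 0.0, 1.0], dtype=np.float32),
    "stdev": np.array([15.0, 0.7, 3.0], dtype=np.float32),
}

normalizer = InputNormalizer(config=config, statistics=statistics, data_indices=data_indices)

x = torch.randn(1, 2, 4, 3)  # (batch, time, gridpoints, variables)
x_norm = normalizer(x, in_place=False)  # full-tensor normalisation
x_back = normalizer.denormalize(x_norm.clone())  # recovers x up to float precision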