From d188a1263267bb7f786453d52272037e93f6558c Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 17 May 2024 15:37:38 +0000 Subject: [PATCH 1/5] refactor: Initial data implementation Encompasses data_module and dataset. Co-authored-by: Jesper Dramsch Co-authored-by: Florian Pinault Co-authored-by: Baudouin Raoult Co-authored-by: Matthew Chantry Co-authored-by: mihai.alexe --- src/anemoi/training/data/__init__.py | 0 src/anemoi/training/data/data_module.py | 195 ++++++++++++++++++ src/anemoi/training/data/dataset.py | 253 ++++++++++++++++++++++++ 3 files changed, 448 insertions(+) create mode 100644 src/anemoi/training/data/__init__.py create mode 100644 src/anemoi/training/data/data_module.py create mode 100644 src/anemoi/training/data/dataset.py diff --git a/src/anemoi/training/data/__init__.py b/src/anemoi/training/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/anemoi/training/data/data_module.py b/src/anemoi/training/data/data_module.py new file mode 100644 index 00000000..0b596f60 --- /dev/null +++ b/src/anemoi/training/data/data_module.py @@ -0,0 +1,195 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os +from functools import cached_property + +import pytorch_lightning as pl +from anemoi.datasets.data import open_dataset +from anemoi.models.data.data_indices.collection import IndexCollection +from omegaconf import DictConfig +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from anemoi.training.data.dataset import NativeGridDataset +from anemoi.training.data.dataset import worker_init_func + +LOGGER = logging.getLogger(__name__) + + +class ECMLDataModule(pl.LightningDataModule): + """ECML data module for PyTorch Lightning.""" + + def __init__(self, config: DictConfig) -> None: + """Initialize ECML data module. 
+ + Parameters + ---------- + config : DictConfig + Job configuration + """ + super().__init__() + LOGGER.setLevel(config.diagnostics.log.code.level) + + self.config = config + + # Determine the step size relative to the data frequency + frequency = self.config.data.frequency + timestep = self.config.data.timestep + assert ( + isinstance(frequency, str) and isinstance(timestep, str) and frequency[-1] == "h" and timestep[-1] == "h" + ), f"Error in format of timestep, {timestep}, or data frequency, {frequency}" + assert ( + int(timestep[:-1]) % int(frequency[:-1]) == 0 + ), f"Timestep isn't a multiple of data frequency, {timestep}, or data frequency, {frequency}" + self.timeincrement = int(timestep[:-1]) // int(frequency[:-1]) + LOGGER.info( + f"Timeincrement set to {self.timeincrement} for data with frequency, {frequency}, and timestep, {timestep}" + ) + + self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank + self.model_comm_group_id = ( + self.global_rank // self.config.hardware.num_gpus_per_model + ) # id of the model communication group the rank is participating in + self.model_comm_group_rank = ( + self.global_rank % self.config.hardware.num_gpus_per_model + ) # rank within one model communication group + total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes + assert ( + total_gpus + ) % self.config.hardware.num_gpus_per_model == 0, ( + f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}" + ) + self.model_comm_num_groups = ( + self.config.hardware.num_gpus_per_node + * self.config.hardware.num_nodes + // self.config.hardware.num_gpus_per_model + ) # number of model communication groups + LOGGER.debug( + "Rank %d model communication group number %d, with local model communication group rank %d", + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + ) + + # Set the maximum rollout to be expected + self.rollout = ( + self.config.training.rollout.max + if self.config.training.rollout.epoch_increment > 0 + else self.config.training.rollout.start + ) + + # Set the training end date if not specified + if self.config.dataloader.training.end is None: + LOGGER.info( + "No end date specified for training data, setting default before validation start date %s.", + self.config.dataloader.validation.start - 1, + ) + self.config.dataloader.training.end = self.config.dataloader.validation.start - 1 + + def _check_resolution(self, resolution) -> None: + assert ( + self.config.data.resolution.lower() == resolution.lower() + ), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}" + + @cached_property + def statistics(self) -> dict: + return self.dataset_train.statistics + + @cached_property + def metadata(self) -> dict: + return self.dataset_train.metadata + + @cached_property + def data_indices(self) -> IndexCollection: + return IndexCollection(self.config, self.dataset_train.name_to_index) + + @cached_property + def dataset_train(self) -> NativeGridDataset: + return self._get_dataset( + open_dataset(OmegaConf.to_container(self.config.dataloader.training, resolve=True)), label="train" + ) + + @cached_property + def dataset_validation(self) -> NativeGridDataset: + r = self.rollout + if self.config.diagnostics.eval.enabled: + r = max(r, self.config.diagnostics.eval.rollout) + assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( + f"Training end date {self.config.dataloader.training.end} is not before" 
+            f" validation start date {self.config.dataloader.validation.start}"
+        )
+        return self._get_dataset(
+            open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
+            shuffle=False,
+            rollout=r,
+            label="validation",
+        )
+
+    @cached_property
+    def dataset_test(self) -> NativeGridDataset:
+        assert self.config.dataloader.training.end < self.config.dataloader.test.start, (
+            f"Training end date {self.config.dataloader.training.end} is not before "
+            f"test start date {self.config.dataloader.test.start}"
+        )
+        assert self.config.dataloader.validation.end < self.config.dataloader.test.start, (
+            f"Validation end date {self.config.dataloader.validation.end} is not before "
+            f"test start date {self.config.dataloader.test.start}"
+        )
+        return self._get_dataset(
+            open_dataset(OmegaConf.to_container(self.config.dataloader.test, resolve=True)),
+            shuffle=False,
+            label="test",
+        )
+
+    def _get_dataset(
+        self, data_reader, shuffle: bool = True, rollout: int = 1, label: str = "generic"
+    ) -> NativeGridDataset:
+        r = max(rollout, self.rollout)
+        data = NativeGridDataset(
+            data_reader=data_reader,
+            rollout=r,
+            multistep=self.config.training.multistep_input,
+            timeincrement=self.timeincrement,
+            model_comm_group_rank=self.model_comm_group_rank,
+            model_comm_group_id=self.model_comm_group_id,
+            model_comm_num_groups=self.model_comm_num_groups,
+            shuffle=shuffle,
+            label=label,
+            logging=self.config.diagnostics.log.code.level,
+        )
+        self._check_resolution(data.resolution)
+        return data
+
+    def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
+        assert stage in ["training", "validation", "test"]
+        return DataLoader(
+            ds,
+            batch_size=self.config.dataloader.batch_size[stage],
+            # number of worker processes
+            num_workers=self.config.dataloader.num_workers[stage],
+            # use of pinned memory can speed up CPU-to-GPU data transfers
+            # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning
+            pin_memory=True,
+            # worker initializer
+            worker_init_fn=worker_init_func,
+            # prefetch batches
+            prefetch_factor=self.config.dataloader.prefetch_factor,
+            persistent_workers=True,
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.dataset_train, "training")
+
+    def val_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.dataset_validation, "validation")
+
+    def test_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.dataset_test, "test")
diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py
new file mode 100644
index 00000000..94550638
--- /dev/null
+++ b/src/anemoi/training/data/dataset.py
@@ -0,0 +1,253 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
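+#
+# Hedged example of the time windowing implemented by NativeGridDataset.__iter__
+# below (hypothetical settings, not part of the original patch): with
+# multi_step=2, rollout=1 and timeincrement=1, a sample index i is expanded to
+# self.data[i - 1 : i + 2], i.e. the two input steps (t-1, t) followed by the
+# rollout target (t+1).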
+
+
+import logging
+import os
+import random
+from functools import cached_property
+from typing import Callable
+from typing import Optional
+
+import numpy as np
+import torch
+from anemoi.utils import get_base_seed
+from einops import rearrange
+from torch.utils.data import IterableDataset
+from torch.utils.data import get_worker_info
+
+LOGGER = logging.getLogger(__name__)
+
+
+class NativeGridDataset(IterableDataset):
+    """Iterable dataset for AnemoI data on arbitrary grids."""
+
+    def __init__(
+        self,
+        data_reader: Callable,
+        rollout: int = 1,
+        multistep: int = 1,
+        timeincrement: int = 1,
+        model_comm_group_rank: int = 0,
+        model_comm_group_id: int = 0,
+        model_comm_num_groups: int = 1,
+        shuffle: bool = True,
+        label: str = "generic",
+        logging: str = "INFO",
+    ) -> None:
+        """Initialize (part of) the dataset state.
+
+        Parameters
+        ----------
+        data_reader : Callable
+            user function that opens and returns the zarr array data
+        rollout : int, optional
+            length of rollout window, by default 1
+        multistep : int, optional
+            collate (t-1, ... t - multistep) into the input state vector, by default 1
+        timeincrement : int, optional
+            step between consecutive dates, in multiples of the data frequency, by default 1
+        model_comm_group_rank : int, optional
+            process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
+        model_comm_group_id: int, optional
+            device group ID, by default 0
+        model_comm_num_groups : int, optional
+            total number of device groups, by default 1
+        shuffle : bool, optional
+            Shuffle batches, by default True
+        label : str, optional
+            dataset label (e.g. "train", "validation"), by default "generic"
+        logging : str, optional
+            logging level, by default "INFO"
+
+        Raises
+        ------
+        AssertionError
+            If multistep is not greater than zero.
+        """
+        LOGGER.setLevel(logging)
+        self.label = label
+
+        self.data = data_reader
+
+        self.rollout = rollout
+        self.timeincrement = timeincrement
+
+        # lazy init
+        self.n_samples_per_epoch_total: int = 0
+        self.n_samples_per_epoch_per_worker: int = 0
+
+        # DDP-relevant info
+        self.model_comm_group_rank = model_comm_group_rank
+        self.model_comm_num_groups = model_comm_num_groups
+        self.model_comm_group_id = model_comm_group_id
+        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
+
+        # additional state vars (lazy init)
+        self.n_samples_per_worker = 0
+        self.chunk_index_range: Optional[np.ndarray] = None
+        self.shuffle = shuffle
+
+        # Data dimensions
+        self.multi_step = multistep
+        assert self.multi_step > 0, "Multistep value must be greater than zero."
+        self.ensemble_dim: int = 2
+        self.ensemble_size = self.data.shape[self.ensemble_dim]
+
+    @cached_property
+    def statistics(self) -> dict:
+        """Return dataset statistics."""
+        return self.data.statistics
+
+    @cached_property
+    def metadata(self) -> dict:
+        """Return dataset metadata."""
+        return self.data.metadata()
+
+    @cached_property
+    def name_to_index(self) -> dict:
+        """Return the mapping of variable names to their indices."""
+        return self.data.name_to_index
+
+    @cached_property
+    def resolution(self) -> dict:
+        """Return dataset resolution."""
+        return self.data.resolution
+
+    def per_worker_init(self, n_workers: int, worker_id: int) -> None:
+        """Called by worker_init_func on each copy of the dataset.
+
+        This initialises after the worker process has been spawned.
+
+        Parameters
+        ----------
+        n_workers : int
+            Number of workers
+        worker_id : int
+            Worker ID
+        """
+        self.worker_id = worker_id
+
+        # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs
+        len_corrected = len(self.data) - (self.rollout + (self.multi_step - 1)) * self.timeincrement
+
+        # Divide this equally across shards (one shard per group!)
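+        # Hedged worked example (hypothetical sizes, not from the original patch):
+        # len(self.data)=1000, rollout=12, multi_step=2, timeincrement=1 gives
+        # len_corrected=987; with model_comm_num_groups=4, shard_size=246, and
+        # model comm group 1 reads indices from shard_start=247 up to
+        # shard_end=min(492, 988)=492.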
+ shard_size = len_corrected // self.model_comm_num_groups + shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement + shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) + + shard_len = shard_end - shard_start + self.n_samples_per_worker = shard_len // n_workers + + low = shard_start + worker_id * self.n_samples_per_worker + high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + + LOGGER.debug( + "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d", + worker_id, + os.getpid(), + self.global_rank, + self.model_comm_group_id, + low, + high, + ) + + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + + # each worker must have a different seed for its random number generator, + # otherwise all the workers will output exactly the same data + # should we check lightning env variable "PL_SEED_WORKERS" here? + # but we alwyas want to seed these anyways ... + + base_seed = get_base_seed() + + seed = ( + base_seed * (self.model_comm_group_id + 1) - worker_id + ) # note that test, validation etc. datasets get same seed + torch.manual_seed(seed) + random.seed(seed) + self.rng = np.random.default_rng(seed=seed) + sanity_rnd = self.rng.random(1) + + LOGGER.debug( + "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, group_rank %d, base_seed %d) using seed %d, sanity rnd %f", + worker_id, + self.label, + os.getpid(), + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + base_seed, + seed, + sanity_rnd, + ) + + def __iter__(self): + """Return an iterator over the dataset. + + The datasets are retrieved by ECML Tools from zarr files. This iterator yields + chunked batches for DDP and sharded training. + + Currently it receives data with an ensemble dimension, which is discarded for + now. (Until the code is "ensemble native".) + """ + if self.shuffle: + shuffled_chunk_indices = self.rng.choice( + self.chunk_index_range, size=self.n_samples_per_worker, replace=False + ) + else: + shuffled_chunk_indices = self.chunk_index_range + + LOGGER.debug( + "Worker pid %d, label %s, worker id %d, global_rank %d, model comm group %d, group_rank %d using indices[0:10]: %s", + os.getpid(), + self.label, + self.worker_id, + self.global_rank, + self.model_comm_group_id, + self.model_comm_group_rank, + shuffled_chunk_indices[:10], + ) + + for i in shuffled_chunk_indices: + start = i - (self.multi_step - 1) * self.timeincrement + end = i + (self.rollout + 1) * self.timeincrement + + x = self.data[start : end : self.timeincrement] + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") + self.ensemble_dim = 1 + + yield torch.from_numpy(x) + + def __repr__(self) -> str: + return f""" + {super().__repr__()} + Dataset: {self.data} + Rollout: {self.rollout} + Multistep: {self.multi_step} + Timeincrement: {self.timeincrement} + """ + + +def worker_init_func(worker_id: int) -> None: + """Configures each dataset worker process. + + Calls WeatherBenchDataset.per_worker_init() on each dataset object. + + Parameters + ---------- + worker_id : int + Worker ID + + Raises + ------ + RuntimeError + If worker_info is None + """ + worker_info = get_worker_info() # information specific to each worker process + if worker_info is None: + LOGGER.error("worker_info is None! 
Set num_workers > 0 in your dataloader!") + raise RuntimeError + dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process. + dataset_obj.per_worker_init( + n_workers=worker_info.num_workers, + worker_id=worker_id, + ) From d79d326513281178eb995c5e42dda55b1fd62186 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 17 May 2024 15:42:23 +0000 Subject: [PATCH 2/5] refactor: Initial loss implementation Encompasses MSE loss and grad_scaler. Co-authored-by: Jesper Dramsch Co-authored-by: Simon Lang Co-authored-by: Matthew Chantry Co-authored-by: mihai.alexe --- src/anemoi/training/losses/__init__.py | 0 src/anemoi/training/losses/mse.py | 71 ++++++++++++++++++++++++++ src/anemoi/training/losses/utils.py | 52 +++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 src/anemoi/training/losses/__init__.py create mode 100644 src/anemoi/training/losses/mse.py create mode 100644 src/anemoi/training/losses/utils.py diff --git a/src/anemoi/training/losses/__init__.py b/src/anemoi/training/losses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py new file mode 100644 index 00000000..6e38b78d --- /dev/null +++ b/src/anemoi/training/losses/mse.py @@ -0,0 +1,71 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import torch +from torch import nn + +LOGGER = logging.getLogger(__name__) + + +class WeightedMSELoss(nn.Module): + """Latitude-weighted MSE loss.""" + + def __init__(self, area_weights: torch.Tensor, data_variances: Optional[torch.Tensor] = None) -> None: + """Latitude- and (inverse-)variance-weighted MSE Loss. + + Parameters + ---------- + area_weights : torch.Tensor + Weights by area + data_variances : Optional[torch.Tensor], optional + precomputed, per-variable stepwise variance estimate, by default None + """ + super().__init__() + + self.register_buffer("weights", area_weights, persistent=True) + if data_variances is not None: + self.register_buffer("ivar", data_variances, persistent=True) + + def forward(self, pred: torch.Tensor, target: torch.Tensor, squash=True) -> torch.Tensor: + """Calculates the lat-weighted MSE loss. 
+ + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + + Returns + ------- + torch.Tensor + Weighted MSE loss + """ + out = torch.square(pred - target) + + # Use variances if available + if hasattr(self, "ivar"): + out *= self.ivar + + # Squash by last dimension + if squash: + out = out.mean(dim=-1) + out = out * self.weights.expand_as(out) + out /= torch.sum(self.weights.expand_as(out)) + return out.sum() + + # Weight by area + out = out * self.weights[..., None].expand_as(out) + out /= torch.sum(self.weights[..., None].expand_as(out)) + return out.sum(axis=(0, 1, 2)) diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py new file mode 100644 index 00000000..dcd2a11b --- /dev/null +++ b/src/anemoi/training/losses/utils.py @@ -0,0 +1,52 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import torch +from torch import nn + +LOGGER = logging.getLogger(__name__) + + +def grad_scaler( + module: nn.Module, + grad_in: tuple[torch.Tensor, ...], + grad_out: tuple[torch.Tensor, ...], +) -> Optional[tuple[torch.Tensor, ...]]: + """Scales the loss gradients. + + Uses the formula in https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2 + + Use .register_full_backward_hook(grad_scaler, prepend=False) to register this hook. + + Parameters + ---------- + module : nn.Module + Loss object (not used) + grad_in : tuple[torch.Tensor, ...] + Loss gradients + grad_out : tuple[torch.Tensor, ...] + Output gradients (not used) + + Returns + ------- + tuple[torch.Tensor, ...] 
+ Re-scaled input gradients + """ + del module, grad_out # not needed + # loss = module(x_pred, x_true) + # so - the first grad_input is that of the predicted state and the second is that of the "ground truth" (== zero) + channels = grad_in[0].shape[-1] # number of channels + channel_weights = torch.reciprocal(torch.sum(torch.abs(grad_in[0]), dim=1, keepdim=True)) # channel-wise weights + new_grad_in = ( + (channels * channel_weights) / torch.sum(channel_weights, dim=-1, keepdim=True) * grad_in[0] + ) # rescaled gradient + return new_grad_in, grad_in[1] From 3a46b4e5cbde2fded41b1d3d876c30e3839c8260 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 20 May 2024 15:11:13 +0000 Subject: [PATCH 3/5] chore: add dependencies --- pyproject.toml | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index deb952d3..bd12f0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,24 @@ classifiers = [ ] dependencies = [ - + "torch>=2.2", + "torch-geometric>=2.3.1,<2.5", + "einops>=0.6.1", + "anemoi-datasets[data]>=0.1.0", + "anemoi-utils[provenance]>=0.1.3", + "anemoi-models@git+https://github.com/ecmwf/anemoi-models.git", + "pytorch-lightning>=2.1.0", + "timm>=0.9.2", + "hydra-core>=1.3", + "matplotlib>=3.7.1", + "tqdm>=4.65.0", + "torchinfo>=1.8.0", + "zarr>=2.14.2", + "pre-commit>=3.3.3", + "mlflow>=2.11.1", + "pynvml>=11.5.0", + "mlflow-export-import>=1.2.0", + "pyshtools>=4.10.4", ] [project.optional-dependencies] From 534012ea960b6777939509c9bea43eeab72b240c Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 20 May 2024 15:12:56 +0000 Subject: [PATCH 4/5] refactor: Initial implementation of diagnostics module Provides a set of tools for monitoring and evaluating the performance of machine learning models. The module includes a set of classes and functions for logging, profiling, and visualizing model performance. Co-authored-by: Jesper Dramsch Co-authored-by: Matthew Chantry Co-authored-by: Simon Lang Co-authored-by: Mihai Alexe Co-authored-by: Sara Hahner Co-authored-by: Ana Prieto Nemesio --- src/anemoi/training/diagnostics/__init__.py | 0 src/anemoi/training/diagnostics/callbacks.py | 737 ++++++++++++++++++ src/anemoi/training/diagnostics/maps.py | 80 ++ .../training/diagnostics/mlflow_logger.py | 352 +++++++++ src/anemoi/training/diagnostics/plots.py | 500 ++++++++++++ src/anemoi/training/diagnostics/profilers.py | 511 ++++++++++++ 6 files changed, 2180 insertions(+) create mode 100644 src/anemoi/training/diagnostics/__init__.py create mode 100644 src/anemoi/training/diagnostics/callbacks.py create mode 100644 src/anemoi/training/diagnostics/maps.py create mode 100644 src/anemoi/training/diagnostics/mlflow_logger.py create mode 100644 src/anemoi/training/diagnostics/plots.py create mode 100644 src/anemoi/training/diagnostics/profilers.py diff --git a/src/anemoi/training/diagnostics/__init__.py b/src/anemoi/training/diagnostics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/anemoi/training/diagnostics/callbacks.py b/src/anemoi/training/diagnostics/callbacks.py new file mode 100644 index 00000000..e0b86481 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks.py @@ -0,0 +1,737 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import copy
+import json
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from zipfile import ZipFile
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torchinfo
+from omegaconf import DictConfig
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+from anemoi.training.diagnostics.plots import init_plot_settings
+from anemoi.training.diagnostics.plots import plot_graph_features
+from anemoi.training.diagnostics.plots import plot_histogram
+from anemoi.training.diagnostics.plots import plot_loss
+from anemoi.training.diagnostics.plots import plot_power_spectrum
+from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample
+
+LOGGER = logging.getLogger(__name__)
+
+
+class PlotCallback(Callback):
+    """Base class for callbacks that plot data and save the figures (optionally logging them to MLflow)."""
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
+        self.save_basedir = config.hardware.paths.plots
+        self.plot_frequency = config.diagnostics.plot.frequency
+        self.normalizer = None
+        self.latlons = None
+        init_plot_settings()
+
+    def _output_figure(self, logger, fig, epoch: int, tag: str = "gnn", exp_log_tag: Optional[str] = None) -> None:
+        """Figure output: save to file and/or log as an MLflow artifact.
+
+        Note: exp_log_tag is accepted for compatibility with the callers below;
+        it is currently unused.
+        """
+        if self.save_basedir is not None:
+            save_path = Path(
+                self.save_basedir,
+                "plots",
+                f"{tag}_epoch{epoch:03d}.png",
+            )
+
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            fig.savefig(save_path, dpi=100, bbox_inches="tight")
+
+            if self.config.diagnostics.log.mlflow.enabled:
+                run_id = logger.run_id
+                logger.experiment.log_artifact(run_id, str(save_path))
+
+        plt.close(fig)  # cleanup
+
+
+class AsyncPlotCallback(PlotCallback):
+    """Base class for plotting callbacks that render their figures asynchronously on a single background thread."""
+
+    def __init__(self, config) -> None:
+        super().__init__(config)
+
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._error: Optional[BaseException] = None
+
+    def teardown(self, trainer, pl_module, stage) -> None:
+        """Close the threads."""
+        self._executor.shutdown(wait=True)
+        self.check_error()
+
+    def check_error(self) -> None:
+        # if an error was raised anytime in any of the `executor.submit` calls
+        if self._error:
+            raise self._error
+
+    def _plot(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError
+
+    def _async_plot(
+        self,
+        trainer,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Execute the plot function, ensuring any errors are caught."""
+        try:
+            if trainer.is_global_zero:
+                self._plot(trainer, *args, **kwargs)
+        except BaseException as ex:
+            self._error = ex
+
+
+class RolloutEval(Callback):
+    """Evaluates the model performance over a (longer) rollout window."""
+
+    def __init__(self, config) -> None:
+        """Initialize RolloutEval callback.
+ + Parameters + ---------- + config : dict + Dictionary with configuration settings + """ + super().__init__() + + LOGGER.setLevel(config.diagnostics.log.code.level) + + LOGGER.debug( + "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", + config.diagnostics.eval.rollout, + config.diagnostics.eval.frequency, + ) + self.rollout = config.diagnostics.eval.rollout + self.frequency = config.diagnostics.eval.frequency + + def _eval( + self, + pl_module: pl.LightningModule, + batch: torch.Tensor, + ) -> None: + loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) + # NB! the batch is already normalized in-place - see pl_model.validation_step() + metrics = {} + + # start rollout + x = batch[ + :, 0 : pl_module.multi_step, ..., pl_module.data_indices.data.input.full + ] # (bs, multi_step, latlon, nvar) + assert ( + batch.shape[1] >= self.rollout + pl_module.multi_step + ), "Batch length not sufficient for requested rollout length!" + + with torch.no_grad(): + for rollout_step in range(self.rollout): + y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) + y = batch[ + :, + pl_module.multi_step + rollout_step, + ..., + pl_module.data_indices.data.output.full, + ] # target, shape = (bs, latlon, nvar) + # y includes the auxiliary variables, so we must leave those out when computing the loss + loss += pl_module.loss(y_pred, y) + + x = pl_module.advance_input(x, y_pred, batch, rollout_step) + + metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) + metrics.update(metrics_next) + + # scale loss + loss *= 1.0 / self.rollout + self._log(pl_module, loss, metrics, batch.shape[0]) + + def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: + pl_module.log( + f"val_r{self.rollout}_wmse", + loss, + on_epoch=True, + on_step=True, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + for mname, mvalue in metrics.items(): + pl_module.log( + f"val_r{self.rollout}_" + mname, + mvalue, + on_epoch=True, + on_step=False, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + del trainer, outputs # not used + if batch_idx % self.frequency == 3 and pl_module.global_rank == 0: + self._eval(pl_module, batch) + + +class GraphTrainableFeaturesPlot(AsyncPlotCallback): + """Visualize the trainable features defined at the data and hidden graph nodes. + + TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? 
+ """ + + def __init__(self, config) -> None: + super().__init__(config) + self._graph_name_data = config.graph.data + self._graph_name_hidden = config.graph.hidden + + def _plot( + # self, trainer, latlons:np.ndarray, features:np.ndarray, tag:str, exp_log_tag:str + self, + trainer, + latlons, + features, + epoch, + tag, + exp_log_tag, + ) -> None: + fig = plot_graph_features(latlons, features) + self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) + + def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if pl_module.global_rank == 0: + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + graph = pl_module.graph_data.cpu() + epoch = trainer.current_epoch + + if model.trainable_data is not None: + data_coords = np.rad2deg( + graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy() + ) + + self._executor.submit( + self._async_plot, + trainer, + data_coords, + model.trainable_data.trainable.cpu(), + epoch=epoch, + tag="trainable_data", + exp_log_tag="trainable_data", + ) + + if model.trainable_hidden is not None: + hidden_coords = np.rad2deg( + graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy() + ) + + self._executor.submit( + self._async_plot, + trainer, + hidden_coords, + model.trainable_hidden.trainable.cpu(), + epoch=epoch, + tag="trainable_hidden", + exp_log_tag="trainable_hidden", + ) + + self.check_error() + + +class PlotLoss(AsyncPlotCallback): + """Plots the unsqueezed loss over rollouts.""" + + def __init__(self, config) -> None: + super().__init__(config) + + def _plot( + self, + trainer, + pl_module, + outputs, + batch, + epoch, + ) -> None: + logger = trainer.logger + del trainer + for rollout_step in range(pl_module.rollout): + y_hat = outputs[1][rollout_step] + y_true = batch[:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.data.output.full] + loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() + fig = plot_loss(loss) + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", + exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", + ) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: + if batch_idx % self.plot_frequency == 3 and trainer.global_rank == 0: + self._async_plot(trainer, pl_module, outputs, batch, epoch=trainer.current_epoch) + + self.check_error() + + +class PlotSample(AsyncPlotCallback): + """Plots a denormalized sample: input, target and prediction.""" + + def __init__(self, config) -> None: + super().__init__(config) + self.sample_idx = self.config.diagnostics.plot.sample_idx + + def _plot( + # batch_idx: int, rollout_step: int, x: torch.Tensor, y_true: torch.Tensor, y_pred: torch.Tensor, + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + epoch, + ) -> None: + logger = trainer.logger + + # Build dictionary of indices and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict = { + pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) + for name in self.config.diagnostics.plot.parameters + } + + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would 
be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.normalizer is None: + # Copy to be used across all the training cycle + self.normalizer = copy.deepcopy(pl_module.model.normalizer).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.data.output.full, + ].cpu() + data = self.normalizer.denormalize(input_tensor).numpy() + + output_tensor = self.normalizer.denormalize( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ).numpy() + + for rollout_step in range(pl_module.rollout): + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.config.diagnostics.plot.per_sample, + self.latlons, + self.config.diagnostics.plot.accumulation_levels_plot, + self.config.diagnostics.plot.cmap_accumulation, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", + ) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: + if batch_idx % self.plot_frequency == 3 and trainer.global_rank == 0: + self._executor.submit( + self._async_plot, trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch + ) + + self.check_error() + + +class PlotAdditionalMetrics(AsyncPlotCallback): + """Plots TP related metric comparing target and prediction. + + The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. 
+
+    - Power Spectrum
+    - Histograms
+    """
+
+    def __init__(self, config) -> None:
+        super().__init__(config)
+        self.sample_idx = self.config.diagnostics.plot.sample_idx
+
+    def _plot(
+        self,
+        trainer,
+        pl_module,
+        outputs,
+        batch,
+        batch_idx,
+        epoch,
+    ) -> None:
+        logger = trainer.logger
+
+        # When running in async mode, it might happen that in the last epoch these tensors
+        # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA
+        # but internal ones would be on the cpu). The lines below address this problem.
+        if self.normalizer is None:
+            # Copy to be used across all the training cycle
+            self.normalizer = copy.deepcopy(pl_module.model.normalizer).cpu()
+        if self.latlons is None:
+            self.latlons = np.rad2deg(pl_module.data_latlons.clone().cpu().numpy())
+        local_rank = pl_module.local_rank
+
+        input_tensor = batch[
+            self.sample_idx,
+            pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
+            ...,
+            pl_module.data_indices.data.output.full,
+        ].cpu()
+        data = self.normalizer.denormalize(input_tensor).numpy()
+        output_tensor = self.normalizer.denormalize(
+            torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
+            in_place=False,
+        ).numpy()
+
+        for rollout_step in range(pl_module.rollout):
+            if self.config.diagnostics.plot.parameters_histogram is not None:
+                # Build dictionary of indices and parameters to be plotted
+
+                diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic
+                plot_parameters_dict_histogram = {
+                    pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics)
+                    for name in self.config.diagnostics.plot.parameters_histogram
+                }
+
+                fig = plot_histogram(
+                    plot_parameters_dict_histogram,
+                    data[0, ...].squeeze(),
+                    data[rollout_step + 1, ...].squeeze(),
+                    output_tensor[rollout_step, ...],
+                )
+
+                self._output_figure(
+                    logger,
+                    fig,
+                    epoch=epoch,
+                    tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+                    exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}",
+                )
+
+            if self.config.diagnostics.plot.parameters_spectrum is not None:
+                # Build dictionary of indices and parameters to be plotted
+                diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic
+
+                plot_parameters_dict_spectrum = {
+                    pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics)
+                    for name in self.config.diagnostics.plot.parameters_spectrum
+                }
+
+                fig = plot_power_spectrum(
+                    plot_parameters_dict_spectrum,
+                    self.latlons,
+                    data[0, ...].squeeze(),
+                    data[rollout_step + 1, ...].squeeze(),
+                    output_tensor[rollout_step, ...],
+                )
+
+                self._output_figure(
+                    logger,
+                    fig,
+                    epoch=epoch,
+                    tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+                    exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}",
+                )
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
+        if batch_idx % self.plot_frequency == 3 and trainer.global_rank == 0:
+            self._executor.submit(
+                self._async_plot, trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch
+            )
+
+        self.check_error()
+
+
+class ParentUUIDCallback(Callback):
+    """A callback that retrieves the parent UUID for a model, if it is a child model."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__()
+        self.config = config
+
+    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
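+        # Propagate the stored UUID of the checkpoint being loaded, so a child
+        # (resumed or forked) run can be traced back to its parent model.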
pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] + + +class AnemoiCheckpoint(ModelCheckpoint): + """A checkpoint callback that saves the model after every validation epoch.""" + + def __init__(self, config, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + self.start = time.time() + self._model_metadata = None + self._tracker_metadata = None + self._tracker_name = None + + def _torch_drop_down(self, trainer: pl.Trainer) -> torch.nn.Module: + # Get the model from the DataParallel wrapper, for single and multi-gpu cases + assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" + return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model + + def model_metadata(self, model): + if self._model_metadata is not None: + return self._model_metadata + + self._model_metadata = { + "model": model.__class__.__name__, + "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), + "total_parameters": sum(p.numel() for p in model.parameters()), + "summary": repr( + torchinfo.summary( + model, + depth=50, + verbose=0, + row_settings=["var_names"], + ), + ), + } + + return self._model_metadata + + def tracker_metadata(self, trainer): + if self._tracker_metadata is not None: + return {self._tracker_name: self._tracker_metadata} + + elif self.config.diagnostics.log.mlflow.enabled: + self._tracker_name = "mlflow" + + from anemoi.training.diagnostics.mlflow_logger import AIFSMLflowLogger + + mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AIFSMLflowLogger)) + run_id = mlflow_logger.run_id + run = mlflow_logger._mlflow_client.get_run(run_id) + + if run is not None: + self._tracker_metadata = { + "id": run.info.run_id, + "name": run.info.run_name, + "url": run.info.artifact_uri, + "project": run.info.experiment_id, + } + else: + self._tracker_metadata = {} + + return {self._tracker_name: self._tracker_metadata} + + def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: + if trainer.is_global_zero: + model = self._torch_drop_down(trainer) + + # We want a different uuid each time we save the model + # so we can tell them apart in the catalogue (i.e. 
different epochs) + checkpoint_uuid = str(uuid.uuid4()) + trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid + + trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) + trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) + + trainer.lightning_module._hparams["metadata"]["training"] = { + "current_epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "elapsed_time": time.time() - self.start, + } + + Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) + + save_config = model.config + model.config = None + + save_metadata = model.metadata + model.metadata = None + + metadata = dict(**save_metadata) + + inference_checkpoint_filepath = Path(lightning_checkpoint_filepath).parent / Path( + "inference-" + str(Path(lightning_checkpoint_filepath).name), + ) + + torch.save(model, inference_checkpoint_filepath) + + with ZipFile(inference_checkpoint_filepath, "a") as zipf: + base = Path(inference_checkpoint_filepath).stem + zipf.writestr( + f"{base}/ai-models.json", + json.dumps(metadata), + ) + + model.config = save_config + model.metadata = save_metadata + + self._last_global_step_saved = trainer.global_step + + trainer.strategy.barrier() + + # saving checkpoint used for pytorch-lightning based training + trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) + self._last_global_step_saved = trainer.global_step + self._last_checkpoint_saved = lightning_checkpoint_filepath + + # notify loggers + if trainer.is_global_zero: + from weakref import proxy + + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) + + +def get_callbacks(config: DictConfig) -> list: + """Setup callbacks for PyTorch Lightning trainer. + + Parameters + ---------- + config : DictConfig + Job configuration + + Returns + ------- + List + A list of PyTorch Lightning callbacks + """ + LOGGER.setLevel(config.diagnostics.log.code.level) + + checkpoint_settings = { + "dirpath": config.hardware.paths.checkpoints, + "verbose": False, + # save weights, optimizer states, LR-schedule states, hyperparameters etc. + # https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html#contents-of-a-checkpoint + "save_weights_only": False, + "auto_insert_metric_name": False, + # save after every validation epoch, if we've improved + "save_on_train_epoch_end": False, + "enable_version_counter": False, + } + + ckpt_frequency_save_dict = {} + for key, frequency_dict in config.diagnostics.checkpoint.items(): + frequency = frequency_dict["save_frequency"] + n_saved = frequency_dict["num_models_saved"] + if key == "every_n_minutes" and frequency_dict["save_frequency"] is not None: + target = "train_time_interval" + frequency = timedelta(minutes=frequency_dict["save_frequency"]) + else: + target = key + ckpt_frequency_save_dict[target] = (config.hardware.files.checkpoint[key], frequency, n_saved) + + trainer_callbacks = [] + if not config.diagnostics.profiler: + for save_key, (name, save_frequency, save_n_models) in ckpt_frequency_save_dict.items(): + if save_frequency is not None: + LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency) + trainer_callbacks.extend( + # save_top_k: the save_top_k flag can either save the best or the last k checkpoints + # depending on the monitor flag on ModelCheckpoint. 
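+                # Hedged example of the semantics relied on below: with
+                # monitor="step" and mode="max", save_top_k=k keeps the k most
+                # recent checkpoints, since the global step only increases.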
+ # See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference + [ + AnemoiCheckpoint( + config=config, + filename=name, + save_last=True, + **{save_key: save_frequency}, + # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved + save_top_k=save_n_models, + monitor="step", + mode="max", + **checkpoint_settings, + ), + ], + ) + else: + LOGGER.debug("Not setting up a checkpoint callback with %s", save_key) + else: + # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints + LOGGER.warning("Profiling is enabled - AIFS will not write any training or inference model checkpoints!") + + if any([config.diagnostics.log.wandb.enabled, config.diagnostics.log.mlflow.enabled]): + from pytorch_lightning.callbacks import LearningRateMonitor + + trainer_callbacks.append( + LearningRateMonitor( + logging_interval="step", + log_momentum=False, + ), + ) + + if config.diagnostics.eval.enabled: + trainer_callbacks.append(RolloutEval(config)) + + if config.diagnostics.plot.enabled: + trainer_callbacks.extend( + [ + PlotLoss(config), + PlotSample(config), + ], + ) + if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: + trainer_callbacks.extend([PlotAdditionalMetrics(config)]) + + if config.training.swa.enabled: + from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging + + trainer_callbacks.append( + StochasticWeightAveraging( + swa_lrs=config.training.swa.lr, + swa_epoch_start=min( + int(0.75 * config.training.max_epochs), + config.training.max_epochs - 1, + ), + annealing_epochs=max(int(0.25 * config.training.max_epochs), 1), + annealing_strategy="cos", + # TODO: do we want the averaging to happen on the CPU, to save memory? + device=None, + ), + ) + + trainer_callbacks.append(ParentUUIDCallback(config)) + + if config.diagnostics.plot.learned_features: + LOGGER.debug("Setting up a callback to plot the trainable graph node features ...") + trainer_callbacks.append(GraphTrainableFeaturesPlot(config)) + + return trainer_callbacks diff --git a/src/anemoi/training/diagnostics/maps.py b/src/anemoi/training/diagnostics/maps.py new file mode 100644 index 00000000..02311826 --- /dev/null +++ b/src/anemoi/training/diagnostics/maps.py @@ -0,0 +1,80 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
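+
+# Hedged usage sketch for the helpers below (values are illustrative only):
+#
+#     proj = EquirectangularProjection()
+#     x, y = proj([350.0], [45.0])  # x[0] ~ -0.175 rad, y[0] ~ 0.785 rad
+#
+# Coastlines() loads continents.json once and can then overlay coastlines on
+# any matplotlib axis via plot_continents(ax).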
+ +import copy +import json + +import numpy as np +from matplotlib.collections import LineCollection + +from anemoi.training import diagnostics + + +class EquirectangularProjection: + """Class to convert lat/lon coordinates to Equirectangular coordinates.""" + + def __init__(self) -> None: + self.x_offset = 0.0 + self.y_offset = 0.0 + + def __call__(self, lon, lat): + lon_rad = np.radians(lon) + lat_rad = np.radians(lat) + x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad] + y = lat_rad + return x, y + + def inverse(self, x, y): + lon = np.degrees(x) + lat = np.degrees(y) + return lon, lat + + +class Coastlines: + """Class to plot coastlines from a GeoJSON file.""" + + def __init__(self, projection=None) -> None: + try: + # this requires python 3.9 or newer + from importlib.resources import files + except ImportError: + try: + from importlib_resources import files + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Please install importlib_resources on Python <=3.8.") from e + + # Get the path to "continents.json" within your library + self.continents_file = files(diagnostics) / "continents.json" + + # Load GeoJSON data from the file + with self.continents_file.open("rt") as file: + self.data = json.load(file) + + if projection is None: + self.projection = EquirectangularProjection() + + self.process_data() + + # Function to extract LineString coordinates + @staticmethod + def extract_coordinates(feature): + return feature["geometry"]["coordinates"] + + def process_data(self) -> None: + lines = [] + for feature in self.data["features"]: + coordinates = self.extract_coordinates(feature) + x, y = zip(*coordinates) # Unzip the coordinates into separate x and y lists + + lines.append(list(zip(*self.projection(x, y)))) # Convert lat/lon to Cartesian coordinates + self.lines = LineCollection(lines, linewidth=0.5, color="black") + + def plot_continents(self, ax) -> None: + # Add the lines to the axis as a collection + # Note that we have to provide a copy of the lines, because of Matplotlib + ax.add_collection(copy.copy(self.lines)) diff --git a/src/anemoi/training/diagnostics/mlflow_logger.py b/src/anemoi/training/diagnostics/mlflow_logger.py new file mode 100644 index 00000000..ea723018 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow_logger.py @@ -0,0 +1,352 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
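+
+# The helpers below use the standard MLflow client API: get_mlflow_run_params()
+# derives the run name and tags for new, resumed (config.training.run_id) or
+# forked (config.training.fork_run_id) runs, while LogsMonitor tees
+# sys.stdout/sys.stderr into a terminal_log.txt file that is periodically
+# uploaded via experiment.log_artifact().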
+ +import io +import logging +import os +import re +import sys +import time +from argparse import Namespace +from pathlib import Path +from threading import Thread +from typing import Any +from typing import Literal +from typing import Optional +from typing import Union +from weakref import WeakValueDictionary + +from pytorch_lightning.loggers.mlflow import MLFlowLogger +from pytorch_lightning.loggers.mlflow import _convert_params +from pytorch_lightning.loggers.mlflow import _flatten_dict +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +LOGGER = logging.getLogger(__name__) + + +def get_mlflow_run_params(config, tracking_uri): + run_id = None + tags = {"projectName": config.diagnostics.log.mlflow.project_name} + # create a tag with the command used to run the script + tags["command"] = sys.argv[0].split("/")[-1] # get the python script name + if len(sys.argv) > 1: + # add the arguments to the command tag + tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) + if config.training.run_id or config.training.fork_run_id: + "Either run_id or fork_run_id must be provided to resume a run." + import mlflow + + mlflow_client = mlflow.MlflowClient(tracking_uri) + + if config.training.run_id: + parent_run_id = config.training.run_id # parent_run_id + run_name = mlflow_client.get_run(parent_run_id).info.run_name + tags["mlflow.parentRunId"] = parent_run_id + tags["resumedRun"] = "True" # tags can't take boolean values + else: + parent_run_id = config.training.fork_run_id + tags["forkedRun"] = "True" + tags["forkedRunId"] = parent_run_id + + if config.diagnostics.log.mlflow.run_name: + run_name = config.diagnostics.log.mlflow.run_name + else: + import uuid + + run_name = f"{uuid.uuid4()!s}" + return run_id, run_name, tags + + +class LogsMonitor: + """Class for logging terminal output. + + Inspired by the class for logging terminal output in aim. + Aim-Code: https://github.com/aimhubio/aim/blob/94646d2d317ec7a43303a16530f7963e4e652921/aim/ext/resource/tracker.py#L20 + + Note: If there is an error, the terminal output logging ends before the error message is printed into the log file. + In order for the user to see the error message, the user must look at the slurm output file. + We provide the SLRM job id in the very beginning of the log file and print the final status of the run in the end. + + Parameters + ---------- + artifact_save_dir : str + Directory to save the terminal logs. + experiment : MLflow experiment object. + MLflow experiment object. + run_id: str + MLflow run ID. 
+    log_time_interval : int
+        Interval (in seconds) at which to write buffered terminal outputs, default 30
+    """
+
+    _buffer_registry = WeakValueDictionary()
+    _old_out_write = None
+    _old_err_write = None
+
+    def __init__(self, artifact_save_dir, experiment, run_id, log_time_interval=30.0) -> None:
+        # active run
+        self.experiment = experiment
+        self.run_id = run_id
+
+        # terminal log capturing
+        self._log_capture_interval = 1
+        self._log_time_interval = log_time_interval
+        self._old_out = None
+        self._old_err = None
+        self._io_buffer = io.BytesIO()
+
+        # Start thread to collect stats and logs at intervals
+        self._th_collector = Thread(target=self._log_collector, daemon=True)
+        self._shutdown = False
+        self._started = False
+
+        # open your files here
+        self.artifact_save_dir = artifact_save_dir
+        self.file_save_path = Path(artifact_save_dir, "terminal_log.txt")
+        self.file_save_path.parent.mkdir(parents=True, exist_ok=True)
+
+    @classmethod
+    def _install_stream_patches(cls) -> None:
+        cls._old_out_write = sys.stdout.write
+        cls._old_err_write = sys.stderr.write
+
+        def new_out_write(data) -> None:
+            # out to buffer
+            cls._old_out_write(data)
+            if isinstance(data, str):
+                data = data.encode()
+            for buffer in cls._buffer_registry.values():
+                buffer.write(data)
+
+        def new_err_write(data) -> None:
+            # err to buffer
+            cls._old_err_write(data)
+            if isinstance(data, str):
+                data = data.encode()
+            for buffer in cls._buffer_registry.values():
+                buffer.write(data)
+
+        sys.stdout.write = new_out_write
+        sys.stderr.write = new_err_write
+
+    @classmethod
+    def _uninstall_stream_patches(cls) -> None:
+        sys.stdout.write = cls._old_out_write
+        sys.stderr.write = cls._old_err_write
+
+    def start(self) -> None:
+        """Start collection."""
+        if self._started:
+            return
+        self._started = True
+        # install the stream patches if not done yet
+        if not self._buffer_registry:
+            self._install_stream_patches()
+        self._buffer_registry[id(self)] = self._io_buffer
+        # Start thread to asynchronously collect logs
+        self._th_collector.start()
+        LOGGER.info("Terminal log path: " + str(self.file_save_path))
+        if os.getenv("SLURM_JOB_ID"):
+            LOGGER.info("SLURM job id: " + os.getenv("SLURM_JOB_ID"))
+
+    def finish(self, status) -> None:
+        """Stop the monitoring and close the log file."""
+        if not self._started:
+            return
+        LOGGER.info(
+            "Stopping terminal log monitoring and saving buffered terminal outputs. Final status: "
+            + status.upper()
+            + "."
+        )
+        self._shutdown = True
+        # read and store remaining buffered logs
+        self._store_buffered_logs()
+        # unregister the buffer
+        del self._buffer_registry[id(self)]
+        # uninstall stream patching if no buffer is left in the registry
+        if not self._buffer_registry:
+            self._uninstall_stream_patches()
+
+        with self.file_save_path.open("a") as logfile:
+            logfile.write("\n\n")
+            logfile.flush()
+            logfile.close()
+
+    def _log_collector(self) -> None:
+        """Log collecting thread body.
+
+        Main monitoring loop, which continuously collects and logs outputs.
+ """ + log_capture_time_counter = 0 + + while True: + if self._shutdown: + break + + time.sleep(self._log_time_interval) # in seconds + log_capture_time_counter += self._log_time_interval + + if log_capture_time_counter > self._log_capture_interval: + self._store_buffered_logs() + log_capture_time_counter = 0 + + def _store_buffered_logs(self) -> None: + _buffer_size = self._io_buffer.tell() + if not _buffer_size: + return + self._io_buffer.seek(0) + # read and reset the buffer + data = self._io_buffer.read(_buffer_size) + self._io_buffer.seek(0) + # handle the buffered data and store + # split lines and keep \n at the end of each line + lines = [e + b"\n" for e in data.split(b"\n") if e] + + ansi_csi_re = re.compile(b"\001?\033\\[((?:\\d|;)*)([a-dA-D])\002?") + + def _handle_csi(line): + # removes the cursor up and down symbols from the line + # skip tqdm status bar updates ending with "curser up" but the last one in buffer to save space + def _remove_csi(line): + return re.sub(ansi_csi_re, b"", line) + + for match in ansi_csi_re.finditer(line): + arg, command = match.groups() + arg = int(arg.decode()) if arg else 1 + if command == b"A" and (b"0%" not in line and not self._shutdown): # cursor up + # only keep x*10% status updates from tqmd status bars that end with a cursor up + # always keep shutdown commands + line = b"" + return _remove_csi(line) + + line = None + with self.file_save_path.open("a") as logfile: + for line in lines: + # handle cursor up and down symbols + line = _handle_csi(line) + # handle each line for carriage returns + line = line.rsplit(b"\r")[-1] + logfile.write(line.decode()) + + logfile.flush() + self.experiment.log_artifact(self.run_id, str(self.file_save_path)) + + +class AIFSMLflowLogger(MLFlowLogger): + """A custom MLflow logger that logs terminal output.""" + + def __init__( + self, + experiment_name: str = "lightning_logs", + run_name: Optional[str] = None, + tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), + tags: Optional[dict[str, Any]] = None, + save_dir: Optional[str] = "./mlruns", + log_model: Literal[True, False, "all"] = False, + prefix: str = "", + resumed: Optional[bool] = False, + forked: Optional[bool] = False, + run_id: Optional[str] = None, + offline: Optional[bool] = False, + # artifact_location: Optional[str] = None, + # avoid passing any artifact location otherwise it would mess up the offline logging of artifacts + ) -> None: + if offline: + # OFFLINE - When we run offline we can pass a save_dir pointing to a local path + tracking_uri = None + + else: + # ONLINE - When we pass a tracking_uri to mlflow then it will ignore the + # saving dir and save all artifacts/metrics to the remote server database + save_dir = None + + self._resumed = resumed + self._forked = forked + + super().__init__( + experiment_name=experiment_name, + run_name=run_name, + tracking_uri=tracking_uri, + tags=tags, + save_dir=save_dir, + log_model=log_model, + prefix=prefix, + run_id=run_id, + ) + + @rank_zero_only + def log_system_metrics(self) -> None: + """Log system metrics (CPU, GPU, etc).""" + import mlflow + from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor + + mlflow.enable_system_metrics_logging() + system_monitor = SystemMetricsMonitor( + self.run_id, + resume_logging=self.run_id is not None, + ) + global run_id_to_system_metrics_monitor + run_id_to_system_metrics_monitor = {} + run_id_to_system_metrics_monitor[self.run_id] = system_monitor + system_monitor.start() + + @rank_zero_only + def log_terminal_output(self, 
artifact_save_dir="") -> None: + """Log terminal logs to MLflow.""" + # path for logging terminal logs + # for now the 'terminal_logs' file is kept in the same folder as the plots + artifact_save_dir = Path(artifact_save_dir, self.run_id, "plots") + + log_monitor = LogsMonitor( + artifact_save_dir, + self.experiment, + self.run_id, + ) + global run_id_to_log_monitor + run_id_to_log_monitor = {} + run_id_to_log_monitor[self.run_id] = log_monitor + log_monitor.start() + + def _clean_params(self, params): + """Clean up params to avoid issues with mlflow. + + Too many logged params will make the server take longer to render the + experiment. + """ + prefixes_to_remove = ["hardware", "data", "dataloader", "model", "training", "diagnostics", "metadata.config"] + keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)] + for key in keys_to_remove: + del params[key] + return params + + @rank_zero_only + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + """Overwrite the log_hyperparams method to flatten config params using '.'.""" + params = _convert_params(params) + params = _flatten_dict(params, delimiter=".") # Flatten dict with '.' to not break API queries + params = self._clean_params(params) + + from mlflow.entities import Param + + # Truncate parameter values to 250 characters. + # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 + params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] + + for idx in range(0, len(params_list), 100): + self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) + + @rank_zero_only + def finalize(self, status: str = "success") -> None: + # finalize logging and system metrics monitor + + if run_id_to_system_metrics_monitor: + run_id_to_system_metrics_monitor[self.run_id].finish() + if run_id_to_log_monitor: + run_id_to_log_monitor[self.run_id].finish(status) + + super().finalize(status) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py new file mode 100644 index 00000000..e6eb9bf3 --- /dev/null +++ b/src/anemoi/training/diagnostics/plots.py @@ -0,0 +1,500 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+import logging
+from typing import Optional
+
+import matplotlib.pyplot as plt
+import matplotlib.style as mplstyle
+import numpy as np
+from matplotlib.colors import BoundaryNorm
+from matplotlib.colors import ListedColormap
+from matplotlib.colors import Normalize
+from matplotlib.colors import TwoSlopeNorm
+from matplotlib.figure import Figure
+from pyshtools.expand import SHGLQ
+from pyshtools.expand import SHExpandGLQ
+from scipy.interpolate import griddata
+
+from anemoi.training.diagnostics.maps import Coastlines
+from anemoi.training.diagnostics.maps import EquirectangularProjection
+
+LOGGER = logging.getLogger(__name__)
+
+
+continents = Coastlines()
+
+
+def init_plot_settings() -> None:
+ """Initialize matplotlib plot settings."""
+ small_font_size = 8
+ medium_font_size = 10
+
+ mplstyle.use("fast")
+ plt.rcParams["path.simplify_threshold"] = 0.9
+
+ plt.rc("font", size=small_font_size) # controls default text sizes
+ plt.rc("axes", titlesize=small_font_size) # fontsize of the axes title
+ plt.rc("axes", labelsize=medium_font_size) # fontsize of the x and y labels
+ plt.rc("xtick", labelsize=small_font_size) # fontsize of the tick labels
+ plt.rc("ytick", labelsize=small_font_size) # fontsize of the tick labels
+ plt.rc("legend", fontsize=small_font_size) # legend fontsize
+ plt.rc("figure", titlesize=small_font_size) # fontsize of the figure title
+
+
+def _hide_axes_ticks(ax) -> None:
+ """Hide x/y-axis ticks.
+
+ Parameters
+ ----------
+ ax : matplotlib.axes
+ Axes object handle
+ """
+ plt.setp(ax.get_xticklabels(), visible=False)
+ plt.setp(ax.get_yticklabels(), visible=False)
+ ax.tick_params(axis="both", which="both", length=0)
+
+
+def plot_loss(
+ x: np.ndarray,
+) -> Figure:
+ """Plots a bar chart of the loss per predicted variable.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Loss values to plot, of shape (npred,)
+
+ Returns
+ -------
+ Figure
+ The figure object handle.
+ """
+ fig, ax = plt.subplots(1, 1, figsize=(4, 3))
+ # colour the bars in fixed-size blocks: 13 bars per colour in "krbgym", then 12 cyan bars
+ colors = []
+ for c in "krbgym":
+ colors.extend([c] * 13)
+ colors.extend(["c"] * 12)
+ ax.bar(np.arange(x.size), x, color=colors, log=True)
+
+ return fig
+
+
+def plot_power_spectrum(
+ parameters: dict[int, tuple[str, bool]],
+ latlons: np.ndarray,
+ x: np.ndarray,
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+) -> Figure:
+ """Plots power spectrum.
+
+ NB: this can be very slow for large data arrays;
+ call it as infrequently as possible!
+
+ Parameters
+ ----------
+ parameters : dict[int, tuple[str, bool]]
+ Mapping of variable indices to (variable name, output-only flag) pairs
+ latlons : np.ndarray
+ lat/lon coordinates array, shape (lat*lon, 2)
+ x : np.ndarray
+ Input data of shape (lat*lon, nvar*level)
+ y_true : np.ndarray
+ Expected data of shape (lat*lon, nvar*level)
+ y_pred : np.ndarray
+ Predicted data of shape (lat*lon, nvar*level)
+
+ Returns
+ -------
+ Figure
+ The figure object handle.
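+
+ Examples
+ --------
+ A minimal sketch of the expected call, assuming two flattened fields on an
+ arbitrary grid, where "2t" also appears in the input and "tp" is output-only:
+
+ >>> params = {0: ("2t", False), 1: ("tp", True)}
+ >>> fig = plot_power_spectrum(params, latlons, x, y_true, y_pred)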
+ """ + n_plots_x, n_plots_y = len(parameters), 1 + + figsize = (n_plots_y * 4, n_plots_x * 3) + fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize) + + pc = EquirectangularProjection() + lat, lon = latlons[:, 0], latlons[:, 1] + pc_lon, pc_lat = pc(lon, lat) + pc_lon = np.array(pc_lon) + pc_lat = np.array(pc_lat) + # Calculate delta_lon and delta_lat on the projected grid + delta_lon = abs(np.diff(pc_lon)) + non_zero_delta_lon = delta_lon[delta_lon != 0] + delta_lat = abs(np.diff(pc_lat)) + non_zero_delta_lat = delta_lat[delta_lat != 0] + + # Define a regular grid for interpolation + n_pix_lon = int(np.floor(abs(pc_lon.max() - pc_lon.min()) / abs(np.min(non_zero_delta_lon)))) # around 400 for O96 + n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) # around 192 for O96 + regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon) + regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat) + grid_pc_lon, grid_pc_lat = np.meshgrid(regular_pc_lon, regular_pc_lat) + + for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): + yt = y_true[..., variable_idx].squeeze() + yp = y_pred[..., variable_idx].squeeze() + if output_only: + xt = x[..., variable_idx].squeeze() * int(output_only) + yt_i = griddata((pc_lon, pc_lat), (yt - xt), (grid_pc_lon, grid_pc_lat), method="cubic", fill_value=0.0) + yp_i = griddata((pc_lon, pc_lat), (yp - xt), (grid_pc_lon, grid_pc_lat), method="cubic", fill_value=0.0) + else: + yt_i = griddata((pc_lon, pc_lat), yt, (grid_pc_lon, grid_pc_lat), method="cubic", fill_value=0.0) + yp_i = griddata((pc_lon, pc_lat), yp, (grid_pc_lon, grid_pc_lat), method="cubic", fill_value=0.0) + + amplitude_t = np.array(compute_spectra(yt_i)) + amplitude_p = np.array(compute_spectra(yp_i)) + + ax[plot_idx].loglog( + np.arange(1, amplitude_t.shape[0]), amplitude_t[1 : (amplitude_t.shape[0])], label="Truth (ERA5)" + ) + ax[plot_idx].loglog( + np.arange(1, amplitude_p.shape[0]), amplitude_p[1 : (amplitude_p.shape[0])], label="Predicted" + ) + + ax[plot_idx].legend() + ax[plot_idx].set_title(variable_name) + + ax[plot_idx].set_xlabel("$k$") + ax[plot_idx].set_ylabel("$P(k)$") + ax[plot_idx].set_aspect("auto", adjustable=None) + fig.tight_layout() + return fig + + +def compute_spectra(field: np.ndarray) -> np.ndarray: + """Compute spectral variability of a field by wavenumber. + + Parameters + ---------- + field : np.ndarray + lat lon field to calculate the spectra of + + Returns + ------- + np.ndarray + spectra of field by wavenumber + """ + field = np.array(field) + + # compute real and imaginary parts of power spectra of field + lmax = field.shape[0] - 1 # maximum degree of expansion + zero_w = SHGLQ(lmax) + coeffs_field = SHExpandGLQ(field, w=zero_w[1], zero=zero_w[0]) + + # Re**2 + Im**2 + coeff_amp = coeffs_field[0, :, :] ** 2 + coeffs_field[1, :, :] ** 2 + + # sum over meridional direction + spectra = np.sum(coeff_amp, axis=0) + + return spectra + + +def plot_histogram( + parameters: dict[str, int], + x: np.ndarray, + y_true: np.ndarray, + y_pred: np.ndarray, +) -> Figure: + """Plots histogram. + + NB: this can be very slow for large data arrays + call it as infrequently as possible! 
+
+ Parameters
+ ----------
+ parameters : dict[int, tuple[str, bool]]
+ Mapping of variable indices to (variable name, output-only flag) pairs
+ x : np.ndarray
+ Input data of shape (lat*lon, nvar*level)
+ y_true : np.ndarray
+ Expected data of shape (lat*lon, nvar*level)
+ y_pred : np.ndarray
+ Predicted data of shape (lat*lon, nvar*level)
+
+ Returns
+ -------
+ Figure
+ The figure object handle.
+ """
+ n_plots_x, n_plots_y = len(parameters), 1
+
+ figsize = (n_plots_y * 4, n_plots_x * 3)
+ fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize)
+ ax = np.atleast_1d(ax) # make indexing safe when only a single variable is plotted
+
+ for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
+ yt = y_true[..., variable_idx].squeeze()
+ yp = y_pred[..., variable_idx].squeeze()
+
+ # Calculate the histogram
+ if output_only:
+ xt = x[..., variable_idx].squeeze() * int(output_only)
+ hist_yt, bins_yt = np.histogram((yt - xt), bins=100)
+ hist_yp, bins_yp = np.histogram((yp - xt), bins=100)
+ else:
+ hist_yt, bins_yt = np.histogram(yt, bins=100)
+ hist_yp, bins_yp = np.histogram(yp, bins=100)
+
+ # Visualization trick for precipitation fields
+ if variable_name in ("tp", "cp"):
+ hist_yt = hist_yt * bins_yt[:-1]
+ hist_yp = hist_yp * bins_yp[:-1]
+ # Plot the modified histogram
+ ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (ERA5)")
+ ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="AIFS")
+
+ ax[plot_idx].set_title(variable_name)
+ ax[plot_idx].set_xlabel(variable_name)
+ ax[plot_idx].set_ylabel("Frequency")
+ ax[plot_idx].legend()
+ ax[plot_idx].set_aspect("auto", adjustable=None)
+
+ fig.tight_layout()
+ return fig
+
+
+def plot_predicted_multilevel_flat_sample(
+ parameters: dict[int, tuple[str, bool]],
+ n_plots_per_sample: int,
+ latlons: np.ndarray,
+ clevels: list[float],
+ cmap_precip: list[str],
+ x: np.ndarray,
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+) -> Figure:
+ """Plots data for one multilevel latlon-"flat" sample.
+
+ NB: this can be very slow for large data arrays;
+ call it as infrequently as possible!
+
+ Parameters
+ ----------
+ parameters : dict[int, tuple[str, bool]]
+ Mapping of variable indices to (variable name, output-only flag) pairs
+ n_plots_per_sample : int
+ Number of plots per sample
+ latlons : np.ndarray
+ lat/lon coordinates array, shape (lat*lon, 2)
+ clevels : list[float]
+ Accumulation levels used for precipitation related plots
+ cmap_precip : list[str]
+ Colors used for each accumulation level
+ x : np.ndarray
+ Input data of shape (lat*lon, nvar*level)
+ y_true : np.ndarray
+ Expected data of shape (lat*lon, nvar*level)
+ y_pred : np.ndarray
+ Predicted data of shape (lat*lon, nvar*level)
+
+ Returns
+ -------
+ Figure
+ The figure object handle.
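+
+ Examples
+ --------
+ A hedged sketch of the expected call; assumes six plots per sample (input,
+ target, prediction, prediction error, increment and persistence error),
+ matching the column layout produced by plot_flat_sample below:
+
+ >>> fig = plot_predicted_multilevel_flat_sample(
+ ... params, 6, latlons, clevels, cmap_precip, x, y_true, y_pred
+ ... )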
+ """ + n_plots_x, n_plots_y = len(parameters), n_plots_per_sample + + figsize = (n_plots_y * 4, n_plots_x * 3) + fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize) + + pc = EquirectangularProjection() + lat, lon = latlons[:, 0], latlons[:, 1] + pc_lon, pc_lat = pc(lon, lat) + + for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): + xt = x[..., variable_idx].squeeze() * int(output_only) + yt = y_true[..., variable_idx].squeeze() + yp = y_pred[..., variable_idx].squeeze() + if n_plots_x > 1: + plot_flat_sample(fig, ax[plot_idx, :], pc_lon, pc_lat, xt, yt, yp, variable_name, clevels, cmap_precip) + else: + plot_flat_sample(fig, ax, pc_lon, pc_lat, xt, yt, yp, variable_name, clevels, cmap_precip) + + return fig + + +def plot_flat_sample( + fig, + ax, + lon: np.ndarray, + lat: np.ndarray, + input_: np.ndarray, + truth: np.ndarray, + pred: np.ndarray, + vname: str, + clevels: float, + cmap_precip: str, +) -> None: + """Plot a "flat" 1D sample. + + Data on non-rectangular (reduced Gaussian) grids. + + Parameters + ---------- + fig : _type_ + Figure object handle + ax : matplotlib.axes + Axis object handle + lon : np.ndarray + longitude coordinates array, shape (lon,) + lat : np.ndarray + latitude coordinates array, shape (lat,) + input_ : np.ndarray + Input data of shape (lat*lon,) + truth : np.ndarray + Expected data of shape (lat*lon,) + pred : np.ndarray + Predicted data of shape (lat*lon,) + vname : str + Variable name + clevels : float + Accumulation levels used for precipitation related plots + cmap_precip: str + Colors used for each accumulation level + """ + if vname == "tp" or vname == "cp": + # Create a custom colormap for precipitation + nws_precip_colors = cmap_precip + precip_colormap = ListedColormap(nws_precip_colors) + + # Defining the actual precipitation accumulation levels in mm + cummulation_lvls = clevels + norm = BoundaryNorm(cummulation_lvls, len(cummulation_lvls) + 1) + + # converting to mm from m + truth = truth * 1000.0 + pred = pred * 1000.0 + scatter_plot(fig, ax[1], lon, lat, truth, cmap=precip_colormap, norm=norm, title=f"{vname} target") + scatter_plot(fig, ax[2], lon, lat, pred, cmap=precip_colormap, norm=norm, title=f"{vname} pred") + scatter_plot( + fig, ax[3], lon, lat, truth - pred, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err" + ) + else: + scatter_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target") + scatter_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred") + scatter_plot( + fig, ax[3], lon, lat, truth - pred, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err" + ) + + if sum(input_) != 0: + scatter_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input") + scatter_plot( + fig, + ax[4], + lon, + lat, + pred - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} increment [pred - input]", + ) + scatter_plot( + fig, + ax[5], + lon, + lat, + truth - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} persist err", + ) + else: + ax[0].axis("off") + ax[4].axis("off") + ax[5].axis("off") + + +def scatter_plot( + fig, + ax, + lon: np.array, + lat: np.array, + data: np.array, + cmap: str = "viridis", + norm: Optional[str] = None, + title: Optional[str] = None, +) -> None: + """Lat-lon scatter plot: can work with arbitrary grids. 
+
+ Parameters
+ ----------
+ fig : matplotlib.figure.Figure
+ Figure object handle
+ ax : matplotlib.axes
+ Axis object handle
+ lon : np.ndarray
+ longitude coordinates array, shape (lon,)
+ lat : np.ndarray
+ latitude coordinates array, shape (lat,)
+ data : np.ndarray
+ Data to plot
+ cmap : str, optional
+ Colormap string from matplotlib, by default "viridis"
+ norm : matplotlib.colors.Normalize, optional
+ Normalization applied to the data, by default None
+ title : str, optional
+ Title for plot, by default None
+
+ """
+ psc = ax.scatter(
+ lon,
+ lat,
+ c=data,
+ cmap=cmap,
+ s=1,
+ alpha=1.0,
+ norm=norm,
+ rasterized=True,
+ )
+ ax.set_xlim((-np.pi, np.pi))
+ ax.set_ylim((-np.pi / 2, np.pi / 2))
+
+ continents.plot_continents(ax)
+
+ if title is not None:
+ ax.set_title(title)
+
+ ax.set_aspect("auto", adjustable=None)
+ _hide_axes_ticks(ax)
+ fig.colorbar(psc, ax=ax)
+
+
+def plot_graph_features(
+ latlons: np.ndarray,
+ features: np.ndarray,
+) -> Figure:
+ """Plot trainable graph features.
+
+ Parameters
+ ----------
+ latlons : np.ndarray
+ Latitudes and longitudes
+ features : np.ndarray
+ Trainable features
+
+ Returns
+ -------
+ Figure
+ Figure object handle
+ """
+ nplots = features.shape[-1]
+ figsize = (nplots * 4, 3)
+ fig, ax = plt.subplots(1, nplots, figsize=figsize)
+
+ lat, lon = latlons[:, 0], latlons[:, 1]
+
+ pc = EquirectangularProjection()
+ pc_lon, pc_lat = pc(lon, lat)
+
+ for i in range(nplots):
+ ax_ = ax[i] if nplots > 1 else ax
+ scatter_plot(fig, ax_, pc_lon, pc_lat, features[..., i])
+
+ return fig
diff --git a/src/anemoi/training/diagnostics/profilers.py b/src/anemoi/training/diagnostics/profilers.py
new file mode 100644
index 00000000..975daedb
--- /dev/null
+++ b/src/anemoi/training/diagnostics/profilers.py
@@ -0,0 +1,511 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import csv
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Any
+from typing import Optional
+
+import memray
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import wandb
+from memray import FileFormat
+from memray import FileReader
+from memray.reporters.table import TableReporter
+from omegaconf import DictConfig
+from pytorch_lightning.callbacks import TQDMProgressBar
+from pytorch_lightning.profilers import Profiler
+from pytorch_lightning.profilers import SimpleProfiler
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+import anemoi.training
+
+LOGGER = logging.getLogger(__name__)
+
+PROFILER_ACTIONS = [
+ r"\[Strategy]\w+\.batch_to_device",
+ r"\[Strategy]\w+\.backward",
+ r"\[Strategy]\w+\.validation_step",
+ "run_training_epoch",
+ "run_training_batch",
+ r"\[_EvaluationLoop\]\.\w+",
+ r"\[_TrainingEpochLoop\]\.\w+",
+ r"\[LightningDataModule]\w+\.train_dataloader",
+ r"\[LightningDataModule]\w+\.val_dataloader",
+ r"\[LightningDataModule]\w+\.state_dict",
+ r"\[LightningDataModule]\w+\.setup",
+ r"\[LightningDataModule]\w+\.prepare_data",
+ r"\[LightningDataModule]\w+\.teardown",
+ r"\[LightningModule]\w+\.optimizer_step",
+ r"\[LightningModule]\w+\.configure_gradient_clipping",
+ r"\[LightningModule]\w+\.on_validation_model_eval",
+ r"\[LightningModule]\w+\.optimizer_zero_grad",
+ r"\[LightningModule]\w+\.transfer_batch_to_device",
+ r"\[LightningModule]\w+\.on_validation_model_train",
+ r"\[LightningModule]\w+\.configure_optimizers",
+ r"\[LightningModule]\w+\.lr_scheduler_step",
+ r"\[LightningModule]\w+\.configure_sharded_model",
+ r"\[LightningModule]\w+\.setup",
+ r"\[LightningModule]\w+\.prepare_data",
+ r"\[Callback\](.*Plot*)",
+ r"\[Callback\](.*Checkpoint*)",
+]
+
+GPU_METRICS_DICT = {
+ "GPU device utilization (%)": "gpu",
+ "GPU memory use (%)": "memory",
+ "GPU memory allocated (%)": "memoryAllocated",
+ "GPU memory allocated (GB)": "memoryAllocatedBytes",
+}
+
+
+def get_wandb_metrics(run_id_path: str) -> tuple[pd.DataFrame, dict]:
+ """Fetches system metrics and metadata from a W&B run."""
+ run = wandb.Api().run(run_id_path)
+ system_metrics = run.history(stream="events")
+ metadata_dict = run.metadata
+ system_metrics = system_metrics.dropna()
+ return system_metrics, metadata_dict
+
+
+def summarize_gpu_metrics(df: pd.DataFrame) -> dict[str, float]:
+ """Given the system metrics DataFrame, summarizes the GPU metrics.
+
+ - gpu.{gpu_index}.memory - GPU memory utilization in percent for each GPU
+ - gpu.{gpu_index}.memoryAllocated - GPU memory allocated as a percentage of the total available memory for each GPU
+ - gpu.{gpu_index}.memoryAllocatedBytes - GPU memory allocated in bytes for each GPU
+ - gpu.{gpu_index}.gpu - GPU utilization in percent for each GPU
+ """
+ average_metric = {}
+ col_names = df.columns
+ for gpu_metric_name, gpu_metric in GPU_METRICS_DICT.items():
+ pattern = rf"system\.gpu\.\d+\.{gpu_metric}$" # escape the dots and allow GPU indices above 9
+ sub_gpu_cols = [string for string in col_names if re.match(pattern, string)]
+ metrics_per_gpu = df[sub_gpu_cols].mean(axis=0)
+ if gpu_metric == "memoryAllocatedBytes":
+ metrics_per_gpu = metrics_per_gpu * 1e-9 # convert bytes to GB
+ average_metric[gpu_metric_name] = metrics_per_gpu.mean()
+ # Only add per-GPU metrics to the report if we have more than 1 GPU
+ if metrics_per_gpu.shape[0] > 1:
+ metrics_per_gpu.index = [" " + index for index in metrics_per_gpu.index]
+ average_metric.update(dict(metrics_per_gpu))
+ return average_metric
+
+
+def summarize_wandb_system_metrics(run_id_path: str) -> dict[str, float]:
+ r"""Summarizes the system metrics from a W&B run.
+
+ Some of the metrics included are:
+ - cpu.{}.cpu_percent - CPU usage of the system on a per-core basis.
+ - system.memory - Represents the total system memory usage as a percentage of the total available memory.
+ - system.cpu - Percentage of CPU usage by the process, normalized by the number of available CPUs
+ - system.disk.\\.usageGB - Represents the total system disk usage in gigabytes (GB)
+ - system.proc.memory.percent - Indicates the memory usage of the process as a percentage of the total available memory
+
+ More information about W&B system metrics can be found here:
+ https://docs.wandb.ai/guides/app/features/system-metrics
+ """
+ system_metrics_df, metadata_dict = get_wandb_metrics(run_id_path)
+
+ col_names = system_metrics_df.columns
+ system_metrics = {}
+
+ n_cpus = metadata_dict["cpu_count"]
+ cpu_cols = list(filter(lambda k: "cpu." in k, col_names))
+ system_metrics["avg CPU usage (%)"] = (system_metrics_df[cpu_cols].sum(axis=1) / n_cpus).mean()
+
+ system_metrics_gpu = summarize_gpu_metrics(system_metrics_df)
+ system_metrics.update(system_metrics_gpu)
+
+ system_metrics["avg Memory usage (%)"] = system_metrics_df["system.memory"].mean()
+ system_metrics["avg Disk usage (GB)"] = system_metrics_df["system.disk.\\.usageGB"].mean()
+ system_metrics["avg Disk usage (%)"] = system_metrics_df["system.disk.\\.usagePercent"].mean()
+
+ system_metrics["execution time (sec)"] = system_metrics_df["_runtime"].iloc[-1] # in seconds
+ return system_metrics
+
+
+class BenchmarkProfiler(Profiler):
+ """Custom PyTorch Lightning profiler for benchmarking.
+
+ Parameters
+ ----------
+ config : DictConfig
+ Configuration object.
+
+ Attributes
+ ----------
+ dirpath : Path
+ Path to the profiler directory.
+ benchmark_filename : Path
+ Path to the benchmark profiler file.
+ time_profiler : SimpleProfiler
+ Simple profiler for time measurements.
+ pid : int
+ Process ID.
+ memfile_name : Path
+ Path to the memory profiler file.
+ memory_profiler : memray.Tracker
+ Memory profiler.
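+
+ Examples
+ --------
+ A minimal usage sketch, assuming `config.hardware.paths.profiler` points to
+ a writable directory:
+
+ >>> profiler = BenchmarkProfiler(config)
+ >>> trainer = pl.Trainer(profiler=profiler)
+ >>> # after trainer.fit(...) the aggregated reports can be retrieved:
+ >>> time_df = profiler.get_time_profiler_df()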
+ """ + + def __init__(self, config: DictConfig) -> None: + super().__init__(config) + + self.config = config + self.dirpath = Path(self.config.hardware.paths.profiler) + self.dirpath.mkdir(parents=True, exist_ok=True) + + self.benchmark_filename = Path(self.dirpath, "aifs-benchmark-profiler.csv") + + self._create_profilers() + + @rank_zero_only + def _create_output_file(self) -> None: + """Creates the output file to aggregate the memory profiling results.""" + fields = ["category", "size (MiB)", "function", "group", "pid"] + with self.benchmark_filename.open("w") as f: + writer = csv.writer(f) + writer.writerow(fields) + + def _create_profilers(self) -> None: + """Creates profilers for time and memory measurements.""" + self.time_profiler = SimpleProfiler( + dirpath=self.dirpath, + ) + self.pid = os.getpid() + + self.memfile_name = Path(self.dirpath, f"aifs-benchmark-mem-profiler_{self.pid}.bin") + self.memory_profiler = memray.Tracker(self.memfile_name, file_format=FileFormat.AGGREGATED_ALLOCATIONS) + self._create_output_file() + + def start(self, action_name: str) -> None: + """Starts recording time for a specific action. + + Parameters + ---------- + action_name : str + Name of the action. + """ + self.time_profiler.start(action_name) + + def stop(self, action_name: str) -> None: + """Stops recording time for a specific action. + + Parameters + ---------- + action_name : str + Name of the action. + """ + self.time_profiler.stop(action_name) + + def _trim_time_report(self, recorded_actions: dict) -> dict[str, float]: + all_actions_names = recorded_actions.keys() + df = pd.DataFrame({"Strings": all_actions_names}) + combined_pattern = "|".join(PROFILER_ACTIONS) + filtered_df = df[df["Strings"].str.contains(combined_pattern, regex=True, na=False)] + trimmed_actions_names = filtered_df["Strings"].tolist() + cleaned_recorded_actions = {key: recorded_actions[key] for key in trimmed_actions_names} + return cleaned_recorded_actions + + def get_time_profiler_df(self, precision: int = 5) -> pd.DataFrame: + """Retrieves a DataFrame with time profiling information. + + Parameters + ---------- + precision : int + Precision for rounding, by default 5 + + Returns + ------- + pd.DataFrame + DataFrame with time profiling information. 
+ """ + self.time_profiler.recorded_durations = self._trim_time_report( + recorded_actions=self.time_profiler.recorded_durations + ) + time_df = pd.DataFrame(self.time_profiler.recorded_durations.items()) + time_df[2] = time_df[1].apply(len) + time_df[3] = time_df[1].apply(np.mean) + time_df[1] = time_df[1].apply(sum) + time_df.columns = ["name", "total_time", "n_calls", "avg_time"] + + def replace_function(value): + # Replace 'apple' with 'fruit' + value = re.sub(r"\{.*?\}", "", value) # Remove anything between brackets + return value + + time_df.to_csv(Path(self.config.hardware.paths.profiler, "time_profiler_no_replace.csv")) + time_df["name"] = time_df["name"].apply(replace_function) + pattern = r"\[(.*?)\]|(.*)" + time_df["category"] = time_df["name"].str.extract(pattern, expand=False)[0].fillna(time_df["name"]) + + pattern = re.compile(r"\[Callback\](.*?)\.") + # Apply the regular expression to the column + callbacks_subcategories = "*Callback_" + time_df[time_df["category"] == "Callback"]["name"].str.extract(pattern) + indexer = time_df[time_df["category"] == "Callback"].index + time_df.loc[indexer, "category"] = callbacks_subcategories[0].tolist() + time_df.to_csv(Path(self.config.hardware.paths.profiler, "time_profiler_complete.csv")) + + # Check if 'Callback' is present in the 'category' column + time_df["is_callback"] = time_df["category"].str.contains("Callback", case=False) + + # Group by the 'is_callback' column and apply groupby operation only on rows with 'Callback' in 'category' + grouped_data = ( + time_df[time_df["is_callback"]] + .groupby("category") + .agg({"n_calls": np.sum, "avg_time": np.sum, "total_time": np.sum}) + .reset_index() + ) + grouped_data["name"] = grouped_data["category"] + + time_df = pd.concat([time_df[~time_df["is_callback"]], grouped_data]) + time_df = time_df.drop("is_callback", axis=1) + time_df = time_df.round(precision) + time_df = time_df.sort_values(by="category", ascending=False) + return time_df + + def _generate_memray_df(self) -> pd.DataFrame: + """Generates dataframe from memray tracking results. + + For each node/process we convert the tracking results to a dataframe just + keeping the high watermark allocations. + + Returns + ------- + pd.DataFrame + Memory profiler data. + """ + self.memory_profiler.__exit__(None, None, None) + memfile_tracking = FileReader(self.memfile_name) + memory_allocations = list(memfile_tracking.get_high_watermark_allocation_records()) + table = TableReporter.from_snapshot(memory_allocations, memory_records=[], native_traces=False) + df = pd.DataFrame(table.data) + memfile_tracking.close() + return df + + def _aggregate_per_category(self, df: pd.DataFrame) -> pd.DataFrame: + """Aggregates memory profiling information per category. + + Each stack_trace tracked by memray is separated into parts + - first part points to the path of the library - referred as category + - second part is the exact name of the function of this script + + Since we can have traces coming from the same script but referring + to multiple functions in that script, we aggregate those and in the + 'function' entry we just keep the function that has a higher memory + consumption. + """ + pattern = r"^(.*?) 
+ new_cols = df.loc[:, "stack_trace"].str.extract(pattern)
+ df = df.assign(function=new_cols[0], category=new_cols[1])
+ df = df.drop("stack_trace", axis=1)
+ df_agg = df.groupby("category").apply(
+ lambda x: pd.Series(
+ {
+ "size (MiB)": x["size (MiB)"].sum(),
+ "function": x.loc[x["size (MiB)"].idxmax()]["function"],
+ },
+ ),
+ )
+ df_agg.reset_index(inplace=True)
+ return df_agg
+
+ def _trim_memray_df(self, memray_df: pd.DataFrame, precision: int = 5, n_items: int = 10) -> pd.DataFrame:
+ """Trims and processes the memray DataFrame.
+
+ Necessary since memray tracks memory allocations in files all over the
+ script; we group those allocations in two categories:
+ - aifs-operations: coming from functions included in this repository
+ - general-operations: coming from functions from other python libraries
+
+ Parameters
+ ----------
+ memray_df : pd.DataFrame
+ Input DataFrame from memray.
+ precision : int, optional
+ Precision for rounding, by default 5
+ n_items : int, optional
+ Number of top memory-consuming items to include, by default 10
+
+ Returns
+ -------
+ pd.DataFrame
+ Compiled dataframe from memray data.
+ """
+ cleaned_memray_df = memray_df.drop("tid", axis=1)
+ cleaned_memray_df = cleaned_memray_df.drop("allocator", axis=1)
+
+ # For readability, we cut the paths to just display the relevant package info
+ module_path = anemoi.training.__path__[0].replace("aifs-mono/aifs", "")
+ env_path = pl.__path__[0].replace("pytorch_lightning", "")
+ base_env_path = pl.__path__[0].replace("/site-packages/pytorch_lightning", "")
+
+ cleaned_memray_df["stack_trace"] = cleaned_memray_df["stack_trace"].apply(lambda x: x.replace(module_path, ""))
+ cleaned_memray_df["stack_trace"] = cleaned_memray_df["stack_trace"].apply(lambda x: x.replace(env_path, ""))
+ cleaned_memray_df["stack_trace"] = cleaned_memray_df["stack_trace"].apply(
+ lambda x: x.replace(base_env_path, "")
+ )
+
+ cleaned_memray_df["size (MiB)"] = cleaned_memray_df["size"] * 9.5367e-7 # convert bytes to MiB (1 / 1024**2)
+ cleaned_memray_df.sort_values("size (MiB)", ascending=False, inplace=True)
+ cleaned_memray_df = cleaned_memray_df.drop("size", axis=1)
+
+ top_most_memory_consuming_df = cleaned_memray_df[~cleaned_memray_df["stack_trace"].str.contains("aifs")].head(
+ n_items
+ )
+ top_most_memory_consuming_df = self._aggregate_per_category(top_most_memory_consuming_df)
+
+ aifs_memray = cleaned_memray_df[cleaned_memray_df["stack_trace"].str.contains("aifs")]
+ aifs_memray = self._aggregate_per_category(aifs_memray)
+
+ aifs_memray["group"] = "aifs-operations"
+ top_most_memory_consuming_df["group"] = "general-operations"
+
+ merged_memory_df = pd.concat([top_most_memory_consuming_df, aifs_memray])
+ merged_memory_df = merged_memory_df.round(precision)
+ return merged_memory_df
+
+ def teardown(self, stage: Optional[str]) -> None:
+ """Clean up before closing the profiler.
+
+ Before closing the profiler, performs the cleanup operations and writes memray
+ data to the common benchmark file.
+ """
+ memray_df = self._generate_memray_df()
+ cleaned_memray_df = self._trim_memray_df(memray_df)
+
+ cleaned_memray_df["pid"] = self.pid
+ cleaned_memray_df.to_csv(self.benchmark_filename, mode="a", index=False, header=False)
+
+ @rank_zero_only
+ def get_memory_profiler_df(self) -> pd.DataFrame:
+ """Retrieves the memory profiler data as a DataFrame.
+
+ Aggregates the results coming from multiple nodes/processes.
+
+ Returns
+ -------
+ pd.DataFrame
+ Memory profiler data.
+ """ + mem_df = pd.read_csv(self.benchmark_filename) + return ( + mem_df.groupby(["category", "group", "function"]) + .apply( + lambda x: pd.Series( + { + "size (MiB)": x["size (MiB)"].mean(), + "pid": len(set(x["pid"])), + }, + ), + ) + .reset_index() + .sort_values("size (MiB)", ascending=False) + ) + + def __del__(self) -> None: + self.teardown(stage=self._stage) + self.memfile_name.unlink() + + +class ProfilerProgressBar(TQDMProgressBar): + """Custom PyTorch Lightning progress bar with profiling functionality. + + Attributes + ---------- + validation_rates : list[float] + List to store validation rates (it/s). + training_rates : list[float] + List to store training rates (it/s). + """ + + def __init__(self, config) -> None: + super().__init__() + self.validation_rates = [] + self.training_rates = [] + + def _extract_rate(self, pbar) -> float: + """Extracts the iteration rate from the progress bar. + + Parameters + ---------- + pbar : tqdm + The progress bar. + + Returns + ------- + float + The iteration rate. + """ + return (pbar.format_dict["n"] - pbar.format_dict["initial"]) / pbar.format_dict["elapsed"] + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Appends the rate from the progress bar to the list of 'training_rates'.""" + batch_idx + 1 + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + if self.train_progress_bar.format_dict["n"] != 0: + self.training_rates.append(self._extract_rate(self.train_progress_bar)) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Append rate from the progress bar to the list of 'validation_rates'.""" + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self.val_progress_bar.format_dict["n"] != 0: + self.validation_rates.append(self._extract_rate(self.val_progress_bar)) + + @rank_zero_only + def summarize_metrics(self, config) -> dict[str, float]: + """Summarizes and returns speed metrics based on training and validation rates. + + Parameters + ---------- + config : Config + The configuration object. + + Returns + ------- + dict + A dictionary containing speed metrics. 
+ """ + speed_metrics = {} + + batch_size_tr = config.dataloader.batch_size.training + batch_size_val = config.dataloader.batch_size.validation + + training_rates_array = np.array(self.training_rates) + speed_metrics["training_avg_throughput"] = training_rates_array.mean() + speed_metrics["training_avg_throughput_per_sample"] = training_rates_array.mean() / batch_size_tr + + validation_rates_array = np.array(self.validation_rates) + speed_metrics["validation_avg_throughput"] = validation_rates_array.mean() + speed_metrics["validation_avg_throughput_per_sample"] = validation_rates_array.mean() / batch_size_val + + return speed_metrics From 1e704be6ea229a617fc6023ba69344eebeeacf86 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Tue, 21 May 2024 16:07:47 +0000 Subject: [PATCH 5/5] refactor: Initial Distributed Implementation Allows Distributed Data Parallel and Distributed Model Parallel training Co-authored-by: Jesper Dramsch Co-authored-by: Simon Lang Co-authored-by: Matthew Chantry Co-authored-by: Mihai Alexe --- .../training/strategy/ddp_group_strategy.py | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 src/anemoi/training/strategy/ddp_group_strategy.py diff --git a/src/anemoi/training/strategy/ddp_group_strategy.py b/src/anemoi/training/strategy/ddp_group_strategy.py new file mode 100644 index 00000000..ec707d5f --- /dev/null +++ b/src/anemoi/training/strategy/ddp_group_strategy.py @@ -0,0 +1,138 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os + +import numpy as np +import pytorch_lightning as pl +import torch +from anemoi.utils.seeding import get_base_seed +from lightning_fabric.utilities.optimizer import _optimizers_to_device +from pytorch_lightning.overrides.distributed import _sync_module_states +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.trainer.states import TrainerFn + +LOGGER = logging.getLogger(__name__) + + +class DDPGroupStrategy(DDPStrategy): + """Distributed Data Parallel strategy with group communication.""" + + def __init__(self, num_gpus_per_model: int, **kwargs) -> None: + super().__init__(**kwargs) + self.model_comm_group_size = num_gpus_per_model + + def setup(self, trainer: pl.Trainer) -> None: + assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" + self.accelerator.setup(trainer) + + # determine the model groups that work together: + + assert self.world_size % self.model_comm_group_size == 0, ( + f"Total number of GPUs ({self.world_size}) must be divisible by the number of GPUs " + f"per model ({self.model_comm_group_size})." 
+ )
+
+ model_comm_group_ranks = np.split(
+ np.arange(self.world_size, dtype=int), self.world_size // self.model_comm_group_size
+ )
+ model_comm_groups = [
+ torch.distributed.new_group(x) for x in model_comm_group_ranks
+ ] # every rank has to create all of these
+
+ model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group(
+ self.model_comm_group_size
+ )
+ model_comm_group = model_comm_groups[model_comm_group_id]
+ self.model.set_model_comm_group(model_comm_group)
+ LOGGER.debug(
+ "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s",
+ self.global_rank,
+ str(model_comm_group_nr),
+ model_comm_group_id,
+ model_comm_group_rank,
+ str(model_comm_group_ranks[model_comm_group_id]),
+ )
+
+ # register hooks for correct gradient reduction
+ self.register_parameter_hooks()
+
+ # move the model to the correct device
+ self.model_to_device()
+
+ # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+ trainer_fn = trainer.state.fn
+
+ if trainer_fn == TrainerFn.FITTING and self._layer_sync:
+ assert self.model is not None, "Model is not initialized for distributed strategy"
+ self.model = self._layer_sync.apply(self.model)
+
+ self.setup_precision_plugin()
+
+ if trainer_fn == TrainerFn.FITTING:
+ # do not wrap with DDP if not fitting as there are no gradients to reduce
+ self.configure_ddp()
+
+ # set up optimizers after the wrapped module has been moved to the device
+ self.setup_optimizers(trainer)
+ _optimizers_to_device(self.optimizers, self.root_device)
+
+ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+ if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+ self._enable_model_averaging()
+ else:
+ # we need to manually synchronize the module's states since we aren't using the DDP wrapper
+ assert self.model is not None, "Model is not initialized for distributed strategy"
+ _sync_module_states(self.model)
+
+ # seed ranks
+ self.seed_rnd(model_comm_group_id)
+
+ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]:
+ """Determine the tasks that work together and form a model group."""
+ model_comm_groups = np.arange(0, self.world_size, dtype=np.int32)
+ model_comm_groups = np.split(model_comm_groups, self.world_size // num_gpus_per_model)
+
+ model_comm_group_id = None
+ for i, model_comm_group in enumerate(model_comm_groups):
+ if self.global_rank in model_comm_group:
+ model_comm_group_id = i
+ model_comm_group_nr = model_comm_group
+ model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0]
+ return model_comm_group_id, model_comm_group_nr, model_comm_group_rank
+
+ def seed_rnd(self, model_comm_group_id: int) -> None:
+ """Seed the random number generators for the rank."""
+ base_seed = get_base_seed()
+ initial_seed = base_seed * (model_comm_group_id + 1)
+ rnd_seed = pl.seed_everything(initial_seed) # note: workers are seeded independently in the dataloader
+ np_rng = np.random.default_rng(rnd_seed)
+ sanity_rnd = (torch.rand(1), np_rng.random())
+ LOGGER.debug(
+ "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, running with random seed: %d, sanity rnd: %s",
+ int(os.environ.get("SLURM_PROCID", "0")),
+ model_comm_group_id,
+ base_seed,
+ initial_seed,
+ rnd_seed,
+ sanity_rnd,
+ )
+
+ def register_parameter_hooks(self) -> None:
+ """Register parameter hooks for gradient reduction.
+
+ Here we rescale the gradients of parameters that only see a subset of the input on each rank:
+ DDP still divides them by the total number of GPUs, as if each rank had seen a full set of inputs.
+ Note: the trainable parameters are added before the split across GPUs and are therefore not rescaled.
+ """
+ for name, param in self.model.named_parameters():
+ if param.requires_grad and "trainable" not in name:
+ param.register_hook(lambda grad: grad * float(self.model_comm_group_size))
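+
+# A minimal illustrative sketch (not part of the strategy above), assuming model
+# groups of two GPUs in a world of four: DDP averages gradients over the full
+# world size as if every rank had seen a complete input, so the hook multiplies
+# the gradients of sharded parameters back by the model communication group size.
+#
+# >>> import torch
+# >>> grad = torch.ones(3) / 4  # gradient after DDP all-reduce over 4 ranks
+# >>> grad * float(2)           # undo the extra division for 2-GPU model groups
+# tensor([0.5000, 0.5000, 0.5000])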