From c052d197936eb02a38a0e9d54c498bfa0967bed5 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Wed, 15 May 2024 10:54:26 +0000 Subject: [PATCH] refactor: Initial model implementation Encompasses encoder, processor and decoder structure. Co-authored-by: Jesper Dramsch Co-authored-by: Matthew Chantry Co-authored-by: Simon Lang Co-authored-by: mihai.alexe --- .../models/encoder_processor_decoder.py | 256 ++++++++++++++++++ src/anemoi/models/utils/config.py | 25 ++ 2 files changed, 281 insertions(+) create mode 100644 src/anemoi/models/models/encoder_processor_decoder.py create mode 100644 src/anemoi/models/utils/config.py diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py new file mode 100644 index 0000000..dd3564f --- /dev/null +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -0,0 +1,256 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import einops +import torch +from hydra.utils import instantiate +from torch import Tensor +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch.utils.checkpoint import checkpoint +from torch_geometric.data import HeteroData + +from anemoi.models.distributed.helpers import get_shape_shards +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.utils.config import DotConfig + +LOGGER = logging.getLogger(__name__) + + +class AnemoiModelEncProcDec(nn.Module): + """Message passing graph neural network.""" + + def __init__( + self, + *, + config: DotConfig, + data_indices: dict, + graph_data: HeteroData, + ) -> None: + """Initializes the graph neural network. + + Parameters + ---------- + config : DictConfig + Job configuration + graph_data : HeteroData + Graph definition + """ + super().__init__() + + self._graph_data = graph_data + self._graph_name_data = config.graph.data + self._graph_name_hidden = config.graph.hidden + + # Calculate shapes and indices + self.num_input_channels = len(data_indices.model.input) + self.num_output_channels = len(data_indices.model.output) + self._internal_input_idx = data_indices.model.input.prognostic + self._internal_output_idx = data_indices.model.output.prognostic + + assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len( + data_indices.model.output.diagnostic + ), ( + f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding " + f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})", + ) + assert len(self._internal_input_idx) == len( + self._internal_output_idx, + ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" + + self.multi_step = config.training.multistep_input + + # Define Sizes of different tensors + self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[ + 0 + ] + self._hidden_grid_size = self._graph_data[ + (self._graph_name_hidden, "to", self._graph_name_hidden) + ].hcoords_rad.shape[0] + + self.trainable_data_size = config.model.trainable_parameters.data + self.trainable_hidden_size = config.model.trainable_parameters.hidden + + # Create trainable tensors + self._create_trainable_attributes() + + # Register lat/lon + self._register_latlon("data", self._graph_name_data) + self._register_latlon("hidden", self._graph_name_hidden) + + self.num_channels = config.model.num_channels + + input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + + # Encoder data -> hidden + self.encoder = instantiate( + config.model.encoder, + in_channels_src=input_dim, + in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + hidden_dim=self.num_channels, + sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], + src_grid_size=self._data_grid_size, + dst_grid_size=self._hidden_grid_size, + ) + + # Processor hidden -> hidden + self.processor = instantiate( + config.model.processor, + num_channels=self.num_channels, + sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], + src_grid_size=self._hidden_grid_size, + dst_grid_size=self._hidden_grid_size, + ) + + # Decoder hidden -> data + self.decoder = instantiate( + config.model.decoder, + in_channels_src=self.num_channels, + in_channels_dst=input_dim, + hidden_dim=self.num_channels, + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], + src_grid_size=self._hidden_grid_size, + dst_grid_size=self._data_grid_size, + ) + + def _register_latlon(self, name: str, key: str) -> None: + """Register lat/lon buffers. + + Parameters + ---------- + name : str + Name of grid to map + key : str + Key of the grid + """ + self.register_buffer( + f"latlons_{name}", + torch.cat( + [ + torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), + torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), + ], + dim=-1, + ), + persistent=True, + ) + + def _create_trainable_attributes(self) -> None: + """Create all trainable attributes.""" + self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) + self.trainable_hidden = TrainableTensor( + trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size + ) + + def _run_mapper( + self, + mapper: nn.Module, + data: tuple[Tensor], + batch_size: int, + shard_shapes: tuple[tuple[int, int], tuple[int, int]], + model_comm_group: Optional[ProcessGroup] = None, + use_reentrant: bool = False, + ) -> Tensor: + """Run mapper with activation checkpoint. + + Parameters + ---------- + mapper : nn.Module + Which processor to use + data : tuple[Tensor] + tuple of data to pass in + batch_size: int, + Batch size + shard_shapes : tuple[tuple[int, int], tuple[int, int]] + Shard shapes for the data + model_comm_group : ProcessGroup + model communication group, specifies which GPUs work together + in one model instance + use_reentrant : bool, optional + Use reentrant, by default False + + Returns + ------- + Tensor + Mapped data + """ + return checkpoint( + mapper, + data, + batch_size=batch_size, + shard_shapes=shard_shapes, + model_comm_group=model_comm_group, + use_reentrant=use_reentrant, + ) + + def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: + batch_size = x.shape[0] + ensemble_size = x.shape[2] + + # add data positional info (lat/lon) + x_data_latent = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + self.trainable_data(self.latlons_data, batch_size=batch_size), + ), + dim=-1, # feature dimension + ) + + x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + + # get shard shapes + shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) + shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group) + + # Run encoder + x_data_latent, x_latent = self._run_mapper( + self.encoder, + (x_data_latent, x_hidden_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_data, shard_shapes_hidden), + model_comm_group=model_comm_group, + ) + + x_latent_proc = self.processor( + x_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hidden, + model_comm_group=model_comm_group, + ) + + # add skip connection (hidden -> hidden) + x_latent_proc = x_latent_proc + x_latent + + # Run decoder + x_out = self._run_mapper( + self.decoder, + (x_latent_proc, x_data_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_hidden, shard_shapes_data), + model_comm_group=model_comm_group, + ) + + x_out = ( + einops.rearrange( + x_out, + "(batch ensemble grid) vars -> batch ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() + ) + + # residual connection (just for the prognostic variables) + x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + return x_out diff --git a/src/anemoi/models/utils/config.py b/src/anemoi/models/utils/config.py new file mode 100644 index 0000000..a541051 --- /dev/null +++ b/src/anemoi/models/utils/config.py @@ -0,0 +1,25 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +class DotConfig(dict): + """A Config dictionary that allows access to its keys as attributes.""" + + def __init__(self, *args, **kwargs) -> None: + for a in args: + self.update(a) + self.update(kwargs) + + def __getattr__(self, name): + if name in self: + x = self[name] + if isinstance(x, dict): + return DotConfig(x) + return x + raise AttributeError(name)