refactor: Initial model implementation

Encompasses the encoder, processor, and decoder structure.

Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: Matthew Chantry <[email protected]>
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: mihai.alexe <[email protected]>
1 parent e812346 · commit c052d19 · Showing 2 changed files with 281 additions and 0 deletions.
@@ -0,0 +1,256 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from typing import Optional

import einops
import torch
from hydra.utils import instantiate
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.models.distributed.helpers import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.utils.config import DotConfig

LOGGER = logging.getLogger(__name__)

class AnemoiModelEncProcDec(nn.Module):
    """Message passing graph neural network."""

    def __init__(
        self,
        *,
        config: DotConfig,
        data_indices: dict,
        graph_data: HeteroData,
    ) -> None:
        """Initializes the graph neural network.

        Parameters
        ----------
        config : DotConfig
            Job configuration
        data_indices : dict
            Input and output data indices of the model
        graph_data : HeteroData
            Graph definition
        """
        super().__init__()

        self._graph_data = graph_data
        self._graph_name_data = config.graph.data
        self._graph_name_hidden = config.graph.hidden

        # Calculate shapes and indices
        self.num_input_channels = len(data_indices.model.input)
        self.num_output_channels = len(data_indices.model.output)
        self._internal_input_idx = data_indices.model.input.prognostic
        self._internal_output_idx = data_indices.model.output.prognostic

        assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len(
            data_indices.model.output.diagnostic
        ), (
            f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding "
            f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})"
        )
        assert len(self._internal_input_idx) == len(
            self._internal_output_idx
        ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

        self.multi_step = config.training.multistep_input

        # Define sizes of different tensors
        self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[
            0
        ]
        self._hidden_grid_size = self._graph_data[
            (self._graph_name_hidden, "to", self._graph_name_hidden)
        ].hcoords_rad.shape[0]

        self.trainable_data_size = config.model.trainable_parameters.data
        self.trainable_hidden_size = config.model.trainable_parameters.hidden

        # Create trainable tensors
        self._create_trainable_attributes()

        # Register lat/lon
        self._register_latlon("data", self._graph_name_data)
        self._register_latlon("hidden", self._graph_name_hidden)

        self.num_channels = config.model.num_channels

        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
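        # For intuition, a worked example with hypothetical sizes: multi_step=2,
        # num_input_channels=90, four sin/cos lat-lon features, and
        # trainable_data_size=8 give input_dim = 2 * 90 + 4 + 8 = 192.
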
        # Encoder data -> hidden
        self.encoder = instantiate(
            config.model.encoder,
            in_channels_src=input_dim,
            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
            hidden_dim=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
            src_grid_size=self._data_grid_size,
            dst_grid_size=self._hidden_grid_size,
        )

        # Processor hidden -> hidden
        self.processor = instantiate(
            config.model.processor,
            num_channels=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
            src_grid_size=self._hidden_grid_size,
            dst_grid_size=self._hidden_grid_size,
        )

        # Decoder hidden -> data
        self.decoder = instantiate(
            config.model.decoder,
            in_channels_src=self.num_channels,
            in_channels_dst=input_dim,
            hidden_dim=self.num_channels,
            out_channels_dst=self.num_output_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
            src_grid_size=self._hidden_grid_size,
            dst_grid_size=self._data_grid_size,
        )
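        # Note (illustrative): hydra's `instantiate` resolves each sub-config's
        # `_target_` entry to a class and calls it with the keyword arguments
        # given above; e.g. a hypothetical config.model.processor of the form
        # {"_target_": "some.module.TransformerProcessor", "num_layers": 16}
        # would build that class with num_layers=16 plus the arguments passed here.
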
    def _register_latlon(self, name: str, key: str) -> None:
        """Register lat/lon buffers.

        Parameters
        ----------
        name : str
            Name of grid to map
        key : str
            Key of the grid
        """
        self.register_buffer(
            f"latlons_{name}",
            torch.cat(
                [
                    torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
                    torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
                ],
                dim=-1,
            ),
            persistent=True,
        )

    def _create_trainable_attributes(self) -> None:
        """Create all trainable attributes."""
        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
        self.trainable_hidden = TrainableTensor(
            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
        )

    def _run_mapper(
        self,
        mapper: nn.Module,
        data: tuple[Tensor],
        batch_size: int,
        shard_shapes: tuple[tuple[int, int], tuple[int, int]],
        model_comm_group: Optional[ProcessGroup] = None,
        use_reentrant: bool = False,
    ) -> Tensor:
        """Run mapper with activation checkpointing.

        Parameters
        ----------
        mapper : nn.Module
            Which mapper to use
        data : tuple[Tensor]
            Tuple of data to pass in
        batch_size : int
            Batch size
        shard_shapes : tuple[tuple[int, int], tuple[int, int]]
            Shard shapes for the data
        model_comm_group : ProcessGroup
            Model communication group, specifies which GPUs work together
            in one model instance
        use_reentrant : bool, optional
            Use reentrant checkpointing, by default False

        Returns
        -------
        Tensor
            Mapped data
        """
        return checkpoint(
            mapper,
            data,
            batch_size=batch_size,
            shard_shapes=shard_shapes,
            model_comm_group=model_comm_group,
            use_reentrant=use_reentrant,
        )

    def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
        """Encode data onto the hidden grid, process it there, and decode back.

        `x` has shape (batch, time, ensemble, grid, vars); the output has shape
        (batch, ensemble, grid, vars).
        """
        batch_size = x.shape[0]
        ensemble_size = x.shape[2]

        # add data positional info (lat/lon)
        x_data_latent = torch.cat(
            (
                einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
                self.trainable_data(self.latlons_data, batch_size=batch_size),
            ),
            dim=-1,  # feature dimension
        )

        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)

        # get shard shapes
        shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
        shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group)

        # Run encoder
        x_data_latent, x_latent = self._run_mapper(
            self.encoder,
            (x_data_latent, x_hidden_latent),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_data, shard_shapes_hidden),
            model_comm_group=model_comm_group,
        )

        # Run processor
        x_latent_proc = self.processor(
            x_latent,
            batch_size=batch_size,
            shard_shapes=shard_shapes_hidden,
            model_comm_group=model_comm_group,
        )

        # add skip connection (hidden -> hidden)
        x_latent_proc = x_latent_proc + x_latent

        # Run decoder
        x_out = self._run_mapper(
            self.decoder,
            (x_latent_proc, x_data_latent),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_hidden, shard_shapes_data),
            model_comm_group=model_comm_group,
        )

        x_out = (
            einops.rearrange(
                x_out,
                "(batch ensemble grid) vars -> batch ensemble grid vars",
                batch=batch_size,
                ensemble=ensemble_size,
            )
            .to(dtype=x.dtype)
            .clone()
        )

        # residual connection (just for the prognostic variables)
        x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
        return x_out
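
For reference, a minimal self-contained sketch of the tensor-shape convention that forward() relies on; all sizes below are made up for illustration:

import einops
import torch

batch, time, ensemble, grid, nvars = 2, 2, 1, 40, 5
x = torch.randn(batch, time, ensemble, grid, nvars)

# inputs are flattened to (batch * ensemble * grid, time * vars) for the encoder
flat = einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)")
assert flat.shape == (batch * ensemble * grid, time * nvars)

# the decoder output carries one step of variables per node and is reshaped
# back to (batch, ensemble, grid, vars); the grid size is inferred by einops
out = torch.randn(batch * ensemble * grid, nvars)
restored = einops.rearrange(
    out, "(batch ensemble grid) vars -> batch ensemble grid vars", batch=batch, ensemble=ensemble
)
assert restored.shape == (batch, ensemble, grid, nvars)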
@@ -0,0 +1,25 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


class DotConfig(dict):
    """A config dictionary that allows access to its keys as attributes."""

    def __init__(self, *args, **kwargs) -> None:
        for a in args:
            self.update(a)
        self.update(kwargs)

    def __getattr__(self, name):
        # __getattr__ is only called when normal attribute lookup fails
        if name in self:
            x = self[name]
            # wrap nested dicts so chained attribute access keeps working
            if isinstance(x, dict):
                return DotConfig(x)
            return x
        raise AttributeError(name)
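
A quick illustrative usage sketch of DotConfig (the keys below are made up); nested dictionaries become attribute-accessible while plain dict indexing keeps working:

cfg = DotConfig({"model": {"num_channels": 512}}, training={"multistep_input": 2})
assert cfg.model.num_channels == 512
assert cfg.training.multistep_input == 2
assert cfg["model"]["num_channels"] == 512  # still a regular dict underneath
try:
    cfg.does_not_exist
except AttributeError:
    pass  # unknown keys raise AttributeError, matching attribute semantics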