Commit
refactor: Initial model implementation
Encompasses encoder, processor and decoder structure.

Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: Matthew Chantry <[email protected]>
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: mihai.alexe <[email protected]>
5 people committed May 15, 2024
1 parent e812346 commit c052d19
Showing 2 changed files with 281 additions and 0 deletions.
256 changes: 256 additions & 0 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -0,0 +1,256 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from typing import Optional

import einops
import torch
from hydra.utils import instantiate
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.models.distributed.helpers import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.utils.config import DotConfig

LOGGER = logging.getLogger(__name__)


class AnemoiModelEncProcDec(nn.Module):
"""Message passing graph neural network."""

def __init__(
self,
*,
config: DotConfig,
data_indices: dict,
graph_data: HeteroData,
) -> None:
"""Initializes the graph neural network.
Parameters
----------
config : DictConfig
Job configuration
graph_data : HeteroData
Graph definition
"""
super().__init__()

self._graph_data = graph_data
self._graph_name_data = config.graph.data
self._graph_name_hidden = config.graph.hidden

# Calculate shapes and indices
self.num_input_channels = len(data_indices.model.input)
self.num_output_channels = len(data_indices.model.output)
self._internal_input_idx = data_indices.model.input.prognostic
self._internal_output_idx = data_indices.model.output.prognostic

assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len(
data_indices.model.output.diagnostic
), (
f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding "
f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})",
)
        assert len(self._internal_input_idx) == len(
            self._internal_output_idx
        ), f"Prognostic input and output indices must have the same length: {self._internal_input_idx} != {self._internal_output_idx}"

self.multi_step = config.training.multistep_input

        # Define sizes of different tensors
self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[
0
]
self._hidden_grid_size = self._graph_data[
(self._graph_name_hidden, "to", self._graph_name_hidden)
].hcoords_rad.shape[0]

self.trainable_data_size = config.model.trainable_parameters.data
self.trainable_hidden_size = config.model.trainable_parameters.hidden

# Create trainable tensors
self._create_trainable_attributes()

# Register lat/lon
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

self.num_channels = config.model.num_channels

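        # Per-node input dimension: multi-step history of the input channels, plus the
        # sin/cos lat-lon embedding and the trainable node features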
input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size

# Encoder data -> hidden
self.encoder = instantiate(
config.model.encoder,
in_channels_src=input_dim,
in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
hidden_dim=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
src_grid_size=self._data_grid_size,
dst_grid_size=self._hidden_grid_size,
)

# Processor hidden -> hidden
self.processor = instantiate(
config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
src_grid_size=self._hidden_grid_size,
dst_grid_size=self._hidden_grid_size,
)

# Decoder hidden -> data
self.decoder = instantiate(
config.model.decoder,
in_channels_src=self.num_channels,
in_channels_dst=input_dim,
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
src_grid_size=self._hidden_grid_size,
dst_grid_size=self._data_grid_size,
)

def _register_latlon(self, name: str, key: str) -> None:
"""Register lat/lon buffers.
Parameters
----------
name : str
Name of grid to map
key : str
Key of the grid
"""
self.register_buffer(
f"latlons_{name}",
torch.cat(
[
torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
],
dim=-1,
),
persistent=True,
)

def _create_trainable_attributes(self) -> None:
"""Create all trainable attributes."""
self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
self.trainable_hidden = TrainableTensor(
trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
)

def _run_mapper(
self,
mapper: nn.Module,
data: tuple[Tensor],
batch_size: int,
shard_shapes: tuple[tuple[int, int], tuple[int, int]],
model_comm_group: Optional[ProcessGroup] = None,
use_reentrant: bool = False,
) -> Tensor:
"""Run mapper with activation checkpoint.
Parameters
----------
mapper : nn.Module
Which processor to use
data : tuple[Tensor]
tuple of data to pass in
batch_size: int,
Batch size
shard_shapes : tuple[tuple[int, int], tuple[int, int]]
Shard shapes for the data
model_comm_group : ProcessGroup
model communication group, specifies which GPUs work together
in one model instance
use_reentrant : bool, optional
Use reentrant, by default False
Returns
-------
Tensor
Mapped data
"""
return checkpoint(
mapper,
data,
batch_size=batch_size,
shard_shapes=shard_shapes,
model_comm_group=model_comm_group,
use_reentrant=use_reentrant,
)

def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
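        """Forward pass through encoder, processor and decoder.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, time, ensemble, grid, vars)
        model_comm_group : ProcessGroup, optional
            Model communication group

        Returns
        -------
        Tensor
            Output tensor of shape (batch, ensemble, grid, vars)
        """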
batch_size = x.shape[0]
ensemble_size = x.shape[2]

# add data positional info (lat/lon)
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
self.trainable_data(self.latlons_data, batch_size=batch_size),
),
dim=-1, # feature dimension
)

x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)

# get shard shapes
shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group)

# Run encoder
x_data_latent, x_latent = self._run_mapper(
self.encoder,
(x_data_latent, x_hidden_latent),
batch_size=batch_size,
shard_shapes=(shard_shapes_data, shard_shapes_hidden),
model_comm_group=model_comm_group,
)

x_latent_proc = self.processor(
x_latent,
batch_size=batch_size,
shard_shapes=shard_shapes_hidden,
model_comm_group=model_comm_group,
)

# add skip connection (hidden -> hidden)
x_latent_proc = x_latent_proc + x_latent

# Run decoder
x_out = self._run_mapper(
self.decoder,
(x_latent_proc, x_data_latent),
batch_size=batch_size,
shard_shapes=(shard_shapes_hidden, shard_shapes_data),
model_comm_group=model_comm_group,
)

x_out = (
einops.rearrange(
x_out,
"(batch ensemble grid) vars -> batch ensemble grid vars",
batch=batch_size,
ensemble=ensemble_size,
)
.to(dtype=x.dtype)
.clone()
)

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
return x_out
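
A minimal usage sketch (illustrative, not part of this commit), assuming a DotConfig exposing the graph.*, model.* and training.* keys referenced above, pre-built data indices, and a HeteroData graph carrying the expected edge sets; config, data_indices and graph_data are placeholders for objects constructed elsewhere in the training pipeline:

import torch

from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec

# config, data_indices and graph_data are hypothetical objects built elsewhere
model = AnemoiModelEncProcDec(config=config, data_indices=data_indices, graph_data=graph_data)

# Input shape: (batch, multi_step, ensemble, grid, vars)
x = torch.randn(2, model.multi_step, 1, model._data_grid_size, model.num_input_channels)

y = model(x)  # -> (batch, ensemble, grid, num_output_channels)
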
25 changes: 25 additions & 0 deletions src/anemoi/models/utils/config.py
@@ -0,0 +1,25 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


class DotConfig(dict):
"""A Config dictionary that allows access to its keys as attributes."""

def __init__(self, *args, **kwargs) -> None:
for a in args:
self.update(a)
self.update(kwargs)

def __getattr__(self, name):
if name in self:
x = self[name]
if isinstance(x, dict):
return DotConfig(x)
return x
raise AttributeError(name)
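
An illustrative use of DotConfig (not part of this commit; all values are made up):

cfg = DotConfig({"model": {"num_channels": 512}}, training={"multistep_input": 2})

assert cfg.model.num_channels == 512        # nested dicts are wrapped on access
assert cfg.training.multistep_input == 2    # keyword arguments are merged in
assert cfg["model"]["num_channels"] == 512  # plain dict access still works

# Missing keys raise AttributeError instead of returning None
try:
    cfg.missing
except AttributeError:
    pass

Note that nested dictionaries are wrapped in a fresh DotConfig on each attribute access, so attribute writes on a nested level do not propagate back to the parent mapping.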
