Refactor node attributes #64

Draft: wants to merge 5 commits into base: develop

CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)

### Changed
- Bugfixes for CI
src/anemoi/models/layers/graph.py (42 changes: 41 additions & 1 deletion)
@@ -11,6 +11,7 @@
import torch
from torch import Tensor
from torch import nn
from torch_geometric.data import HeteroData


class TrainableTensor(nn.Module):
@@ -35,8 +36,47 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
    def forward(self, x: Tensor, batch_size: int) -> Tensor:
        latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
        if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
        return torch.cat(
            latent,
            dim=-1,  # feature dimension
        )
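The single deletion in this file is the old `latent.append` line; its replacement moves the trainable tensor onto the input's device before concatenation. A minimal standalone illustration of the failure mode this guards against (the shapes and the scenario are invented, not from the PR):

```python
import torch

coords = torch.rand(4, 3)     # stand-in for the repeated lat/lon features
trainable = torch.rand(4, 2)  # stand-in for the trainable node features

# If coords and trainable lived on different devices (e.g. one on CUDA, one
# on CPU), torch.cat would raise a RuntimeError; aligning devices first, as
# the new code does with .to(x.device), avoids that failure mode.
trainable = trainable.to(coords.device)
out = torch.cat([coords, trainable], dim=-1)  # shape (4, 5)
```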


class NamedNodesAttributes(torch.nn.Module):
    """Named Node Attributes Module."""

    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
        """Initialize NamedNodesAttributes."""
        super().__init__()

        self.num_trainable_params = num_trainable_params
        self.register_fixed_attributes(graph_data)

        self.trainable_tensors = nn.ModuleDict()
        for nodes_name in self.nodes_names:
            self.register_coordinates(nodes_name, graph_data[nodes_name].x)
            self.register_tensor(nodes_name)

    def register_fixed_attributes(self, graph_data: HeteroData) -> None:
        """Register fixed attributes."""
        self.nodes_names = list(graph_data.node_types)
        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names}
        self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names}
        self.attr_ndims = {
            nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names
        }

    def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None:
        """Register coordinates."""
        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

    def register_tensor(self, name: str) -> None:
        """Register a trainable tensor."""
        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params)

    def forward(self, name: str, batch_size: int) -> Tensor:
        """Forward pass."""
        latlons = getattr(self, f"latlons_{name}")
        return self.trainable_tensors[name](latlons, batch_size)
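For illustration, a minimal sketch of how the new class is meant to be used; the node-set names, node counts, and trainable size below are hypothetical, not taken from this PR:

```python
import torch
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes

# Hypothetical two-level graph; names and sizes are made up for the example.
graph = HeteroData()
graph["data"].x = torch.rand(40, 2) * 3.14    # 40 nodes with (lat, lon) coordinates
graph["hidden"].x = torch.rand(10, 2) * 3.14  # 10 coarser hidden nodes

node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)

# Sizes are discovered from the graph: the sin/cos encoding doubles the
# coordinate dimension, and the trainable features are added on top.
assert node_attributes.num_nodes["data"] == 40
assert node_attributes.attr_ndims["data"] == 2 * 2 + 8  # 12

# The forward pass is keyed by node-set name and returns a tensor of shape
# (batch_size * num_nodes, attr_ndims), ready to concatenate with model inputs.
x_hidden = node_attributes("hidden", batch_size=4)
assert x_hidden.shape == (4 * 10, 12)
```

Because the node sets are discovered from `graph.node_types`, a graph with additional node sets needs no extra buffers or attributes in the model code.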
src/anemoi/models/models/encoder_processor_decoder.py (65 changes: 14 additions & 51 deletions)
@@ -21,7 +21,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.graph import NamedNodesAttributes

LOGGER = logging.getLogger(__name__)

@@ -55,42 +55,33 @@ def __init__(

        self._calculate_shapes_and_indices(data_indices)
        self._assert_matching_indices(data_indices)

-        self.multi_step = model_config.training.multistep_input
-
-        self._define_tensor_sizes(model_config)
-
-        # Create trainable tensors
-        self._create_trainable_attributes()
-
-        # Register lat/lon of nodes
-        self._register_latlon("data", self._graph_name_data)
-        self._register_latlon("hidden", self._graph_name_hidden)
-
        self.data_indices = data_indices

+        self.multi_step = model_config.training.multistep_input
        self.num_channels = model_config.model.num_channels

-        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
+
+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

        # Encoder data -> hidden
        self.encoder = instantiate(
            model_config.model.encoder,
            in_channels_src=input_dim,
-            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
            hidden_dim=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
-            src_grid_size=self._data_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        # Processor hidden -> hidden
        self.processor = instantiate(
            model_config.model.processor,
            num_channels=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        # Decoder hidden -> data
@@ -101,8 +92,8 @@ def __init__(
            hidden_dim=self.num_channels,
            out_channels_dst=self.num_output_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
        )

        # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -132,34 +123,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
            self._internal_output_idx,
        ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )

    def _run_mapper(
        self,
        mapper: nn.Module,
@@ -209,12 +172,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
        x_data_latent = torch.cat(
            (
                einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
            ),
            dim=-1,  # feature dimension
        )

-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

        # get shard shapes
        shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
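As a concrete check of the dimension bookkeeping this refactor centralizes, the old and new `input_dim` expressions compute the same number; the sizes below are invented for illustration, not taken from any anemoi configuration:

```python
# Hypothetical configuration values, not from this PR.
multi_step = 2            # model_config.training.multistep_input
num_input_channels = 90   # input variables per step
coord_dim = 2             # (lat, lon) per node
trainable_size = 8        # trainable features per node

# Old bookkeeping: lat/lon buffer width and trainable size tracked separately.
latlons_width = 2 * coord_dim  # sin/cos encoding doubles the coordinates -> 4
old_input_dim = multi_step * num_input_channels + latlons_width + trainable_size

# New bookkeeping: NamedNodesAttributes precomputes the total per node set.
attr_ndims = 2 * coord_dim + trainable_size  # 12
new_input_dim = multi_step * num_input_channels + attr_ndims

assert old_input_dim == new_input_dim == 192
```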