diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57078fd..963c62d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
 - configurabilty of the dropout probability in the the MultiHeadSelfAttention module
 - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
 - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
+- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
 
 ### Changed
 - Bugfixes for CI
diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py
index 71703d9..5d96e73 100644
--- a/src/anemoi/models/layers/graph.py
+++ b/src/anemoi/models/layers/graph.py
@@ -11,6 +11,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torch_geometric.data import HeteroData
 
 
 class TrainableTensor(nn.Module):
@@ -35,8 +36,47 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
     def forward(self, x: Tensor, batch_size: int) -> Tensor:
         latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
         if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
         return torch.cat(
             latent,
             dim=-1,  # feature dimension
         )
+
+
+class NamedNodesAttributes(torch.nn.Module):
+    """Named Node Attributes Module."""
+
+    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
+        """Initialize NamedNodesAttributes."""
+        super().__init__()
+
+        self.num_trainable_params = num_trainable_params
+        self.register_fixed_attributes(graph_data)
+
+        self.trainable_tensors = nn.ModuleDict()
+        for nodes_name in self.nodes_names:
+            self.register_coordinates(nodes_name, graph_data[nodes_name].x)
+            self.register_tensor(nodes_name)
+
+    def register_fixed_attributes(self, graph_data: HeteroData) -> None:
+        """Register fixed attributes."""
+        self.nodes_names = list(graph_data.node_types)
+        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names}
+        self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names}
+        self.attr_ndims = {
+            nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names
+        }
+
+    def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None:
+        """Register coordinates."""
+        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
+        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
+
+    def register_tensor(self, name: str) -> None:
+        """Register a trainable tensor."""
+        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params)
+
+    def forward(self, name: str, batch_size: int) -> Tensor:
+        """Forward pass."""
+        latlons = getattr(self, f"latlons_{name}")
+        return self.trainable_tensors[name](latlons, batch_size)
diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py
index c77db6e..592d6d4 100644
--- a/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/src/anemoi/models/models/encoder_processor_decoder.py
@@ -21,7 +21,7 @@
 from torch_geometric.data import HeteroData
 
 from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.graph import NamedNodesAttributes
 
 LOGGER = logging.getLogger(__name__)
 
@@ -55,33 +55,24 @@ def __init__(
 
         self._calculate_shapes_and_indices(data_indices)
         self._assert_matching_indices(data_indices)
-
-        self.multi_step = model_config.training.multistep_input
-
-        self._define_tensor_sizes(model_config)
-
-        # Create trainable tensors
-        self._create_trainable_attributes()
-
-        # Register lat/lon of nodes
-        self._register_latlon("data", self._graph_name_data)
-        self._register_latlon("hidden", self._graph_name_hidden)
-
         self.data_indices = data_indices
 
+        self.multi_step = model_config.training.multistep_input
         self.num_channels = model_config.model.num_channels
 
-        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
+
+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]
 
         # Encoder data -> hidden
         self.encoder = instantiate(
             model_config.model.encoder,
             in_channels_src=input_dim,
-            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
             hidden_dim=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
-            src_grid_size=self._data_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )
 
         # Processor hidden -> hidden
@@ -89,8 +80,8 @@ def __init__(
             model_config.model.processor,
             num_channels=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )
 
         # Decoder hidden -> data
@@ -101,8 +92,8 @@ def __init__(
             hidden_dim=self.num_channels,
             out_channels_dst=self.num_output_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
         )
 
         # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -132,34 +123,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
             self._internal_output_idx,
         ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"
 
-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )
-
     def _run_mapper(
         self,
         mapper: nn.Module,
@@ -209,12 +172,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
         x_data_latent = torch.cat(
             (
                 einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
            ),
             dim=-1,  # feature dimension
         )
 
-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)
 
         # get shard shapes
         shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
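Below is a minimal usage sketch of the new `NamedNodesAttributes` module, assuming a toy `HeteroData` graph with two node sets ("data" and "hidden") whose coordinates are given in radians. The node counts and `num_trainable_params=8` are invented for illustration; the behaviour shown (sin/cos expansion of the node coordinates plus a per-node-set `TrainableTensor`, repeated over the batch) follows the class added in this diff.

```python
import torch
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes

# Toy graph with two node sets; coordinates are (lat, lon) in radians, as expected
# by the sin/cos expansion in register_coordinates. Sizes are arbitrary.
graph = HeteroData()
graph["data"].x = torch.rand(40, 2)
graph["hidden"].x = torch.rand(10, 2)

node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)

# attr_ndims = 2 * coordinate features (sin/cos) + trainable params: 2 * 2 + 8 = 12 per node.
assert node_attributes.attr_ndims["data"] == 12

# forward(name, batch_size) repeats the registered lat/lon buffer over the batch and
# concatenates the named TrainableTensor along the feature dimension.
x_hidden = node_attributes("hidden", batch_size=2)
assert x_hidden.shape == (2 * 10, 12)
```

With every node set registered under a single module, the model only needs a node-set name such as `self._graph_name_data` to obtain both the embedded coordinates and the trainable parameters, which is what lets the `_register_latlon` and `_create_trainable_attributes` helpers be removed from `encoder_processor_decoder.py`.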