Skip to content

Commit

Permalink
[Feature] 28 Make models switchable through the config (#45)
Browse files Browse the repository at this point in the history
* feat: make model instantiateable

* docs: instantiation explained in changelog

* refactor: rename model config object

* fix: rename config to model_config

* fix: mark non-recursive

* docs: changelog
  • Loading branch information
JesperDramsch authored and theissenhelen committed Sep 27, 2024
1 parent 2b1d2ce commit 79c8321
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Keep it human-readable, your future self will thank you!
- Update CI to inherit from common infrastructue reusable workflows
- run downstream-ci only when src and tests folders have changed
- New error messages for wrongs graphs.
- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)

### Removed

Expand Down
11 changes: 7 additions & 4 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
from anemoi.models.preprocessing import Processors


Expand Down Expand Up @@ -73,9 +72,13 @@ def _build_model(self) -> None:
self.pre_processors = Processors(processors)
self.post_processors = Processors(processors, inverse=True)

# Instantiate the model (Can be generalised to other models in the future, here we use AnemoiModelEncProcDec)
self.model = AnemoiModelEncProcDec(
config=self.config, data_indices=self.data_indices, graph_data=self.graph_data
# Instantiate the model
self.model = instantiate(
self.config.model.model,
model_config=self.config,
data_indices=self.data_indices,
graph_data=self.graph_data,
_recursive_=False, # Disables recursive instantiation by Hydra
)

# Use the forward method of the model directly
Expand Down
24 changes: 12 additions & 12 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ class AnemoiModelEncProcDec(nn.Module):
def __init__(
self,
*,
config: DotDict,
model_config: DotDict,
data_indices: dict,
graph_data: HeteroData,
) -> None:
"""Initializes the graph neural network.
Parameters
----------
config : DotDict
Job configuration
model_config : DotDict
Model configuration
data_indices : dict
Data indices
graph_data : HeteroData
Expand All @@ -50,15 +50,15 @@ def __init__(
super().__init__()

self._graph_data = graph_data
self._graph_name_data = config.graph.data
self._graph_name_hidden = config.graph.hidden
self._graph_name_data = model_config.graph.data
self._graph_name_hidden = model_config.graph.hidden

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)

self.multi_step = config.training.multistep_input
self.multi_step = model_config.training.multistep_input

self._define_tensor_sizes(config)
self._define_tensor_sizes(model_config)

# Create trainable tensors
self._create_trainable_attributes()
Expand All @@ -69,13 +69,13 @@ def __init__(

self.data_indices = data_indices

self.num_channels = config.model.num_channels
self.num_channels = model_config.model.num_channels

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size

# Encoder data -> hidden
self.encoder = instantiate(
config.model.encoder,
model_config.model.encoder,
in_channels_src=input_dim,
in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
hidden_dim=self.num_channels,
Expand All @@ -86,7 +86,7 @@ def __init__(

# Processor hidden -> hidden
self.processor = instantiate(
config.model.processor,
model_config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
src_grid_size=self._hidden_grid_size,
Expand All @@ -95,7 +95,7 @@ def __init__(

# Decoder hidden -> data
self.decoder = instantiate(
config.model.decoder,
model_config.model.decoder,
in_channels_src=self.num_channels,
in_channels_dst=input_dim,
hidden_dim=self.num_channels,
Expand All @@ -109,7 +109,7 @@ def __init__(
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index)
for cfg in getattr(config.model, "bounding", [])
for cfg in getattr(model_config.model, "bounding", [])
]
)

Expand Down

0 comments on commit 79c8321

Please sign in to comment.