Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] 28 Make models switchable through the config #45

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Keep it human-readable, your future self will thank you!
- Update CI to inherit from common infrastructue reusable workflows
- run downstream-ci only when src and tests folders have changed
- New error messages for wrongs graphs.
- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)

### Removed

Expand Down
11 changes: 7 additions & 4 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
from anemoi.models.preprocessing import Processors


Expand Down Expand Up @@ -73,9 +72,13 @@ def _build_model(self) -> None:
self.pre_processors = Processors(processors)
self.post_processors = Processors(processors, inverse=True)

# Instantiate the model (Can be generalised to other models in the future, here we use AnemoiModelEncProcDec)
self.model = AnemoiModelEncProcDec(
config=self.config, data_indices=self.data_indices, graph_data=self.graph_data
# Instantiate the model
self.model = instantiate(
self.config.model.model,
model_config=self.config,
data_indices=self.data_indices,
graph_data=self.graph_data,
_recursive_=False, # Disables recursive instantiation by Hydra
)

# Use the forward method of the model directly
Expand Down
24 changes: 12 additions & 12 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ class AnemoiModelEncProcDec(nn.Module):
def __init__(
self,
*,
config: DotDict,
model_config: DotDict,
data_indices: dict,
graph_data: HeteroData,
) -> None:
"""Initializes the graph neural network.

Parameters
----------
config : DotDict
Job configuration
model_config : DotDict
Model configuration
data_indices : dict
Data indices
graph_data : HeteroData
Expand All @@ -50,15 +50,15 @@ def __init__(
super().__init__()

self._graph_data = graph_data
self._graph_name_data = config.graph.data
self._graph_name_hidden = config.graph.hidden
self._graph_name_data = model_config.graph.data
self._graph_name_hidden = model_config.graph.hidden

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)

self.multi_step = config.training.multistep_input
self.multi_step = model_config.training.multistep_input

self._define_tensor_sizes(config)
self._define_tensor_sizes(model_config)

# Create trainable tensors
self._create_trainable_attributes()
Expand All @@ -69,13 +69,13 @@ def __init__(

self.data_indices = data_indices

self.num_channels = config.model.num_channels
self.num_channels = model_config.model.num_channels

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size

# Encoder data -> hidden
self.encoder = instantiate(
config.model.encoder,
model_config.model.encoder,
in_channels_src=input_dim,
in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
hidden_dim=self.num_channels,
Expand All @@ -86,7 +86,7 @@ def __init__(

# Processor hidden -> hidden
self.processor = instantiate(
config.model.processor,
model_config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
src_grid_size=self._hidden_grid_size,
Expand All @@ -95,7 +95,7 @@ def __init__(

# Decoder hidden -> data
self.decoder = instantiate(
config.model.decoder,
model_config.model.decoder,
in_channels_src=self.num_channels,
in_channels_dst=input_dim,
hidden_dim=self.num_channels,
Expand All @@ -109,7 +109,7 @@ def __init__(
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index)
for cfg in getattr(config.model, "bounding", [])
for cfg in getattr(model_config.model, "bounding", [])
]
)

Expand Down
Loading