From 79c83215a5748e574105ba510e0ad66068af7456 Mon Sep 17 00:00:00 2001 From: Jesper Dramsch Date: Mon, 23 Sep 2024 16:52:14 +0200 Subject: [PATCH] [Feature] 28 Make models switchable through the config (#45) * feat: make model instantiateable * docs: instantiation explained in changelog * refactor: rename model config object * fix: rename config to model_config * fix: mark non-recursive * docs: changelog --- CHANGELOG.md | 1 + src/anemoi/models/interface/__init__.py | 11 +++++---- .../models/encoder_processor_decoder.py | 24 +++++++++---------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7811ef3..5678486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ Keep it human-readable, your future self will thank you! - Update CI to inherit from common infrastructue reusable workflows - run downstream-ci only when src and tests folders have changed - New error messages for wrongs graphs. +- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45) ### Removed diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 626940f..aba62a2 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -14,7 +14,6 @@ from hydra.utils import instantiate from torch_geometric.data import HeteroData -from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec from anemoi.models.preprocessing import Processors @@ -73,9 +72,13 @@ def _build_model(self) -> None: self.pre_processors = Processors(processors) self.post_processors = Processors(processors, inverse=True) - # Instantiate the model (Can be generalised to other models in the future, here we use AnemoiModelEncProcDec) - self.model = AnemoiModelEncProcDec( - config=self.config, data_indices=self.data_indices, graph_data=self.graph_data + # Instantiate the model + self.model = instantiate( + self.config.model.model, + model_config=self.config, + data_indices=self.data_indices, + graph_data=self.graph_data, + _recursive_=False, # Disables recursive instantiation by Hydra ) # Use the forward method of the model directly diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index b043b0c..aa7e8bb 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -32,7 +32,7 @@ class AnemoiModelEncProcDec(nn.Module): def __init__( self, *, - config: DotDict, + model_config: DotDict, data_indices: dict, graph_data: HeteroData, ) -> None: @@ -40,8 +40,8 @@ def __init__( Parameters ---------- - config : DotDict - Job configuration + model_config : DotDict + Model configuration data_indices : dict Data indices graph_data : HeteroData @@ -50,15 +50,15 @@ def __init__( super().__init__() self._graph_data = graph_data - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden + self._graph_name_data = model_config.graph.data + self._graph_name_hidden = model_config.graph.hidden self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - self.multi_step = config.training.multistep_input + self.multi_step = model_config.training.multistep_input - self._define_tensor_sizes(config) + self._define_tensor_sizes(model_config) # Create trainable tensors self._create_trainable_attributes() @@ -69,13 +69,13 @@ def __init__( self.data_indices = data_indices - self.num_channels = config.model.num_channels + self.num_channels = model_config.model.num_channels input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size # Encoder data -> hidden self.encoder = instantiate( - config.model.encoder, + model_config.model.encoder, in_channels_src=input_dim, in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, hidden_dim=self.num_channels, @@ -86,7 +86,7 @@ def __init__( # Processor hidden -> hidden self.processor = instantiate( - config.model.processor, + model_config.model.processor, num_channels=self.num_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], src_grid_size=self._hidden_grid_size, @@ -95,7 +95,7 @@ def __init__( # Decoder hidden -> data self.decoder = instantiate( - config.model.decoder, + model_config.model.decoder, in_channels_src=self.num_channels, in_channels_dst=input_dim, hidden_dim=self.num_channels, @@ -109,7 +109,7 @@ def __init__( self.boundings = nn.ModuleList( [ instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index) - for cfg in getattr(config.model, "bounding", []) + for cfg in getattr(model_config.model, "bounding", []) ] )