Skip to content

Commit

Permalink
style: dict in config files for defining the variables to be remapped…
Browse files Browse the repository at this point in the history
…. structure and additional assert in index collection.
  • Loading branch information
sahahner committed Sep 9, 2024
1 parent 167c1ee commit 6b02507
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/anemoi/models/data_indices/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#

import operator
from itertools import chain

import yaml
from omegaconf import OmegaConf
Expand All @@ -26,18 +25,14 @@ class IndexCollection:

def __init__(self, config, name_to_index) -> None:
self.config = OmegaConf.to_container(config, resolve=True)

self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
self.diagnostic = (
[] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
)
# config.data.remapped is a list of diccionaries: every remapper is one entry of the list
self.remapped = (
[]
if config.data.remapped is None
else dict(
chain.from_iterable(d.items() for d in OmegaConf.to_container(config.data.remapped, resolve=True))
)
dict() if config.data.remapped is None else OmegaConf.to_container(config.data.remapped, resolve=True)
)
self.forcing_remapped = self.forcing.copy()

Expand All @@ -48,16 +43,20 @@ def __init__(self, config, name_to_index) -> None:
assert set(self.remapped).isdisjoint(self.diagnostic), (
"Remapped variable overlap with diagnostic variables. Not implemented.",
)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
name_to_index_internal_data_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.remapped)
}
assert set(self.remapped).issubset(self.name_to_index), (
"Remapping a variable that does not exist in the dataset. Check for typos: ",
f"{set(self.remapped).difference(self.name_to_index)}",
)
name_to_index_model_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
}
name_to_index_model_output = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
}
# remove remapped variables from internal data and model indices
name_to_index_internal_data_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.remapped)
}
name_to_index_internal_model_input = {
name: i for i, name in enumerate(key for key in name_to_index_model_input if key not in self.remapped)
}
Expand Down

0 comments on commit 6b02507

Please sign in to comment.