Skip to content

Commit

Permalink
refactor: Initial data implementation
Browse files Browse the repository at this point in the history
Dynamic data handling and indexing

Co-authored-by: Florian Pinault <[email protected]>
Co-authored-by: Baudouin Raoult <[email protected]>
Co-authored-by: Matthew Chantry <[email protected]>
Co-authored-by: mihai.alexe <[email protected]>
Co-authored-by: Simon Lang <[email protected]>
  • Loading branch information
6 people committed May 15, 2024
1 parent c052d19 commit e6c4225
Show file tree
Hide file tree
Showing 7 changed files with 934 additions and 0 deletions.
74 changes: 74 additions & 0 deletions src/anemoi/models/data/data_indices/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import operator

import yaml
from omegaconf import OmegaConf

from anemoi.models.data.data_indices.index import BaseIndex
from anemoi.models.data.data_indices.index import DataIndex
from anemoi.models.data.data_indices.index import ModelIndex
from anemoi.models.data.data_indices.tensor import BaseTensorIndex
from anemoi.models.data.data_indices.tensor import InputTensorIndex
from anemoi.models.data.data_indices.tensor import OutputTensorIndex


class IndexCollection:
"""Collection of data and model indices."""

def __init__(self, config, name_to_index) -> None:
self.config = OmegaConf.to_container(config, resolve=True)

self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
self.diagnostic = (
[] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
)

assert set(self.diagnostic).isdisjoint(self.forcing), (
f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ",
"Please drop them at a dataset-level to exclude them from the training data.",
)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
name_to_index_model_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
}
name_to_index_model_output = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
}

self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)

def __repr__(self) -> str:
return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"

def __eq__(self, other):
if not isinstance(other, IndexCollection):
# don't attempt to compare against unrelated types
return NotImplemented

return self.model == other.model and self.data == other.data

def __getitem__(self, key):
return getattr(self, key)

def todict(self):
return {
"data": self.data.todict(),
"model": self.model.todict(),
}

@staticmethod
def representer(dumper, data):
return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))


for cls in [BaseTensorIndex, InputTensorIndex, OutputTensorIndex, BaseIndex, DataIndex, ModelIndex, IndexCollection]:
yaml.add_representer(cls, cls.representer)
93 changes: 93 additions & 0 deletions src/anemoi/models/data/data_indices/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from anemoi.models.data.data_indices.tensor import InputTensorIndex
from anemoi.models.data.data_indices.tensor import OutputTensorIndex


class BaseIndex:
"""Base class for data and model indices."""

def __init__(self) -> None:
self.input = NotImplementedError
self.output = NotImplementedError

def __eq__(self, other):
if not isinstance(other, BaseIndex):
# don't attempt to compare against unrelated types
return NotImplemented

return self.input == other.input and self.output == other.output

def __repr__(self) -> str:
return f"{self.__class__.__name__}(input={self.input}, output={self.output})"

def __getitem__(self, key):
return getattr(self, key)

def todict(self):
return {
"input": self.input.todict(),
"output": self.output.todict(),
}

@staticmethod
def representer(dumper, data):
return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))


class DataIndex(BaseIndex):
"""Indexing for data variables."""

def __init__(self, diagnostic, forcing, name_to_index) -> None:
self._diagnostic = diagnostic
self._forcing = forcing
self._name_to_index = name_to_index
self.input = InputTensorIndex(
includes=forcing,
excludes=diagnostic,
name_to_index=name_to_index,
)

self.output = OutputTensorIndex(
includes=diagnostic,
excludes=forcing,
name_to_index=name_to_index,
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(diagnostic={self._input}, forcing={self._output}, name_to_index={self._name_to_index})"


class ModelIndex(BaseIndex):
"""Indexing for model variables."""

def __init__(self, diagnostic, forcing, name_to_index_model_input, name_to_index_model_output) -> None:
self._diagnostic = diagnostic
self._forcing = forcing
self._name_to_index_model_input = name_to_index_model_input
self._name_to_index_model_output = name_to_index_model_output
self.input = InputTensorIndex(
includes=forcing,
excludes=[],
name_to_index=name_to_index_model_input,
)

self.output = OutputTensorIndex(
includes=diagnostic,
excludes=[],
name_to_index=name_to_index_model_output,
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(diagnostic={self._input}, forcing={self._output}, "
f"name_to_index_model_input={self._name_to_index_model_input}, "
f"name_to_index_model_output={self._name_to_index_model_output})"
)
114 changes: 114 additions & 0 deletions src/anemoi/models/data/data_indices/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import torch


class BaseTensorIndex:
"""Indexing for variables in index as Tensor."""

def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
"""Initialize indexing tensors from includes and excludes using name_to_index.
Parameters
----------
includes : list[str]
Variables to include in the indexing that are exclusive to this indexing.
e.g. Forcing variables for the input indexing, diagnostic variables for the output indexing
excludes : list[str]
Variables to exclude from the indexing.
e.g. Diagnostic variables for the input indexing, forcing variables for the output indexing
name_to_index : dict[str, int]
Dictionary mapping variable names to their index in the Tensor.
"""
self.includes = includes
self.excludes = excludes
self.name_to_index = name_to_index

assert set(self.excludes).issubset(
self.name_to_index.keys(),
), f"Data indexing has invalid entries {[var for var in self.excludes if var not in self.name_to_index]}, not in dataset."
assert set(self.includes).issubset(
self.name_to_index.keys(),
), f"Data indexing has invalid entries {[var for var in self.includes if var not in self.name_to_index]}, not in dataset."

self.full = self._build_idx_from_excludes()
self._only = self._build_idx_from_includes()
self._removed = self._build_idx_from_includes(self.excludes)
self.prognostic = self._build_idx_prognostic()
self.diagnostic = NotImplementedError
self.forcing = NotImplementedError

def __len__(self) -> int:
return len(self.full)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(includes={self.includes}, excludes={self.excludes}, name_to_index={self.name_to_index})"

def __eq__(self, other):
if not isinstance(other, BaseTensorIndex):
# don't attempt to compare against unrelated types
return NotImplemented

return (
torch.allclose(self.full, other.full)
and torch.allclose(self._only, other._only)
and torch.allclose(self._removed, other._removed)
and torch.allclose(self.prognostic, other.prognostic)
and torch.allclose(self.diagnostic, other.diagnostic)
and torch.allclose(self.forcing, other.forcing)
and self.includes == other.includes
and self.excludes == other.excludes
)

def __getitem__(self, key):
return getattr(self, key)

def todict(self):
return {
"full": self.full,
"prognostic": self.prognostic,
"diagnostic": self.diagnostic,
"forcing": self.forcing,
}

@staticmethod
def representer(dumper, data):
return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))

def _build_idx_from_excludes(self, excludes=None) -> "torch.Tensor[int]":
if excludes is None:
excludes = self.excludes
return torch.Tensor(sorted(i for name, i in self.name_to_index.items() if name not in excludes)).to(torch.int)

def _build_idx_from_includes(self, includes=None) -> "torch.Tensor[int]":
if includes is None:
includes = self.includes
return torch.Tensor(sorted(self.name_to_index[name] for name in includes)).to(torch.int)

def _build_idx_prognostic(self) -> "torch.Tensor[int]":
return self._build_idx_from_excludes(self.includes + self.excludes)


class InputTensorIndex(BaseTensorIndex):
"""Indexing for input variables."""

def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index)
self.forcing = self._only
self.diagnostic = self._removed


class OutputTensorIndex(BaseTensorIndex):
"""Indexing for output variables."""

def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index)
self.forcing = self._removed
self.diagnostic = self._only
Loading

0 comments on commit e6c4225

Please sign in to comment.