test: data module, data indices and normalizer
Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: Matthew Chantry <[email protected]>
Co-authored-by: Mihai Alexe <[email protected]>
Co-authored-by: Florian Pinault <[email protected]>
Co-authored-by: Baudouin Raoult <[email protected]>
theissenhelen and b8raoult committed May 15, 2024
1 parent 19c26c8 commit eba71c7
Showing 4 changed files with 444 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tests/data/data_indices/test_collection.py
@@ -0,0 +1,92 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch
from omegaconf import DictConfig

from anemoi.models.data.data_indices.collection import IndexCollection


@pytest.fixture()
def data_indices():
    config = DictConfig(
        {
            "data": {
                "forcing": ["x"],
                "diagnostic": ["z", "q"],
            },
        },
    )
    name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    return IndexCollection(config=config, name_to_index=name_to_index)


def test_dataindices_init(data_indices) -> None:
    assert data_indices.data.input.includes == ["x"]
    assert data_indices.data.input.excludes == ["z", "q"]
    assert data_indices.data.output.includes == ["z", "q"]
    assert data_indices.data.output.excludes == ["x"]
    assert data_indices.model.input.includes == ["x"]
    assert data_indices.model.input.excludes == []
    assert data_indices.model.output.includes == ["z", "q"]
    assert data_indices.model.output.excludes == []
    assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2}
    assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3}


def test_dataindices_max(data_indices) -> None:
    assert max(data_indices.data.input.full) == max(data_indices.data.input.name_to_index.values())
    assert max(data_indices.data.output.full) == max(data_indices.data.output.name_to_index.values())
    assert max(data_indices.model.input.full) == max(data_indices.model.input.name_to_index.values())
    assert max(data_indices.model.output.full) == max(data_indices.model.output.name_to_index.values())


def test_dataindices_todict(data_indices) -> None:
    expected_output = {
        "input": {
            "full": torch.Tensor([0, 1, 4]).to(torch.int),
            "forcing": torch.Tensor([0]).to(torch.int),
            "diagnostic": torch.Tensor([2, 3]).to(torch.int),
            "prognostic": torch.Tensor([1, 4]).to(torch.int),
        },
        "output": {
            "full": torch.Tensor([1, 2, 3, 4]).to(torch.int),
            "forcing": torch.Tensor([0]).to(torch.int),
            "diagnostic": torch.Tensor([2, 3]).to(torch.int),
            "prognostic": torch.Tensor([1, 4]).to(torch.int),
        },
    }

    for key in ["output", "input"]:
        for subkey, value in data_indices.data.todict()[key].items():
            assert subkey in expected_output[key]
            assert torch.allclose(value, expected_output[key][subkey])


def test_modelindices_todict(data_indices) -> None:
    expected_output = {
        "input": {
            "full": torch.Tensor([0, 1, 2]).to(torch.int),
            "forcing": torch.Tensor([0]).to(torch.int),
            "diagnostic": torch.Tensor([]).to(torch.int),
            "prognostic": torch.Tensor([1, 2]).to(torch.int),
        },
        "output": {
            "full": torch.Tensor([0, 1, 2, 3]).to(torch.int),
            "forcing": torch.Tensor([]).to(torch.int),
            "diagnostic": torch.Tensor([1, 2]).to(torch.int),
            "prognostic": torch.Tensor([0, 3]).to(torch.int),
        },
    }

    for key in ["output", "input"]:
        for subkey, value in data_indices.model.todict()[key].items():
            assert subkey in expected_output[key]
            assert torch.allclose(value, expected_output[key][subkey])
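
For orientation, here is a minimal usage sketch of the IndexCollection API exercised by test_collection.py above. The config keys, constructor arguments, attribute names, and index values are taken from the fixture and assertions; the print calls and inline comments are illustrative only.

from omegaconf import DictConfig

from anemoi.models.data.data_indices.collection import IndexCollection

# Mirror the fixture above: "x" is forcing (input-only), "z" and "q" are
# diagnostic (output-only); "y" and "other" remain prognostic.
config = DictConfig({"data": {"forcing": ["x"], "diagnostic": ["z", "q"]}})
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}

indices = IndexCollection(config=config, name_to_index=name_to_index)

# Data-space view: positions on the raw dataset variable axis.
print(indices.data.input.full)   # indices 0, 1, 4 -> x, y, other
print(indices.data.output.full)  # indices 1, 2, 3, 4 -> y, z, q, other

# Model-space view: positions after dropping variables the model never sees.
print(indices.model.input.name_to_index)   # {"x": 0, "y": 1, "other": 2}
print(indices.model.output.name_to_index)  # {"y": 0, "z": 1, "q": 2, "other": 3}
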
168 changes: 168 additions & 0 deletions tests/data/data_indices/test_data_indices.py
@@ -0,0 +1,168 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch
import yaml

from anemoi.models.data.data_indices.collection import IndexCollection
from anemoi.models.data.data_indices.index import BaseIndex
from anemoi.models.data.data_indices.index import DataIndex
from anemoi.models.data.data_indices.tensor import BaseTensorIndex
from anemoi.models.data.data_indices.tensor import InputTensorIndex
from anemoi.models.data.data_indices.tensor import OutputTensorIndex


@pytest.mark.data_dependent()
@pytest.mark.parametrize(
    ("data_model", "in_out", "full_only_prognostic"),
    [
        (a, b, c)
        for a in ["data", "model"]
        for b in ["input", "output"]
        for c in ["full", "forcing", "diagnostic", "prognostic"]
    ],
)
def test_dataindex_types(datamodule, data_model, in_out, full_only_prognostic) -> None:
    assert hasattr(datamodule, "data_indices")
    assert isinstance(datamodule.data_indices, IndexCollection)
    data_indices = datamodule.data_indices

    assert hasattr(data_indices, data_model)
    assert isinstance(getattr(data_indices, data_model), BaseIndex)
    assert hasattr(getattr(data_indices, data_model), in_out)
    assert isinstance(getattr(getattr(data_indices, data_model), in_out), BaseTensorIndex)
    assert hasattr(getattr(getattr(data_indices, data_model), in_out), full_only_prognostic)
    assert isinstance(getattr(getattr(getattr(data_indices, data_model), in_out), full_only_prognostic), torch.Tensor)


@pytest.fixture()
def fake_data():
    name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    forcing = ["x", "y"]
    diagnostic = ["z"]
    return forcing, diagnostic, name_to_index


@pytest.fixture()
def input_tensor_index(fake_data):
    forcing, diagnostic, name_to_index = fake_data
    return InputTensorIndex(includes=forcing, excludes=diagnostic, name_to_index=name_to_index)


@pytest.fixture()
def output_tensor_index(fake_data):
    forcing, diagnostic, name_to_index = fake_data
    return OutputTensorIndex(includes=diagnostic, excludes=forcing, name_to_index=name_to_index)


def test_dataindex_init(fake_data, input_tensor_index, output_tensor_index) -> None:
    forcing, diagnostic, name_to_index = fake_data
    data_index = DataIndex(forcing=forcing, diagnostic=diagnostic, name_to_index=name_to_index)
    assert data_index.input == input_tensor_index
    assert data_index.output == output_tensor_index


def test_output_tensor_index_full(output_tensor_index) -> None:
    expected_output = torch.Tensor([2, 3, 4]).to(torch.int)
    assert torch.allclose(output_tensor_index.full, expected_output)


def test_output_tensor_index_only(output_tensor_index) -> None:
    expected_output = torch.Tensor([2]).to(torch.int)
    assert torch.allclose(output_tensor_index._only, expected_output)


def test_output_tensor_index_prognostic(output_tensor_index) -> None:
    expected_output = torch.Tensor([3, 4]).to(torch.int)
    assert torch.allclose(output_tensor_index.prognostic, expected_output)


def test_output_tensor_index_todict(output_tensor_index) -> None:
    expected_output = {
        "full": torch.Tensor([2, 3, 4]).to(torch.int),
        "diagnostic": torch.Tensor([2]).to(torch.int),
        "forcing": torch.Tensor([0, 1]).to(torch.int),
        "prognostic": torch.Tensor([3, 4]).to(torch.int),
    }
    for key, value in output_tensor_index.todict().items():
        assert key in expected_output
        assert torch.allclose(value, expected_output[key])


def test_output_tensor_index_getattr(output_tensor_index) -> None:
    assert output_tensor_index.full is not None
    with pytest.raises(AttributeError):
        output_tensor_index.z


def test_output_tensor_index_build_idx_from_excludes(output_tensor_index) -> None:
    expected_output = torch.Tensor([2, 3, 4]).to(torch.int)
    assert torch.allclose(output_tensor_index._build_idx_from_excludes(), expected_output)


def test_output_tensor_index_build_idx_from_includes(output_tensor_index) -> None:
    expected_output = torch.Tensor([2]).to(torch.int)
    assert torch.allclose(output_tensor_index._build_idx_from_includes(), expected_output)


def test_output_tensor_index_build_idx_prognostic(output_tensor_index) -> None:
    expected_output = torch.Tensor([3, 4]).to(torch.int)
    assert torch.allclose(output_tensor_index._build_idx_prognostic(), expected_output)


def test_input_tensor_index_full(input_tensor_index) -> None:
    expected_output = torch.Tensor([0, 1, 3, 4]).to(torch.int)
    assert torch.allclose(input_tensor_index.full, expected_output)


def test_input_tensor_index_only(input_tensor_index) -> None:
    expected_output = torch.Tensor([0, 1]).to(torch.int)
    assert torch.allclose(input_tensor_index._only, expected_output)


def test_input_tensor_index_prognostic(input_tensor_index) -> None:
    expected_output = torch.Tensor([3, 4]).to(torch.int)
    assert torch.allclose(input_tensor_index.prognostic, expected_output)


def test_input_tensor_index_todict(input_tensor_index) -> None:
    expected_output = {
        "full": torch.Tensor([0, 1, 3, 4]).to(torch.int),
        "diagnostic": torch.Tensor([2]).to(torch.int),
        "forcing": torch.Tensor([0, 1]).to(torch.int),
        "prognostic": torch.Tensor([3, 4]).to(torch.int),
    }
    for key, value in input_tensor_index.todict().items():
        assert key in expected_output
        assert torch.allclose(value, expected_output[key])


def test_input_tensor_index_getattr(input_tensor_index) -> None:
    assert input_tensor_index.full is not None
    with pytest.raises(AttributeError):
        input_tensor_index.z


def test_input_tensor_index_build_idx_from_excludes(input_tensor_index) -> None:
    expected_output = torch.Tensor([0, 1, 3, 4]).to(torch.int)
    assert torch.allclose(input_tensor_index._build_idx_from_excludes(), expected_output)


def test_input_tensor_index_build_idx_from_includes(input_tensor_index) -> None:
    expected_output = torch.Tensor([0, 1]).to(torch.int)
    assert torch.allclose(input_tensor_index._build_idx_from_includes(), expected_output)


def test_input_tensor_index_build_idx_prognostic(input_tensor_index) -> None:
    expected_output = torch.Tensor([3, 4]).to(torch.int)
    assert torch.allclose(input_tensor_index._build_idx_prognostic(), expected_output)


@pytest.mark.data_dependent()
def test_yaml_dump(datamodule) -> None:
    yaml.dump(datamodule.data_indices)
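
As a complement to the fixtures above, here is a short sketch of how the index tensors can slice the variable axis of a field tensor. The InputTensorIndex/OutputTensorIndex arguments and the full/prognostic attributes come from the tests in this file; the dummy fields tensor, its shape, and the defensive .long() cast are assumptions for illustration.

import torch

from anemoi.models.data.data_indices.tensor import InputTensorIndex
from anemoi.models.data.data_indices.tensor import OutputTensorIndex

# Same fake data as the fixtures above: "x" and "y" are forcing, "z" is diagnostic.
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
forcing = ["x", "y"]
diagnostic = ["z"]

in_idx = InputTensorIndex(includes=forcing, excludes=diagnostic, name_to_index=name_to_index)
out_idx = OutputTensorIndex(includes=diagnostic, excludes=forcing, name_to_index=name_to_index)

# Hypothetical field tensor with the variable axis last: (batch, grid points, variables).
fields = torch.randn(2, 10, len(name_to_index))

# Cast to long defensively before indexing; per the tests, in_idx.full is [0, 1, 3, 4]
# (x, y, q, other) and out_idx.prognostic is [3, 4] (q, other).
model_input = fields[..., in_idx.full.long()]               # shape (2, 10, 4)
prognostic_output = fields[..., out_idx.prognostic.long()]  # shape (2, 10, 2)
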
97 changes: 97 additions & 0 deletions tests/data/test_data_module.py
@@ -0,0 +1,97 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch

from anemoi.models.data.data_indices.collection import IndexCollection
from anemoi.models.data.data_indices.index import BaseIndex
from anemoi.models.data.data_indices.tensor import BaseTensorIndex


@pytest.mark.data_dependent()
def test_datamodule_datasets(datamodule) -> None:
    assert hasattr(datamodule, "dataset_train")
    assert hasattr(datamodule, "dataset_valid")
    assert hasattr(datamodule, "dataset_test")


def test_datamodule_dataloaders(datamodule) -> None:
    assert hasattr(datamodule, "train_dataloader")
    assert hasattr(datamodule, "val_dataloader")
    assert hasattr(datamodule, "test_dataloader")


@pytest.mark.data_dependent()
def test_datamodule_metadata(datamodule) -> None:
    assert hasattr(datamodule, "metadata")
    assert isinstance(datamodule.metadata, dict)


@pytest.mark.data_dependent()
def test_datamodule_statistics(datamodule) -> None:
    assert hasattr(datamodule, "statistics")
    assert isinstance(datamodule.statistics, dict)
    assert "mean" in datamodule.statistics
    assert "stdev" in datamodule.statistics
    assert "minimum" in datamodule.statistics
    assert "maximum" in datamodule.statistics


@pytest.mark.data_dependent()
@pytest.mark.parametrize(
    ("data_model", "in_out", "full_only_prognostic"),
    [
        (a, b, c)
        for a in ["data", "model"]
        for b in ["input", "output"]
        for c in ["full", "forcing", "diagnostic", "prognostic"]
    ],
)
def test_datamodule_api(datamodule, data_model, in_out, full_only_prognostic) -> None:
    assert hasattr(datamodule, "data_indices")
    assert isinstance(datamodule.data_indices, IndexCollection)
    assert hasattr(datamodule.data_indices, data_model)
    assert isinstance(datamodule.data_indices[data_model], BaseIndex)
    data_indices = getattr(datamodule.data_indices, data_model)
    assert isinstance(getattr(data_indices, in_out), BaseTensorIndex)
    assert hasattr(getattr(data_indices, in_out), full_only_prognostic)
    assert isinstance(getattr(getattr(data_indices, in_out), full_only_prognostic), torch.Tensor)


@pytest.mark.data_dependent()
def test_datamodule_data_indices(datamodule) -> None:
    # Check that different indices are split correctly
    all_data = set(datamodule.data_indices.data.input.name_to_index.values())
    assert (
        set(datamodule.data_indices.data.input.full.numpy()).union(
            datamodule.data_indices.data.input.name_to_index[v] for v in datamodule.config.data.diagnostic
        )
        == all_data
    )
    assert len(datamodule.data_indices.data.input.prognostic) <= len(datamodule.data_indices.data.input.full)
    assert len(datamodule.data_indices.data.output.prognostic) <= len(datamodule.data_indices.data.output.full)
    assert len(datamodule.data_indices.data.output.prognostic) == len(datamodule.data_indices.data.input.prognostic)

    assert len(datamodule.data_indices.model.input.prognostic) <= len(datamodule.data_indices.model.input.full)
    assert len(datamodule.data_indices.model.output.prognostic) <= len(datamodule.data_indices.model.output.full)
    assert len(datamodule.data_indices.model.output.prognostic) == len(datamodule.data_indices.model.input.prognostic)


@pytest.mark.data_dependent()
def test_datamodule_batch(datamodule) -> None:
    first_batch = next(iter(datamodule.train_dataloader()))
    assert isinstance(first_batch, torch.Tensor)
    assert first_batch.shape[-1] == len(
        datamodule.data_indices.data.input.name_to_index.values(),
    ), "Batch should have all variables"
    assert (
        first_batch.shape[0] == datamodule.config.dataloader.batch_size.training
    ), "Batch should have correct batch size"
    assert (
        first_batch.shape[1] == datamodule.config.training.multistep_input + 1
    ), "Batch needs correct sequence length (steps + 1)"
