-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: data module, data indices and normalizer
Co-authored-by: Jesper Dramsch <[email protected]> Co-authored-by: Matthew Chantry <[email protected]> Co-authored-by: Mihai Alexe <[email protected]> Co-authored-by: Florian Pinault <[email protected]> Co-authored-by: Baudouin Raoult <[email protected]>
- Loading branch information
1 parent
19c26c8
commit eba71c7
Showing
4 changed files
with
444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
import torch | ||
from omegaconf import DictConfig | ||
|
||
from anemoi.models.data.data_indices.collection import IndexCollection | ||
|
||
|
||
@pytest.fixture()
def data_indices():
    """Build an IndexCollection for a five-variable toy setup.

    Layout: ``x`` is forcing, ``z``/``q`` are diagnostic, ``y`` and
    ``other`` are left prognostic.
    """
    variable_positions = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    cfg = DictConfig({"data": {"forcing": ["x"], "diagnostic": ["z", "q"]}})
    return IndexCollection(config=cfg, name_to_index=variable_positions)
|
||
|
||
def test_dataindices_init(data_indices) -> None:
    """Include/exclude lists and name maps must be routed correctly on construction."""
    # (space, direction) -> (includes, excludes) expected after init.
    expected_lists = {
        ("data", "input"): (["x"], ["z", "q"]),
        ("data", "output"): (["z", "q"], ["x"]),
        ("model", "input"): (["x"], []),
        ("model", "output"): (["z", "q"], []),
    }
    for (space, direction), (includes, excludes) in expected_lists.items():
        side = getattr(getattr(data_indices, space), direction)
        assert side.includes == includes
        assert side.excludes == excludes

    # Data-space sides keep the complete variable map; model-space sides are
    # renumbered after dropping diagnostics (input) or forcings (output).
    full_map = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    assert data_indices.data.input.name_to_index == full_map
    assert data_indices.data.output.name_to_index == full_map
    assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2}
    assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3}
|
||
|
||
def test_dataindices_max(data_indices) -> None:
    """The largest index in ``full`` must equal the largest mapped position."""
    for space in ("data", "model"):
        for direction in ("input", "output"):
            side = getattr(getattr(data_indices, space), direction)
            assert max(side.full) == max(side.name_to_index.values())
|
||
|
||
def test_dataindices_todict(data_indices) -> None:
    """``data.todict()`` must expose the expected data-space index tensors."""
    expected_output = {
        "input": {
            "full": torch.tensor([0, 1, 4], dtype=torch.int),
            "forcing": torch.tensor([0], dtype=torch.int),
            "diagnostic": torch.tensor([2, 3], dtype=torch.int),
            "prognostic": torch.tensor([1, 4], dtype=torch.int),
        },
        "output": {
            "full": torch.tensor([1, 2, 3, 4], dtype=torch.int),
            "forcing": torch.tensor([0], dtype=torch.int),
            "diagnostic": torch.tensor([2, 3], dtype=torch.int),
            "prognostic": torch.tensor([1, 4], dtype=torch.int),
        },
    }

    actual = data_indices.data.todict()
    for direction in ("output", "input"):
        for category, index_tensor in actual[direction].items():
            assert category in expected_output[direction]
            assert torch.allclose(index_tensor, expected_output[direction][category])
|
||
|
||
def test_modelindices_todict(data_indices) -> None:
    """``model.todict()`` must expose the renumbered model-space index tensors."""
    expected_output = {
        "input": {
            "full": torch.tensor([0, 1, 2], dtype=torch.int),
            "forcing": torch.tensor([0], dtype=torch.int),
            "diagnostic": torch.tensor([], dtype=torch.int),
            "prognostic": torch.tensor([1, 2], dtype=torch.int),
        },
        "output": {
            "full": torch.tensor([0, 1, 2, 3], dtype=torch.int),
            "forcing": torch.tensor([], dtype=torch.int),
            "diagnostic": torch.tensor([1, 2], dtype=torch.int),
            "prognostic": torch.tensor([0, 3], dtype=torch.int),
        },
    }

    actual = data_indices.model.todict()
    for direction in ("output", "input"):
        for category, index_tensor in actual[direction].items():
            assert category in expected_output[direction]
            assert torch.allclose(index_tensor, expected_output[direction][category])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
import torch | ||
import yaml | ||
|
||
from anemoi.models.data.data_indices.collection import IndexCollection | ||
from anemoi.models.data.data_indices.index import BaseIndex | ||
from anemoi.models.data.data_indices.index import DataIndex | ||
from anemoi.models.data.data_indices.tensor import BaseTensorIndex | ||
from anemoi.models.data.data_indices.tensor import InputTensorIndex | ||
from anemoi.models.data.data_indices.tensor import OutputTensorIndex | ||
|
||
|
||
@pytest.mark.data_dependent()
@pytest.mark.parametrize(
    ("data_model", "in_out", "full_only_prognostic"),
    [
        (space, direction, category)
        for space in ["data", "model"]
        for direction in ["input", "output"]
        for category in ["full", "forcing", "diagnostic", "prognostic"]
    ],
)
def test_dataindex_types(datamodule, data_model, in_out, full_only_prognostic) -> None:
    """Every (space, direction, category) attribute chain exists with the right type."""
    assert hasattr(datamodule, "data_indices")
    assert isinstance(datamodule.data_indices, IndexCollection)
    data_indices = datamodule.data_indices

    assert hasattr(data_indices, data_model)
    space = getattr(data_indices, data_model)
    assert isinstance(space, BaseIndex)
    assert hasattr(space, in_out)
    direction = getattr(space, in_out)
    assert isinstance(direction, BaseTensorIndex)
    assert hasattr(direction, full_only_prognostic)
    assert isinstance(getattr(direction, full_only_prognostic), torch.Tensor)
|
||
|
||
@pytest.fixture()
def fake_data():
    """Return ``(forcing, diagnostic, name_to_index)`` for a five-variable toy set."""
    positions = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
    return ["x", "y"], ["z"], positions
|
||
|
||
@pytest.fixture()
def input_tensor_index(fake_data):
    """InputTensorIndex including the forcing names and excluding the diagnostics."""
    forcing_names, diagnostic_names, positions = fake_data
    return InputTensorIndex(includes=forcing_names, excludes=diagnostic_names, name_to_index=positions)
|
||
|
||
@pytest.fixture()
def output_tensor_index(fake_data):
    """OutputTensorIndex including the diagnostics and excluding the forcing names."""
    forcing_names, diagnostic_names, positions = fake_data
    return OutputTensorIndex(includes=diagnostic_names, excludes=forcing_names, name_to_index=positions)
|
||
|
||
def test_dataindex_init(fake_data, input_tensor_index, output_tensor_index) -> None:
    """A DataIndex must equal hand-built input/output tensor indices."""
    forcing_names, diagnostic_names, positions = fake_data
    built = DataIndex(forcing=forcing_names, diagnostic=diagnostic_names, name_to_index=positions)
    assert built.input == input_tensor_index
    assert built.output == output_tensor_index
|
||
|
||
def test_output_tensor_index_full(output_tensor_index) -> None:
    """``full`` keeps every position not excluded (forcing excluded here)."""
    assert torch.allclose(output_tensor_index.full, torch.tensor([2, 3, 4], dtype=torch.int))


def test_output_tensor_index_only(output_tensor_index) -> None:
    """``_only`` holds just the included (diagnostic) positions."""
    assert torch.allclose(output_tensor_index._only, torch.tensor([2], dtype=torch.int))


def test_output_tensor_index_prognostic(output_tensor_index) -> None:
    """Prognostic = full minus the included diagnostics."""
    assert torch.allclose(output_tensor_index.prognostic, torch.tensor([3, 4], dtype=torch.int))


def test_output_tensor_index_todict(output_tensor_index) -> None:
    """``todict()`` must round-trip all four index categories."""
    expected = {
        "full": torch.tensor([2, 3, 4], dtype=torch.int),
        "diagnostic": torch.tensor([2], dtype=torch.int),
        "forcing": torch.tensor([0, 1], dtype=torch.int),
        "prognostic": torch.tensor([3, 4], dtype=torch.int),
    }
    for category, index_tensor in output_tensor_index.todict().items():
        assert category in expected
        assert torch.allclose(index_tensor, expected[category])


def test_output_tensor_index_getattr(output_tensor_index) -> None:
    """Known categories resolve; unknown variable names raise AttributeError."""
    assert output_tensor_index.full is not None
    with pytest.raises(AttributeError):
        output_tensor_index.z


def test_output_tensor_index_build_idx_from_excludes(output_tensor_index) -> None:
    """Index built from excludes drops the forcing positions 0 and 1."""
    built = output_tensor_index._build_idx_from_excludes()
    assert torch.allclose(built, torch.tensor([2, 3, 4], dtype=torch.int))


def test_output_tensor_index_build_idx_from_includes(output_tensor_index) -> None:
    """Index built from includes keeps only the diagnostic position 2."""
    built = output_tensor_index._build_idx_from_includes()
    assert torch.allclose(built, torch.tensor([2], dtype=torch.int))


def test_output_tensor_index_build_idx_prognostic(output_tensor_index) -> None:
    """Prognostic index is the remainder after removing includes from full."""
    built = output_tensor_index._build_idx_prognostic()
    assert torch.allclose(built, torch.tensor([3, 4], dtype=torch.int))
|
||
|
||
def test_input_tensor_index_full(input_tensor_index) -> None:
    """``full`` keeps every position not excluded (diagnostic excluded here)."""
    assert torch.allclose(input_tensor_index.full, torch.tensor([0, 1, 3, 4], dtype=torch.int))


def test_input_tensor_index_only(input_tensor_index) -> None:
    """``_only`` holds just the included (forcing) positions."""
    assert torch.allclose(input_tensor_index._only, torch.tensor([0, 1], dtype=torch.int))


def test_input_tensor_index_prognostic(input_tensor_index) -> None:
    """Prognostic = full minus the included forcings."""
    assert torch.allclose(input_tensor_index.prognostic, torch.tensor([3, 4], dtype=torch.int))


def test_input_tensor_index_todict(input_tensor_index) -> None:
    """``todict()`` must round-trip all four index categories."""
    expected = {
        "full": torch.tensor([0, 1, 3, 4], dtype=torch.int),
        "diagnostic": torch.tensor([2], dtype=torch.int),
        "forcing": torch.tensor([0, 1], dtype=torch.int),
        "prognostic": torch.tensor([3, 4], dtype=torch.int),
    }
    for category, index_tensor in input_tensor_index.todict().items():
        assert category in expected
        assert torch.allclose(index_tensor, expected[category])


def test_input_tensor_index_getattr(input_tensor_index) -> None:
    """Known categories resolve; unknown variable names raise AttributeError."""
    assert input_tensor_index.full is not None
    with pytest.raises(AttributeError):
        input_tensor_index.z


def test_input_tensor_index_build_idx_from_excludes(input_tensor_index) -> None:
    """Index built from excludes drops the diagnostic position 2."""
    built = input_tensor_index._build_idx_from_excludes()
    assert torch.allclose(built, torch.tensor([0, 1, 3, 4], dtype=torch.int))


def test_input_tensor_index_build_idx_from_includes(input_tensor_index) -> None:
    """Index built from includes keeps only the forcing positions 0 and 1."""
    built = input_tensor_index._build_idx_from_includes()
    assert torch.allclose(built, torch.tensor([0, 1], dtype=torch.int))


def test_input_tensor_index_build_idx_prognostic(input_tensor_index) -> None:
    """Prognostic index is the remainder after removing includes from full."""
    built = input_tensor_index._build_idx_prognostic()
    assert torch.allclose(built, torch.tensor([3, 4], dtype=torch.int))
|
||
|
||
@pytest.mark.data_dependent()
def test_yaml_dump(datamodule) -> None:
    """The index collection must serialise via PyYAML without raising."""
    yaml.dump(datamodule.data_indices)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
import torch | ||
|
||
from anemoi.models.data.data_indices.collection import IndexCollection | ||
from anemoi.models.data.data_indices.index import BaseIndex | ||
from anemoi.models.data.data_indices.tensor import BaseTensorIndex | ||
|
||
|
||
@pytest.mark.data_dependent()
def test_datamodule_datasets(datamodule) -> None:
    """The datamodule must expose train/valid/test dataset attributes."""
    for attribute in ("dataset_train", "dataset_valid", "dataset_test"):
        assert hasattr(datamodule, attribute)
|
||
|
||
def test_datamodule_dataloaders(datamodule) -> None:
    """The datamodule must expose the three standard dataloader hooks."""
    for hook in ("train_dataloader", "val_dataloader", "test_dataloader"):
        assert hasattr(datamodule, hook)
|
||
|
||
@pytest.mark.data_dependent()
def test_datamodule_metadata(datamodule) -> None:
    """Metadata must be present and be a plain dict."""
    assert hasattr(datamodule, "metadata")
    metadata = datamodule.metadata
    assert isinstance(metadata, dict)
|
||
|
||
@pytest.mark.data_dependent()
def test_datamodule_statistics(datamodule) -> None:
    """Statistics must be a dict carrying all four normaliser fields."""
    assert hasattr(datamodule, "statistics")
    statistics = datamodule.statistics
    assert isinstance(statistics, dict)
    for field in ("mean", "stdev", "minimum", "maximum"):
        assert field in statistics
|
||
|
||
@pytest.mark.data_dependent()
@pytest.mark.parametrize(
    ("data_model", "in_out", "full_only_prognostic"),
    [
        (a, b, c)
        for a in ["data", "model"]
        for b in ["input", "output"]
        for c in ["full", "forcing", "diagnostic", "prognostic"]
    ],
)
def test_datamodule_api(datamodule, data_model, in_out, full_only_prognostic) -> None:
    """Every (space, direction, category) attribute chain exists with the right type.

    Parameters are the attribute names at each level of the chain:
    ``data_indices.<data_model>.<in_out>.<full_only_prognostic>``.
    """
    assert hasattr(datamodule, "data_indices")
    assert isinstance(datamodule.data_indices, IndexCollection)
    assert hasattr(datamodule.data_indices, data_model)
    # Use attribute access, consistent with the rest of the suite
    # (IndexCollection is not required to support subscripting), and
    # avoid looking the space up twice.
    data_indices = getattr(datamodule.data_indices, data_model)
    assert isinstance(data_indices, BaseIndex)
    assert isinstance(getattr(data_indices, in_out), BaseTensorIndex)
    assert hasattr(getattr(data_indices, in_out), full_only_prognostic)
    assert isinstance(getattr(getattr(data_indices, in_out), full_only_prognostic), torch.Tensor)
|
||
|
||
@pytest.mark.data_dependent()
def test_datamodule_data_indices(datamodule) -> None:
    """Index splits must partition the data variables consistently."""
    indices = datamodule.data_indices
    all_positions = set(indices.data.input.name_to_index.values())
    diagnostic_positions = {
        indices.data.input.name_to_index[v] for v in datamodule.config.data.diagnostic
    }
    # Input variables plus diagnostics must together cover every data variable.
    assert set(indices.data.input.full.numpy()) | diagnostic_positions == all_positions

    # Prognostic is a subset of full on every side, and input/output
    # prognostic sets must be the same size in both spaces.
    for space in (indices.data, indices.model):
        assert len(space.input.prognostic) <= len(space.input.full)
        assert len(space.output.prognostic) <= len(space.output.full)
        assert len(space.output.prognostic) == len(space.input.prognostic)
|
||
|
||
@pytest.mark.data_dependent()
def test_datamodule_batch(datamodule) -> None:
    """The first training batch must carry the configured shape."""
    batch = next(iter(datamodule.train_dataloader()))
    assert isinstance(batch, torch.Tensor)

    n_variables = len(datamodule.data_indices.data.input.name_to_index.values())
    assert batch.shape[-1] == n_variables, "Batch should have all variables"
    assert (
        batch.shape[0] == datamodule.config.dataloader.batch_size.training
    ), "Batch should have correct batch size"
    assert (
        batch.shape[1] == datamodule.config.training.multistep_input + 1
    ), "Batch needs correct sequence length (steps + 1)"
Oops, something went wrong.