Skip to content

Commit

Permalink
fix(models): 74 imputer inference mode (#127)
Browse files Browse the repository at this point in the history
* feature: reset nan mask outside of training

* refactor: common test combinations for different training and evaluation mode

---------

Authored by: Jesper Dramsch <[email protected]>
Co-authored-by: Harrison Cook <[email protected]>
  • Loading branch information
sahahner authored Feb 10, 2025
1 parent dfef377 commit 0a9cfa7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 55 deletions.
6 changes: 5 additions & 1 deletion models/src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()

# Initialize nan mask once
# Reset NaN locations outside of training for validation and inference.
if not self.training:
self.nan_locations = None

# Initialise mask if not cached.
if self.nan_locations is None:

# Get NaN locations
Expand Down
113 changes: 59 additions & 54 deletions models/tests/preprocessing/test_preprocessor_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ def default_constant_data():
return base, expected


fixture_combinations = (
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
)


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer does not modify the input tensor when in_place=False."""
Expand All @@ -150,12 +153,7 @@ def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer modifies the input tensor when in_place=True."""
Expand All @@ -169,12 +167,7 @@ def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -186,12 +179,7 @@ def test_transform_with_nan(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_small(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -211,12 +199,7 @@ def test_transform_with_nan_small(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs in inference."""
Expand Down Expand Up @@ -244,12 +227,7 @@ def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_noop(imputer_fixture, data_fixture, request):
"""Check that the imputer does not modify a tensor without NaNs."""
Expand All @@ -262,12 +240,7 @@ def test_transform_noop(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_inverse_transform(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly inverts the transformation."""
Expand All @@ -281,12 +254,7 @@ def test_inverse_transform(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_mask_saving(imputer_fixture, data_fixture, request):
"""Check that the imputer saves the NaN mask correctly."""
Expand All @@ -299,12 +267,7 @@ def test_mask_saving(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_loss_nan_mask(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand Down Expand Up @@ -336,3 +299,45 @@ def test_reuse_imputer(imputer_fixture, data_fixture, request):
assert torch.allclose(
transformed2, expected, equal_nan=True
), "Imputer does not reuse mask correctly on subsequent runs."


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
fixture_combinations,
)
def test_inference_imputer(imputer_fixture, data_fixture, request):
"""Check that the imputer resets its mask during inference."""
x, expected = request.getfixturevalue(data_fixture)
imputer = request.getfixturevalue(imputer_fixture)

# Check training flag
assert imputer.training, "Imputer is not set to training mode."

expected_mask = torch.isnan(x)
transformed = imputer.transform(x, in_place=False)
assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly."
restored = imputer.inverse_transform(transformed, in_place=False)
assert torch.allclose(restored, x, equal_nan=True), "Inverse transform does not restore NaNs correctly."
assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."

imputer.eval()
with torch.no_grad():
x2 = x.roll(-1, dims=0)
expected2 = expected.roll(-1, dims=0)
expected_mask2 = torch.isnan(x2)

assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."

# Check training flag
assert not imputer.training, "Imputer is not set to evaluation mode."

assert not torch.allclose(x, x2, equal_nan=True), "Failed to modify the input data."
assert not torch.allclose(expected, expected2, equal_nan=True), "Failed to modify the expected data."
assert not torch.allclose(expected_mask, expected_mask2, equal_nan=True), "Failed to modify the nan mask."

transformed = imputer.transform(x2, in_place=False)
assert torch.allclose(transformed, expected2, equal_nan=True), "Transform does not handle NaNs correctly."
restored = imputer.inverse_transform(transformed, in_place=False)
assert torch.allclose(restored, x2, equal_nan=True), "Inverse transform does not restore NaNs correctly."

assert torch.equal(imputer.nan_locations, expected_mask2), "Mask not saved correctly after evaluation run."

0 comments on commit 0a9cfa7

Please sign in to comment.