diff --git a/models/src/anemoi/models/preprocessing/imputer.py b/models/src/anemoi/models/preprocessing/imputer.py index d0bd5896..536b1709 100644 --- a/models/src/anemoi/models/preprocessing/imputer.py +++ b/models/src/anemoi/models/preprocessing/imputer.py @@ -124,7 +124,11 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: if not in_place: x = x.clone() - # Initialize nan mask once + # Reset NaN locations outside of training for validation and inference. + if not self.training: + self.nan_locations = None + + # Initialise mask if not cached. if self.nan_locations is None: # Get NaN locations diff --git a/models/tests/preprocessing/test_preprocessor_imputer.py b/models/tests/preprocessing/test_preprocessor_imputer.py index 5d1035eb..8557b64b 100644 --- a/models/tests/preprocessing/test_preprocessor_imputer.py +++ b/models/tests/preprocessing/test_preprocessor_imputer.py @@ -130,14 +130,17 @@ def default_constant_data(): return base, expected +fixture_combinations = ( + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), +) + + @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None: """Check that the imputer does not modify the input tensor when in_place=False.""" @@ -150,12 +153,7 @@ def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None: @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None: """Check that the imputer modifies the input tensor when in_place=True.""" @@ -169,12 +167,7 @@ def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None: @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_transform_with_nan(imputer_fixture, data_fixture, request): """Check that the imputer correctly transforms a tensor with NaNs.""" @@ -186,12 +179,7 @@ def test_transform_with_nan(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_transform_with_nan_small(imputer_fixture, data_fixture, request): """Check that the imputer correctly transforms a tensor with NaNs.""" @@ -211,12 +199,7 @@ def test_transform_with_nan_small(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_transform_with_nan_inference(imputer_fixture, data_fixture, request): """Check that the imputer correctly transforms a tensor with NaNs in inference.""" @@ -244,12 +227,7 @@ def test_transform_with_nan_inference(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_transform_noop(imputer_fixture, data_fixture, request): """Check that the imputer does not modify a tensor without NaNs.""" @@ -262,12 +240,7 @@ def test_transform_noop(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_inverse_transform(imputer_fixture, data_fixture, request): """Check that the imputer correctly inverts the transformation.""" @@ -281,12 +254,7 @@ def test_inverse_transform(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_mask_saving(imputer_fixture, data_fixture, request): """Check that the imputer saves the NaN mask correctly.""" @@ -299,12 +267,7 @@ def test_mask_saving(imputer_fixture, data_fixture, request): @pytest.mark.parametrize( ("imputer_fixture", "data_fixture"), - [ - ("default_constant_imputer", "default_constant_data"), - ("non_default_constant_imputer", "non_default_constant_data"), - ("default_input_imputer", "default_input_data"), - ("non_default_input_imputer", "non_default_input_data"), - ], + fixture_combinations, ) def test_loss_nan_mask(imputer_fixture, data_fixture, request): """Check that the imputer correctly transforms a tensor with NaNs.""" @@ -336,3 +299,45 @@ def test_reuse_imputer(imputer_fixture, data_fixture, request): assert torch.allclose( transformed2, expected, equal_nan=True ), "Imputer does not reuse mask correctly on subsequent runs." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + fixture_combinations, +) +def test_inference_imputer(imputer_fixture, data_fixture, request): + """Check that the imputer resets its mask during inference.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + + # Check training flag + assert imputer.training, "Imputer is not set to training mode." + + expected_mask = torch.isnan(x) + transformed = imputer.transform(x, in_place=False) + assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly." + restored = imputer.inverse_transform(transformed, in_place=False) + assert torch.allclose(restored, x, equal_nan=True), "Inverse transform does not restore NaNs correctly." + assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run." + + imputer.eval() + with torch.no_grad(): + x2 = x.roll(-1, dims=0) + expected2 = expected.roll(-1, dims=0) + expected_mask2 = torch.isnan(x2) + + assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run." + + # Check training flag + assert not imputer.training, "Imputer is not set to evaluation mode." + + assert not torch.allclose(x, x2, equal_nan=True), "Failed to modify the input data." + assert not torch.allclose(expected, expected2, equal_nan=True), "Failed to modify the expected data." + assert not torch.allclose(expected_mask, expected_mask2, equal_nan=True), "Failed to modify the nan mask." + + transformed = imputer.transform(x2, in_place=False) + assert torch.allclose(transformed, expected2, equal_nan=True), "Transform does not handle NaNs correctly." + restored = imputer.inverse_transform(transformed, in_place=False) + assert torch.allclose(restored, x2, equal_nan=True), "Inverse transform does not restore NaNs correctly." + + assert torch.equal(imputer.nan_locations, expected_mask2), "Mask not saved correctly after evaluation run."