From 3e316aafec7e46629b4b5b7a35d13d3762ac55c7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 5 Oct 2023 10:02:14 +0200 Subject: [PATCH 1/3] add error on missing module --- src/torchmetrics/image/fid.py | 6 ++++++ tests/unittests/image/test_fid.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index e8ea1582ff5..2ba787cf017 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -55,6 +55,12 @@ def __init__( features_list: List[str], feature_extractor_weights_path: Optional[str] = None, ) -> None: + if not _TORCH_FIDELITY_AVAILABLE: + raise ModuleNotFoundError( + "NoTrainInceptionV3 module requires that `Torch-fidelity` is installed." + " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." + ) + super().__init__(name, features_list, feature_extractor_weights_path) # put into evaluation mode self.eval() diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 684803afbf5..61e0b1dc98e 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -19,12 +19,21 @@ import torch from torch.nn import Module from torch.utils.data import Dataset -from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.fid import FrechetInceptionDistance, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCH_GREATER_EQUAL_1_9 torch.manual_seed(42) +@pytest.mark.skipif(_TORCH_FIDELITY_AVAILABLE, test="test only works if torch-fidelity is not installed") +def test_no_train_network_missing_torch_fidelity(): + """Assert that NoTrainInceptionV3 raises an error if torch-fidelity is not installed.""" + with pytest.raises( + ModuleNotFoundError, match="NoTrainInceptionV3 module requires that `Torch-fidelity` is installed.*" + ): + NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"]) + + @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9") @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") def test_no_train(): From c7885f275ff1c344f3ffc12b9864f78a81e0cc95 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 5 Oct 2023 10:17:20 +0200 Subject: [PATCH 2/3] provide reason --- tests/unittests/image/test_fid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 61e0b1dc98e..2fe3529f4c6 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -25,7 +25,7 @@ torch.manual_seed(42) -@pytest.mark.skipif(_TORCH_FIDELITY_AVAILABLE, test="test only works if torch-fidelity is not installed") +@pytest.mark.skipif(_TORCH_FIDELITY_AVAILABLE, reason="test only works if torch-fidelity is not installed") def test_no_train_network_missing_torch_fidelity(): """Assert that NoTrainInceptionV3 raises an error if torch-fidelity is not installed.""" with pytest.raises( From ed08f273dddad96af9b32e6a9a1c296bac04a4cf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 Oct 2023 10:19:28 +0200 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff0dd083d3e..7bcf3e90631 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084)) +- Added error if `NoTrainInceptionV3` is being initialized without `torch-fidelity` not being installed ([#2143](https://github.com/Lightning-AI/torchmetrics/pull/2143)) + + ### Changed - Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089)) @@ -25,7 +28,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- ## [1.2.0] - 2023-09-22