diff --git a/CHANGELOG.md b/CHANGELOG.md index e26ef438ddb..50e9482bb51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084)) +- Added error if `NoTrainInceptionV3` is being initialized without `torch-fidelity` not being installed ([#2143](https://github.com/Lightning-AI/torchmetrics/pull/2143)) + + ### Changed - Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089)) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index e8ea1582ff5..2ba787cf017 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -55,6 +55,12 @@ def __init__( features_list: List[str], feature_extractor_weights_path: Optional[str] = None, ) -> None: + if not _TORCH_FIDELITY_AVAILABLE: + raise ModuleNotFoundError( + "NoTrainInceptionV3 module requires that `Torch-fidelity` is installed." + " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." + ) + super().__init__(name, features_list, feature_extractor_weights_path) # put into evaluation mode self.eval() diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 684803afbf5..2fe3529f4c6 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -19,12 +19,21 @@ import torch from torch.nn import Module from torch.utils.data import Dataset -from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.fid import FrechetInceptionDistance, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCH_GREATER_EQUAL_1_9 torch.manual_seed(42) +@pytest.mark.skipif(_TORCH_FIDELITY_AVAILABLE, reason="test only works if torch-fidelity is not installed") +def test_no_train_network_missing_torch_fidelity(): + """Assert that NoTrainInceptionV3 raises an error if torch-fidelity is not installed.""" + with pytest.raises( + ModuleNotFoundError, match="NoTrainInceptionV3 module requires that `Torch-fidelity` is installed.*" + ): + NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"]) + + @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_9, reason="test requires torch>=1.9") @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") def test_no_train():