From 8782ca45204a083b3150271895751fd8fe0046d6 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 14 Oct 2024 19:40:24 +0900 Subject: [PATCH] modify import (#3293) --- ignite/metrics/precision_recall_curve.py | 10 ++++++---- tests/ignite/metrics/test_precision_recall_curve.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ignite/metrics/precision_recall_curve.py b/ignite/metrics/precision_recall_curve.py index 5b9ece27545..4a8fde1fe25 100644 --- a/ignite/metrics/precision_recall_curve.py +++ b/ignite/metrics/precision_recall_curve.py @@ -8,10 +8,7 @@ def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]: - try: - from sklearn.metrics import precision_recall_curve - except ImportError: - raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.") + from sklearn.metrics import precision_recall_curve y_true = y_targets.cpu().numpy() y_pred = y_preds.cpu().numpy() @@ -83,6 +80,11 @@ def __init__( device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: + try: + from sklearn.metrics import precision_recall_curve # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scikit-learn to be installed.") + super(PrecisionRecallCurve, self).__init__( precision_recall_curve_compute_fn, # type: ignore[arg-type] output_transform=output_transform, diff --git a/tests/ignite/metrics/test_precision_recall_curve.py b/tests/ignite/metrics/test_precision_recall_curve.py index bc7770e9e2b..8c448d4a2da 100644 --- a/tests/ignite/metrics/test_precision_recall_curve.py +++ b/tests/ignite/metrics/test_precision_recall_curve.py @@ -21,7 +21,7 @@ def mock_no_sklearn(): def test_no_sklearn(mock_no_sklearn): - with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."): + with pytest.raises(ModuleNotFoundError, match=r"This module requires scikit-learn to be installed."): y = torch.tensor([1, 1]) pr_curve = PrecisionRecallCurve() pr_curve.update((y, y))