diff --git a/src/anomalib/metrics/optimal_f1.py b/src/anomalib/metrics/optimal_f1.py index c7f983c696..d9d4537973 100644 --- a/src/anomalib/metrics/optimal_f1.py +++ b/src/anomalib/metrics/optimal_f1.py @@ -7,7 +7,8 @@ import torch from torchmetrics import Metric -from torchmetrics.classification import BinaryPrecisionRecallCurve + +from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve logger = logging.getLogger(__name__) diff --git a/tests/unit/metrics/test_optimal_f1.py b/tests/unit/metrics/test_optimal_f1.py index 4f6aaa0126..8dcece255d 100644 --- a/tests/unit/metrics/test_optimal_f1.py +++ b/tests/unit/metrics/test_optimal_f1.py @@ -42,4 +42,10 @@ def test_optimal_f1_raw() -> None: metric.update(preds, labels) assert metric.compute() == 1.0 - assert metric.threshold == 0.5 + assert metric.threshold == 0.0 + + metric.reset() + preds = torch.tensor([-0.5, 0.0, 1.0, 2.0, -0.1]) + metric.update(preds, labels) + assert metric.compute() == torch.tensor(1.0) + assert metric.threshold == -0.1