fix remaining issues
SkafteNicki committed Oct 12, 2024
1 parent ae10d90 commit b9bb1b0
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/cluster_accuracy.py
@@ -16,6 +16,7 @@
 from torch_linear_assignment import batch_linear_assignment
 
 from torchmetrics.functional.classification import multiclass_confusion_matrix
+from torchmetrics.functional.clustering.utils import check_cluster_labels
 
 
 def _cluster_accuracy_compute(confmat: Tensor) -> Tensor:
@@ -48,5 +49,6 @@ def cluster_accuracy(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
         tensor(1.000)
     """
+    check_cluster_labels(preds, target)
     confmat = multiclass_confusion_matrix(preds, target, num_classes=num_classes)
     return _cluster_accuracy_compute(confmat)
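For reference, a minimal usage sketch of the patched function. It assumes cluster_accuracy is re-exported from torchmetrics.functional.clustering, and the label tensors are illustrative rather than taken from the repository. With the new check_cluster_labels call, inputs are validated before the confusion matrix is built; a pure relabeling of clusters still scores 1.0 because cluster ids are matched to classes via linear assignment.

import torch
from torchmetrics.functional.clustering import cluster_accuracy  # assumed export path

# Predicted cluster ids are a permutation of the true labels (illustrative data).
preds = torch.tensor([0, 0, 1, 1, 2, 2])
target = torch.tensor([2, 2, 0, 0, 1, 1])

# Labels are validated by check_cluster_labels before the confusion matrix is built.
acc = cluster_accuracy(preds, target, num_classes=3)
print(acc)  # expected: tensor(1.)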
9 changes: 5 additions & 4 deletions tests/unittests/clustering/test_cluster_accuracy.py
@@ -36,7 +36,7 @@ class TestAdjustedMutualInfoScore(MetricTester):
     """Test class for `AdjustedMutualInfoScore` metric."""
 
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_cluster_accuracy_score(self, preds, target, ddp):
+    def test_cluster_accuracy(self, preds, target, ddp):
         """Test class implementation of metric."""
         self.run_class_metric_test(
             ddp=ddp,
@@ -47,13 +47,14 @@ def test_cluster_accuracy_score(self, preds, target, ddp):
             metric_args={"num_classes": NUM_CLASSES},
         )
 
-    def test_cluster_accuracy_score_functional(self, preds, target):
+    def test_cluster_accuracy_functional(self, preds, target):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(
             preds=preds,
             target=target,
             metric_functional=cluster_accuracy,
             reference_metric=clustering_accuracy_score,
+            metric_args={"num_classes": NUM_CLASSES},
         )
 
 
@@ -66,8 +67,8 @@ def test_cluster_accuracy_sanity_check():
     assert torch.allclose(res, torch.tensor(1.0))
 
 
-def test_cluster_accuracy_score_functional_raises_invalid_task():
+def test_cluster_accuracy_functional_raises_invalid_task():
    """Check that metric rejects continuous-valued inputs."""
    preds, target = _float_inputs_extrinsic
    with pytest.raises(ValueError, match=r"Expected *"):
-        cluster_accuracy(preds, target)
+        cluster_accuracy(preds, target, num_classes=NUM_CLASSES)
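A sketch of the error path the updated test exercises, hedged since the exact message raised by check_cluster_labels is not shown here; the float tensors stand in for _float_inputs_extrinsic. The second change simply reflects that the functional call now has to forward num_classes as well.

import pytest
import torch
from torchmetrics.functional.clustering import cluster_accuracy  # assumed export path

# Continuous values are not valid cluster labels (illustrative stand-in data).
preds = torch.rand(10)
target = torch.rand(10)

# check_cluster_labels runs before the confusion matrix is built and should raise here.
with pytest.raises(ValueError):
    cluster_accuracy(preds, target, num_classes=3)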
