Skip to content

Commit

Permalink
skip doctests on missing package
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 12, 2024
1 parent 9090b36 commit 4c35de9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/torchmetrics/clustering/cluster_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
from torchmetrics.functional.classification import multiclass_confusion_matrix
from torchmetrics.functional.clustering.cluster_accuracy import _cluster_accuracy_compute
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_LINEAR_ASSIGNMENT_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ClusterAccuracy.plot"]

if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
__doctest_skip__ = ["ClusterAccuracy"]


class ClusterAccuracy(Metric):
r"""Compute `Cluster Accuracy`_ between predicted and target clusters.
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/clustering/cluster_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

from torchmetrics.functional.classification import multiclass_confusion_matrix
from torchmetrics.functional.clustering.utils import check_cluster_labels
from torchmetrics.utilities.imports import _TORCH_LINEAR_ASSIGNMENT_AVAILABLE

if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
__doctest_skip__ = ["cluster_accuracy"]


def _cluster_accuracy_compute(confmat: Tensor) -> Tensor:
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SCIPI_AVAILABLE = RequirementCache("scipy")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")
_TORCH_LINEAR_ASSIGNMENT_AVAILABLE = RequirementCache("torch_linear_assignment")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None

0 comments on commit 4c35de9

Please sign in to comment.