Skip to content

Commit

Permalink
fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 12, 2024
1 parent 184d910 commit 805543d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/torchmetrics/clustering/cluster_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ClusterAccuracy(Metric):
>>> target = torch.tensor([1, 1, 0, 0])
>>> metric = ClusterAccuracy(num_classes=2)
>>> metric(preds, target)
tensor(1.0000)
tensor(1.)
"""

Expand Down Expand Up @@ -123,7 +123,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import ClusterAccuracy
>>> metric = ClusterAccuracy()
>>> metric = ClusterAccuracy(num_classes=4)
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
Expand All @@ -133,7 +133,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import ClusterAccuracy
>>> metric = ClusterAccuracy()
>>> metric = ClusterAccuracy(num_classes=4)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))))
Expand Down

0 comments on commit 805543d

Please sign in to comment.