Clarify language about one-vs-rest for classification metrics (#2051)
* docs improve

* suggestions

(cherry picked from commit d4ab66c)
SkafteNicki authored and Borda committed Sep 9, 2023
1 parent b2cc42b commit 39c40e3
Showing 6 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/auroc.py
@@ -172,6 +172,11 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):
multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
corresponds to random guessing.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric. By default the reported metric is then the average over all classes, but this behavior
can be changed by setting the ``average`` argument.

As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
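
To make the averaging behavior concrete, here is a minimal sketch of the one-vs-rest scoring described above (the data and class count are illustrative, not part of this diff):

    import torch
    from torchmetrics.classification import MulticlassAUROC

    preds = torch.softmax(torch.randn(16, 3), dim=-1)  # (N, C) class probabilities
    target = torch.randint(0, 3, (16,))                # (N,) integer class labels

    # Default: compute one-vs-rest AUROC per class, then average over classes.
    macro = MulticlassAUROC(num_classes=3, average="macro")
    print(macro(preds, target))  # scalar tensor

    # ``average=None`` keeps the individual one-vs-rest scores instead.
    per_class = MulticlassAUROC(num_classes=3, average=None)
    print(per_class(preds, target))  # tensor of shape (3,)
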
5 changes: 5 additions & 0 deletions src/torchmetrics/classification/average_precision.py
@@ -171,6 +171,11 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
where :math:`P_n, R_n` are the precision and recall at threshold index :math:`n`, respectively. This value is
equivalent to the area under the precision-recall curve (AUPRC).

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric. By default the reported metric is then the average over all classes, but this behavior
can be changed by setting the ``average`` argument.

As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
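
The same one-vs-rest averaging is also reachable through the functional interface; a short sketch with illustrative data (not part of this diff):

    import torch
    from torchmetrics.functional.classification import multiclass_average_precision

    preds = torch.softmax(torch.randn(16, 3), dim=-1)
    target = torch.randint(0, 3, (16,))

    # "macro" averages the per-class (one-vs-rest) scores; ``average=None`` returns them unreduced.
    print(multiclass_average_precision(preds, target, num_classes=3, average="macro"))
    print(multiclass_average_precision(preds, target, num_classes=3, average=None))
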
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -224,6 +224,10 @@ class MulticlassPrecisionRecallCurve(Metric):
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that
the tradeoff between the two values can be seen.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric.

As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor containing
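
A minimal sketch (illustrative data, not part of this diff) of what the one-vs-rest computation yields here: with ``thresholds=None``, each class gets its own precision, recall, and threshold tensors.

    import torch
    from torchmetrics.classification import MulticlassPrecisionRecallCurve

    preds = torch.softmax(torch.randn(16, 3), dim=-1)
    target = torch.randint(0, 3, (16,))

    curve = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=None)
    precision, recall, thresholds = curve(preds, target)

    # One entry per class, each from treating that class as positive and the rest as negative.
    print(len(precision), len(recall), len(thresholds))  # 3 3 3
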
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/recall_fixed_precision.py
@@ -180,6 +180,10 @@ class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve):
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall
for a given precision level.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric.

As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
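
As a sketch (illustrative precision level and data, not part of this diff), the class-based version returns one recall/threshold pair per one-vs-rest subproblem:

    import torch
    from torchmetrics.classification import MulticlassRecallAtFixedPrecision

    preds = torch.softmax(torch.randn(16, 3), dim=-1)
    target = torch.randint(0, 3, (16,))

    metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.8)
    recall, thresholds = metric(preds, target)  # each of shape (num_classes,)
    print(recall, thresholds)
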
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/roc.py
@@ -174,6 +174,10 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at
different thresholds, such that the tradeoff between the two values can be seen.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric.

As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
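
Analogous to the precision-recall curve above, a brief sketch (illustrative data, not part of this diff) of the per-class ROC output:

    import torch
    from torchmetrics.classification import MulticlassROC

    preds = torch.softmax(torch.randn(16, 3), dim=-1)
    target = torch.randint(0, 3, (16,))

    roc = MulticlassROC(num_classes=3, thresholds=None)
    fpr, tpr, thresholds = roc(preds, target)  # one tensor per class (one-vs-rest)
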
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/specificity_sensitivity.py
@@ -133,6 +133,10 @@ class MulticlassSpecificityAtSensitivity(MulticlassPrecisionRecallCurve):
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
then finding the specificity for a given sensitivity level.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not
supported by this metric.

Accepts the following input tensors:
- ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
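
A final sketch in the same spirit (illustrative sensitivity level and data, not part of this diff):

    import torch
    from torchmetrics.classification import MulticlassSpecificityAtSensitivity

    preds = torch.softmax(torch.randn(16, 3), dim=-1)
    target = torch.randint(0, 3, (16,))

    metric = MulticlassSpecificityAtSensitivity(num_classes=3, min_sensitivity=0.5)
    specificity, thresholds = metric(preds, target)  # each of shape (num_classes,)
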
