diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index fb3465f4010..950f8f34f9d 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -172,6 +172,11 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing. + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. By default the reported metric is then the average over all classes, but this behavior can be changed + by setting the ``average`` argument. + As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 2d29c30e05b..e579cfde173 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -171,6 +171,11 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is equivalent to the area under the precision-recall curve (AUPRC). + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. By default the reported metric is then the average over all classes, but this behavior can be changed + by setting the ``average`` argument. 
+ As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 75117ddb956..0bbd00fa05f 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -224,6 +224,10 @@ class MulticlassPrecisionRecallCurve(Metric): The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. + As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor containing diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index b97fe90a433..6f00df004af 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -180,6 +180,10 @@ class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve): This is done by first calculating the precision-recall curve for different thresholds and the find the recall for a given precision level. + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. 
+ As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 66370f42f34..7ae73b1fe56 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -174,6 +174,10 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen. + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. + As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index 20f219df2c0..42429bf90fd 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -133,6 +133,10 @@ class MulticlassSpecificityAtSensitivity(MulticlassPrecisionRecallCurve): This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the find the specificity for a given sensitivity level. + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. + Accepts the following input tensors: - ``preds`` (float tensor): ``(N, C, ...)``. 
Preds should be a tensor containing probabilities or logits for each