Clarify input shape expectation in classification for samplewise reduction #2119

Merged
merged 5 commits into from
Oct 2, 2023
9 changes: 8 additions & 1 deletion src/torchmetrics/classification/accuracy.py
@@ -49,7 +49,8 @@ class BinaryAccuracy(BinaryStatScores):
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.

Additional dimension ``...`` will be flattened into the batch dimension.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
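To make the new wording concrete, here is a minimal pure-Python sketch of the two reduction modes for binary accuracy. It is an illustration only and not the torchmetrics implementation: `flatten` and `binary_accuracy` are hypothetical helpers, and nested lists stand in for tensors of shape ``(N, ...)``.

```python
from itertools import chain


def flatten(x):
    """Recursively flatten a nested list into a flat list of scalars."""
    if not isinstance(x, list):
        return [x]
    return list(chain.from_iterable(flatten(item) for item in x))


def binary_accuracy(preds, target, multidim_average="global"):
    """Toy accuracy over nested lists shaped (N, ...); hypothetical helper."""
    if multidim_average == "global":
        # Everything, including the sample dimension N, is reduced to one scalar.
        p, t = flatten(preds), flatten(target)
        return sum(a == b for a, b in zip(p, t)) / len(t)
    # samplewise: every sample must carry at least one extra dimension,
    # because the reduction runs over "..." and keeps N.
    if not all(isinstance(sample, list) for sample in target):
        raise ValueError("samplewise needs at least one dimension beyond N")
    scores = []
    for p_s, t_s in zip(preds, target):
        p, t = flatten(p_s), flatten(t_s)
        scores.append(sum(a == b for a, b in zip(p, t)) / len(t))
    return scores


preds = [[1, 0, 1, 1], [0, 0, 1, 0]]   # shape (N=2, 4)
target = [[1, 0, 0, 1], [0, 1, 1, 1]]  # shape (N=2, 4)

global_acc = binary_accuracy(preds, target)                    # one scalar
samplewise_acc = binary_accuracy(preds, target, "samplewise")  # one value per sample
```

With a purely 1D input there is nothing left to reduce over once ``N`` is kept, which is why the sketch rejects that case.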
@@ -176,6 +177,9 @@ class MulticlassAccuracy(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
average:
@@ -325,6 +329,9 @@ class MultilabelAccuracy(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
8 changes: 6 additions & 2 deletions src/torchmetrics/classification/exact_match.py
@@ -54,14 +54,16 @@ class MulticlassExactMatch(Metric):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mcem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:

- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
multidim_average:
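For exact match, the extra dimension is what makes a per-sample result meaningful: a sample counts as correct only if every element along ``...`` matches. The sketch below assumes flat per-sample lists of class ids; `exact_match` is a hypothetical helper, not the library function.

```python
def exact_match(preds, target, multidim_average="global"):
    """Toy exact match over N samples, each a list of class ids."""
    # A sample scores 1.0 only when all of its elements agree with the target.
    per_sample = [float(p == t) for p, t in zip(preds, target)]
    if multidim_average == "samplewise":
        return per_sample  # shape (N,)
    return sum(per_sample) / len(per_sample)  # scalar


preds = [[0, 1, 2], [2, 1, 0], [1, 1, 1]]   # shape (N=3, 3)
target = [[0, 1, 2], [2, 0, 0], [1, 1, 1]]  # second sample differs in one position

em_samplewise = exact_match(preds, target, "samplewise")
em_global = exact_match(preds, target)
```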
@@ -206,14 +208,16 @@ class MultilabelExactMatch(Metric):
sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mlem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:

- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
20 changes: 18 additions & 2 deletions src/torchmetrics/classification/f_beta.py
@@ -66,6 +66,9 @@ class BinaryFBetaScore(BinaryStatScores):
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` consisting of
a scalar value per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
threshold: Threshold for transforming probability to binary {0,1} predictions
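A rough sketch of what a samplewise F-beta score looks like when computed from per-sample counts. The function name, the zero-denominator fallback of 0.0, and the list-based inputs are all assumptions made for illustration, not the behaviour of the library.

```python
def fbeta_samplewise(preds, target, beta=1.0):
    """Toy per-sample F-beta for binary inputs of shape (N, L)."""
    b2 = beta ** 2
    scores = []
    for p_s, t_s in zip(preds, target):
        # Counts are accumulated over the extra dimension only, so N is kept.
        tp = sum(p == 1 and t == 1 for p, t in zip(p_s, t_s))
        fp = sum(p == 1 and t == 0 for p, t in zip(p_s, t_s))
        fn = sum(p == 0 and t == 1 for p, t in zip(p_s, t_s))
        denom = (1 + b2) * tp + b2 * fn + fp
        # Assumption: an undefined score is reported as 0.0 in this sketch.
        scores.append((1 + b2) * tp / denom if denom else 0.0)
    return scores


preds = [[1, 1, 0, 0], [1, 0, 1, 0]]
target = [[1, 0, 1, 0], [1, 0, 1, 0]]

f1_per_sample = fbeta_samplewise(preds, target, beta=1.0)
```

The first sample has one true positive, one false positive and one false negative; the second is predicted perfectly.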
@@ -202,7 +205,6 @@ class MulticlassFBetaScore(MulticlassStatScores):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mcfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
@@ -218,6 +220,9 @@ class MulticlassFBetaScore(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
num_classes: Integer specifying the number of classes
@@ -382,7 +387,6 @@ class MultilabelFBetaScore(MultilabelStatScores):
per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mlfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
@@ -398,6 +402,9 @@ class MultilabelFBetaScore(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
num_labels: Integer specifying the number of labels
@@ -566,6 +573,9 @@ class BinaryF1Score(BinaryFBetaScore):
- If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
@@ -706,6 +716,9 @@ class MulticlassF1Score(MulticlassFBetaScore):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
preds: Tensor with predictions
target: Tensor with true labels
@@ -876,6 +889,9 @@ class MultilabelF1Score(MultilabelFBetaScore):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
11 changes: 9 additions & 2 deletions src/torchmetrics/classification/hamming.py
@@ -58,6 +58,9 @@ class BinaryHammingDistance(BinaryStatScores):
- If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a
scalar value per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
@@ -171,7 +174,6 @@ class MulticlassHammingDistance(MulticlassStatScores):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mchd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
@@ -187,6 +189,9 @@ class MulticlassHammingDistance(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
average:
@@ -324,7 +329,6 @@ class MultilabelHammingDistance(MultilabelStatScores):
``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mlhd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
@@ -340,6 +344,9 @@ class MultilabelHammingDistance(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
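For the multilabel case the input is ``(N, C, ...)``, so the extra dimension sits after the label dimension. The following sketch shows a micro-style samplewise Hamming distance on such input. It is a hypothetical helper working on nested lists, and it assumes probabilities are binarized with a strict ``> threshold`` comparison.

```python
def multilabel_hamming_samplewise(preds, target, threshold=0.5):
    """Toy Hamming distance per sample for inputs shaped (N, C, L)."""
    distances = []
    for p_s, t_s in zip(preds, target):        # iterate over samples (N)
        mismatches = 0
        total = 0
        for p_c, t_c in zip(p_s, t_s):         # iterate over labels (C)
            for p, t in zip(p_c, t_c):         # iterate over the extra dim
                if int(p > threshold) != t:
                    mismatches += 1
                total += 1
        distances.append(mismatches / total)   # fraction of wrong labels
    return distances


# N=2 samples, C=2 labels, one extra dimension of length 2
preds = [[[0.9, 0.2], [0.4, 0.8]], [[0.1, 0.7], [0.6, 0.3]]]
target = [[[1, 0], [1, 1]], [[0, 0], [0, 0]]]

hamming_per_sample = multilabel_hamming_samplewise(preds, target)
```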
18 changes: 18 additions & 0 deletions src/torchmetrics/classification/precision_recall.py
@@ -57,6 +57,9 @@ class BinaryPrecision(BinaryStatScores):
value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a
scalar value per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
@@ -187,6 +190,9 @@ class MulticlassPrecision(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
average:
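The ``(N,)`` versus ``(N, C)`` output shapes listed above can be pictured with a small pure-Python sketch of samplewise multiclass precision. The helper names are hypothetical, and treating an undefined precision (no predictions for a class) as 0.0 is an assumption of this sketch.

```python
def precision_per_class_samplewise(preds, target, num_classes):
    """Toy per-sample, per-class precision; returns an N x C nested list."""
    result = []
    for p_s, t_s in zip(preds, target):
        row = []
        for c in range(num_classes):
            tp = sum(p == c and t == c for p, t in zip(p_s, t_s))
            fp = sum(p == c and t != c for p, t in zip(p_s, t_s))
            # Assumption: classes that were never predicted score 0.0.
            row.append(tp / (tp + fp) if tp + fp else 0.0)
        result.append(row)
    return result


def macro_over_classes(per_class):
    """Average each row over C, giving one value per sample, i.e. shape (N,)."""
    return [sum(row) / len(row) for row in per_class]


preds = [[0, 1, 1, 2], [2, 2, 0, 0]]   # shape (N=2, 4)
target = [[0, 1, 2, 2], [2, 1, 0, 0]]

per_class = precision_per_class_samplewise(preds, target, num_classes=3)  # like average=None
macro = macro_over_classes(per_class)                                     # like average='macro'
```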
@@ -340,6 +346,9 @@ class MultilabelPrecision(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
@@ -479,6 +488,9 @@ class BinaryRecall(BinaryStatScores):
value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of
a scalar value per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
@@ -608,6 +620,9 @@ class MulticlassRecall(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
average:
@@ -760,6 +775,9 @@ class MultilabelRecall(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
10 changes: 9 additions & 1 deletion src/torchmetrics/classification/specificity.py
@@ -50,6 +50,9 @@ class BinarySpecificity(BinaryStatScores):
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value
per sample.

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
@@ -174,6 +177,9 @@ class MulticlassSpecificity(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
average:
@@ -307,7 +313,6 @@ class MultilabelSpecificity(MultilabelStatScores):
per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``


As output to ``forward`` and ``compute`` the metric returns the following output:

- ``mls`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
@@ -323,6 +328,9 @@ class MultilabelSpecificity(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
43 changes: 29 additions & 14 deletions src/torchmetrics/classification/stat_scores.py
@@ -107,8 +107,11 @@ class BinaryStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on the ``multidim_average`` parameter:

- If ``multidim_average`` is set to ``global``, the shape will be ``(5,)``
- If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)``
- If ``multidim_average`` is set to ``global``, the shape will be ``(5,)``
- If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
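As a sketch of the ``(5,)`` versus ``(N, 5)`` shapes described in this docstring, the snippet below counts ``[tp, fp, tn, fn, sup]`` either over all elements or per sample. `count_stats` and `binary_stat_scores` are hypothetical names and nested lists stand in for tensors of shape ``(N, L)``.

```python
def count_stats(p_flat, t_flat):
    """Return [tp, fp, tn, fn, sup] for two flat lists of 0/1 values."""
    tp = sum(p == 1 and t == 1 for p, t in zip(p_flat, t_flat))
    fp = sum(p == 1 and t == 0 for p, t in zip(p_flat, t_flat))
    tn = sum(p == 0 and t == 0 for p, t in zip(p_flat, t_flat))
    fn = sum(p == 0 and t == 1 for p, t in zip(p_flat, t_flat))
    return [tp, fp, tn, fn, tp + fn]  # support equals tp + fn


def binary_stat_scores(preds, target, multidim_average="global"):
    """Toy stat scores for inputs shaped (N, L)."""
    if multidim_average == "samplewise":
        # One row of five counts per sample: shape (N, 5).
        return [count_stats(p, t) for p, t in zip(preds, target)]
    # Global: merge N and L into one long vector, giving shape (5,).
    p_all = [v for row in preds for v in row]
    t_all = [v for row in target for v in row]
    return count_stats(p_all, t_all)


preds = [[1, 0, 1, 1], [0, 0, 1, 0]]
target = [[1, 0, 0, 1], [0, 1, 1, 1]]

stats_global = binary_stat_scores(preds, target)
stats_samplewise = binary_stat_scores(preds, target, "samplewise")
```

Summing the samplewise rows column by column recovers the global counts, which is a handy sanity check on real data.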
@@ -208,12 +211,18 @@ class MulticlassStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on ``average`` and ``multidim_average`` parameters:

- If ``multidim_average`` is set to ``global``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
- If ``multidim_average`` is set to ``global``:

- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``

- If ``multidim_average`` is set to ``samplewise``:

- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_classes: Integer specifying the number of classes
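The nested shape list above, ``(C, 5)`` for global and ``(N, C, 5)`` for samplewise with ``average=None``, can be illustrated with one-vs-rest counting. This is a simplified sketch with hypothetical helper names, not the code path used by ``MulticlassStatScores``.

```python
def one_vs_rest_stats(p_flat, t_flat, num_classes):
    """Return a C x 5 list of [tp, fp, tn, fn, sup], one row per class."""
    rows = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(p_flat, t_flat))
        fp = sum(p == c and t != c for p, t in zip(p_flat, t_flat))
        fn = sum(p != c and t == c for p, t in zip(p_flat, t_flat))
        tn = len(t_flat) - tp - fp - fn
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


def multiclass_stat_scores(preds, target, num_classes, multidim_average="global"):
    """Toy per-class stat scores for integer inputs shaped (N, L)."""
    if multidim_average == "samplewise":
        # Shape (N, C, 5): the counts are reduced over L only.
        return [one_vs_rest_stats(p, t, num_classes) for p, t in zip(preds, target)]
    # Shape (C, 5): N and L are reduced together.
    p_all = [v for row in preds for v in row]
    t_all = [v for row in target for v in row]
    return one_vs_rest_stats(p_all, t_all, num_classes)


preds = [[0, 1, 1, 2], [2, 2, 0, 0]]
target = [[0, 1, 2, 2], [2, 1, 0, 0]]

mc_global = multiclass_stat_scores(preds, target, num_classes=3)
mc_samplewise = multiclass_stat_scores(preds, target, 3, "samplewise")
```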
@@ -352,12 +361,18 @@ class MultilabelStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on ``average`` and ``multidim_average`` parameters:

- If ``multidim_average`` is set to ``global``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
- If ``multidim_average`` is set to ``global``:

- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``

- If ``multidim_average`` is set to ``samplewise``:

- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``

If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.

Args:
num_labels: Integer specifying the number of labels