Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate num_outputs in R2 because it is no longer needed #2705

Merged
merged 7 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- update `InfoLM` class to dynamically set `higher_is_better` ([#2674](https://github.com/Lightning-AI/torchmetrics/pull/2674))


### Deprecated

- Deprecated `num_outputs` in `R2Score` ([#2705](https://github.com/Lightning-AI/torchmetrics/pull/2705))


### Removed

-
Expand Down
47 changes: 31 additions & 16 deletions src/torchmetrics/regression/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

Expand Down Expand Up @@ -65,23 +65,32 @@ class R2Score(Metric):
* ``'variance_weighted'`` scores are weighted by their individual variances
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

.. warning::
Argument ``num_outputs`` in ``R2Score`` has been deprecated because it is no longer necessary and will be
removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape
of the input tensors.

Raises:
ValueError:
If ``adjusted`` parameter is not an integer larger or equal to 0.
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.

Example:
Example (single output):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
Example (multioutput):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])

Expand All @@ -100,14 +109,20 @@ class R2Score(Metric):

def __init__(
self,
num_outputs: int = 1,
num_outputs: Optional[int] = None,
adjusted: int = 0,
multioutput: str = "uniform_average",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.num_outputs = num_outputs
if num_outputs is not None:
rank_zero_warn(
"Argument `num_outputs` in `R2Score` has been deprecated because it is no longer necessary and will be"
"removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape"
"of the input tensors.",
DeprecationWarning,
)

if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.")
Expand All @@ -120,19 +135,19 @@ def __init__(
)
self.multioutput = multioutput

self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("residual", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total
self.sum_squared_error = self.sum_squared_error + sum_squared_error
self.sum_error = self.sum_error + sum_error
self.residual = self.residual + residual
self.total = self.total + total

def compute(self) -> Tensor:
"""Compute r2 score over the metric states."""
Expand Down
8 changes: 7 additions & 1 deletion tests/unittests/test_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch
from torchmetrics.functional.regression import kl_divergence
from torchmetrics.regression import KLDivergence
from torchmetrics.regression import KLDivergence, R2Score


def test_deprecated_kl_divergence_input_order():
Expand All @@ -14,3 +14,9 @@ def test_deprecated_kl_divergence_input_order():

with pytest.deprecated_call(match="The input order and naming in metric `KLDivergence` is set to be deprecated.*"):
KLDivergence()


def test_deprecated_r2_score_num_outputs():
"""Ensure that the deprecated num_outputs argument in R2Score raises a warning."""
with pytest.deprecated_call(match="Argument `num_outputs` in `R2Score` has been deprecated"):
R2Score(num_outputs=2)
Loading