Add average to curve metrics (#2084)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 4, 2023
1 parent 2387f2a commit b12a647
Showing 11 changed files with 220 additions and 36 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084))

### Changed

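A quick usage sketch of the argument this commit adds. This is illustrative only: tensor shapes and values are made up, and the imports assume a torchmetrics build that already includes this change.

```python
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve, MulticlassROC

preds = torch.randn(10, 3).softmax(dim=-1)   # (N, C) class probabilities
target = torch.randint(0, 3, (10,))          # (N,) integer labels

# average=None (default) keeps one curve per class; "micro" / "macro" collapse them to a single curve.
pr_curve = MulticlassPrecisionRecallCurve(num_classes=3, average="macro")
precision, recall, pr_thresholds = pr_curve(preds, target)

roc = MulticlassROC(num_classes=3, average="micro")
fpr, tpr, roc_thresholds = roc(preds, target)
```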
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -164,3 +164,4 @@
.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/auroc.py
@@ -266,13 +266,15 @@ def __init__(
)
if validate_args:
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_auroc_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/average_precision.py
@@ -264,13 +264,15 @@ def __init__(
)
if validate_args:
_multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_average_precision_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
25 changes: 21 additions & 4 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -103,6 +103,8 @@ class BinaryPrecisionRecallCurve(Metric):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -266,6 +268,15 @@ class MulticlassPrecisionRecallCurve(Metric):
- If set to a 1D `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
Whether to aggregate the per-class curves. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric aggregates the curves by one-hot
encoding the targets and flattening the predictions, treating all classes jointly as a single binary problem.
If `average` is set to ``"macro"``, the metric aggregates the curves by first interpolating the curve of
each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
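A sketch of how the `average` setting documented above changes what ``compute`` returns when ``thresholds=None``. Values are illustrative; the list-vs-tensor behaviour follows the compute branches further down in this diff.

```python
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve

preds = torch.randn(20, 4).softmax(dim=-1)
target = torch.randint(0, 4, (20,))

# Default (average=None): per-class curves, returned as lists with num_classes entries.
precision, recall, thresholds = MulticlassPrecisionRecallCurve(num_classes=4)(preds, target)
assert len(precision) == len(recall) == 4

# average="micro": targets are one-hot encoded, predictions flattened, and a single
# binary-style curve (plain tensors) is returned.
p_micro, r_micro, t_micro = MulticlassPrecisionRecallCurve(num_classes=4, average="micro")(preds, target)
```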
@@ -314,15 +325,17 @@ def __init__(
self,
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if validate_args:
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average)

self.num_classes = num_classes
self.average = average
self.ignore_index = ignore_index
self.validate_args = validate_args

@@ -344,9 +357,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
if self.validate_args:
_multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target, _ = _multiclass_precision_recall_curve_format(
preds, target, self.num_classes, self.thresholds, self.ignore_index
preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average
)
state = _multiclass_precision_recall_curve_update(
preds, target, self.num_classes, self.thresholds, self.average
)
state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds)
if isinstance(state, Tensor):
self.confmat += state
else:
@@ -356,7 +371,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds)
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -456,6 +471,8 @@ class MultilabelPrecisionRecallCurve(Metric):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
15 changes: 14 additions & 1 deletion src/torchmetrics/classification/roc.py
@@ -87,6 +87,8 @@ class BinaryROC(BinaryPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -229,6 +231,15 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
Whether to aggregate the per-class curves. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric aggregates the curves by one-hot
encoding the targets and flattening the predictions, treating all classes jointly as a single binary problem.
If `average` is set to ``"macro"``, the metric aggregates the curves by first interpolating the curve of
each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -276,7 +287,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_roc_compute(state, self.num_classes, self.thresholds)
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -381,6 +392,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
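Mirroring the `average` documentation above for ROC, a small sketch with assumed shapes; per the `_multiclass_roc_compute` call in this hunk, ``average="macro"`` returns one interpolated curve instead of per-class lists.

```python
import torch
from torchmetrics.classification import MulticlassROC

preds = torch.randn(32, 5).softmax(dim=-1)
target = torch.randint(0, 5, (32,))

# Per-class ROC curves are interpolated onto a combined grid and averaged into one curve.
roc = MulticlassROC(num_classes=5, thresholds=None, average="macro")
fpr, tpr, thresholds = roc(preds, target)
```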
src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -20,7 +20,7 @@
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.compute import _safe_divide, interp
from torchmetrics.utilities.data import _bincount, _cumsum
from torchmetrics.utilities.enums import ClassificationTask

@@ -363,6 +363,7 @@ def _multiclass_precision_recall_curve_arg_validation(
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
) -> None:
"""Validate non tensor input.
@@ -373,6 +374,8 @@
"""
if not isinstance(num_classes, int) or num_classes < 2:
raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
if average not in (None, "micro", "macro"):
raise ValueError(f"Expected argument `average` to be one of None, 'micro' or 'macro', but got {average}")
_binary_precision_recall_curve_arg_validation(thresholds, ignore_index)


@@ -423,6 +426,7 @@ def _multiclass_precision_recall_curve_format(
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
"""Convert all input to the right format.
@@ -443,6 +447,10 @@
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.softmax(1)

if average == "micro":
preds = preds.flatten()
target = torch.nn.functional.one_hot(target, num_classes=num_classes).flatten()

thresholds = _adjust_threshold_arg(thresholds, preds.device)
return preds, target, thresholds
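The ``average == "micro"`` branch above reduces the multiclass inputs to a single binary problem. A standalone sketch of the same transformation with toy tensors:

```python
import torch

num_classes = 3
preds = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1]])   # (N, C) probabilities
target = torch.tensor([0, 2])             # (N,) class indices

# One-hot encode the targets and flatten both tensors: every (sample, class) pair becomes
# one binary prediction/target, which is then scored with the binary curve machinery.
flat_preds = preds.flatten()                                                          # shape (N * C,)
flat_target = torch.nn.functional.one_hot(target, num_classes=num_classes).flatten()  # shape (N * C,)
```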

@@ -452,6 +460,7 @@ def _multiclass_precision_recall_curve_update(
target: Tensor,
num_classes: int,
thresholds: Optional[Tensor],
average: Optional[Literal["micro", "macro"]] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Return the state to calculate the pr-curve with.
@@ -461,6 +470,8 @@
"""
if thresholds is None:
return preds, target
if average == "micro":
return _binary_precision_recall_curve_update(preds, target, thresholds)
if preds.numel() * num_classes <= 1_000_000:
update_fn = _multiclass_precision_recall_curve_update_vectorized
else:
@@ -520,13 +531,17 @@ def _multiclass_precision_recall_curve_compute(
state: Union[Tensor, Tuple[Tensor, Tensor]],
num_classes: int,
thresholds: Optional[Tensor],
average: Optional[Literal["micro", "macro"]] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute the final pr-curve.
If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is
original input, then we dynamically compute the binary classification curve in an iterative way.
"""
if average == "micro":
return _binary_precision_recall_curve_compute(state, thresholds)

if isinstance(state, Tensor) and thresholds is not None:
tps = state[:, :, 1, 1]
fps = state[:, :, 0, 1]
@@ -535,22 +550,45 @@
recall = _safe_divide(tps, tps + fns)
precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)])
recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)])
return precision.T, recall.T, thresholds

precision_list, recall_list, threshold_list = [], [], []
for i in range(num_classes):
res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i)
precision_list.append(res[0])
recall_list.append(res[1])
threshold_list.append(res[2])
return precision_list, recall_list, threshold_list
precision = precision.T
recall = recall.T
thres = thresholds
tensor_state = True
else:
precision_list, recall_list, thres_list = [], [], []
for i in range(num_classes):
res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i)
precision_list.append(res[0])
recall_list.append(res[1])
thres_list.append(res[2])
tensor_state = False

if average == "macro":
thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, 0)
thres = thres.sort().values
mean_precision = precision.flatten() if tensor_state else torch.cat(precision_list, 0)
mean_precision = mean_precision.sort().values
mean_recall = torch.zeros_like(mean_precision)
for i in range(num_classes):
mean_recall += interp(
mean_precision,
precision[i] if tensor_state else precision_list[i],
recall[i] if tensor_state else recall_list[i],
)
mean_recall /= num_classes
return mean_precision, mean_recall, thres

if tensor_state:
return precision, recall, thres
return precision_list, recall_list, thres_list
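The ``average == "macro"`` branch above builds a combined, sorted grid and interpolates every per-class curve onto it with ``interp`` before averaging. A simplified sketch of that idea with two made-up curves; the call convention ``interp(x_new, x_known, y_known)`` is inferred from the usage in this hunk.

```python
import torch
from torchmetrics.utilities.compute import interp  # same helper the macro branch imports

# Two per-class curves sampled at different x positions (toy values).
x1, y1 = torch.tensor([0.0, 0.5, 1.0]), torch.tensor([1.0, 0.6, 0.2])
x2, y2 = torch.tensor([0.0, 0.3, 1.0]), torch.tensor([0.9, 0.7, 0.1])

# Combined, sorted grid, then interpolate each curve onto it and average pointwise.
grid = torch.cat([x1, x2]).sort().values
mean_y = (interp(grid, x1, y1) + interp(grid, x2, y2)) / 2
```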


def multiclass_precision_recall_curve(
preds: Tensor,
target: Tensor,
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
@@ -590,6 +628,13 @@ def multiclass_precision_recall_curve(
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
Whether to aggregate the per-class curves. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric aggregates the curves by one-hot
encoding the targets and flattening the predictions, treating all classes jointly as a single binary problem.
If `average` is set to ``"macro"``, the metric aggregates the curves by first interpolating the curve of
each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
@@ -643,13 +688,18 @@ def multiclass_precision_recall_curve(
"""
if validate_args:
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average)
_multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
preds, target, thresholds = _multiclass_precision_recall_curve_format(
preds, target, num_classes, thresholds, ignore_index
preds,
target,
num_classes,
thresholds,
ignore_index,
average,
)
state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds)
state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average)
return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds, average)
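A usage sketch of the functional form with the new argument (tensor values are illustrative):

```python
import torch
from torchmetrics.functional.classification import multiclass_precision_recall_curve

preds = torch.randn(16, 3).softmax(dim=-1)
target = torch.randint(0, 3, (16,))

precision, recall, thresholds = multiclass_precision_recall_curve(
    preds, target, num_classes=3, thresholds=None, average="micro"
)
```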


def _multilabel_precision_recall_curve_arg_validation(
@@ -892,6 +942,7 @@ def precision_recall_curve(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
@@ -940,7 +991,9 @@
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args)
return multiclass_precision_recall_curve(
preds, target, num_classes, thresholds, average, ignore_index, validate_args
)
if task == ClassificationTask.MULTILABEL:
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
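And the task wrapper, which this hunk updates to forward `average` to the multiclass implementation; a minimal sketch with assumed inputs:

```python
import torch
from torchmetrics.functional import precision_recall_curve

preds = torch.randn(16, 3).softmax(dim=-1)
target = torch.randint(0, 3, (16,))

precision, recall, thresholds = precision_recall_curve(
    preds, target, task="multiclass", num_classes=3, average="macro"
)
```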
