diff --git a/CHANGELOG.md b/CHANGELOG.md index 2600ee1526e..ff0dd083d3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084)) ### Changed diff --git a/docs/source/links.rst b/docs/source/links.rst index 3685eb8ae1a..8c8dc593fe3 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -164,3 +164,4 @@ .. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html .. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index .. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score +.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index e8973ca226b..428c961ec1d 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -266,13 +266,15 @@ def __init__( ) if validate_args: _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) - self.average = average + self.average = average # type: ignore[assignment] self.validate_args = validate_args def compute(self) -> Tensor: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds) + return _multiclass_auroc_compute( + state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type] + ) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index ae26d212a51..569c00b73d8 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -264,13 +264,15 @@ def __init__( ) if validate_args: _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) - self.average = average + self.average = average # type: ignore[assignment] self.validate_args = validate_args def compute(self) -> Tensor: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds) + return _multiclass_average_precision_compute( + state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type] + ) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 25f4025f52f..9996cfd683e 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -103,6 +103,8 @@ class BinaryPrecisionRecallCurve(Metric): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the 
tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -266,6 +268,15 @@ class MulticlassPrecisionRecallCurve(Metric): - If set to a 1D `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -314,15 +325,17 @@ def __init__( self, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) self.num_classes = num_classes + self.average = average self.ignore_index = ignore_index self.validate_args = validate_args @@ -344,9 +357,11 @@ def update(self, preds: Tensor, target: Tensor) -> None: if self.validate_args: _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index) preds, target, _ = _multiclass_precision_recall_curve_format( - preds, target, self.num_classes, self.thresholds, self.ignore_index + preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average + ) + state = _multiclass_precision_recall_curve_update( + preds, target, self.num_classes, self.thresholds, self.average ) - state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds) if isinstance(state, Tensor): self.confmat += state else: @@ -356,7 +371,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds) + return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, @@ -456,6 +471,8 @@ class MultilabelPrecisionRecallCurve(Metric): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. 
+ ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a76c3ebc02a..a391cd2046f 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -87,6 +87,8 @@ class BinaryROC(BinaryPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -229,6 +231,15 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -276,7 +287,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat - return _multiclass_roc_compute(state, self.num_classes, self.thresholds) + return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, @@ -381,6 +392,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
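For context, a minimal usage sketch of the new `average` argument on the class-based `MulticlassPrecisionRecallCurve` and `MulticlassROC` changed above (illustrative values, assuming this branch is installed; not taken from the PR's test suite):

```python
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve, MulticlassROC

preds = torch.randn(20, 5).softmax(dim=-1)  # probabilities for 5 classes
target = torch.randint(5, (20,))

# "micro": the targets are one-hot encoded and the predictions flattened, so a
# single binary curve is computed over all classes jointly
pr_micro = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None, average="micro")
precision, recall, thresholds = pr_micro(preds, target)

# "macro": the per-class curves are interpolated onto a combined threshold grid
# and then averaged into one curve
roc_macro = MulticlassROC(num_classes=5, thresholds=100, average="macro")
fpr, tpr, roc_thresholds = roc_macro(preds, target)
```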
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index c9b576703fe..64958267737 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -20,7 +20,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.data import _bincount, _cumsum from torchmetrics.utilities.enums import ClassificationTask @@ -363,6 +363,7 @@ def _multiclass_precision_recall_curve_arg_validation( num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ) -> None: """Validate non tensor input. @@ -373,6 +374,8 @@ def _multiclass_precision_recall_curve_arg_validation( """ if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if average not in (None, "micro", "macro"): + raise ValueError(f"Expected argument `average` to be one of None, 'micro' or 'macro', but got {average}") _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -423,6 +426,7 @@ def _multiclass_precision_recall_curve_format( num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. @@ -443,6 +447,10 @@ def _multiclass_precision_recall_curve_format( if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) + if average == "micro": + preds = preds.flatten() + target = torch.nn.functional.one_hot(target, num_classes=num_classes).flatten() + thresholds = _adjust_threshold_arg(thresholds, preds.device) return preds, target, thresholds @@ -452,6 +460,7 @@ def _multiclass_precision_recall_curve_update( target: Tensor, num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. @@ -461,6 +470,8 @@ def _multiclass_precision_recall_curve_update( """ if thresholds is None: return preds, target + if average == "micro": + return _binary_precision_recall_curve_update(preds, target, thresholds) if preds.numel() * num_classes <= 1_000_000: update_fn = _multiclass_precision_recall_curve_update_vectorized else: @@ -520,6 +531,7 @@ def _multiclass_precision_recall_curve_compute( state: Union[Tensor, Tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. @@ -527,6 +539,9 @@ def _multiclass_precision_recall_curve_compute( original input, then we dynamically compute the binary classification curve in an iterative way. 
""" + if average == "micro": + return _binary_precision_recall_curve_compute(state, thresholds) + if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -535,15 +550,37 @@ def _multiclass_precision_recall_curve_compute( recall = _safe_divide(tps, tps + fns) precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)]) recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)]) - return precision.T, recall.T, thresholds - - precision_list, recall_list, threshold_list = [], [], [] - for i in range(num_classes): - res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) - precision_list.append(res[0]) - recall_list.append(res[1]) - threshold_list.append(res[2]) - return precision_list, recall_list, threshold_list + precision = precision.T + recall = recall.T + thres = thresholds + tensor_state = True + else: + precision_list, recall_list, thres_list = [], [], [] + for i in range(num_classes): + res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) + precision_list.append(res[0]) + recall_list.append(res[1]) + thres_list.append(res[2]) + tensor_state = False + + if average == "macro": + thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, 0) + thres = thres.sort().values + mean_precision = precision.flatten() if tensor_state else torch.cat(precision_list, 0) + mean_precision = mean_precision.sort().values + mean_recall = torch.zeros_like(mean_precision) + for i in range(num_classes): + mean_recall += interp( + mean_precision, + precision[i] if tensor_state else precision_list[i], + recall[i] if tensor_state else recall_list[i], + ) + mean_recall /= num_classes + return mean_precision, mean_recall, thres + + if tensor_state: + return precision, recall, thres + return precision_list, recall_list, thres_list def multiclass_precision_recall_curve( @@ -551,6 +588,7 @@ def multiclass_precision_recall_curve( target: Tensor, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -590,6 +628,13 @@ def multiclass_precision_recall_curve( - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. 
@@ -643,13 +688,18 @@ def multiclass_precision_recall_curve( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, + target, + num_classes, + thresholds, + ignore_index, + average, ) - state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) - return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) + return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds, average) def _multilabel_precision_recall_curve_arg_validation( @@ -892,6 +942,7 @@ def precision_recall_curve( thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -940,7 +991,9 @@ def precision_recall_curve( if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) + return multiclass_precision_recall_curve( + preds, target, num_classes, thresholds, average, ignore_index, validate_args + ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 65d2c16dc87..d61b920aa9b 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -33,7 +33,7 @@ _multilabel_precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.enums import ClassificationTask @@ -163,7 +163,11 @@ def _multiclass_roc_compute( state: Union[Tensor, Tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if average == "micro": + return _binary_roc_compute(state, thresholds, pos_label=1) + if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -172,14 +176,32 @@ def _multiclass_roc_compute( tpr = _safe_divide(tps, tps + fns).flip(0).T fpr = _safe_divide(fps, fps + tns).flip(0).T thres = thresholds.flip(0) + tensor_state = True else: - fpr, tpr, thres = [], [], [] # type: ignore[assignment] + fpr_list, tpr_list, thres_list = [], [], [] for i in range(num_classes): res = _binary_roc_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) - fpr.append(res[0]) - tpr.append(res[1]) - thres.append(res[2]) - return 
fpr, tpr, thres + fpr_list.append(res[0]) + tpr_list.append(res[1]) + thres_list.append(res[2]) + tensor_state = False + + if average == "macro": + thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, dim=0) + thres = thres.sort(descending=True).values + mean_fpr = fpr.flatten() if tensor_state else torch.cat(fpr_list, dim=0) + mean_fpr = mean_fpr.sort().values + mean_tpr = torch.zeros_like(mean_fpr) + for i in range(num_classes): + mean_tpr += interp( + mean_fpr, fpr[i] if tensor_state else fpr_list[i], tpr[i] if tensor_state else tpr_list[i] + ) + mean_tpr /= num_classes + return mean_fpr, mean_tpr, thres + + if tensor_state: + return fpr, tpr, thres + return fpr_list, tpr_list, thres_list def multiclass_roc( @@ -187,6 +209,7 @@ def multiclass_roc( target: Tensor, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -229,6 +252,13 @@ def multiclass_roc( - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. 
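To illustrate the new `average="micro"` branch of `_multiclass_roc_compute`, the result should coincide with `binary_roc` on one-hot encoded targets and flattened predictions, mirroring the formatting step above (an illustrative check, not part of this PR's tests):

```python
import torch
from torchmetrics.functional.classification import binary_roc, multiclass_roc

preds = torch.randn(32, 4).softmax(dim=-1)
target = torch.randint(4, (32,))

fpr, tpr, thr = multiclass_roc(preds, target, num_classes=4, thresholds=100, average="micro")

# micro averaging one-hot encodes the targets, flattens the predictions and
# computes a single binary ROC curve over the flattened tensors
flat_preds = preds.flatten()
flat_target = torch.nn.functional.one_hot(target, num_classes=4).flatten()
fpr_ref, tpr_ref, thr_ref = binary_roc(flat_preds, flat_target, thresholds=100)

# with the same evenly spaced threshold grid the two curves should coincide
assert torch.allclose(fpr, fpr_ref) and torch.allclose(tpr, tpr_ref)
```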
@@ -282,13 +312,18 @@ def multiclass_roc( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, + target, + num_classes, + thresholds, + ignore_index, + average, ) - state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) - return _multiclass_roc_compute(state, num_classes, thresholds) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) + return _multiclass_roc_compute(state, num_classes, thresholds, average) def _multilabel_roc_compute( @@ -440,6 +475,7 @@ def roc( thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -506,7 +542,7 @@ def roc( if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) + return multiclass_roc(preds, target, num_classes, thresholds, average, ignore_index, validate_args) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 0ae6de91cbd..c8cb48a8cdb 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -131,3 +131,29 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: """ x, y = _auc_format_inputs(x, y) return _auc_compute(x, y, reorder=reorder) + + +def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: + """One-dimensional linear interpolation for monotonically increasing sample points. + + Returns the one-dimensional piecewise linear interpolant to a function with + given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. + + Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 + + Args: + x: the :math:`x`-coordinates at which to evaluate the interpolated values. + xp: the :math:`x`-coordinates of the data points, must be increasing. + fp: the :math:`y`-coordinates of the data points, same length as `xp`. + + Returns: + the interpolated values, same size as `x`. 
+ + """ + m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1]) + b = fp[:-1] - (m * xp[:-1]) + + indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1 + indices = torch.clamp(indices, 0, len(m) - 1) + + return m[indices] * x + b[indices] diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 834c5be5161..7167c9711bb 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -280,6 +280,25 @@ def test_multiclass_error_on_wrong_dtypes(self, inputs): with pytest.raises(ValueError, match="Expected `preds` to be a float tensor, but got.*"): multiclass_precision_recall_curve(preds[0].long(), target[0], num_classes=NUM_CLASSES) + @pytest.mark.parametrize("average", ["macro", "micro"]) + @pytest.mark.parametrize("thresholds", [None, 100]) + def test_multiclass_average(self, inputs, average, thresholds): + """Test that the average argument works as expected.""" + preds, target = inputs + output = multiclass_precision_recall_curve( + preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=average + ) + assert all(isinstance(o, torch.Tensor) for o in output) + none_output = multiclass_precision_recall_curve( + preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=None + ) + if average == "macro": + assert len(output[0]) == len(none_output[0][0]) * NUM_CLASSES + assert len(output[1]) == len(none_output[1][0]) * NUM_CLASSES + assert ( + len(output[2]) == (len(none_output[2][0]) if thresholds is None else len(none_output[2])) * NUM_CLASSES + ) + def _sklearn_precision_recall_curve_multilabel(preds, target, ignore_index=None): precision, recall, thresholds = [], [], [] diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 41f5f1d9253..4829b078ec3 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -251,6 +251,21 @@ def test_multiclass_roc_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(r1[i], r2[i]) assert torch.allclose(t1[i], t2) + @pytest.mark.parametrize("average", ["macro", "micro"]) + @pytest.mark.parametrize("thresholds", [None, 100]) + def test_multiclass_average(self, inputs, average, thresholds): + """Test that the average argument works as expected.""" + preds, target = inputs + output = multiclass_roc(preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=average) + assert all(isinstance(o, torch.Tensor) for o in output) + none_output = multiclass_roc(preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=None) + if average == "macro": + assert len(output[0]) == len(none_output[0][0]) * NUM_CLASSES + assert len(output[1]) == len(none_output[1][0]) * NUM_CLASSES + assert ( + len(output[2]) == (len(none_output[2][0]) if thresholds is None else len(none_output[2])) * NUM_CLASSES + ) + def _sklearn_roc_multilabel(preds, target, ignore_index=None): fpr, tpr, thresholds = [], [], []
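The new `interp` helper used by the macro averaging can be sanity-checked against `numpy.interp` for query points inside the sample range (outside the range it extrapolates linearly from the edge segments, whereas NumPy clamps to the endpoint values); an illustrative check, not part of the PR:

```python
import numpy as np
import torch
from torchmetrics.utilities.compute import interp

xp = torch.tensor([0.0, 1.0, 3.0, 6.0])      # increasing sample points
fp = torch.tensor([0.0, 2.0, 2.0, 5.0])      # function values at xp
x = torch.tensor([0.5, 1.0, 2.0, 4.5, 6.0])  # queries within [xp.min(), xp.max()]

torch_result = interp(x, xp, fp)
numpy_result = np.interp(x.numpy(), xp.numpy(), fp.numpy())
assert np.allclose(torch_result.numpy(), numpy_result)
```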