diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py
index 9645c7b1097..dcc5bfde5f7 100644
--- a/src/torchmetrics/functional/classification/average_precision.py
+++ b/src/torchmetrics/functional/classification/average_precision.py
@@ -460,7 +460,9 @@ def _average_precision_compute(
         weights = weights / torch.sum(weights)
     else:
         weights = None
-    return _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=average, weights=weights)
+    return _average_precision_compute_with_precision_recall(
+        precision, recall, num_classes=num_classes, average=average, weights=weights
+    )
 
 
 def _average_precision_compute_with_precision_recall(
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 8683ee4dbfe..aa63f03b950 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -865,11 +865,7 @@ def _precision_recall_curve_compute_multi_class(
             num_classes=1,
         )
         if target.ndim > 1:
-            prc_args.update(
-                dict(
-                    target=target[:, cls]
-                )
-            )
+            prc_args.update(dict(target=target[:, cls]))
         res = precision_recall_curve(**prc_args)
         precision.append(res[0])
         recall.append(res[1])
@@ -922,7 +918,9 @@ def _precision_recall_curve_compute(
     if num_classes == 1:
         if pos_label is None:
             pos_label = 1
-        return _precision_recall_curve_compute_single_class(preds, target, pos_label=pos_label, sample_weights=sample_weights)
+        return _precision_recall_curve_compute_single_class(
+            preds, target, pos_label=pos_label, sample_weights=sample_weights
+        )
     return _precision_recall_curve_compute_multi_class(preds, target, num_classes=num_classes)
 
 
diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py
index 9d01b2babfd..5d990e91cc8 100644
--- a/src/torchmetrics/functional/classification/roc.py
+++ b/src/torchmetrics/functional/classification/roc.py
@@ -621,10 +621,24 @@ def roc(
         return binary_roc(preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
     if task == "multiclass":
         assert isinstance(num_classes, int)
-        return multiclass_roc(preds, target, num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
+        return multiclass_roc(
+            preds,
+            target,
+            num_classes=num_classes,
+            thresholds=thresholds,
+            ignore_index=ignore_index,
+            validate_args=validate_args,
+        )
     if task == "multilabel":
         assert isinstance(num_labels, int)
-        return multilabel_roc(preds, target, num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
+        return multilabel_roc(
+            preds,
+            target,
+            num_labels=num_labels,
+            thresholds=thresholds,
+            ignore_index=ignore_index,
+            validate_args=validate_args,
+        )
     raise ValueError(
         f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
     )
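
All three hunks re-wrap long call sites without changing behavior. For reference, a minimal usage sketch of the `roc` dispatcher reformatted in the last hunk, assuming a torchmetrics version with task-based routing; the tensor shapes and values below are illustrative only:

```python
import torch
from torchmetrics.functional.classification.roc import roc

# Illustrative inputs: 16 samples, 3 classes (hypothetical shapes).
preds = torch.randn(16, 3).softmax(dim=-1)  # per-class probabilities
target = torch.randint(0, 3, (16,))         # integer class labels

# task="multiclass" dispatches to multiclass_roc, per the branch above.
fpr, tpr, thresholds = roc(preds, target, task="multiclass", num_classes=3)
```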