Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 3, 2023
1 parent 4e43ff4 commit 9d7a67c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def _average_precision_compute(
weights = weights / torch.sum(weights)
else:
weights = None
return _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=average, weights=weights)
return _average_precision_compute_with_precision_recall(
precision, recall, num_classes=num_classes, average=average, weights=weights
)


def _average_precision_compute_with_precision_recall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,11 +865,7 @@ def _precision_recall_curve_compute_multi_class(
num_classes=1,
)
if target.ndim > 1:
prc_args.update(
dict(
target=target[:, cls]
)
)
prc_args.update(dict(target=target[:, cls]))
res = precision_recall_curve(**prc_args)
precision.append(res[0])
recall.append(res[1])
Expand Down Expand Up @@ -922,7 +918,9 @@ def _precision_recall_curve_compute(
if num_classes == 1:
if pos_label is None:
pos_label = 1
return _precision_recall_curve_compute_single_class(preds, target, pos_label=pos_label, sample_weights=sample_weights)
return _precision_recall_curve_compute_single_class(
preds, target, pos_label=pos_label, sample_weights=sample_weights
)
return _precision_recall_curve_compute_multi_class(preds, target, num_classes=num_classes)


Expand Down
18 changes: 16 additions & 2 deletions src/torchmetrics/functional/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,24 @@ def roc(
return binary_roc(preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "multiclass":
assert isinstance(num_classes, int)
return multiclass_roc(preds, target, num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
return multiclass_roc(
preds,
target,
num_classes=num_classes,
thresholds=thresholds,
ignore_index=ignore_index,
validate_args=validate_args,
)
if task == "multilabel":
assert isinstance(num_labels, int)
return multilabel_roc(preds, target, num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
return multilabel_roc(
preds,
target,
num_labels=num_labels,
thresholds=thresholds,
ignore_index=ignore_index,
validate_args=validate_args,
)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)

0 comments on commit 9d7a67c

Please sign in to comment.