Commit 4e43ff4: top_k None
Borda committed Jul 3, 2023
1 parent 9f714e4 commit 4e43ff4
Showing 5 changed files with 23 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/auroc.py
@@ -519,7 +519,7 @@ def _auroc_compute(
num_classes = class_observed.sum()
if num_classes == 1:
raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation")
- fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)
+ fpr, tpr, _ = roc(preds, target, "binary", num_classes, pos_label, sample_weights)

# calculate standard roc auc score
if max_fpr is None or max_fpr == 1:
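
The hunk above routes binary AUROC through the task-based roc wrapper. A minimal sketch of calling that wrapper directly with task="binary"; the probabilities and labels below are illustrative, not from the commit:

    import torch
    from torchmetrics.functional import roc

    preds = torch.tensor([0.1, 0.4, 0.35, 0.8])  # predicted probability of the positive class
    target = torch.tensor([0, 0, 1, 1])          # binary ground-truth labels
    fpr, tpr, thresholds = roc(preds, target, task="binary")
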
11 changes: 5 additions & 6 deletions src/torchmetrics/functional/classification/average_precision.py
@@ -451,7 +451,7 @@ def _average_precision_compute(
target = target.flatten()
num_classes = 1

- precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
+ precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes=num_classes, pos_label=pos_label)
if average == "weighted":
if preds.ndim == target.ndim and target.ndim > 1:
weights = target.sum(dim=0).float()
@@ -460,7 +460,7 @@ def _average_precision_compute(
weights = weights / torch.sum(weights)
else:
weights = None
- return _average_precision_compute_with_precision_recall(precision, recall, num_classes, average, weights)
+ return _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=average, weights=weights)


def _average_precision_compute_with_precision_recall(
@@ -488,10 +488,9 @@ def _average_precision_compute_with_precision_recall(
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
- >>> num_classes = 5
- >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
- >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes)
- >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None)
+ >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes=5)
+ >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes=num_classes)
+ >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes=num_classes, average=None)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""

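
For context on the average == "weighted" branch shown above: each class's average precision is weighted by that class's share of positive targets. A small sketch of the weight computation, assuming one-hot (multilabel-style) targets; the values are illustrative:

    import torch

    # one-hot targets: rows are samples, columns are classes
    target = torch.tensor([[1, 0, 0],
                           [0, 1, 0],
                           [0, 1, 0],
                           [0, 0, 1]])
    weights = target.sum(dim=0).float()     # positives per class -> tensor([1., 2., 1.])
    weights = weights / torch.sum(weights)  # normalized -> tensor([0.2500, 0.5000, 0.2500])
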
17 changes: 6 additions & 11 deletions src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -851,7 +851,6 @@ def _precision_recall_curve_compute_multi_class(
preds: Tensor,
target: Tensor,
num_classes: int,
- sample_weights: Optional[Sequence] = None,
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
"""Computes precision-recall pairs for multiclass inputs."""

@@ -864,14 +863,11 @@
preds=preds_cls,
target=target,
num_classes=1,
- pos_label=cls,
- sample_weights=sample_weights,
)
if target.ndim > 1:
prc_args.update(
dict(
- target=target[:, cls],
- pos_label=1,
+ target=target[:, cls]
)
)
res = precision_recall_curve(**prc_args)
@@ -911,9 +907,8 @@ def _precision_recall_curve_compute(
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
- >>> num_classes = 5
- >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes)
- >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes)
+ >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes=5)
+ >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes=num_classes)
>>> precision
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
@@ -927,8 +922,8 @@
if num_classes == 1:
if pos_label is None:
pos_label = 1
- return _precision_recall_curve_compute_single_class(preds, target, pos_label, sample_weights)
- return _precision_recall_curve_compute_multi_class(preds, target, num_classes, sample_weights)
+ return _precision_recall_curve_compute_single_class(preds, target, pos_label=pos_label, sample_weights=sample_weights)
+ return _precision_recall_curve_compute_multi_class(preds, target, num_classes=num_classes)


def precision_recall_curve(
@@ -942,7 +937,7 @@
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
r"""Computes the precision-recall curve. The curve consist of multiple pairs of precision and recall values
- evaluated at different thresholds, such that the tradeoff between the two values can been seen.
+ evaluated at different thresholds, such that the tradeoff between the two values can be seen.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
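
The multi-class helper above assembles per-class arguments and delegates one binary curve per class. A hedged one-vs-rest sketch of the same idea using the public task-based API; the class count and scores are illustrative, not the library's internal code path:

    import torch
    from torchmetrics.functional import precision_recall_curve

    preds = torch.tensor([[0.75, 0.05, 0.05],
                          [0.05, 0.75, 0.05],
                          [0.05, 0.05, 0.75]])
    target = torch.tensor([0, 1, 2])

    curves = []
    for cls in range(preds.shape[1]):
        # treat class `cls` against the rest as a binary problem
        p, r, t = precision_recall_curve(preds[:, cls], (target == cls).long(), task="binary")
        curves.append((p, r, t))
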
22 changes: 8 additions & 14 deletions src/torchmetrics/functional/classification/roc.py
@@ -479,7 +479,6 @@ def _roc_compute_multi_class(
preds: Tensor,
target: Tensor,
num_classes: int,
- sample_weights: Optional[Sequence] = None,
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
"""Computes Receiver Operating Characteristic for multi class inputs.
@@ -491,16 +490,13 @@
for cls in range(num_classes):
if preds.shape == target.shape:
target_cls = target[:, cls]
- pos_label = 1
else:
target_cls = target
- pos_label = cls
res = roc(
+ task="multiclass",
preds=preds[:, cls],
target=target_cls,
num_classes=1,
- pos_label=pos_label,
- sample_weights=sample_weights,
)
fpr.append(res[0])
tpr.append(res[1])
@@ -514,7 +510,6 @@ def _roc_compute(
target: Tensor,
num_classes: int,
pos_label: Optional[int] = None,
- sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Computes Receiver Operating Characteristic based on the number of classes.
@@ -538,9 +533,8 @@
... [0.05, 0.05, 0.75, 0.05],
... [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
- >>> num_classes = 4
- >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes)
- >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes)
+ >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes=4)
+ >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes=4)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
@@ -556,8 +550,8 @@
if num_classes == 1 and preds.ndim == 1: # binary
if pos_label is None:
pos_label = 1
- return _roc_compute_single_class(preds, target, pos_label, sample_weights)
- return _roc_compute_multi_class(preds, target, num_classes, sample_weights)
+ return _roc_compute_single_class(preds, target, pos_label)
+ return _roc_compute_multi_class(preds, target, num_classes)


def roc(
@@ -624,13 +618,13 @@ def roc(
tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])]
"""
if task == "binary":
- return binary_roc(preds, target, thresholds, ignore_index, validate_args)
+ return binary_roc(preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "multiclass":
assert isinstance(num_classes, int)
- return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args)
+ return multiclass_roc(preds, target, num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
if task == "multilabel":
assert isinstance(num_labels, int)
- return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
+ return multilabel_roc(preds, target, num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)
raise ValueError(
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
)
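
The dispatch above now forwards every option to the task-specific functions by keyword. A short usage example for the multiclass branch, reusing the tensor shapes from the docstring example earlier in this file:

    import torch
    from torchmetrics.functional.classification import multiclass_roc

    preds = torch.tensor([[0.75, 0.05, 0.05, 0.05],
                          [0.05, 0.75, 0.05, 0.05],
                          [0.05, 0.05, 0.75, 0.05],
                          [0.05, 0.05, 0.05, 0.75]])
    target = torch.tensor([0, 1, 3, 2])
    fpr, tpr, thresholds = multiclass_roc(preds, target, num_classes=4)
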
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -912,7 +912,7 @@ def _stat_scores_update(
reduce: Optional[str] = "micro",
mdmc_reduce: Optional[str] = None,
num_classes: Optional[int] = None,
- top_k: Optional[int] = 1,
+ top_k: Optional[int] = None,
threshold: float = 0.5,
multiclass: Optional[bool] = None,
ignore_index: Optional[int] = None,
@@ -1005,12 +1005,12 @@ def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tensor
Example:
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
- >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3)
+ >>> tp, fp, tn, fn = _stat_scores_update(preds, target, top_k=1, reduce='macro', num_classes=3)
>>> _stat_scores_compute(tp, fp, tn, fn)
tensor([[0, 1, 2, 1, 1],
[1, 1, 1, 1, 2],
[1, 0, 3, 0, 1]])
- >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro')
+ >>> tp, fp, tn, fn = _stat_scores_update(preds, target, top_k=1, reduce='micro')
>>> _stat_scores_compute(tp, fp, tn, fn)
tensor([2, 2, 6, 2, 4])
"""
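
With top_k now defaulting to None, the doctests above pass top_k=1 explicitly. A matching sketch, assuming the private helpers keep the signatures shown in this diff (they are internal and may change):

    import torch
    from torchmetrics.functional.classification.stat_scores import (
        _stat_scores_compute,
        _stat_scores_update,
    )

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    tp, fp, tn, fn = _stat_scores_update(preds, target, top_k=1, reduce="macro", num_classes=3)
    # each row is [tp, fp, tn, fn, support] for one class
    print(_stat_scores_compute(tp, fp, tn, fn))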
