Fix sigmoid overflow for large logits causing incorrect AUROC results

Description

Fixes an issue where binary_auroc and other classification metrics return incorrect results when logits are very large (above roughly 16.6 for float32 or 36.7 for float64). The sigmoid saturates to exactly 1.0 for all such values, losing the ranking information needed for the AUROC calculation.

Problem

When all logits lie in a large range (e.g., 97-100), naively applying sigmoid saturates every value:

import torch
from torchmetrics.functional.classification import binary_auroc

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 
                      99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 
                      99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

# Before fix: returns 0.5 (random guessing)
# After fix: returns 0.9286 (correct)
binary_auroc(preds, labels)

The issue occurs because sigmoid(x) evaluates to exactly 1.0 in float32 for x above roughly 16.6, making all predictions indistinguishable and destroying the ranking information that AUROC depends on.
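The collapse is easy to demonstrate directly (a minimal illustration; the exact saturation point depends on dtype):

import torch

a, b = torch.tensor(98.0), torch.tensor(99.0)
# Both saturate to exactly 1.0 in float32, so their ordering is lost.
print(a.sigmoid() == b.sigmoid())  # tensor(True)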

Solution

Modified normalize_logits_if_needed in src/torchmetrics/utilities/compute.py to apply a numerically stable sigmoid when needed (a sketch follows the list below):

  • Conditional stabilization: applied only when min(logits) > 15, i.e., when every value would saturate to 1.0
  • Preserves ranking: subtracts the maximum value before sigmoid; since sigmoid is monotonic, relative ordering is maintained
  • Avoids artificial ties: stabilization is not applied to mixed-range logits (e.g., -5 to 100), where shifting would create spurious ties
  • Backward compatible: normal-range logits go through the standard sigmoid, so existing behavior is unchanged
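A minimal sketch of the stabilized path, assuming a simplified signature for normalize_logits_if_needed (the merged implementation may differ in details such as other normalization modes):

import torch

def normalize_logits_if_needed(tensor: torch.Tensor, normalization: str) -> torch.Tensor:
    # Sketch only: mirrors the behavior described above, not the exact source.
    if normalization == "sigmoid":
        if tensor.min() > 15:
            # All values would saturate sigmoid to exactly 1.0, so shift by the
            # max first; sigmoid is monotonic, so ranking is preserved.
            tensor = tensor - tensor.max()
        return tensor.sigmoid()
    return tensor

Because AUROC depends only on the ordering of the scores, shifting every logit by the same constant changes the absolute probabilities but not the resulting curve.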

Changes

  • Updated normalize_logits_if_needed() to check both min and max values before stabilization
  • Added a comprehensive regression test (sketched after this list) covering:
    • Original issue case (logits 97-100)
    • Very large logits (200+)
    • Mixed range logits (-5 to 100)
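For illustration, the first of those cases might look like the following sketch (test name and tolerance are ours, not necessarily the PR's):

import torch
from torchmetrics.functional.classification import binary_auroc

def test_binary_auroc_large_logits():
    # Regression case from the original issue: all logits far beyond the
    # float32 sigmoid saturation point.
    preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037,
                          98.9425, 99.2644, 99.5014, 99.7280, 99.6595,
                          99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
    target = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    assert torch.isclose(binary_auroc(preds, target), torch.tensor(0.9286), atol=1e-4)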

Testing

All existing tests pass:

  • ✅ 92 binary AUROC tests passed
  • ✅ 30 precision-recall curve tests passed
  • ✅ 30 stat_scores tests passed
  • ✅ New regression test with 3 cases added

Closes #XXXX

Original prompt

This section details the original issue you should resolve

<issue_title>torchmetrics.functional.classification.binary_auroc gives wrong results when logits are large</issue_title>
<issue_description>

🐛 Bug

torchmetrics.functional.classification.binary_auroc always gives 0.5 when all logits are large. This seems to be caused by a floating point precision error with sigmoid.

To Reproduce

Code sample
import torch
import torchmetrics.functional.classification

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

torchmetrics.functional.classification.binary_auroc(preds, labels)

Output:

tensor(0.5000)

Expected behavior

AUROC of the above example should be 0.9286, as computed by sklearn.

import sklearn.metrics

sklearn.metrics.roc_auc_score(labels, preds)

Output:

0.9285714285714286

Environment

  • Windows 11 24H2
  • Python version 3.10.11
  • TorchMetrics version 1.4.3
  • PyTorch version 2.4.1+cu124

Additional context

This appears to be a floating-point precision problem with the sigmoid applied at line 185 of _binary_precision_recall_curve_format in src/torchmetrics/functional/classification/precision_recall_curve.py.

I extracted the necessary functions into a miniature binary_auroc that uses exactly the same algorithm (it reproduces the failure for the example above; I did not test other examples):

def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    preds = preds.sigmoid()
    print(preds)

    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]
    # print(preds, target)

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    # print(distinct_value_indices)
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # print(threshold_idxs)
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # print(tps, fps)

    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # print(fpr, tpr)
    return torch.trapezoid(tpr, fpr, dim=-1)

Output:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

tensor(0.5000)

preds = preds.sigmoid() converts all of these logits to exactly 1, as if they were all the same, which is not the case. A logit must stay below 36.74 for float64 or 16.64 for float32 to avoid its sigmoid being rounded to exactly 1.
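Those thresholds follow from the floating-point formats themselves: evaluating sigmoid(x) = 1/(1 + e^-x) naively returns exactly 1 once e^-x falls below the precision of the significand, i.e. when

x > p·ln 2, where p is the number of significand bits: 24·ln 2 ≈ 16.64 (float32) and 53·ln 2 ≈ 36.74 (float64).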

Suggested fix

It's probably a good idea to scale the raw logits before sigmoid, something like below:

preds /= torch.max(torch.abs(preds)) # scales max element to 1
preds = preds.sigmoid()

All functions that apply sigmoid to raw logits will need such a fix.</issue_description>
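As a quick illustration of this suggestion (ours, not from the issue): dividing by the maximum absolute value pulls the logits into a range where sigmoid keeps them distinguishable.

import torch

preds = torch.tensor([98.0, 99.0, 100.0])
scaled = preds / preds.abs().max()  # 0.98, 0.99, 1.00 -- max element scaled to 1
print(scaled.sigmoid())             # tensor([0.7271, 0.7291, 0.7311]) -- distinct, ranking preserved

(The merged fix instead subtracts the max, as described above, but the idea is the same: keep the values in sigmoid's resolvable range.)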

Comments on the Issue (you are @copilot in this section)

@Borda, quoting the suggested fix:

> It's probably a good idea to scale the raw logits before sigmoid, something like below:
>
>     preds /= torch.max(torch.abs(preds))  # scales max element to 1
>     preds = preds.sigmoid()
>
> All functions that apply sigmoid to raw logits will need such a fix.

That sounds reasonable to me... @SkafteNicki your thoughts? 🤔

Fixes #2819



📚 Documentation preview 📚: https://torchmetrics--3283.org.readthedocs.build/en/3283/
