diff --git a/torchrec/metrics/auprc.py b/torchrec/metrics/auprc.py index ed99417d2..d43c74fe0 100644 --- a/torchrec/metrics/auprc.py +++ b/torchrec/metrics/auprc.py @@ -66,8 +66,9 @@ def _compute_auprc_helper( recall = torch.cat([recall, recall.new_zeros(1)]) # If recalls are NaNs, set NaNs to 1.0s. + # 1.0 / 0.0 is inf and 0.0 / 0.0 is NaN. We need to fix both if torch.isnan(recall[0]): - recall = torch.nan_to_num(recall, 1.0) + recall = torch.nan_to_num(recall, 1.0, 1.0, 0.0) auprc = _riemann_integral(recall, precision) return auprc