
Commit c996464

zhangtemplar authored and facebook-github-bot committed
Fix NaN handling in AUPRC metric calculation
Summary: Improved the NaN handling logic in the AUPRC (Area Under Precision-Recall Curve) metric calculation to correctly handle edge cases where division by zero occurs. The changes address NaN values that arise from 0.0 / 0.0 divisions in both the recall and precision calculations.

**Recall NaN handling:**
- NaNs occur on the right side of the recall tensor: num_fp is a cumsum, so its zeros start on the left, but the tensor is subsequently flipped.
- Changed the NaN replacement value from 1.0 to 0.0 to match the 0.0 appended on the right side.
- Removed the conditional check, since NaNs should always be handled consistently.

**Precision NaN handling:**
- Added explicit NaN handling for the precision tensor (previously missing).
- NaNs in precision occur on the right side, as with recall.
- Replaced NaNs with 0.0 to prevent NaN propagation in `_riemann_integral`; otherwise a single NaN precision element makes the entire AUPRC result NaN.

Added detailed comments explaining the root cause of the NaN values and the rationale for the chosen replacement values.

Differential Revision: D86464670
1 parent 8f3ed1a commit c996464
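To make the failure mode concrete, here is a minimal sketch, not torchrec's actual code: the tensors and names below are illustrative stand-ins for the cumulative counts the helper computes. It shows how a 0.0 / 0.0 division produces NaNs that the flip moves to the right side of recall, and how `torch.nan_to_num` replaces them:

```python
import torch

# Illustrative tensors only; the real helper derives these from
# cumulative true-positive / positive counts per threshold.
num_tp = torch.tensor([0.0, 0.0, 1.0, 2.0])   # cumsum: zeros on the left
num_pos = torch.tensor([0.0, 0.0, 1.0, 2.0])

recall = num_tp / num_pos                # 0.0 / 0.0 -> NaN on the left
recall = torch.flip(recall, dims=[0])    # the flip moves the NaNs right
print(recall)                            # tensor([1., 1., nan, nan])

# The fix: replace NaNs with 0.0, matching the 0.0 appended on the right.
recall = torch.nan_to_num(recall, nan=0.0)
print(recall)                            # tensor([1., 1., 0., 0.])
```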

File tree

1 file changed: +10 −3 lines changed


torchrec/metrics/auprc.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -65,9 +65,16 @@ def _compute_auprc_helper(
     precision = torch.cat([precision, precision.new_ones(1)])
     recall = torch.cat([recall, recall.new_zeros(1)])
 
-    # If recalls are NaNs, set NaNs to 1.0s.
-    if torch.isnan(recall[0]):
-        recall = torch.nan_to_num(recall, 1.0)
+    # NaN comes from 0.0 / 0.0. For recall this happens on the right side:
+    # num_fp is a cumsum, so its 0.0s start on the left, but recall is
+    # flipped, which moves those 0.0s (and the resulting NaNs) to the right.
+    # Replace NaNs with 0.0 to match the 0.0 appended on the right above.
+    recall = torch.nan_to_num(recall, 0.0)
+
+    # As with recall, precision's NaNs appear on its right side. If any
+    # element of precision is NaN, _riemann_integral returns NaN, so
+    # replace NaNs with 0.0 to keep the result finite.
+    precision = torch.nan_to_num(precision, 0.0)
 
     auprc = _riemann_integral(recall, precision)
     return auprc
```
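Why a single NaN in precision poisons the whole result: any sum over the precision values propagates NaN. The sketch below uses a trapezoidal integral as a stand-in for `_riemann_integral` (the real helper may be implemented differently) to show the propagation and the effect of the 0.0 replacement:

```python
import torch

# Stand-in for torchrec's _riemann_integral; the actual helper may
# differ, but any weighted sum over precision propagates NaN the same way.
def riemann_integral(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # recall (x) is decreasing, so the raw integral is negative; negate it.
    return -torch.trapezoid(y, x)

recall = torch.tensor([1.0, 0.5, 0.0])
precision = torch.tensor([0.5, 1.0, float("nan")])  # one NaN on the right

print(riemann_integral(recall, precision))  # tensor(nan): NaN poisons the sum
print(riemann_integral(recall, torch.nan_to_num(precision, nan=0.0)))
# tensor(0.6250): finite once the NaN is replaced with 0.0
```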
