Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions torchrec/metrics/auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ def _compute_auprc_helper(
precision = torch.cat([precision, precision.new_ones(1)])
recall = torch.cat([recall, recall.new_zeros(1)])

# If recalls are NaNs, set NaNs to 1.0s.
if torch.isnan(recall[0]):
recall = torch.nan_to_num(recall, 1.0)
# nan happens with 0.0 / 0.0. For recall's case, this could happen from its right side:
# num_fp is a cumsum and thus 0.0 starts from its left side. But given recall has a flip,
# then those 0.0 goes to right side and thus nan.
# If recalls are NaNs, set NaNs to 0.0s, as append a 0.0 on its right side above.
recall = torch.nan_to_num(recall, 0.0)

# similar as recall, precision's nan would happen from its right side.
# since we append 1.0 on its right side above, we replace nan by 1.0.
# If any element in precision is Nan, _riemann_integral will return NaN.
precision = torch.nan_to_num(precision, 1.0)

auprc = _riemann_integral(recall, precision)
return auprc
Expand Down
Loading