Commit

formatting
tyler-romero committed Feb 7, 2025
1 parent abf31e1 commit 3839df6
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions verl/utils/torch_functional.py
@@ -58,7 +58,7 @@ def logprobs_from_logits(logits, labels):
         output = logprobs_from_logits_flash_attn(logits, labels)
         output = output.view(*batch_dim)
     else:
-        output = logprobs_from_logits_naive(logits, labels)
+        output = logprobs_from_logits_v2(logits, labels)
     return output


@@ -75,15 +75,14 @@ def logprobs_from_logits_naive(logits, labels):
     return logpy


-def logprobs_of_labels_v2(logits: torch.FloatTensor, labels):
+def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
     """
     A memory efficient implementation of logprobs_from_logits
     """
     if logits.dtype in [torch.float32, torch.float64]:
         logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
-        logsumexp_values = torch.stack(
-            [torch.logsumexp(l, dim=-1) for l in logits]  # loop to reduce peak mem consumption
-        )
+        # loop to reduce peak mem consumption
+        logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
         logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
     else:
         # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach
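
For reference (not part of this commit), a minimal standalone sketch of the memory-saving idea behind the v2 path: gather the label logits first, then compute logsumexp one batch row at a time so the full log-softmax tensor is never materialized. The shapes and helper names below are illustrative assumptions, not verl code.

import torch
import torch.nn.functional as F

def logprobs_naive(logits, labels):
    # Materializes the full (batch, seq_len, vocab) log-softmax tensor before gathering.
    return torch.gather(F.log_softmax(logits, dim=-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

def logprobs_memory_efficient(logits, labels):
    # Gather the label logits, then reduce one batch row at a time to lower peak memory.
    logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)

# Hypothetical shapes: batch=2, seq_len=5, vocab=32, float32 (the stable case in the diff).
logits = torch.randn(2, 5, 32)
labels = torch.randint(0, 32, (2, 5))
assert torch.allclose(logprobs_naive(logits, labels), logprobs_memory_efficient(logits, labels), atol=1e-5)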
