Skip to content

Commit

Permalink
assert to make sure flash-attn>=2.4.3
Browse files Browse the repository at this point in the history
  • Loading branch information
bebetterest committed Feb 5, 2025
1 parent ddfc04e commit cc68a04
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions verl/utils/torch_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def logprobs_from_logits(logits, labels):


def logprobs_from_logits_flash_attn(logits, labels):
output = -cross_entropy_loss(logits, labels)[0]
return output
output = cross_entropy_loss(logits, labels)
assert isinstance(output, tuple), "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
return -output[0]


def logprobs_from_logits_naive(logits, labels):
Expand Down

0 comments on commit cc68a04

Please sign in to comment.