Fix mean reduction in cross entropy loss
The mean reduction should reduce the s_loss to a scalar. Also, I'm not sure why division was being used here instead of multiplication by the mask, but I changed it to multiplication.
davidbrandfonbrener authored Mar 24, 2024 · 1 parent 8949bd8 · commit a57f380
Showing 1 changed file with 1 addition and 1 deletion: olmo/train.py
```diff
@@ -105,7 +105,7 @@ def cross_entropy_loss(

     z_squared = logits.logsumexp(-1).pow(2)
     if reduction == "mean":
-        z_squared = z_squared / (labels != ignore_index).mean()
+        z_squared = (z_squared * (labels != ignore_index)).mean()
     elif reduction == "sum":
         z_squared = (z_squared * (labels != ignore_index)).sum()
```
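To see why the change matters, here is a minimal NumPy stand-in for the tensor ops above (toy shapes and values, not from OLMo; `-100` is assumed as the conventional `ignore_index`). Dividing by the mask's mean only rescales each element, so the "mean" branch never reduced `z_squared` to a scalar; multiplying by the mask and then taking `.mean()` both zeroes the ignored positions and collapses to one number.

```python
import numpy as np

# Toy batch of per-token logits and labels; -100 marks padding (assumed
# ignore_index, matching the common PyTorch convention).
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.1, 0.1],
                   [3.0, 3.0, 3.0]])
labels = np.array([0, -100, 2])
ignore_index = -100

# Per-token z^2 = logsumexp(logits)^2, mirroring
# logits.logsumexp(-1).pow(2) in the diff.
z = np.log(np.exp(logits).sum(axis=-1))
z_squared = z ** 2

mask = labels != ignore_index

# Old (buggy) "mean": dividing by the mask's mean rescales every element
# but leaves z_squared a per-token vector, not a scalar.
buggy = z_squared / mask.mean()
assert buggy.shape == z_squared.shape  # not reduced

# Fixed "mean": zero out ignored positions, then average to a scalar
# (mean over all positions, including the masked ones).
fixed = (z_squared * mask).mean()
assert np.ndim(fixed) == 0
```

Note the fixed version averages over all positions, masked ones included, which matches the existing `"sum"` branch (`(z_squared * mask).sum()`) up to the `1/N` factor.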
