From a57f380332e7021755d0a36cf79406b4423cf361 Mon Sep 17 00:00:00 2001 From: David Brandfonbrener Date: Sun, 24 Mar 2024 15:10:17 -0400 Subject: [PATCH] Fix mean reduction in cross entropy loss The mean reduction should reduce the z-loss term (z_squared) to a scalar. Also, division by the mean of the mask is replaced with multiplication by the mask, matching the "sum" branch; I'm not sure why division was being used here in the first place. --- olmo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 1494a1b49..4454786e3 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -105,7 +105,7 @@ def cross_entropy_loss( z_squared = logits.logsumexp(-1).pow(2) if reduction == "mean": - z_squared = z_squared / (labels != ignore_index).mean() + z_squared = (z_squared * (labels != ignore_index)).mean() elif reduction == "sum": z_squared = (z_squared * (labels != ignore_index)).sum()