Scale loss before backward (#35207)
qgallouedec authored Dec 23, 2024
1 parent f5264a8 commit 3cd3cd5
6 changes: 4 additions & 2 deletions src/transformers/trainer.py
@@ -3698,10 +3698,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+
             return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
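Why the reordering matters, as a minimal standalone sketch (this is not the Trainer code; the toy linear model, random micro-batches, and accumulation_steps value below are illustrative assumptions): under gradient accumulation, calling backward() on each unscaled micro-batch loss sums the gradients, so the optimizer would see a gradient roughly accumulation_steps times larger than a single full-batch step. Dividing each loss by the number of accumulation steps before backward(), as this commit does, recovers the averaged gradient.

# Minimal sketch, not the Hugging Face Trainer code: the model, data, and
# accumulation_steps below are made-up illustrative values.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
micro_batches = [torch.randn(8, 4) for _ in range(4)]
accumulation_steps = len(micro_batches)

# Unscaled accumulation: backward() on each raw loss sums the gradients.
model.zero_grad()
for x in micro_batches:
    loss = model(x).pow(2).mean()
    loss.backward()
summed_grad = model.weight.grad.clone()

# Scaled accumulation: divide each loss before backward, as in this commit.
model.zero_grad()
for x in micro_batches:
    loss = model(x).pow(2).mean() / accumulation_steps
    loss.backward()
averaged_grad = model.weight.grad.clone()

# The summed gradient is accumulation_steps times the averaged one.
print(torch.allclose(summed_grad, averaged_grad * accumulation_steps))  # True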
