-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update trainer for easier handling of accumulate, compile fixes, and proper reporting #34511
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couple comments about the test!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests look clean to me now, and I'm trusting you on the accelerate side of things! 😅
cc @LysandreJik / @ArthurZucker for core maintainer review
context = ( | ||
functools.partial(self.accelerator.no_sync, model=model) | ||
if i == len(batch_samples) - 1 | ||
else contextlib.nullcontext | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For an explanation on what we have going on here @Rocketknight1 , during DDP we use model.no_sync()
to only communicate across all GPUs during the next step outside it (so we speed up training when not needed when doing gradient accumulation). accelerator.no_sync()
is the lower-level accumulate()
API which makes that op backed-independent (so on a single GPU it just does nullcontext
)
@Milad335t just warning you to stop spamming or we'll have to block you 😢 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, let's hope this gets stabilized!
num_items_in_batch = sum( | ||
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] | ||
) | ||
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weird to me that we have to use -100 here, instead of a general parameter but whit was already the case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC we use -100
for padding by default in the Trainer. I can align it to self.processor
if it exists else -100
if that's better?:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually our padding index is -100 everywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay sounds good then sorry
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries, it's weird for me too :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we no longer need to shift labels ["labels"][...,1:]
when getting num_items_in_batch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, I also think we need to shift labels before computing the num_items_in_batch
. Otherwise, the value is incorrect as the first element in labels may not be -100
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks patching today!
num_items_in_batch = sum( | ||
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] | ||
) | ||
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay sounds good then sorry
…proper reporting (#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
…proper reporting (#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
…proper reporting (huggingface#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
…proper reporting (huggingface#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
…proper reporting (huggingface#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
…proper reporting (huggingface#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta <[email protected]>
@@ -3643,15 +3650,11 @@ def training_step( | |||
with amp.scale_loss(loss, self.optimizer) as scaled_loss: | |||
scaled_loss.backward() | |||
else: | |||
if num_items_in_batch is not None: | |||
if self.compute_loss_func or self.model_accepts_loss_kwargs: | |||
loss *= self.args.gradient_accumulation_steps |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused that the loss is no longer multiplied by the gradient accumulation steps here, because the loss has been multiplied by the data parallel size in https://github.com/huggingface/transformers/pull/34511/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR3702-R3703
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, it should be solved in #35207
What does this PR do?
Alternative to #34442
TL;DR we just need to remove
lru_cache
and everything will work fine. (and adds a test)This PR also takes the full lessons from my article and adds it to the
Trainer
for a simpler solution to the grad accum calculation (we shouldn't rely onaccelerator
from now on bc it can't handle the nuances with the grad accum fix at the highest level API, so we use a lower level version instead)Fixes #34402
Would recommend a patch after this
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @Rocketknight1