-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update trainer for easier handling of accumulate, compile fixes, and proper reporting #34511
Changes from 11 commits
ec49756
4cdee53
fb8070f
9c6ed74
aab5467
c56ffe6
4a8a2a3
18abcb5
2f41eb7
27eaadc
6fcb0b5
43f6a2f
238c985
93f36e8
c633219
c452194
7739461
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -233,7 +233,6 @@ | |
from accelerate.utils import ( | ||
DistributedDataParallelKwargs, | ||
DistributedType, | ||
GradientAccumulationPlugin, | ||
load_fsdp_model, | ||
load_fsdp_optimizer, | ||
save_fsdp_model, | ||
|
@@ -2445,7 +2444,7 @@ def _inner_training_loop( | |
update_step += 1 | ||
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder | ||
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) | ||
for inputs in batch_samples: | ||
for i, inputs in enumerate(batch_samples): | ||
step += 1 | ||
total_batched_samples += 1 | ||
is_last_step_and_steps_less_than_grad_acc = ( | ||
|
@@ -2491,7 +2490,13 @@ def _inner_training_loop( | |
if step % args.gradient_accumulation_steps == 0: | ||
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) | ||
|
||
with self.accelerator.accumulate(model): | ||
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training | ||
context = ( | ||
functools.partial(self.accelerator.no_sync, model=model) | ||
if i == len(batch_samples) - 1 | ||
else contextlib.nullcontext | ||
) | ||
with context(): | ||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch) | ||
|
||
if ( | ||
|
@@ -3643,15 +3648,13 @@ def training_step( | |
with amp.scale_loss(loss, self.optimizer) as scaled_loss: | ||
scaled_loss.backward() | ||
else: | ||
if num_items_in_batch is not None: | ||
if self.compute_loss_func or self.model_accepts_loss_kwargs: | ||
loss *= self.args.gradient_accumulation_steps | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm confused that the loss is no longer multiplied by the gradient accumulation steps here, because the loss has been multiplied by the data parallel size in https://github.com/huggingface/transformers/pull/34511/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR3702-R3703 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see, it should be solved in #35207 |
||
# Average tokens across devices is orthogonal to gradient accumulation | ||
if self.args.average_tokens_across_devices: | ||
loss *= self.args.world_size | ||
# Average tokens across devices is orthogonal to gradient accumulation | ||
if num_items_in_batch is not None and self.args.average_tokens_across_devices: | ||
loss *= self.args.world_size | ||
self.accelerator.backward(loss, **kwargs) | ||
|
||
return loss.detach() / self.args.gradient_accumulation_steps | ||
if num_items_in_batch is None: | ||
return loss.detach() / self.args.gradient_accumulation_steps | ||
return loss.detach() | ||
|
||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): | ||
""" | ||
|
@@ -4953,24 +4956,21 @@ def _add_sm_patterns_to_gitignore(self) -> None: | |
self.repo.git_push() | ||
|
||
def create_accelerator_and_postprocess(self): | ||
# We explicitly don't rely on the `Accelerator` to do gradient accumulation | ||
grad_acc_kwargs = {} | ||
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: | ||
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs | ||
|
||
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs | ||
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: | ||
# raise because we do not know which setting is intended. | ||
raise ValueError( | ||
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" | ||
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." | ||
) | ||
elif "num_steps" not in grad_acc_kwargs: | ||
# take the gradient_accumulation_steps setting from TrainingArguments. | ||
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps | ||
|
||
grad_acc_kwargs["sync_with_dataloader"] = False | ||
|
||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) | ||
if "num_steps" in grad_acc_kwargs: | ||
if self.args.gradient_accumulation_steps > 1: | ||
# raise because we do not know which setting is intended. | ||
raise ValueError( | ||
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" | ||
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." | ||
) | ||
else: | ||
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] | ||
|
||
accelerator_config = self.args.accelerator_config.to_dict() | ||
|
||
|
@@ -5001,7 +5001,6 @@ def create_accelerator_and_postprocess(self): | |
|
||
args = { | ||
"deepspeed_plugin": self.args.deepspeed_plugin, | ||
"gradient_accumulation_plugin": gradient_accumulation_plugin, | ||
} | ||
if is_accelerate_available("0.28.0"): | ||
args["dataloader_config"] = dataloader_config | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
For an explanation of what we have going on here @Rocketknight1: during DDP we use `model.no_sync()` to only communicate across all GPUs during the next step outside it (so we speed up training by skipping communication when it is not needed during gradient accumulation). `accelerator.no_sync()` is the lower-level `accumulate()` API which makes that op backend-independent (so on a single GPU it just does `nullcontext`).