Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update trainer for easier handling of accumulate, compile fixes, and proper reporting #34511

Merged
merged 17 commits into from
Nov 4, 2024
Merged
3 changes: 1 addition & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
Expand Down Expand Up @@ -5014,7 +5014,6 @@ def _is_quantized_training_enabled(self):
return self.hf_quantizer.is_trainable

@property
@lru_cache
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type
Expand Down
49 changes: 24 additions & 25 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
Expand Down Expand Up @@ -2445,7 +2444,7 @@ def _inner_training_loop(
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
for i, inputs in enumerate(batch_samples):
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (
Expand Down Expand Up @@ -2491,7 +2490,13 @@ def _inner_training_loop(
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
Comment on lines +2496 to +2500
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For an explanation on what we have going on here @Rocketknight1 , during DDP we use model.no_sync() to only communicate across all GPUs during the next step outside it (so we speed up training when not needed when doing gradient accumulation). accelerator.no_sync() is the lower-level accumulate() API which makes that op backed-independent (so on a single GPU it just does nullcontext)

with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

if (
Expand Down Expand Up @@ -3643,15 +3648,13 @@ def training_step(
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused that the loss is no longer multiplied by the gradient accumulation steps here, because the loss has been multiplied by the data parallel size in https://github.com/huggingface/transformers/pull/34511/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR3702-R3703

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it should be solved in #35207

# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
# Average tokens across devices is orthogonal to gradient accumulation
if num_items_in_batch is not None and self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
Expand Down Expand Up @@ -4953,24 +4956,21 @@ def _add_sm_patterns_to_gitignore(self) -> None:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

accelerator_config = self.args.accelerator_config.to_dict()

Expand Down Expand Up @@ -5001,7 +5001,6 @@ def create_accelerator_and_postprocess(self):

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
Expand Down
43 changes: 31 additions & 12 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,19 @@ def __getitem__(self, i):
return {"input_ids": self.x, "labels": self.x}


class SequenceClassificationDataset:
def __init__(self, length=64, vocab_size=100, num_labels=5):
self.length = length
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
self.labels = torch.randint(0, num_labels, (length,)).tolist()

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_ids": self.sequences[i], "label": self.labels[i]}


class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
Expand Down Expand Up @@ -1144,6 +1157,23 @@ def test_number_of_steps_in_training_with_ipex(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)

def test_torch_compile_loss_func_compatibility(self):
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
per_device_train_batch_size=2,
torch_compile=True,
max_steps=1, # compile happens on the first step
)
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
trainer.train()

@require_peft
@require_bitsandbytes
def test_bnb_compile(self):
Expand Down Expand Up @@ -3676,9 +3706,6 @@ def test_accelerator_config_from_dict(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
Expand All @@ -3691,8 +3718,6 @@ def test_accelerator_config_from_yaml(self):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
Expand All @@ -3706,9 +3731,6 @@ def test_accelerator_config_from_yaml(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
Expand Down Expand Up @@ -3754,10 +3776,7 @@ def test_accelerate_config_from_dataclass_grad_accum(self):
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)

def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
Expand Down
Loading