diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index be45697762e2..4c3ade60f815 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2504,8 +2504,45 @@ def _inner_training_loop(
                         # Gradient clipping
                         if args.max_grad_norm is not None and args.max_grad_norm > 0:
-                            if is_sagemaker_mp_enabled() and args.fp16:
-                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+                            if args.skip_unnecessary_grad_clip:
+                                # Always compute the grad norm first
+                                grad_norm_context = contextlib.nullcontext
+                                if self.is_tp_enabled:
+                                    from torch.distributed._tensor.experimental import implicit_replication
+
+                                    grad_norm_context = implicit_replication
+                                with grad_norm_context():
+                                    # Compute the grad norm only (no clipping yet)
+                                    _grad_norm = (
+                                        self.accelerator.get_grad_norm(
+                                            model.parameters(),
+                                            norm_type=2.0,  # Match the norm type used in clipping
+                                        )
+                                        if hasattr(self.accelerator, "get_grad_norm")
+                                        else (
+                                            torch.norm(
+                                                torch.stack(
+                                                    [
+                                                        torch.norm(p.grad.detach(), 2)
+                                                        for p in model.parameters()
+                                                        if p.grad is not None
+                                                    ]
+                                                ),
+                                                2,
+                                            ).item()
+                                        )
+                                    )
+                                    logger.debug(f"grad_norm={_grad_norm}, threshold={args.max_grad_norm}")
+                                    # Only clip if the norm exceeds the threshold
+                                    if _grad_norm > args.max_grad_norm:
+                                        _grad_norm = self.accelerator.clip_grad_norm_(
+                                            model.parameters(),
+                                            args.max_grad_norm,
+                                        )
+                                    grad_norm = _grad_norm  # Always log the norm used, clipped or unclipped
+
+                            elif is_sagemaker_mp_enabled() and args.fp16:
+                                grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                             else:
                                 grad_norm_context = contextlib.nullcontext
                                 if self.is_tp_enabled:
@@ -2513,18 +2550,16 @@ def _inner_training_loop(
                                     grad_norm_context = implicit_replication
                                 with grad_norm_context():
-                                    _grad_norm = self.accelerator.clip_grad_norm_(
+                                    grad_norm = self.accelerator.clip_grad_norm_(
                                         model.parameters(),
                                         args.max_grad_norm,
                                     )
 
+                            # If DeepSpeed is used, the grad norm is updated below, as before
                             if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                                 grad_norm = model.get_global_grad_norm()
-                                # In some cases the grad norm may not return a float
                                 if hasattr(grad_norm, "item"):
                                     grad_norm = grad_norm.item()
-                            else:
-                                grad_norm = _grad_norm
 
                             self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 412ba01f4e17..99b03b63fba5 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -828,6 +828,9 @@ class TrainingArguments:
     adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
     adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
     max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
+    skip_unnecessary_grad_clip: bool = field(
+        default=False, metadata={"help": "Skip gradient clipping when the grad norm is below the max_grad_norm threshold."}
+    )
     num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
     max_steps: int = field(
diff --git a/tests/trainer/test_gradient_clipping.py b/tests/trainer/test_gradient_clipping.py
new file mode 100644
index 000000000000..16402ac34030
--- /dev/null
+++ b/tests/trainer/test_gradient_clipping.py
@@ -0,0 +1,34 @@
+from datasets import Dataset
+
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
+
+
+def test_skip_unnecessary_grad_clip(monkeypatch):
+    # Dummy model and data
+    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
+    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+    data = {"text": ["hello world", "foo bar"], "label": [0, 1]}
+    ds = Dataset.from_dict(data)
+    ds = ds.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length"), batched=True)
+
+    args = TrainingArguments(
+        output_dir="./test_output",
+        skip_unnecessary_grad_clip=True,  # Enable the new conditional clipping behavior
+        max_grad_norm=1e8,  # Threshold set so high that the grad norm always stays below it
+        per_device_train_batch_size=2,
+        num_train_epochs=1,
+        logging_steps=1,  # Log every step so grad_norm appears in log_history
+    )
+
+    def fake_clip_grad_norm(*args, **kwargs):
+        raise RuntimeError("clip_grad_norm_ should not be called when the grad norm is under the threshold.")
+
+    trainer = Trainer(model=model, args=args, train_dataset=ds)
+    monkeypatch.setattr(trainer.accelerator, "clip_grad_norm_", fake_clip_grad_norm)
+
+    # Run a short training loop; no RuntimeError means clipping was never triggered
+    trainer.train()
+
+    # The logged grad_norm values must stay below the (huge) threshold
+    logged_norms = [entry["grad_norm"] for entry in trainer.state.log_history if "grad_norm" in entry]
+    assert all(norm < args.max_grad_norm for norm in logged_norms), f"Logged grad norms exceeded the threshold: {logged_norms}"
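Note for reviewers: the conditional clipping that the new skip_unnecessary_grad_clip branch performs can be illustrated outside of Trainer with plain PyTorch. The sketch below is illustrative only and is not part of the patch; conditional_clip_grad_norm_ is a hypothetical helper name, and the norm-of-norms computation mirrors the fallback path in the patched branch (it assumes gradients have already been populated by backward()).

# Illustrative sketch only -- not part of the patch.
import torch


def conditional_clip_grad_norm_(parameters, max_grad_norm, norm_type=2.0):
    """Clip gradients in place only if their total norm exceeds `max_grad_norm`.

    Returns the norm to log: the pre-clip norm when no clipping happens,
    otherwise the norm reported by `torch.nn.utils.clip_grad_norm_`.
    """
    params = [p for p in parameters if p.grad is not None]
    # Norm-of-norms, matching the fallback computation in the patched branch.
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type) for p in params]), norm_type
    ).item()
    if total_norm > max_grad_norm:
        # clip_grad_norm_ rescales gradients in place and returns the norm it measured.
        total_norm = torch.nn.utils.clip_grad_norm_(params, max_grad_norm, norm_type=norm_type).item()
    return total_norm


# Usage: one tiny model, one backward pass, then conditional clipping.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
print(conditional_clip_grad_norm_(model.parameters(), max_grad_norm=1.0))

With the patch applied, the same behavior is enabled end to end by passing skip_unnecessary_grad_clip=True together with the desired max_grad_norm to TrainingArguments.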