src/transformers/trainer.py (47 changes: 41 additions & 6 deletions)
@@ -2504,27 +2504,62 @@ def _inner_training_loop(

                     # Gradient clipping
                     if args.max_grad_norm is not None and args.max_grad_norm > 0:
-                        if is_sagemaker_mp_enabled() and args.fp16:
-                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+                        if args.skip_unnecessary_grad_clip:
+                            # Always compute grad norm first
+                            grad_norm_context = contextlib.nullcontext
+                            if self.is_tp_enabled:
+                                from torch.distributed._tensor.experimental import implicit_replication
+
+                                grad_norm_context = implicit_replication
+                            with grad_norm_context():
+                                # Compute grad norm ONLY (no clip yet)
+                                _grad_norm = (
+                                    self.accelerator.get_grad_norm(
+                                        model.parameters(),
+                                        norm_type=2.0,  # Match the norm used in clipping
+                                    )
+                                    if hasattr(self.accelerator, "get_grad_norm")
+                                    else (
+                                        torch.norm(
+                                            torch.stack(
+                                                [
+                                                    torch.norm(p.grad.detach(), 2)
+                                                    for p in model.parameters()
+                                                    if p.grad is not None
+                                                ]
+                                            ),
+                                            2,
+                                        ).item()
+                                    )
+                                )
+                            print(f"DEBUG: grad_norm={_grad_norm}, threshold={args.max_grad_norm}")
+                            # Only clip if norm exceeds threshold
+                            if _grad_norm > args.max_grad_norm:
+                                _grad_norm = self.accelerator.clip_grad_norm_(
+                                    model.parameters(),
+                                    args.max_grad_norm,
+                                )
+                            grad_norm = _grad_norm  # Always log the norm used, clipped or unclipped
+
+                        elif is_sagemaker_mp_enabled() and args.fp16:
+                            grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         else:
                             grad_norm_context = contextlib.nullcontext
                             if self.is_tp_enabled:
                                 from torch.distributed._tensor.experimental import implicit_replication

                                 grad_norm_context = implicit_replication
                             with grad_norm_context():
-                                _grad_norm = self.accelerator.clip_grad_norm_(
+                                grad_norm = self.accelerator.clip_grad_norm_(
                                     model.parameters(),
                                     args.max_grad_norm,
                                 )

+                        # If DeepSpeed, update grad norm handling as before...
                         if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                             grad_norm = model.get_global_grad_norm()
                             # In some cases the grad norm may not return a float
                             if hasattr(grad_norm, "item"):
                                 grad_norm = grad_norm.item()
-                        else:
-                            grad_norm = _grad_norm

                     self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

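In plain terms, the new skip_unnecessary_grad_clip branch measures the global gradient norm first (via accelerator.get_grad_norm when that method exists, otherwise a manual L2 reduction over per-parameter gradient norms) and only calls clip_grad_norm_ when the measured norm exceeds max_grad_norm; the measured norm is logged either way. Below is a minimal standalone sketch of the same compute-then-conditionally-clip pattern in plain PyTorch, assuming an ordinary eager-mode model; the helper name clip_if_needed is illustrative and not part of this PR.

import torch
from torch import nn


def clip_if_needed(parameters, max_norm: float) -> float:
    # Materialize once so the iterable can be traversed twice
    params = [p for p in parameters if p.grad is not None]
    # Global L2 norm over all gradients, matching norm_type=2.0 in the diff above
    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2).item()
    # Clip only when the norm actually exceeds the threshold
    if total_norm > max_norm:
        torch.nn.utils.clip_grad_norm_(params, max_norm)
    return total_norm  # the pre-clip norm, which is what gets logged


model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grad_norm = clip_if_needed(model.parameters(), max_norm=1.0)

Since clip_grad_norm_ clamps its scaling coefficient at 1.0, skipping the call when the norm is already below the threshold leaves the parameter update mathematically unchanged; the apparent goal of the flag is to avoid the redundant rescale pass in that case.
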
src/transformers/training_args.py (3 changes: 3 additions & 0 deletions)
@@ -828,6 +828,9 @@ class TrainingArguments:
     adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
     adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
     max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
+    skip_unnecessary_grad_clip: bool = field(
+        default=False, metadata={"help": "Skip gradient clipping when the gradient norm is below the max_grad_norm threshold."}
+    )

     num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
     max_steps: int = field(
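Because the new flag defaults to False, existing training runs keep the unconditional clipping behavior; users opt in through TrainingArguments. A hedged usage sketch (output_dir and the other values are placeholders):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    max_grad_norm=1.0,
    skip_unnecessary_grad_clip=True,  # new flag from this PR: clip only when the norm exceeds 1.0
    per_device_train_batch_size=8,
)
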
tests/trainer/test_gradient_clipping.py (33 changes: 33 additions & 0 deletions)
@@ -0,0 +1,33 @@
from datasets import Dataset

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


def test_skip_unnecessary_grad_clip(monkeypatch):
    # Dummy model and data
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    data = {"text": ["hello world", "foo bar"], "label": [0, 1]}
    ds = Dataset.from_dict(data)
    ds = ds.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length"), batched=True)

    args = TrainingArguments(
        output_dir="./test_output",
        skip_unnecessary_grad_clip=True,  # enable the conditional-clip path added by this PR
        max_grad_norm=1e8,  # threshold so high that the grad norm always stays below it
        per_device_train_batch_size=2,
        num_train_epochs=1,
    )

    def fake_clip_grad_norm(*args, **kwargs):
        raise RuntimeError("Should not clip! Grad norm is under threshold.")

    trainer = Trainer(model=model, args=args, train_dataset=ds)
    monkeypatch.setattr(trainer.accelerator, "clip_grad_norm_", fake_clip_grad_norm)

    # Run one training epoch and make sure no RuntimeError is raised (i.e. clipping was never triggered)
    trainer.train()

    # Any logged grad_norm values should be below the (huge) threshold
    logged_norms = [entry["grad_norm"] for entry in trainer.state.log_history if "grad_norm" in entry]
    assert all(norm < args.max_grad_norm for norm in logged_norms), f"Grad norm logging failed! {logged_norms}"
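
The test above only exercises the skip path, and with the default logging_steps a one-step run may never write a grad_norm entry to log_history, which would make the final assertion vacuous; passing logging_steps=1 would tighten it. A hypothetical companion test, not part of this diff, could check the opposite case: that clipping still fires when the norm exceeds a tiny threshold. The sketch below reuses the imports from the file above, adds import torch, and records calls to the monkeypatched clip_grad_norm_.

import torch


def test_clip_applied_when_norm_exceeds_threshold(monkeypatch):
    # Hypothetical companion test: with a tiny threshold the measured norm is always above it,
    # so the conditional path should call clip_grad_norm_ at least once.
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    ds = Dataset.from_dict({"text": ["hello world", "foo bar"], "label": [0, 1]})
    ds = ds.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length"), batched=True)

    args = TrainingArguments(
        output_dir="./test_output",
        skip_unnecessary_grad_clip=True,
        max_grad_norm=1e-8,  # threshold so small that clipping should always trigger
        per_device_train_batch_size=2,
        num_train_epochs=1,
    )
    trainer = Trainer(model=model, args=args, train_dataset=ds)

    calls = []

    def recording_clip(parameters, max_norm, **kwargs):
        calls.append(max_norm)  # remember that clipping was requested
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm)

    monkeypatch.setattr(trainer.accelerator, "clip_grad_norm_", recording_clip)
    trainer.train()

    assert calls, "Expected clip_grad_norm_ to be called when the grad norm exceeds max_grad_norm"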