[GRPO] add gradient_checkpointing (#2848)
* add gradient_checkpointing

* added a helper

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* minor refactor for better readability

* use accelerate util

* enable_input_require_grads is in base class

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Feb 18, 2025
1 parent 15fec31 commit 9b3c5bf
Showing 2 changed files with 84 additions and 5 deletions.
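
For context, here is a minimal usage sketch of what this commit enables: turning on gradient checkpointing through `GRPOConfig` while training a PEFT adapter. It mirrors the new test in the diff below and reuses the tiny test fixtures named there; the output directory is hypothetical and the snippet is illustrative, not part of the diff.

from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-gc-demo",  # hypothetical output directory
    gradient_checkpointing=True,  # option wired up by this commit
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=32,
    report_to="none",
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"]),
)
trainer.train()
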
53 changes: 52 additions & 1 deletion tests/test_grpo_trainer.py
@@ -28,7 +28,7 @@


if is_peft_available():
-    from peft import LoraConfig
+    from peft import LoraConfig, PeftModel


class GRPOTrainerTester(unittest.TestCase):
@@ -133,6 +133,57 @@ def test_training_peft(self):
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

+    @require_peft
+    def test_training_peft_with_gradient_checkpointing(self):
+        """Test that training works with PEFT and gradient checkpointing enabled."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            torch_dtype=torch.float32,  # Use float32 for testing to avoid precision issues
+            use_cache=False,  # Required for gradient checkpointing
+        )
+
+        lora_config = LoraConfig(
+            r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
+        )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                per_device_train_batch_size=3,
+                num_generations=3,
+                max_completion_length=32,
+                gradient_checkpointing=True,  # Enable gradient checkpointing
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model=model,
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+                peft_config=lora_config,
+            )
+
+            # Verify gradient checkpointing is enabled
+            self.assertIsInstance(trainer.model, PeftModel)
+
+            # Store initial parameters to check which ones change
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that only LoRA parameters have changed, base model parameters remain unchanged
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if "lora" in n.lower():  # LoRA parameters should change
+                    self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+                else:  # Base model parameters should not change
+                    self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")

    def test_training_different_reward_model(self):
        # Use a reward model different from the model: different chat template, tokenization, etc.
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
36 changes: 32 additions & 4 deletions trl/trainer/grpo_trainer.py
@@ -247,21 +247,27 @@ def __init__(
        self.beta = args.beta

        if peft_config is not None:
+            if not is_peft_available():
+                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

+        # Enable gradient checkpointing if requested
+        if args.gradient_checkpointing:
+            model = self._enable_gradient_checkpointing(model, args)
+
        # Reference model
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
-        elif not is_peft_model(model):
-            # If PEFT configuration is not provided, create a reference model based on the initial model.
-            self.ref_model = create_reference_model(model)
-        else:
+        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
+        else:
+            # If PEFT configuration is not provided, create a reference model based on the initial model.
+            self.ref_model = create_reference_model(model)

        # Processing class
        if processing_class is None:
@@ -488,6 +494,28 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)

+    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
+        """Enables gradient checkpointing for the model."""
+        # Ensure use_cache is disabled
+        model.config.use_cache = False
+
+        # Enable gradient checkpointing on the base model for PEFT
+        if is_peft_model(model):
+            model.base_model.gradient_checkpointing_enable()
+        # Enable gradient checkpointing for non-PEFT models
+        else:
+            model.gradient_checkpointing_enable()
+
+        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
+        use_reentrant = (
+            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
+        )
+
+        if use_reentrant:
+            model.enable_input_require_grads()
+
+        return model
+
    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
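A note on the `use_reentrant` handling in the new `_enable_gradient_checkpointing` helper: with reentrant checkpointing (PyTorch's default), gradients only flow if the checkpointed inputs require grad, which is why the helper calls `model.enable_input_require_grads()` (a method already provided by the `transformers` base class, per the commit message); with a frozen base model and LoRA adapters, this is what keeps adapter gradients alive. Non-reentrant checkpointing can be requested through the `gradient_checkpointing_kwargs` field that `GRPOConfig` inherits from `TrainingArguments`, in which case the helper skips that call. A small config sketch, reusing the hypothetical output directory from the example above:

training_args = GRPOConfig(
    output_dir="grpo-gc-demo",  # hypothetical
    gradient_checkpointing=True,
    # Non-reentrant checkpointing: the helper then skips enable_input_require_grads()
    gradient_checkpointing_kwargs={"use_reentrant": False},
)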
