📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device (huggingface#2776)

* Distribute

* fix some logic errors

* fix and document RepeatRandomSampler

* comment

* doc clarification

* fix type hint

* more readable

* fix eval

* fix tests

* roll back to distribute generation

* improve comment [ci skip]

* fix slice

* catch for eval batch size as well; fix completion_ids in vllm

* log completions

* Revert "log completions"

This reverts commit 1e4af8f.

* Before the first training step, the model has no optimizer: fix ds3
qgallouedec authored Feb 6, 2025
1 parent 724acb9 commit cf97133
Showing 6 changed files with 107 additions and 68 deletions.
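The "fix and document RepeatRandomSampler" bullet above refers to the mechanism that makes the new batch definition work: each prompt index is repeated so that all `num_generations` completions of a prompt occupy consecutive slots of the (global) batch. The actual sampler lives in trl/trainer/grpo_trainer.py, whose diff is not shown below; the following is only a minimal sketch of the idea, with the class name and constructor arguments assumed rather than taken from the commit.

```python
# Minimal sketch of a "repeat" sampler: shuffle the dataset indices, then yield each
# index `repeat_count` times in a row so consecutive batch entries share a prompt.
# Illustration only; the real RepeatRandomSampler in TRL may differ.
from typing import Iterator, Optional, Sized

import torch
from torch.utils.data import Sampler


class RepeatRandomSamplerSketch(Sampler[int]):
    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self) -> Iterator[int]:
        # Shuffle once per epoch, then repeat each index so a prompt's generations
        # end up contiguous in the global batch and can be sliced per device.
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
        return iter(index for index in indexes for _ in range(self.repeat_count))

    def __len__(self) -> int:
        return self.num_samples * self.repeat_count
```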
2 changes: 0 additions & 2 deletions docs/source/grpo_trainer.md
@@ -24,8 +24,6 @@ This example demonstrates how to train a model using the GRPO method. We train a
></iframe>
Below is the script to train the model.
- Note that the input tensor for the forward pass has a size of `num_generations * per_device_train_batch_size` because GRPO generates `num_generations` completions for each prompt in the batch. Adjusting these values appropriately can help prevent OOM errors.
- Consequently, the effective train batch size is `num_generations * per_device_train_batch_size * gradient_accumulation_steps`.

```python
# train_grpo.py
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -47,7 +47,7 @@ def test_grpo(self):
from trl.cli import main

with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory
- command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 3 --max_completion_length 32 --report_to none"
+ command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none"
with patch("sys.argv", command.split(" ")):
main()

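A plausible reading of the `--num_generations 3` → `4` change: this commit also drops the GRPOConfig override of per_device_train_batch_size (see the grpo_config.py diff below), so the CLI falls back to the transformers default of 8 generations per device, and the new constraint requires the global batch to be divisible by num_generations. A quick check under that assumption (single process, default per-device batch size of 8):

```python
# Assumes a single process and the transformers default per_device_train_batch_size=8,
# which applies once the GRPOConfig override removed in this commit is gone.
num_processes = 1
per_device_train_batch_size = 8
global_batch_size = num_processes * per_device_train_batch_size

print(global_batch_size % 3)  # 2 -> num_generations=3 would violate the divisibility constraint
print(global_batch_size % 4)  # 0 -> num_generations=4 divides the global batch evenly
```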
26 changes: 13 additions & 13 deletions tests/test_grpo_trainer.py
@@ -49,7 +49,7 @@ def test_training(self, config_name):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -78,8 +78,8 @@ def test_training_with_eval(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
- per_device_eval_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
eval_strategy="steps",
@@ -106,7 +106,7 @@ def test_training_peft(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -149,7 +149,7 @@ def test_training_different_reward_model(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -185,7 +185,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -221,7 +221,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -260,7 +260,7 @@ def reward_func2(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -295,7 +295,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -334,7 +334,7 @@ def reward_func(completions, some_values, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -367,7 +367,7 @@ def test_training_vllm(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -400,7 +400,7 @@ def test_training_torch_compile(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
torch_compile=True,
@@ -431,7 +431,7 @@ def test_training_with_sync_ref_model(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
sync_ref_model=True,
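Every test above now pairs per_device_train_batch_size=3 with num_generations=3: under the new definition the per-device batch counts generations, so a single-process test batch must be a multiple of num_generations. A minimal end-to-end sketch using the same tiny fixtures as the tests; the exact keyword arguments are assumptions here, not part of the diff.

```python
# Sketch of a training setup consistent with the updated tests. The model, reward
# model and dataset ids are the tiny internal test fixtures used in this commit;
# the real test file may wire things up slightly differently.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-test",
    learning_rate=0.1,               # increased to speed up the test
    per_device_train_batch_size=3,   # now counted in generations per device
    num_generations=3,               # global batch size must be divisible by this
    max_completion_length=32,        # short completions to limit memory usage
    report_to="none",
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```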
4 changes: 4 additions & 0 deletions trl/models/utils.py
@@ -137,6 +137,8 @@ def setup_chat_format(

def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+ return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
@@ -164,6 +166,8 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+ return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
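The two new hasattr guards make remove_hooks and add_hooks safe to call before the first optimizer step, which is exactly when GRPO first needs to gather ZeRO-3 parameters for generation. A hedged sketch of the surrounding pattern follows; the context manager below is illustrative, not the actual TRL call site.

```python
# Illustrative context manager around the guarded helpers: detach the DeepSpeed ZeRO-3
# hooks while generating, then restore them. Thanks to the new hasattr check, both
# calls are no-ops when the engine has not built its optimizer yet (i.e. before the
# first training step).
from contextlib import contextmanager

from trl.models.utils import add_hooks, remove_hooks


@contextmanager
def detached_zero3_hooks(engine):
    remove_hooks(engine)   # returns early if the engine has no optimizer yet
    try:
        yield engine
    finally:
        add_hooks(engine)  # likewise a no-op before the first step
```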
30 changes: 6 additions & 24 deletions trl/trainer/grpo_config.py
@@ -45,7 +45,8 @@ class GRPOConfig(TrainingArguments):
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
- Number of generations per prompt to sample.
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
+ must be divisible by this value.
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
@@ -83,11 +84,6 @@ class GRPOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
- per_device_train_batch_size (`int`, *optional*, defaults to `1`):
- Number of prompts sampled per device for training. The actual batch passed into the model will be this
- value multiplied by `num_generations`.
- gradient_accumulation_steps (`int`, *optional*, defaults to `8`):
- Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
sync_ref_model (`bool`, *optional*, defaults to `False`):
@@ -132,7 +128,10 @@ class GRPOConfig(TrainingArguments):
)
num_generations: Optional[int] = field(
default=8,
metadata={"help": "Number of generations to sample."},
metadata={
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
"must be divisible by this value."
},
)
temperature: Optional[float] = field(
default=0.9,
@@ -202,23 +201,6 @@ class GRPOConfig(TrainingArguments):
"`transformers.TrainingArguments`."
},
)
- # GRPO generates multiple completions per prompt, increasing memory usage.
- # To accommodate this, the per-device train batch size is decreased (overriden from the parent class),
- # and the number gradient accumulation steps is increased to maintain the effective batch size.
- per_device_train_batch_size: int = field(
- default=1,
- metadata={
- "help": "Number of prompts sampled per device for training. The actual batch passed into the model will "
- "be this value multiplied by `num_generations`."
- },
- )
- gradient_accumulation_steps: int = field(
- default=8,
- metadata={
- "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update "
- "pass."
- },
- )
beta: float = field(
default=0.04,
metadata={"help": "KL coefficient."},
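The updated docstring states the one invariant introduced by this redefinition: the global batch size, num_processes * per_device_batch_size, must be divisible by num_generations. A small sketch of that check and of how many distinct prompts a global batch then holds; the function name is illustrative, not the trainer's actual validation code.

```python
# Illustrative check for the divisibility constraint described in the docstring above.
def prompts_per_global_batch(num_processes: int, per_device_batch_size: int, num_generations: int) -> int:
    """Return the number of distinct prompts in one global batch, or raise if the sizes are incompatible."""
    global_batch_size = num_processes * per_device_batch_size
    if global_batch_size % num_generations != 0:
        raise ValueError(
            f"The global batch size ({num_processes} x {per_device_batch_size} = {global_batch_size}) "
            f"must be divisible by num_generations ({num_generations})."
        )
    return global_batch_size // num_generations


# Example: 2 processes with 4 generations per device -> 8 generations per global batch,
# i.e. 2 distinct prompts when num_generations=4.
assert prompts_per_global_batch(num_processes=2, per_device_batch_size=4, num_generations=4) == 2
```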
(Diff for the sixth changed file, not rendered on this page.)
