
GRPO split generations into multiple training batches #3017


Open
JamesBowerXanda opened this issue Mar 6, 2025 · 15 comments

@JamesBowerXanda

Feature request

In GRPO training it would be useful to be able to split the generations into smaller batches for the gradient calculations, similar to how gradient_accumulation_steps splits a batch into multiple gradient calculations.

I am imagining the config to work something like this:

per_device_train_batch_size = 4
num_generations = 8
gradient_accumulation_steps = 4

with the condition that per_device_train_batch_size * gradient_accumulation_steps is a multiple of num_generations.
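
For illustration, a sketch of what this could look like with GRPOConfig (the splitting behaviour itself is hypothetical and not currently supported; the three argument names already exist in trl, and output_dir is just a placeholder):

from trl import GRPOConfig

# Hypothetical behaviour: 8 completions per prompt are generated and scored together,
# but the backward pass is split across 4 gradient-accumulation micro-batches of 4.
training_args = GRPOConfig(
    output_dir="grpo-split-generations",  # placeholder
    per_device_train_batch_size=4,
    num_generations=8,
    gradient_accumulation_steps=4,
)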

Motivation

In the GRPO algorithm the loss calculation (ignoring the KL part) is an estimate of an expectation under the current model's distribution. This estimate has very high variance if we limit the sample size (number of generations) to small numbers, giving us a poor estimate of the expectation and therefore making training less stable.
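
Concretely, treating the per-prompt term in the loss as a Monte-Carlo average over G sampled completions (a standard variance argument, not taken from the paper verbatim):

$$
\hat{J}_q = \frac{1}{G}\sum_{i=1}^{G} f(o_i), \qquad o_i \sim \pi_\theta(\cdot \mid q), \qquad \operatorname{Var}\big[\hat{J}_q\big] = \frac{\operatorname{Var}[f(o)]}{G},
$$

so the variance of the per-prompt estimate shrinks linearly as num_generations grows.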

Currently per_device_train_batch_size must be a multiple of num_generations, which can severely limit how large you can make it before hitting OOM, particularly in resource-constrained environments with long context windows. This seems like an unnecessary restriction, since nothing in the algorithm stops us from splitting the gradient calculation for a generation batch into multiple smaller batches.

Your contribution

I don't think I would be able to create the PR myself unfortunately.

@qgallouedec
Member

Let's say you have 8 GPUs: in the limit you can have per_device_train_batch_size=1 and num_generations=8, and set the number of gradient accumulation steps to any value.

Currently per_device_train_batch_size must be a multiple of num_generations which can severely limit how large you can make it before

That's not exactly right: it's per_device_train_batch_size * num_devices that must be a multiple of num_generations.
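
For instance, a purely illustrative check of the current requirement (single-node numbers assumed):

num_devices = 8
per_device_train_batch_size = 1
num_generations = 8

# Global generation batch = 1 * 8 = 8 completions, which is divisible by the
# 8 generations per prompt, so this configuration is accepted.
assert (per_device_train_batch_size * num_devices) % num_generations == 0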

While I understand the motivation, I think it's not straightforward to implement.

@JamesBowerXanda
Author

Ah yes, sorry, I forgot about the number of devices. This doesn't change much though, right? We just amend my statement to:

num_devices * per_device_train_batch_size * gradient_accumulation_steps must be a multiple of num_generations.

Is it complicated because currently the prepare_inputs method does both the generation and the score calculation, and the inputs are then passed straight to the compute_loss method by the Trainer superclass?

I can see how fiddling with the core pipeline just for one trainer could cause more issues than it is worth. I just thought I would bring it up because I noticed how much smoother training seemed when I was able to increase the number of generations with smaller models, and this seemed to be the big bottleneck to that.

@qgallouedec
Member

Is it complicated because currently the prepare_inputs method does both the generation and the score calculation, and the inputs are then passed straight to the compute_loss method by the Trainer superclass?

Yes, that's correct.

I was able to increase the number of generations with smaller models, and this seemed to be the big bottleneck to that.

You can actually increase the number of generations quite a lot. For example, if you have 8 GPUs that can each handle 4 generations, you can use up to 32 generations per prompt.

@JamesBowerXanda
Author

Ok, I understand, thanks for your prompt responses.

Unfortunately I am most interested in using this on my personal GPU, so I am not using a multi-GPU cluster.

Thanks for your time, I am happy for the issue to be closed since it is not deemed feasible.

@qgallouedec
Member

With 1 GPU, the best you can do is to set num_generations=per_device_train_batch_size, and set the gradient_accumulation_steps depending on the desired effective batch size. Example:

per_device_train_batch_size = 8
num_generations = 8
gradient_accumulation_steps = 16

This gives an effective batch size of 128 (8 × 16).
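
As a sketch (the argument names are the real GRPOConfig ones; output_dir is a placeholder):

from trl import GRPOConfig

# Single GPU: effective batch size = 8 * 16 * 1 = 128 completions per optimizer step,
# but still only 8 completions per prompt for each advantage estimate.
training_args = GRPOConfig(
    output_dir="grpo-single-gpu",  # placeholder
    per_device_train_batch_size=8,
    num_generations=8,
    gradient_accumulation_steps=16,
)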

@JamesBowerXanda
Author

I understand this, but it doesn't solve the issue of the loss function being an estimate based on a sample size of 8.

[Image: GRPO loss formulation]
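
Presumably this is the GRPO objective from the DeepSeekMath paper, roughly (reconstructed here, so the notation may differ slightly from the image):

$$
\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim P(Q),\, \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O \mid q)} \Bigg[ \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \Big( \min\big( r_{i,t}(\theta)\,\hat{A}_{i,t},\ \operatorname{clip}(r_{i,t}(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_{i,t} \big) - \beta\, \mathbb{D}_{\text{KL}}\big[\pi_\theta \,\|\, \pi_{\text{ref}}\big] \Big) \Bigg]
$$

with $r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}$ and group-normalised advantages $\hat{A}_{i,t} = \frac{r_i - \operatorname{mean}(\{r_1,\dots,r_G\})}{\operatorname{std}(\{r_1,\dots,r_G\})}$, where $G$ is num_generations.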

Based on the GRPO loss formulation, the expectation we estimate is conditional on the input prompt, as are the advantage calculations, so just increasing the gradient accumulation to 16 gives us 16 high-variance estimates of the expectation rather than one low-variance estimate.

I hope this makes sense. As I said before, I can see why this is deemed not worth it, since most large-scale use cases can probably afford to just add more GPUs. I had just hoped it would be an easier adjustment that would allow us hobbyists to stick closer to the theory of the paper.

@qgallouedec
Member

Then you should increase num_generations. By default it's 8, but in the DeepSeek Math paper they use 64. Of course, you'll probably be limited by compute here if you only have 1 GPU.

@qgallouedec
Member

qgallouedec commented Mar 7, 2025

I had just hoped it would be an easier adjustment

In fact, this is tricky, as it would involve sampling, generating and calculating the advantage for the whole batch, then iterating somehow over the batch. It's not impossible, but it adds implementation complexity that I don't think is justified.
In my experience, playing with a low num_generations gives good results.

@JamesBowerXanda
Author

Forgive my naivety, but would it not be as simple as overriding the training_step method for GRPOTrainer? The base Trainer one is:

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Finally we need to normalize the loss for reporting
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

            return loss.detach()

to something like:

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        
        # CHANGED: Split the inputs into mini-batches
        mini_batch_size = self.args.per_device_train_batch_size * self.args.n_gpu
        mini_batch_inputs = []
        for i in range(inputs["prompt_ids"].shape[0] // mini_batch_size):
            mini_batch_inputs.append(
                {
                    key: value[i * mini_batch_size : (i + 1) * mini_batch_size] for key, value in inputs.items()
                }
            )
        losses = []

        del inputs

        # CHANGED: Iterate over the mini-batches for loss calculation and gradient backward pass
        for inputs in mini_batch_inputs:

            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

            del inputs
            if (
                self.args.torch_empty_cache_steps is not None
                and self.state.global_step % self.args.torch_empty_cache_steps == 0
            ):
                if is_torch_xpu_available():
                    torch.xpu.empty_cache()
                elif is_torch_mlu_available():
                    torch.mlu.empty_cache()
                elif is_torch_musa_available():
                    torch.musa.empty_cache()
                elif is_torch_npu_available():
                    torch.npu.empty_cache()
                elif is_torch_mps_available(min_version="2.0"):
                    torch.mps.empty_cache()
                else:
                    torch.cuda.empty_cache()

            kwargs = {}

            # For LOMO optimizers you need to explicitly use the learning rate
            if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
                kwargs["learning_rate"] = self._get_learning_rate()

            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            if self.use_apex:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                # Finally we need to normalize the loss for reporting
                if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                    loss = loss / self.args.gradient_accumulation_steps

                # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
                # https://github.com/huggingface/transformers/pull/35808
                if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                    kwargs["scale_wrt_gas"] = False

                self.accelerator.backward(loss, **kwargs)

            # CHANGED: Append the loss to the list so that we can average it later and return the same value as before
            losses.append(loss.detach())

        # CHANGED: Average the losses and return the same value as before
        loss = torch.stack(losses).mean()

        return loss.detach()

I have added comments starting with # CHANGED: to all the parts I have edited relative to the Trainer method.

@JamesBowerXanda
Author

Sorry, I am not trying to be a pain. As I said previously, I am happy for you to close this if it is just a no-go. I just thought I would offer the suggestion in case it helped.

@qgallouedec
Member

It might work, but that's the complexity I want to avoid. Forking the repo might be the best option here. Or subclass GRPOTrainer to override the training_step method.
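
For reference, a minimal sketch of that subclassing approach (untested; the class name and chunk size are made up, and the batch key names and Trainer internals are assumptions that may differ across trl/transformers versions):

import torch
from trl import GRPOTrainer


class ChunkedGRPOTrainer(GRPOTrainer):
    """Splits the prepared generation batch into smaller chunks for the backward pass."""

    def training_step(self, model, inputs, num_items_in_batch=None):
        model.train()
        # Generation, reward scoring and advantage computation happen here.
        inputs = self._prepare_inputs(inputs)

        chunk_size = 4  # hypothetical micro-batch size for the loss/backward pass
        total = inputs["prompt_ids"].shape[0]  # assumes every value is a tensor with a batch dim
        losses = []
        for start in range(0, total, chunk_size):
            chunk = {k: v[start : start + chunk_size] for k, v in inputs.items()}
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, chunk, num_items_in_batch=num_items_in_batch)
            # Scale as the base Trainer does so gradient accumulation still averages correctly.
            loss = loss / self.args.gradient_accumulation_steps
            self.accelerator.backward(loss)
            losses.append(loss.detach())

        return torch.stack(losses).mean()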

@JamesBowerXanda
Author

Ok, I am happy to do that. I won't bog you down anymore on this.

@ingambe
Contributor

ingambe commented Mar 16, 2025

Actually, having the minibatch size restricted by the number of trajectories is very limiting.
Depending on the problem, if the variance is large or the reward is very sparse, 8 trajectories will not cut it.

@jaeminSon

If I understand correctly, per_device_train_batch_size is an integer, which means a single GPU has to be able to handle a full backward pass. An H100 has roughly 80 GB of memory, and I encountered GPU OOM with the Qwen2-7B model. If I'm correct, this could be quite a constraint, as bigger models cannot be run.

@jarrelscy

Hi @JamesBowerXanda, I ran into a similar issue and needed a larger generation batch size. I've implemented something which you can run like this. As mentioned above, I overrode training_step within GRPOTrainer for this to work.

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=10,
    per_device_train_batch_size=16,  # needs to be a multiple of num_generations
    num_generations=8,               # needs to be a multiple of num_generations_chunks
    num_generations_chunks=8,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

You can find it here
