diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md
index b22ce22756..79ccb58b68 100644
--- a/docs/source/dpo_trainer.md
+++ b/docs/source/dpo_trainer.md
@@ -168,6 +168,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ
 
 The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
 
+### LD-DPO loss
+
+The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the likelihood of the portion of a response that extends beyond the length shared by the chosen and rejected responses (the public length) into two components, human-like preference and verbosity preference, mixed by a coefficient \\( \alpha \\). To use this method, set `ld_alpha` in the [`DPOConfig`] to an appropriate value; the paper suggests a value between `0.0` and `1.0`.
+
 ### For Mixture of Experts Models: Enabling the auxiliary loss
 
 MOEs are the most efficient if the load is about equally distributed between experts.
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 69b7cfaf43..ed0f7a4770 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1258,6 +1258,37 @@ def dummy_compute_metrics(*args, **kwargs):
         self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
 
 
+    def test_train_with_length_desensitization(self):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                learning_rate=9e-1,
+                ld_alpha=0.5,
+                report_to="none",
+            )
+            trainer = DPOTrainer(
+                model=model_id,
+                args=training_args,
+                processing_class=tokenizer,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the parameters have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if param.sum() != 0:  # ignore 0 biases
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
 
 @require_vision
 class DPOVisionTrainerTester(unittest.TestCase):
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 94bff43fb3..ac03f61dcd 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -131,14 +131,19 @@ class DPOConfig(TrainingArguments):
             Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
             probability to all responses.
         label_smoothing (`float`, *optional*, defaults to `0.0`):
-            Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
+            Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and
             [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
         use_weighting (`bool`, *optional*, defaults to `False`):
-            Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
+            Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
         rpo_alpha (`float`, *optional*, defaults to `None`):
-            α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
+            α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
             weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
             DPO loss. The paper recommends `rpo_alpha=1.0`.
+        ld_alpha (`float` or `None`, *optional*, defaults to `None`):
+            α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
+            of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
+            part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
+            `0.0` and `1.0`.
         discopop_tau (`float`, *optional*, defaults to `0.05`):
             τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
             the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
@@ -346,6 +351,14 @@ class DPOConfig(TrainingArguments):
             "`rpo_alpha=1.0`."
         },
     )
+    ld_alpha: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
+            "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is "
+            "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.",
+        },
+    )
     discopop_tau: float = field(
         default=0.05,
         metadata={
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index f01c5e3ff7..1c255ffe14 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -804,9 +804,9 @@ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict:
         with torch.no_grad(), compte_ref_context_manager:
             if self.ref_model is None:
                 with self.null_ref_context():
-                    ref_model_output = self.concatenated_forward(self.model, batch)
+                    ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
             else:
-                ref_model_output = self.concatenated_forward(self.ref_model, batch)
+                ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
         return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
 
     @staticmethod
@@ -1066,10 +1066,22 @@ def dpo_loss(
 
         return losses, chosen_rewards, rejected_rewards
 
-    def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
-        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
+    ):
+        """
+        Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
+
+        Args:
+            model:
+                Model to run the forward pass on.
+            batch:
+                Batch of input data.
+            is_ref_model:
+                Whether this method is being called for the reference model. If `True`, length desensitization is not
+                applied.
         """
         num_examples = batch["prompt_input_ids"].shape[0]
 
@@ -1218,6 +1230,28 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
         if self.loss_type == "ipo":
             all_logps = all_logps / loss_mask.sum(-1)
 
+        if self.args.ld_alpha is not None and not is_ref_model:
+            # Compute response lengths based on loss_mask
+            completion_lengths = loss_mask.sum(dim=1)
+
+            chosen_lengths = completion_lengths[:num_examples]
+            rejected_lengths = completion_lengths[num_examples:]
+            public_lengths = torch.min(chosen_lengths, rejected_lengths)  # l_p in the paper
+            public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
+
+            seq_len = per_token_logps.size(1)
+            position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+            ld_mask = position_ids < public_lengths.unsqueeze(1)
+            mask = position_ids < completion_lengths.unsqueeze(1)
+
+            front_mask = (ld_mask & mask).float()
+            rear_mask = (~ld_mask & mask).float()
+            front_logps = (per_token_logps * front_mask).sum(dim=1)
+            rear_logps = (per_token_logps * rear_mask).sum(dim=1)
+
+            all_logps = front_logps + self.args.ld_alpha * rear_logps
+
         output["chosen_logps"] = all_logps[:num_examples]
         output["rejected_logps"] = all_logps[num_examples:]