From 09fcf9fa899fff5743517bf98449201e9bad4f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AE=B8=E4=B8=96=E8=B6=8A?= <1531845081@qq.com> Date: Fri, 16 May 2025 20:34:18 +0800 Subject: [PATCH 1/8] LD-DPO support --- trl/trainer/dpo_config.py | 12 ++++++++++++ trl/trainer/dpo_trainer.py | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 94bff43fb3..d2a8808c72 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -139,6 +139,10 @@ class DPOConfig(TrainingArguments): α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. + ld_alpha (`float`, *optional*, defaults to `None`): + α parameter from the LD-DPO paper, which controls the verbose token logp in responses. + If `None`, no weighting is applied on the verbose part and the loss is the same as the DPO loss. + The paper recommends `ld_alpha` should be between `0.0` and `1.0` discopop_tau (`float`, *optional*, defaults to `0.05`): τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. @@ -346,6 +350,14 @@ class DPOConfig(TrainingArguments): "`rpo_alpha=1.0`." }, ) + ld_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "α parameter from the LD-DPO paper, which controls the verbose token logp in responses. If " + "`None`, no weighting is applied on the verbose part and the loss is the same as the DPO loss. " + "The paper recommends `ld_alpha` should be between `0.0` and `1.0`" + }, + ) discopop_tau: float = field( default=0.05, metadata={ diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index f01c5e3ff7..0918ae7c18 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1218,6 +1218,28 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to if self.loss_type == "ipo": all_logps = all_logps / loss_mask.sum(-1) + if self.args.ld_alpha is not None: + # Compute response lengths based on loss_mask + completion_lengths = loss_mask.sum(dim=1) + + chosen_lengths = completion_lengths[:num_examples] + rejected_lengths = completion_lengths[num_examples:] + l_p = torch.min(chosen_lengths, rejected_lengths) + l_p = torch.cat([l_p, l_p], dim=0) + + seq_len = per_token_logps.size(1) + position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) + + ld_mask = position_ids < l_p.unsqueeze(1) + mask = position_ids < completion_lengths.unsqueeze(1) + + front_mask = (ld_mask & mask).float() + rear_mask = (~ld_mask & mask).float() + front_logps = (per_token_logps * front_mask).sum(dim=1) + rear_logps = (per_token_logps * rear_mask).sum(dim=1) + + all_logps = front_logps + self.args.ld_alpha * rear_logps + output["chosen_logps"] = all_logps[:num_examples] output["rejected_logps"] = all_logps[num_examples:] From 11b575a5562bd2de91cecc9a8022dd94ee16fcfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AE=B8=E4=B8=96=E8=B6=8A?= <1531845081@qq.com> Date: Fri, 16 May 2025 21:39:17 +0800 Subject: [PATCH 2/8] add description of LD-DPO --- docs/source/dpo_trainer.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index b22ce22756..e8f8f8a93a 100644 --- 
a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -168,6 +168,9 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. +### LD-DPO loss +The [LD-DPO](https://arxiv.org/abs/2409.06411) The paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. From f13447457fbe2e9f6b96c4320226290eadbf614b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 May 2025 15:43:28 +0200 Subject: [PATCH 3/8] Update docs/source/dpo_trainer.md --- docs/source/dpo_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index e8f8f8a93a..50eacf259f 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -169,7 +169,7 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. ### LD-DPO loss -The [LD-DPO](https://arxiv.org/abs/2409.06411) The paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. +The [LD-DPO](https://huggingface.co/papers/2409.06411) The paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. 
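> Illustrative usage sketch (not part of the patch): enabling LD-DPO only requires setting `ld_alpha` in `DPOConfig`. The tiny model and `zen` dataset below are borrowed from the test added later in this series; `ld_alpha=0.8` and the output directory are arbitrary placeholders.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ld_alpha < 1.0 down-weights the log-probabilities of tokens beyond the shared
# length of the chosen/rejected pair; ld_alpha=1.0 is equivalent to standard DPO.
training_args = DPOConfig(
    output_dir="dpo-ld",  # placeholder output directory
    ld_alpha=0.8,         # arbitrary value in the recommended (0.0, 1.0) range
    report_to="none",
)
trainer = DPOTrainer(
    model=model_id,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```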
### For Mixture of Experts Models: Enabling the auxiliary loss From d375a06b1641d8d101d7986d2a155d95b781029d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AE=B8=E4=B8=96=E8=B6=8A?= <1531845081@qq.com> Date: Tue, 27 May 2025 10:55:00 +0800 Subject: [PATCH 4/8] fix the logps computation of LD-DPO --- docs/source/dpo_trainer.md | 2 +- trl/trainer/dpo_trainer.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 50eacf259f..8e29fbc666 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -169,7 +169,7 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. ### LD-DPO loss -The [LD-DPO](https://huggingface.co/papers/2409.06411) The paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. +The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. ### For Mixture of Experts Models: Enabling the auxiliary loss diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0918ae7c18..be48b8e47f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -804,9 +804,9 @@ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: with torch.no_grad(), compte_ref_context_manager: if self.ref_model is None: with self.null_ref_context(): - ref_model_output = self.concatenated_forward(self.model, batch) + ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) else: - ref_model_output = self.concatenated_forward(self.ref_model, batch) + ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] @staticmethod @@ -1066,10 +1066,17 @@ def dpo_loss( return losses, chosen_rewards, rejected_rewards - def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): + def concatenated_forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False + ): """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. + + Args: + model: The model to run forward pass on + batch: The batch of inputs + is_ref_model: Whether this is being called for the reference model. 
""" num_examples = batch["prompt_input_ids"].shape[0] @@ -1218,7 +1225,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to if self.loss_type == "ipo": all_logps = all_logps / loss_mask.sum(-1) - if self.args.ld_alpha is not None: + if self.args.ld_alpha is not None and not is_ref_model: # Compute response lengths based on loss_mask completion_lengths = loss_mask.sum(dim=1) From 2e36155466fcf951ef8a6d1f941f761a0d179e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 27 May 2025 22:23:47 +0000 Subject: [PATCH 5/8] refine doc --- trl/trainer/dpo_config.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index d2a8808c72..ac03f61dcd 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -131,18 +131,19 @@ class DPOConfig(TrainingArguments): Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. label_smoothing (`float`, *optional*, defaults to `0.0`): - Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and + Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. use_weighting (`bool`, *optional*, defaults to `False`): - Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. + Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). rpo_alpha (`float`, *optional*, defaults to `None`): - α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the + α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. - ld_alpha (`float`, *optional*, defaults to `None`): - α parameter from the LD-DPO paper, which controls the verbose token logp in responses. - If `None`, no weighting is applied on the verbose part and the loss is the same as the DPO loss. - The paper recommends `ld_alpha` should be between `0.0` and `1.0` + ld_alpha (`float` or `None`, *optional*, defaults to `None`): + α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting + of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose + part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between + `0.0` and `1.0`. discopop_tau (`float`, *optional*, defaults to `0.05`): τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. @@ -353,9 +354,9 @@ class DPOConfig(TrainingArguments): ld_alpha: Optional[float] = field( default=None, metadata={ - "help": "α parameter from the LD-DPO paper, which controls the verbose token logp in responses. If " - "`None`, no weighting is applied on the verbose part and the loss is the same as the DPO loss. 
" - "The paper recommends `ld_alpha` should be between `0.0` and `1.0`" + "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token " + "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is " + "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.", }, ) discopop_tau: float = field( From bd7b8aabe62807f83fd9eb3a13ea158bfeff7a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 27 May 2025 22:25:06 +0000 Subject: [PATCH 6/8] fix latex --- docs/source/dpo_trainer.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 8e29fbc666..79ccb58b68 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -169,7 +169,8 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. ### LD-DPO loss -The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient $\alpha$. To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. + +The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. 
### For Mixture of Experts Models: Enabling the auxiliary loss From 4760d670a8ab4c3d8c59b35d2743aad6e5c1ee6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 27 May 2025 22:43:03 +0000 Subject: [PATCH 7/8] add a test --- tests/test_dpo_trainer.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 69b7cfaf43..ed0f7a4770 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1258,6 +1258,37 @@ def dummy_compute_metrics(*args, **kwargs): self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + def test_train_with_length_desensitization(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + learning_rate=9e-1, + ld_alpha=0.5, + report_to="none", + ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + @require_vision class DPOVisionTrainerTester(unittest.TestCase): From a1809032e8cd706c974b806b4118377147c7f170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 27 May 2025 22:43:16 +0000 Subject: [PATCH 8/8] nits --- trl/trainer/dpo_trainer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index be48b8e47f..1c255ffe14 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1069,14 +1069,19 @@ def dpo_loss( def concatenated_forward( self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False ): - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + """ + Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. Args: - model: The model to run forward pass on - batch: The batch of inputs - is_ref_model: Whether this is being called for the reference model. + model: + Model to run the forward pass on. + batch: + Batch of input data. + is_ref_model: + Whether this method is being called for the reference model. If `True`, length desensitization is not + applied. 
""" num_examples = batch["prompt_input_ids"].shape[0] @@ -1231,13 +1236,13 @@ def concatenated_forward( chosen_lengths = completion_lengths[:num_examples] rejected_lengths = completion_lengths[num_examples:] - l_p = torch.min(chosen_lengths, rejected_lengths) - l_p = torch.cat([l_p, l_p], dim=0) + public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper + public_lengths = torch.cat([public_lengths, public_lengths], dim=0) seq_len = per_token_logps.size(1) position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) - ld_mask = position_ids < l_p.unsqueeze(1) + ld_mask = position_ids < public_lengths.unsqueeze(1) mask = position_ids < completion_lengths.unsqueeze(1) front_mask = (ld_mask & mask).float()