From 5bcd60ceaf6ee244740b4eeb0e2151d96e22120e Mon Sep 17 00:00:00 2001
From: Gaetan LOPEZ
Date: Sun, 29 Sep 2024 15:28:46 -0400
Subject: [PATCH] formatting

---
 tests/test_dpo_trainer.py  | 43 ++++++++++++++++++++++++++++++++++++--
 trl/trainer/dpo_trainer.py | 37 +++++++++++++++++++-------------
 2 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index a2fde2e3b5..6c9816e03e 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -230,8 +230,6 @@ def setUp(self):
             ["t5", "exo_pair", True],
             ["gpt2", "apo_zero", True],
             ["t5", "apo_down", False],
-            ["gpt2", "wpo", False],
-            ["t5", "wpo", True],
         ]
     )
     def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -283,6 +281,47 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                 if param.sum() != 0:
                     assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
 
+    def test_dpo_trainer_with_weighting(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=0.1,
+                loss_type="sigmoid",
+                precompute_ref_log_probs=False,
+                use_weighting=True,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+
     @parameterized.expand(
         [
             [None, "Test when rpo_alpha is set to None"],
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 9cabd6219f..8c32c26e1f 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1316,7 +1316,7 @@ def get_batch_logps(
         labels: torch.LongTensor,
         label_pad_token_id: int = -100,
         is_encoder_decoder: bool = False,
-        use_weighting: bool = False
+        use_weighting: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]:
        """Compute the log probabilities of the given labels under the given logits.
 
@@ -1326,7 +1326,7 @@ def get_batch_logps(
             label_pad_token_id: The label pad token id.
             is_encoder_decoder: Whether the model is an encoder-decoder model.
             use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
-
+
         Returns
             A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing:
             - The sum of log probabilities of the given labels under the given logits.
@@ -1347,17 +1347,17 @@ def get_batch_logps(
         labels[labels == label_pad_token_id] = 0
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-
+
         all_logps = (per_token_logps * loss_mask).sum(-1)
-
+
         all_weights = None
         if use_weighting:
             # eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827
             probs = F.softmax(logits, dim=-1)
-            weights_adjustment_factor = torch.log((probs ** 2).sum(-1))
+            weights_adjustment_factor = torch.log((probs**2).sum(-1))
             per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
             all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
-
+
         return all_logps, loss_mask.sum(-1), all_weights
 
     def concatenated_forward(
@@ -1432,13 +1432,12 @@ def cross_entropy_loss(logits, labels):
 
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
-
-        policy_weights = None
+
+        policy_weights = None
         if self.args.use_weighting:
             chosen_weights = all_weights[:len_chosen]
             rejected_weights = all_weights[len_chosen:]
             policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
-
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]
 
@@ -1447,9 +1446,17 @@ def cross_entropy_loss(logits, labels):
         rejected_logits = all_logits[len_chosen:]
 
         if self.aux_loss_enabled:
-            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights, outputs.aux_loss)
+            return (
+                chosen_logps,
+                rejected_logps,
+                chosen_logits,
+                rejected_logits,
+                nll_loss,
+                policy_weights,
+                outputs.aux_loss,
+            )
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, policy_weights, nll_loss)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights)
 
     def get_batch_loss_metrics(
         self,
@@ -1503,7 +1510,7 @@ def get_batch_loss_metrics(
         if self.args.rpo_alpha is not None:
             # RPO loss from V3 of the paper:
             losses = losses + policy_nll_loss * self.args.rpo_alpha
-
+
         if self.args.use_weighting:
             losses = losses * policy_weights
 
@@ -1730,7 +1737,8 @@ def create_model_card(
         else:
             base_model = None
 
-        citation = textwrap.dedent("""\
+        citation = textwrap.dedent(
+            """\
         @inproceedings{rafailov2023direct,
             title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
             author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
@@ -1738,7 +1746,8 @@ def create_model_card(
             booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
             url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
             editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
-        }""")
+        }"""
+        )
 
         model_card = generate_model_card(
             base_model=base_model,
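For reviewers who want to sanity-check the weighting path touched by this patch outside the trainer, here is a minimal standalone sketch (not part of the patch) of the WPO weight computation from eqn (2) of https://huggingface.co/papers/2406.11827, mirroring the operations in `get_batch_logps` and `concatenated_forward`; the tensor shapes, `batch_size`, and random values are illustrative assumptions only.

```python
import torch
import torch.nn.functional as F

# Toy tensors: first half of the batch is "chosen", second half "rejected".
batch_size, seq_len, vocab_size = 4, 6, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
loss_mask = torch.ones(batch_size, seq_len)  # no padding in this toy example

# Per-token log-probs of the labels, as in get_batch_logps.
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

# Eqn (2): subtract the log of the sum of squared token probabilities,
# then average over the non-masked tokens of each sequence.
probs = F.softmax(logits, dim=-1)
weights_adjustment_factor = torch.log((probs**2).sum(-1))
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)

# Sequence-level policy weights, as in concatenated_forward: combine the
# chosen and rejected halves and clamp the exponentiated sum at 1.
len_chosen = batch_size // 2
chosen_weights, rejected_weights = all_weights[:len_chosen], all_weights[len_chosen:]
policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
print(policy_weights.shape)  # (len_chosen,)
```

In the trainer itself this path is enabled by passing `use_weighting=True` to `DPOConfig`, as exercised by the new `test_dpo_trainer_with_weighting` test added above.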