From 31b7820aade7cefa06e3d4160dcff8a602a14850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 18 Oct 2024 21:02:24 +0200 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=94=80=20Rename=20`get=5Fbatch=5Fsamp?= =?UTF-8?q?le`=20and=20add=20`num=5Fitems=5Fin=5Fbatch`=20to=20`compute=5F?= =?UTF-8?q?loss`=20(#2246)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/bco_trainer.py | 5 +++-- trl/trainer/cpo_trainer.py | 5 +++-- trl/trainer/dpo_trainer.py | 5 +++-- trl/trainer/gkd_trainer.py | 8 +++++--- trl/trainer/kto_trainer.py | 5 +++-- trl/trainer/nash_md_trainer.py | 4 +++- trl/trainer/online_dpo_trainer.py | 4 +++- trl/trainer/orpo_trainer.py | 5 +++-- trl/trainer/reward_trainer.py | 1 + trl/trainer/xpo_trainer.py | 4 +++- 10 files changed, 30 insertions(+), 16 deletions(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 91461a9b0d..c6ce2d4902 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -1260,6 +1260,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( @@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: return None return SequentialSampler(self.train_dataset) - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1407,7 +1408,7 @@ def evaluation_loop( "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]), "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), } - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) self.log( { diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5847cb182b..5e74fdaceb 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -828,6 +828,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( @@ -847,7 +848,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -938,7 +939,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded = self.generate_from_model(self.model, random_batch) self.log( { diff --git 
a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8b9843cd91..082a627ce0 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1547,6 +1547,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext() with compute_loss_context_manager: @@ -1561,7 +1562,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1672,7 +1673,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) self.log( { diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 1b7c77557d..49e93e269b 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -215,7 +215,7 @@ def generalized_jsd_loss( else: return jsd - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # compute student output outputs_student = model( input_ids=inputs["input_ids"], @@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No return generated_tokens, new_attention_mask, new_labels - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: """ Perform a training step for the Generalized Knowledge Distillation (GKD) model. 
@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, inputs["attention_mask"] = new_attention_mask inputs["labels"] = new_labels - loss = super().training_step(model, inputs) + loss = super().training_step(model, inputs, num_items_in_batch) return loss def _prepare_deepspeed(self, model: PreTrainedModelWrapper): diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index ab9ba87e41..7f32424812 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1234,6 +1234,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( @@ -1264,7 +1265,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: return None return SequentialSampler(self.train_dataset) - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1383,7 +1384,7 @@ def evaluation_loop( "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies], "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), } - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) self.log( { diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index db0c3046b3..73aab7899a 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -328,7 +328,9 @@ def gather_mean(tensor): self.stats["beta"].append(self.beta) self.stats["mixture_coef"].append(self.mixture_coef) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index ffc407b57d..c480c61fc5 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None return self.accelerator.prepare(eval_dataloader) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input. 
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 4edbf9b1a5..123f935208 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -844,6 +844,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( @@ -866,7 +867,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -957,7 +958,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded = self.generate_from_model(self.model, random_batch) self.log( { diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 787c6cbd54..0ebdee68b4 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -266,6 +266,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_reward_data_collator: warnings.warn( diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 0255e6206f..a154875821 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -377,7 +377,9 @@ def gather_mean(tensor): self.stats["alpha"].append(self.alpha) self.stats["beta"].append(self.beta) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input From 92f6d246d39e339396353bc0c54f7c541775d2c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 21 Oct 2024 12:47:33 +0200 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=8F=97=EF=B8=8F=20Refactor=20DPO=20da?= =?UTF-8?q?ta=20processing=20(#2209)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * in progress * refactor concatenated_inputs and concatenated_forward * progress * further modif * padding side * eos prompt enc dec * prompt_padding_side * drop prompt apdding side collator * working on decoder only * dpo trainer * Fix loss_mask type conversion bug * bad attention mask * try to get the same tokens as main * fix loss mask * fix unused col * added comment * raise error when paddind token not set * remove private method tests * initial vlm support * make it work for paligemma * minor test updates * style * improve readibility * improve doc * style * flush left and truncate * flush left in the code * fix empty_cols and make max_length optional * always add eos token * minor changes and doc * style * fix docstring * preference collator in doc * fix doc * optional max_completion_length * Investigating CI failing * 
style * just dpo trainer test * just idefics * paligemma * llava * test cli * dataset in test * all tests * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py Co-authored-by: lewtun * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py * reference to ref * rich descriptions * fix logits reporting * fix truncation * remove chat template from dpo_vlm * `get_batch_sample` -> `generate_from_model[_and_ref]` * add `num_items_in_batch=None` * `num_items_in_batch` in `training_step` * Fix return type hint * test tokenize row * fix test --------- Co-authored-by: lewtun --- docs/source/dpo_trainer.mdx | 4 + examples/scripts/dpo_vlm.py | 12 - tests/test_cli.py | 2 +- tests/test_dpo_trainer.py | 260 ++++---- trl/trainer/dpo_trainer.py | 1148 +++++++++++++++-------------------- 5 files changed, 597 insertions(+), 829 deletions(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 728e0e39e3..0b5020dbad 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -276,3 +276,7 @@ dpo_trainer = DPOTrainer( ## DPOConfig [[autodoc]] DPOConfig + +## PreferenceCollator + +[[autodoc]] trainer.dpo_trainer.PreferenceCollator \ No newline at end of file diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 5c7cf4ba56..08f5687afe 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -27,7 +27,6 @@ """ import torch -from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForVision2Seq, AutoProcessor @@ -106,17 +105,6 @@ ################ dataset = load_dataset(script_args.dataset_name) - def process(row): - row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False) - row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False) - row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False) - return row - - # Compute that only on the main process for faster data processing. 
- # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().local_main_process_first(): - dataset = dataset.map(process, num_proc=training_args.dataset_num_proc) - ################ # Training ################ diff --git a/tests/test_cli.py b/tests/test_cli.py index b01c3c1852..2fc0e6eb03 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -32,7 +32,7 @@ def test_sft_cli(): def test_dpo_cli(): try: subprocess.run( - "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-lib/ultrafeedback_binarized --learning_rate 1e-4 --lr_scheduler_type cosine", + "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine", shell=True, check=True, ) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6dff6a3892..a090f60aed 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -14,6 +14,7 @@ import tempfile import unittest +from unittest.mock import MagicMock import numpy as np import pytest @@ -27,172 +28,129 @@ AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, + PreTrainedTokenizerBase, ) from transformers.testing_utils import require_bitsandbytes, require_peft from trl import DPOConfig, DPOTrainer, FDivergenceType -from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens from .testing_utils import require_no_wandb -class TestBuildTokenizedAnswer(unittest.TestCase): +class TestTokenizeRow(unittest.TestCase): def setUp(self): - self.tokenizer = AutoTokenizer.from_pretrained("gpt2") - self.tokenizer.pad_token = self.tokenizer.eos_token - - def test_basic_functionality(self): - prompt = "Hello, how are you?" - answer = "I'm doing well, thank you!" - - result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer) - - self.assertIn("prompt_input_ids", result) - self.assertIn("prompt_attention_mask", result) - self.assertIn("input_ids", result) - self.assertIn("attention_mask", result) - - self.assertEqual(len(result["prompt_input_ids"]), len(result["prompt_attention_mask"])) - self.assertEqual(len(result["input_ids"]), len(result["attention_mask"])) - - decoded_prompt = self.tokenizer.decode(result["prompt_input_ids"]) - self.assertTrue(prompt in decoded_prompt) - - decoded_answer = self.tokenizer.decode(result["input_ids"]) - self.assertTrue(answer in decoded_answer) - - def test_with_processor(self): - def mock_processor(text, images=None, add_special_tokens=True): - return {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])} - - prompt = "Describe this image:" - answer = "A beautiful sunset over the ocean." - - result = _build_tokenized_answer(prompt, answer, processor=mock_processor) - - self.assertIn("prompt_input_ids", result) - self.assertIn("prompt_attention_mask", result) - self.assertIn("input_ids", result) - self.assertIn("attention_mask", result) - - self.assertEqual(result["prompt_input_ids"], [1, 2, 3]) - self.assertEqual(result["prompt_attention_mask"], [1, 1, 1]) - - def test_token_merging(self): - prompt = "The quick brown" - answer = " fox jumps over the lazy dog." 
- - result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer) - - full_text = prompt + answer - full_tokenized = self.tokenizer(full_text, add_special_tokens=False) - - self.assertEqual(result["prompt_input_ids"] + result["input_ids"], full_tokenized["input_ids"]) + # Set up the mock tokenizer with specific behaviors + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.tokenizer.bos_token_id = 0 + self.tokenizer.eos_token_id = 2 + + # Define mock return values for the tokenizer's 'input_ids' for the different text inputs + self.tokenizer.return_value = { + "input_ids": {"The sky is": [464, 6766, 318], " blue": [4171], " green": [4077]} + } - def test_vision_model(self): - def mock_vision_processor(text, images=None, add_special_tokens=True): - return { - "input_ids": torch.tensor([[1, 2, 3]]), - "attention_mask": torch.tensor([[1, 1, 1]]), - "pixel_values": torch.rand(1, 3, 224, 224), - "pixel_attention_mask": torch.ones(1, 224, 224), + # Define tokenizer behavior when called + def mock_tokenizer_call(text, add_special_tokens): + token_map = { + "The sky is": {"input_ids": [464, 6766, 318]}, + " blue": {"input_ids": [4171]}, + " green": {"input_ids": [4077]}, } + return token_map[text] - prompt = "Describe this image:" - answer = "A cat sitting on a windowsill." + self.tokenizer.side_effect = mock_tokenizer_call - result = _build_tokenized_answer(prompt, answer, processor=mock_vision_processor) + def test_tokenize_row_no_truncation_no_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} - self.assertIn("prompt_pixel_values", result) - self.assertIn("prompt_pixel_attention_mask", result) - self.assertTrue(torch.is_tensor(result["prompt_pixel_values"])) - self.assertTrue(torch.is_tensor(result["prompt_pixel_attention_mask"])) + # Call the method with no truncation and no special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=None, + max_completion_length=None, + add_special_tokens=False, + ) + # Assert the correct output without truncation or special tokens + self.assertEqual( + result, + { + "prompt_input_ids": [464, 6766, 318], + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + }, + ) -class TestTruncateTokens(unittest.TestCase): - def setUp(self): - with tempfile.TemporaryDirectory() as tmp_dir: - self.training_args = DPOConfig( - max_length=20, max_prompt_length=10, truncation_mode="keep_start", output_dir=tmp_dir - ) + def test_tokenize_row_with_truncation(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with truncation + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=2, + max_completion_length=1, + add_special_tokens=False, + ) - def test_truncate_tokens(self): - chosen_tokens = [ - { - "prompt_input_ids": list(range(15)), - "prompt_attention_mask": [1] * 15, - "input_ids": list(range(10)), - "attention_mask": [1] * 10, - } - ] - rejected_tokens = [ - { - "prompt_input_ids": list(range(15)), - "prompt_attention_mask": [1] * 15, - "input_ids": list(range(12)), - "attention_mask": [1] * 12, - } - ] - prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}] - - _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.training_args) - - # Check if prompt is 
truncated correctly - self.assertEqual(len(chosen_tokens[0]["prompt_input_ids"]), 10) - self.assertEqual(len(chosen_tokens[0]["prompt_attention_mask"]), 10) - self.assertEqual(len(rejected_tokens[0]["prompt_input_ids"]), 10) - self.assertEqual(len(rejected_tokens[0]["prompt_attention_mask"]), 10) - self.assertEqual(len(prompt_tokens[0]["prompt_input_ids"]), 10) - self.assertEqual(len(prompt_tokens[0]["prompt_attention_mask"]), 10) - - # Check if responses are truncated correctly - self.assertEqual(len(chosen_tokens[0]["input_ids"]), 10) - self.assertEqual(len(chosen_tokens[0]["attention_mask"]), 10) - self.assertEqual(len(rejected_tokens[0]["input_ids"]), 10) - self.assertEqual(len(rejected_tokens[0]["attention_mask"]), 10) - - def test_truncation_mode_keep_end(self): - self.training_args.truncation_mode = "keep_end" - chosen_tokens = [ - { - "prompt_input_ids": list(range(15)), - "prompt_attention_mask": [1] * 15, - "input_ids": list(range(15, 25)), - "attention_mask": [1] * 10, - } - ] - rejected_tokens = [ + # Assert the correct output with truncation applied + self.assertEqual( + result, { - "prompt_input_ids": list(range(15)), - "prompt_attention_mask": [1] * 15, - "input_ids": list(range(15, 28)), - "attention_mask": [1] * 13, - } - ] - prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}] - - _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.training_args) + "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + }, + ) - # Check if prompt is truncated correctly from the end - self.assertEqual(prompt_tokens[0]["prompt_input_ids"], list(range(5, 15))) - self.assertEqual(prompt_tokens[0]["prompt_attention_mask"], [1] * 10) + def test_tokenize_row_with_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=None, + max_completion_length=None, + add_special_tokens=True, + ) - # Check if chosen tokens are truncated correctly - self.assertEqual(chosen_tokens[0]["prompt_input_ids"], list(range(5, 15))) - self.assertEqual(chosen_tokens[0]["prompt_attention_mask"], [1] * 10) - self.assertEqual(chosen_tokens[0]["input_ids"], list(range(15, 25))) - self.assertEqual(chosen_tokens[0]["attention_mask"], [1] * 10) + # Assert the correct output with special tokens added + self.assertEqual( + result, + { + "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + }, + ) - # Check if rejected tokens are truncated correctly - self.assertEqual(rejected_tokens[0]["prompt_input_ids"], list(range(5, 15))) - self.assertEqual(rejected_tokens[0]["prompt_attention_mask"], [1] * 10) - self.assertEqual(rejected_tokens[0]["input_ids"], list(range(15, 25))) - self.assertEqual(rejected_tokens[0]["attention_mask"], [1] * 10) + def test_tokenize_row_with_truncation_and_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with both truncation and special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=4, + max_completion_length=1, 
+ add_special_tokens=True, + ) - def test_invalid_truncation_mode(self): - self.training_args.truncation_mode = "invalid_mode" - with self.assertRaises(ValueError): - _truncate_tokens([], [], [], self.training_args) + # Assert the correct output with both truncation and special tokens + self.assertEqual( + result, + { + "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + }, + ) class DPOTrainerTester(unittest.TestCase): @@ -461,9 +419,9 @@ def test_dpo_trainer_padding_token_is_none(self): with self.assertRaisesRegex( ValueError, - expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token." - r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)" - r" before calling the trainer.", + expected_regex=r"Can't find `pad_token_id` in the `processing_class`. " + r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) " + r"before instantiating the trainer.", ): trainer = DPOTrainer( model=self.model, @@ -498,9 +456,9 @@ def test_dpo_trainer_w_dataset_num_proc(self): with self.assertRaisesRegex( ValueError, - expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token." - r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)" - r" before calling the trainer.", + expected_regex=r"Can't find `pad_token_id` in the `processing_class`. " + r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) " + r"before instantiating the trainer.", ): trainer = DPOTrainer( model=self.model, @@ -1139,7 +1097,7 @@ def test_vdpo_trainer(self, model_id): output_dir=tmp_dir, per_device_train_batch_size=2, max_length=512, - max_prompt_length=128, + max_prompt_length=512, remove_unused_columns=False, report_to="none", ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 082a627ce0..03061bcb46 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import inspect import os import random @@ -20,6 +21,7 @@ from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch @@ -42,6 +44,7 @@ Trainer, is_wandb_available, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput @@ -53,13 +56,11 @@ from .callbacks import SyncRefModelCallback from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType from .utils import ( - DPODataCollatorWithPadding, RunningMoments, - add_bos_token_if_needed, - add_eos_token_if_needed, cap_exp, disable_dropout_in_model, generate_model_card, + pad, pad_to_length, peft_module_casting_to_bf16, ) @@ -76,266 +77,73 @@ import deepspeed -def _tokenize( - features: Dict[str, List], - tokenizer: PreTrainedTokenizerBase, - args: DPOConfig, - processor: Optional[Callable] = None, - model: Optional[PreTrainedModel] = None, -) -> Dict[str, List]: - """ - Tokenizes and processes a batch of input features using the provided tokenizer and processor. - """ - batch = defaultdict(list) - - if model is None: - prompt = features["prompt"] - images = features.get("images", [None] * len(features["prompt"])) - - prompt_tokens = _process_prompt(prompt, processor, tokenizer, images) - chosen_tokens = _process_answer(prompt, features["chosen"], processor, tokenizer, images) - rejected_tokens = _process_answer(prompt, features["rejected"], processor, tokenizer, images) - - prompt_len_input_ids = _adjust_prompt_length(prompt_tokens, chosen_tokens, rejected_tokens) - - prompt_tokens, chosen_tokens, rejected_tokens = _add_special_tokens( - tokenizer, prompt_len_input_ids, prompt_tokens, chosen_tokens, rejected_tokens - ) - - _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, args) - - _build_sequence_tokens(batch, chosen_tokens, args, "chosen") - _build_sequence_tokens(batch, rejected_tokens, args, "rejected") - - _append_prompt_tokens_to_batch(batch, prompt_tokens) - - else: - _tokenize_encoder_decoder(batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args) - - return dict(batch) - - -def _process_prompt( - prompts: List[str], processor: Optional[Callable], tokenizer: PreTrainedTokenizerBase, images: List[Optional[Any]] -) -> List[Dict[str, List[int]]]: +@dataclass +class PreferenceCollator(DataCollatorMixin): """ - Processes a list of prompts by tokenizing them, optionally using a processor for additional processing. 
- """ - if processor: - processor_kwargs = ( - {"add_special_tokens": False} if "add_special_tokens" in inspect.signature(processor).parameters else {} - ) - prompt_tokens = [] - for prompt, image in zip(prompts, images): - tokens = processor(images=image, text=prompt, **processor_kwargs) - tokens = {k: v[0] for k, v in tokens.items()} - if not isinstance(tokens["input_ids"], list): - tokens["input_ids"] = tokens["input_ids"].tolist() - tokens["attention_mask"] = tokens["attention_mask"].tolist() - prompt_tokens.append(tokens) - else: - prompt_tokens = [tokenizer(prompt, add_special_tokens=False) for prompt in prompts] - return [{f"prompt_{k}": v for k, v in tokens.items()} for tokens in prompt_tokens] - - -def _process_answer( - prompts: List[str], - answers: List[str], - processor: Optional[Callable], - tokenizer: PreTrainedTokenizerBase, - images: List[Optional[Any]], -) -> List[Dict[str, Any]]: - return [ - _build_tokenized_answer(prompt, answer, image, processor=processor, tokenizer=tokenizer) - for prompt, answer, image in zip(prompts, answers, images) - ] - - -def _adjust_prompt_length( - prompt_tokens: List[Dict[str, List[int]]], - chosen_tokens: List[Dict[str, List[int]]], - rejected_tokens: List[Dict[str, List[int]]], -) -> List[int]: - prompt_len_input_ids = [] - for p_tokens, c_tokens, r_tokens in zip(prompt_tokens, chosen_tokens, rejected_tokens): - c_len = len(c_tokens["prompt_input_ids"]) - r_len = len(r_tokens["prompt_input_ids"]) - min_len = min(c_len, r_len) - - for k, v in p_tokens.items(): - p_tokens[k] = v[:min_len] - - num_diff_tokens = sum([a != b for a, b in zip(c_tokens["prompt_input_ids"], r_tokens["prompt_input_ids"])]) - num_diff_len = abs(c_len - r_len) - if num_diff_tokens > 1 or num_diff_len > 1: - raise ValueError( - "Chosen and rejected prompt_input_ids might only differ on the last token due to tokenizer merge ops." - ) - prompt_len_input_ids.append(min_len) - return prompt_len_input_ids - - -def _add_special_tokens( - tokenizer: PreTrainedTokenizerBase, - prompt_len_input_ids: List[int], - prompt_tokens: List[Dict[str, List[int]]], - chosen_tokens: List[Dict[str, List[int]]], - rejected_tokens: List[Dict[str, List[int]]], -) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]], List[Dict[str, List[int]]]]: - for i in range(len(prompt_tokens)): - prompt_tokens[i], chosen_tokens[i], rejected_tokens[i] = add_bos_token_if_needed( - tokenizer.bos_token_id, - prompt_len_input_ids[i], - prompt_tokens[i], - len(chosen_tokens[i]["prompt_input_ids"]), - chosen_tokens[i], - len(rejected_tokens[i]["prompt_input_ids"]), - rejected_tokens[i], - ) - - chosen_tokens[i], rejected_tokens[i] = add_eos_token_if_needed( - tokenizer.eos_token_id, chosen_tokens[i], rejected_tokens[i] - ) - return prompt_tokens, chosen_tokens, rejected_tokens - + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they + are not all of the same length. -def _truncate_tokens( - chosen_tokens: List[Dict[str, List[int]]], - rejected_tokens: List[Dict[str, List[int]]], - prompt_tokens: List[Dict[str, List[int]]], - args: DPOConfig, -) -> None: - """ - Truncates the tokens in chosen, rejected, and prompt sequences to ensure they fit within the maximum length constraints. 
- """ - if args.truncation_mode not in ["keep_start", "keep_end"]: - raise ValueError(f"Invalid truncation mode: {args.truncation_mode}") - - for c_tokens, r_tokens, p_tokens in zip(chosen_tokens, rejected_tokens, prompt_tokens): - longer_response_length = max(len(c_tokens["input_ids"]), len(r_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [c_tokens, r_tokens, p_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > args.max_length: - if args.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][: args.max_prompt_length] - elif args.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][-args.max_prompt_length :] - - # if that's still too long, truncate the response from the end - for answer_tokens in [c_tokens, r_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > args.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][: args.max_length - args.max_prompt_length] - - -def _build_sequence_tokens( - batch: Dict[str, List[int]], tokens: List[Dict[str, List[int]]], args: DPOConfig, prefix: str -) -> None: - for token in tokens: - sequence_tokens = {f"{prefix}_{k}": token[f"prompt_{k}"] + token[k] for k in ["input_ids", "attention_mask"]} - sequence_tokens[f"{prefix}_labels"] = sequence_tokens[f"{prefix}_input_ids"][:] - sequence_tokens[f"{prefix}_labels"][: len(token["prompt_input_ids"])] = [args.label_pad_token_id] * len( - token["prompt_input_ids"] - ) - for k, v in sequence_tokens.items(): - batch[k].append(v) - - -def _append_prompt_tokens_to_batch(batch: Dict[str, List[int]], prompt_tokens: List[Dict[str, List[int]]]) -> None: - for p_tokens in prompt_tokens: - for k, v in p_tokens.items(): - batch[k].append(v) - - -def _tokenize_encoder_decoder( - batch: Dict[str, List[int]], - tokenizer: PreTrainedTokenizerBase, - prompt: List[str], - chosen: List[str], - rejected: List[str], - args: DPOConfig, -) -> None: - chosen_tokens = tokenizer(chosen, truncation=True, max_length=args.max_completion_length, add_special_tokens=True) - rejected_tokens = tokenizer( - rejected, truncation=True, max_length=args.max_completion_length, add_special_tokens=True - ) - prompt_tokens = tokenizer(prompt, truncation=True, max_length=args.max_prompt_length, add_special_tokens=True) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - -def _build_tokenized_answer( - prompt: str, - answer: str, - images: Optional[List[Any]] = None, - processor: Optional[Callable] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, -) -> Dict[str, Any]: - """ - Build tokenized response, handling vision models and different tokenizers. 
- """ - - def tokenize(text, images=None): - if processor: - processor_kwargs = ( - {"add_special_tokens": False} - if "add_special_tokens" in inspect.signature(processor).parameters - else {} - ) - tokenized = processor(images=images, text=text, **processor_kwargs) - tokenized = {k: v[0] for k, v in tokenized.items()} - if not isinstance(tokenized["input_ids"], list): - tokenized["input_ids"] = tokenized["input_ids"].tolist() - tokenized["attention_mask"] = tokenized["attention_mask"].tolist() - else: - tokenized = tokenizer(text, add_special_tokens=False) - return tokenized - - full_tokenized = tokenize(prompt + answer, images) - prompt_tokenized = tokenize(prompt, images) - - prompt_input_ids = prompt_tokenized["input_ids"] - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] - - if len(full_tokenized["input_ids"]) != len(prompt_input_ids + answer_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = len(prompt_input_ids) - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - return_dict = { - "prompt_input_ids": prompt_input_ids, - "prompt_attention_mask": prompt_attention_mask, - "input_ids": answer_input_ids, - "attention_mask": answer_attention_mask, + Args: + pad_token_id (`int`): + Token ID to use for padding. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl import PreferenceCollator + >>> collator = PreferenceCollator(pad_token_id=0) + >>> examples = [ + ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, + ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]} + ... 
] + >>> collator(examples) + {'prompt_input_ids': tensor([[1, 2, 3], + [0, 7, 8]]), + 'prompt_attention_mask': tensor([[1, 1, 1], + [0, 1, 1]]), + 'chosen_input_ids': tensor([[ 4, 5], + [ 9, 10]]), + 'chosen_attention_mask': tensor([[1, 1], + [1, 1]]), + 'rejected_input_ids': tensor([[ 6, 0, 0], + [11, 12, 13]]), + 'rejected_attention_mask': tensor([[1, 0, 0], + [1, 1, 1]]) } - if "pixel_values" in full_tokenized: - return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"] - if "pixel_attention_mask" in full_tokenized: - return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"] + ``` + """ - return return_dict + pad_token_id: int + return_tensors: str = "pt" + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + # Convert to tensor + prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] + prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] + if "pixel_values" in examples[0]: + pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] + if "pixel_attention_mask" in examples[0]: + pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] + + # Pad + output = {} + output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") + output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") + output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) + output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) + output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) + output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) + if "pixel_values" in examples[0]: + output["pixel_values"] = pad(pixel_values, padding_value=0.0) + if "pixel_attention_mask" in examples[0]: + output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) + + return output class DPOTrainer(Trainer): @@ -618,12 +426,6 @@ def make_inputs_require_grad(module, input, output): ) self.is_vision_model = False - if self.is_vision_model: - self.processor = processing_class - self.processing_class = self.processor.tokenizer # tokenizer is actually a processor at this point - else: - self.processing_class = processing_class - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) if model_adapter_name is not None: warnings.warn( @@ -668,51 +470,47 @@ def make_inputs_require_grad(module, input, output): "You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.max_length = max_length - if args.max_length is None: - warnings.warn( - "`max_length` is not set in the DPOConfig's init" - " it will default to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - args.max_length = 512 if max_prompt_length is not None: warnings.warn( "You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." 
) args.max_prompt_length = max_prompt_length - if args.max_prompt_length is None: - warnings.warn( - "`max_prompt_length` is not set in the DPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - args.max_prompt_length = 128 if max_target_length is not None: warnings.warn( "You passed `max_target_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.max_completion_length = max_target_length - if args.max_completion_length is None and self.is_encoder_decoder: - warnings.warn( - "When using an encoder decoder architecture, you should set `max_completion_length` in the DPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - args.max_completion_length = 128 if label_pad_token_id != -100: warnings.warn( "You passed `label_pad_token_id` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.label_pad_token_id = label_pad_token_id - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=self.processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, + + if padding_value is not None: + warnings.warn( + "You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) + args.padding_value = padding_value + + if args.padding_value is not None: + self.padding_value = args.padding_value + else: + if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: + self.padding_value = processing_class.pad_token_id + elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: + self.padding_value = processing_class.tokenizer.pad_token_id + else: + raise ValueError( + "Can't find `pad_token_id` in the `processing_class`. " + "Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`) " + "before instantiating the trainer." + ) + + if data_collator is None: + data_collator = PreferenceCollator(pad_token_id=self.padding_value) if not disable_dropout: warnings.warn( @@ -727,12 +525,6 @@ def make_inputs_require_grad(module, input, output): self.max_length = args.max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id - if padding_value is not None: - warnings.warn( - "You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." 
- ) - args.padding_value = padding_value - self.padding_value = args.padding_value if padding_value is not None else self.processing_class.pad_token_id self.max_prompt_length = args.max_prompt_length if truncation_mode != "keep_end": warnings.warn( @@ -801,38 +593,45 @@ def make_inputs_require_grad(module, input, output): # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): # Extract the prompt if needed, and apply the chat template if needed - train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + train_dataset = train_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to train dataset", ) if eval_dataset is not None: - eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) eval_dataset = eval_dataset.map( maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc, + desc="Applying chat template to eval dataset", ) # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models) fn_kwargs = { - "tokenizer": self.processing_class, - "args": args, - "processor": self.processor if self.is_vision_model else None, - "model": model if self.is_encoder_decoder else None, + "processing_class": processing_class, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) + "add_special_tokens": self.is_encoder_decoder, } train_dataset = train_dataset.map( - _tokenize, + self.tokenize_row if not self.is_vision_model else self.process_row, fn_kwargs=fn_kwargs, - batched=True, num_proc=self.dataset_num_proc, writer_batch_size=10, desc="Tokenizing train dataset", ) if eval_dataset is not None: eval_dataset = eval_dataset.map( - _tokenize, + self.tokenize_row if not self.is_vision_model else self.process_row, fn_kwargs=fn_kwargs, - batched=True, num_proc=self.dataset_num_proc, writer_batch_size=10, desc="Tokenizing eval dataset", @@ -893,6 +692,107 @@ def make_inputs_require_grad(module, input, output): if self.loss_type == "bco_pair": self.running = RunningMoments(self.accelerator) + @staticmethod + def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + """ + Tokenize a row of the dataset. + + Args: + features (`Dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. + processing_class (`PreTrainedTokenizerBase`): + Processing class used to process the data. + max_prompt_length (`int` or `None`): + Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + add_special_tokens (`bool`): + Whether to add special tokens to the sequences. Typically used for encoder-decoder models. 
If `True`, + the prompt sequence will have a bos token prepended and an eos token appended. In any case, the + completion sequences will have an eos token appended. + + Returns: + `Dict[str, List[int]]`: + Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and + `"rejected_input_ids". + + Example: + ```python + >>> from transformers import GPT2Tokenizer + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + >>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False) + {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} + ``` + """ + tokenizer = processing_class # the processing class is a tokenizer + prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + return { + "prompt_input_ids": prompt_input_ids, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + @staticmethod + def process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + """ + Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. 
+ """ + processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor + processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) + + prompt_input_ids = processed_features["input_ids"][0] + pixel_values = processed_features["pixel_values"][0] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + output = { + "prompt_input_ids": prompt_input_ids, + "pixel_values": pixel_values, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + if "pixel_attention_mask" in processed_features: + output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] + + return output + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin @@ -930,16 +830,7 @@ def _set_signature_columns_if_needed(self): # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override. 
if self._signature_columns is None: - self._signature_columns = [ - "chosen_input_ids", - "chosen_attention_mask", - "chosen_labels", - "rejected_input_ids", - "rejected_attention_mask", - "rejected_labels", - "prompt_input_ids", - "prompt_attention_mask", - ] + self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids"] def get_train_dataloader(self) -> DataLoader: """ @@ -960,28 +851,26 @@ def get_train_dataloader(self) -> DataLoader: # prepare dataloader data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - reference_chosen_logps = [] - reference_rejected_logps = [] + ref_chosen_logps = [] + ref_rejected_logps = [] for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) - reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics( - (reference_chosen_logp, reference_rejected_logp) + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) ) - reference_chosen_logps.append(reference_chosen_logp.cpu()) - reference_rejected_logps.append(reference_rejected_logp.cpu()) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) # Unnecessary cache clearing to avoid OOM torch.cuda.empty_cache() self.accelerator.free_memory() - all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() - all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() + all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() + self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) self.train_dataset = self.train_dataset.add_column( - name="reference_chosen_logps", column=all_reference_chosen_logps - ) - self.train_dataset = self.train_dataset.add_column( - name="reference_rejected_logps", column=all_reference_rejected_logps + name="ref_rejected_logps", column=all_ref_rejected_logps ) self._precomputed_train_ref_log_probs = True @@ -1015,25 +904,23 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa # prepare dataloader data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - reference_chosen_logps = [] - reference_rejected_logps = [] + ref_chosen_logps = [] + ref_rejected_logps = [] for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) - reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics( - (reference_chosen_logp, reference_rejected_logp) + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) ) - reference_chosen_logps.append(reference_chosen_logp.cpu()) - reference_rejected_logps.append(reference_rejected_logp.cpu()) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) - all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() - all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() 
+ all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() - eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps) - eval_dataset = eval_dataset.add_column( - name="reference_rejected_logps", column=all_reference_rejected_logps - ) + eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) + eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) - # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs if self.eval_dataset is not None: self.eval_dataset = eval_dataset self._precomputed_eval_ref_log_probs = True @@ -1052,121 +939,116 @@ def null_ref_context(self): if self.ref_adapter_name: self.model.set_adapter(self.model_adapter_name or "default") - def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: + def compute_ref_log_probs(self, batch: Dict[str, torch.LongTensor]) -> Dict: """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" compte_ref_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext() - - # compute reference logps with torch.no_grad(), compte_ref_context_manager: if self.ref_model is None: with self.null_ref_context(): - reference_chosen_logps, reference_rejected_logps = self.concatenated_forward( - self.model, padded_batch - )[:2] + ref_model_output = self.concatenated_forward(self.model, batch) else: - reference_chosen_logps, reference_rejected_logps = self.concatenated_forward( - self.ref_model, padded_batch - )[:2] - - return reference_chosen_logps, reference_rejected_logps + ref_model_output = self.concatenated_forward(self.ref_model, batch) + return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] @staticmethod def concatenated_inputs( - batch: Dict[str, Union[List, torch.LongTensor]], - is_encoder_decoder: bool = False, - is_vision_model: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - device: Optional[torch.device] = None, + batch: Dict[str, Union[List, torch.LongTensor]], padding_value: int ) -> Dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. + """ + Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt + and completion sequences. Args: - batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). - is_encoder_decoder: Whether the model is an encoder-decoder model. - label_pad_token_id: The label pad token id. - padding_value: The padding value to use for the concatenated inputs_ids. - device: The device for the concatenated inputs. + batch (`Dict[str, Union[List, torch.LongTensor]]`): + A batch of input data. The batch must contain the following keys: + + - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input IDs. + - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen completion input IDs. + - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected completion input IDs. 
+ - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. + - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. + + padding_value (`int`): + The padding value to use for the concatenated completion sequences (`chosen_input_ids` and + `rejected_input_ids`). Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + `Dict[str, torch.LongTensor]`: A dictionary containing: + + - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. + - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. + - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, prompt_length)`. + - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. + - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. + - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if `"prompt_pixel_attention_mask"` are present. + + Notes: + The completion input IDs and attention masks are padded to the maximum completion length of the chosen + or rejected sequences. """ - concatenated_batch = {} + output = {} - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) - else: - max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("rejected", "concatenated") - concatenated_batch[concatenated_key] = torch.cat( - ( - concatenated_batch[concatenated_key], - pad_to_length(batch[k], max_length, pad_value=pad_value), - ), - dim=0, - ).to(device=device) - - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1).to(device=device) - ) + # For the prompt, the input_ids are the same for both the chosen and rejected responses + output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) + output["prompt_attention_mask"] = torch.cat( + [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 + ) + if "pixel_values" in batch: + output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) - if is_vision_model: - concatenated_batch["pixel_values"] = torch.cat( - [batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0 + if "pixel_attention_mask" in batch: + output["pixel_attention_mask"] = torch.cat( + [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 ) - if 
"prompt_pixel_attention_mask" in batch: - concatenated_batch["pixel_attention_mask"] = torch.cat( - [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0 - ) - return concatenated_batch + + # Concatenate the chosen and rejected completions + max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + output["completion_input_ids"] = torch.cat( + ( + pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), + pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), + ), + ) + output["completion_attention_mask"] = torch.cat( + ( + pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), + pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), + ), + ) + + return output def dpo_loss( self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - reference_chosen_logps: torch.FloatTensor, - reference_rejected_logps: torch.FloatTensor, + chosen_logps: torch.FloatTensor, + rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute the DPO loss for a batch of policy and reference model log probabilities. + """ + Compute the DPO loss for a batch of policy and reference model log probabilities. Args: - policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) - reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + chosen_logps (`torch.FloatTensor`): + Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. + rejected_logps (`torch.FloatTensor`): + Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. + ref_chosen_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. + ref_rejected_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO loss for each example in the batch. - The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The `chosen_rewards` and `rejected_rewards` tensors contain the rewards for the chosen and rejected + responses, respectively. 
""" - chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - ( - not self.reference_free - ) * reference_chosen_logps.to(self.accelerator.device) - rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - ( - not self.reference_free - ) * reference_rejected_logps.to(self.accelerator.device) + device = self.accelerator.device + + # Get the log ratios for the chosen and rejected responses + chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) + rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: # The alpha-divergence formula: (1 - u^-alpha) / alpha @@ -1180,15 +1062,15 @@ def dpo_loss( alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef else: - pi_logratios = policy_chosen_logps - policy_rejected_logps + logratios = chosen_logps - rejected_logps if self.reference_free: - ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device) + ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) else: - ref_logratios = reference_chosen_logps - reference_rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps - pi_logratios = pi_logratios.to(self.accelerator.device) + logratios = logratios.to(self.accelerator.device) ref_logratios = ref_logratios.to(self.accelerator.device) - logits = pi_logratios - ref_logratios + logits = logratios - ref_logratios if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: # The js-divergence formula: log(2 * u / (1 + u)) @@ -1200,18 +1082,20 @@ def dpo_loss( logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. - # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and - # calculates a conservative DPO loss. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the + # labels and calculates a conservative DPO loss. if self.loss_type == "sigmoid": losses = ( -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) + elif self.loss_type == "robust": losses = ( -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + F.logsigmoid(-self.beta * logits) * self.label_smoothing ) / (1 - 2 * self.label_smoothing) + elif self.loss_type == "exo_pair": # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 import math @@ -1221,61 +1105,61 @@ def dpo_loss( losses = (self.beta * logits).sigmoid() * ( F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) + elif self.loss_type == "hinge": losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. 
losses = (logits - 1 / (2 * self.beta)) ** 2 - elif self.loss_type == "bco_pair": - chosen_logratios = policy_chosen_logps - reference_chosen_logps - rejected_logratios = policy_rejected_logps - reference_rejected_logps + elif self.loss_type == "bco_pair": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps chosen_rewards = self.beta * chosen_logratios rejected_rewards = self.beta * rejected_logratios rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() self.running.update(rewards) delta = self.running.mean - losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( -(self.beta * rejected_logratios - delta) ) - elif self.loss_type == "sppo_hard": - # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is set to 1 for the winner and 0 for the loser. - a = policy_chosen_logps - reference_chosen_logps - b = policy_rejected_logps - reference_rejected_logps + elif self.loss_type == "sppo_hard": + # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, + # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. + # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is + # set to 1 for the winner and 0 for the loser. + a = chosen_logps - ref_chosen_logps + b = rejected_logps - ref_rejected_logps losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 + elif self.loss_type == "nca_pair": - chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.beta - rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * self.beta + chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta + rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta losses = ( -F.logsigmoid(chosen_rewards) - 0.5 * F.logsigmoid(-chosen_rewards) - 0.5 * F.logsigmoid(-rejected_rewards) ) - elif self.loss_type == "aot_pair": - chosen_logratios = policy_chosen_logps - reference_chosen_logps - rejected_logratios = policy_rejected_logps - reference_rejected_logps + elif self.loss_type == "aot_pair": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) - delta = chosen_logratios_sorted - rejected_logratios_sorted - losses = ( -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * delta) * self.label_smoothing ) elif self.loss_type == "aot": - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - pi_logratios_sorted, _ = torch.sort(pi_logratios, dim=0) + logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logratios_sorted, _ = torch.sort(logratios, dim=0) ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) - - delta = pi_logratios_sorted - ref_logratios_sorted - + delta = logratios_sorted - ref_logratios_sorted losses = ( -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * delta) * self.label_smoothing 
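For readers following the loss branches above, here is a self-contained toy rendering of the default `sigmoid` DPO loss together with the reward definitions used later in this function; the log-probabilities are made up and no trainer state is involved:

```python
import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0  # typical DPO settings

# Made-up per-example log-probabilities for two prompts
chosen_logps = torch.tensor([-12.3, -8.7])
rejected_logps = torch.tensor([-14.1, -9.2])
ref_chosen_logps = torch.tensor([-12.9, -8.9])
ref_rejected_logps = torch.tensor([-13.8, -9.0])

# logits = (policy log-ratio) - (reference log-ratio)
logits = (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)

# Implied rewards and reward accuracy, as logged in get_batch_loss_metrics
chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
print(losses.mean().item(), (chosen_rewards > rejected_rewards).float().mean().item())
```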
@@ -1284,190 +1168,148 @@ def dpo_loss( elif self.loss_type == "apo_zero": # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) # Use this loss when you believe the chosen outputs are better than your model's default output - losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood - losses = losses_chosen + losses_rejected elif self.loss_type == "apo_down": # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are worse than your model's default output - - losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood - losses_rejected = 1 - F.sigmoid( - self.beta * (chosen_logratios - rejected_logratios) - ) # Decrease rejected likelihood more - + # Use this loss when you believe the chosen outputs are worse than your model's default output. + # Decrease chosen likelihood and decrease rejected likelihood more + losses_chosen = F.sigmoid(self.beta * chosen_logratios) + losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) losses = losses_chosen + losses_rejected else: raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']" + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " + "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']" ) - chosen_rewards = ( - self.beta - * ( - policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device) - ).detach() - ) - rejected_rewards = ( - self.beta - * ( - policy_rejected_logps.to(self.accelerator.device) - - reference_rejected_logps.to(self.accelerator.device) - ).detach() - ) + chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() + rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() return losses, chosen_rewards, rejected_rewards - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - use_weighting: bool = False, - ) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) - label_pad_token_id: The label pad token id. - is_encoder_decoder: Whether the model is an encoder-decoder model. - use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. - - Returns - A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing: - - The sum of log probabilities of the given labels under the given logits. - - The number of non-masked tokens. - - The wpo weighting (if use_weighting is True, otherwise None). - """ - if logits.shape[:-1] != labels.shape: - raise ValueError( - f"Logits (batch and sequence length dim) {logits.shape[:-1]} and labels must have the same shape {labels.shape}." 
- ) - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == label_pad_token_id] = 0 - - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - - all_logps = (per_token_logps * loss_mask).sum(-1) - - all_weights = None - if use_weighting: - # eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827 - probs = F.softmax(logits, dim=-1) - weights_adjustment_factor = torch.log((probs**2).sum(-1)) - per_token_logps_adjusted = per_token_logps - weights_adjustment_factor - all_weights = ((per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)).detach() - - return all_logps, loss_mask.sum(-1), all_weights - - def concatenated_forward( - self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]): """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. """ - concatenated_batch = self.concatenated_inputs( - batch, - is_encoder_decoder=self.is_encoder_decoder, - is_vision_model=self.is_vision_model, - label_pad_token_id=self.label_pad_token_id, - padding_value=self.padding_value, - device=self.accelerator.device, - ) - len_chosen = batch["chosen_labels"].shape[0] + num_examples = batch["prompt_input_ids"].shape[0] - model_kwargs = {} - - if self.is_encoder_decoder: - model_kwargs["labels"] = concatenated_batch["concatenated_labels"] - - if self.is_vision_model: - model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] - if "pixel_attention_mask" in concatenated_batch: - model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) + model_kwargs = {} if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]: - # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens) - seq_len = concatenated_batch["concatenated_labels"].shape[1] - all_logits = all_logits[:, -seq_len:] - - all_logps, size_completion, all_weights = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - # average_log_prob=self.loss_type == "ipo", - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - use_weighting=self.use_weighting, - ) + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + + prompt_input_ids = concatenated_batch["prompt_input_ids"] + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_input_ids = 
concatenated_batch["completion_input_ids"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + if self.is_encoder_decoder: + labels = completion_input_ids + labels[completion_attention_mask == 0] = self.label_pad_token_id + outputs = model( + input_ids=prompt_input_ids, + attention_mask=prompt_attention_mask, + labels=labels, # we need the labels for the logits to be returned + **model_kwargs, + ) + logits = outputs.logits + loss_mask = completion_attention_mask.bool() + else: + # Concatenate the prompt and completion inputs + input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) + attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id) - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - labels = concatenated_batch["concatenated_labels"].clone() - nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) + + # Get the first column idx that is all zeros and remove every column after that + empty_cols = torch.sum(attention_mask, dim=0) == 0 + first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) + 1 + input_ids = input_ids[:, : first_empty_col - 1] + attention_mask = attention_mask[:, : first_empty_col - 1] + loss_mask = loss_mask[:, : first_empty_col - 1] + + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + loss_mask = loss_mask[:, : self.args.max_length] + + outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs) + + # Offset the logits by one to align with the labels + logits = outputs.logits[:, :-1, :] + labels = input_ids[:, 1:].clone() + loss_mask = loss_mask[:, 1:].bool() + + if logits.shape[:2] != labels.shape[:2]: + # for llava, the returned logits include the image tokens (placed before the text tokens) + seq_len = labels.shape[1] + logits = logits[:, -seq_len:] + + # Compute the log probabilities of the labels + labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps[~loss_mask] = 0 + all_logps = per_token_logps.sum(-1) - if self.loss_type == "ipo": - all_logps = all_logps / size_completion + output = {} - policy_weights = None if self.use_weighting: - chosen_weights = all_weights[:len_chosen] - rejected_weights = all_weights[len_chosen:] - 
policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) + with torch.no_grad(): + # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 + logprobs = F.log_softmax(logits, dim=-1) + weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space + per_token_logps_adjusted = per_token_logps - weights_adjustment_factor + all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_weights = all_weights[:num_examples] + rejected_weights = all_weights[num_examples:] + output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] + if self.args.rpo_alpha is not None: + # Only use the chosen logits for the RPO loss + chosen_logits = logits[:num_examples] + chosen_labels = labels[:num_examples] + + # Compute the log probabilities of the labels + output["nll_loss"] = F.cross_entropy( + torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 + ) - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + if self.loss_type == "ipo": + all_logps = all_logps / loss_mask.sum(-1) + + output["chosen_logps"] = all_logps[:num_examples] + output["rejected_logps"] = all_logps[num_examples:] + output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean() + output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean() if self.aux_loss_enabled: - return ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - nll_loss, - policy_weights, - outputs.aux_loss, - ) + output["aux_loss"] = outputs.aux_loss - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights) + return output def get_batch_loss_metrics( self, @@ -1478,67 +1320,42 @@ def get_batch_loss_metrics( """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} - forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - policy_weights, - ) = forward_output[:6] - if self.aux_loss_enabled: - aux_loss = forward_output[6] + model_output = self.concatenated_forward(model, batch) - # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model - if ( - "reference_chosen_logps" in batch - and "reference_rejected_logps" in batch - and (self.precompute_ref_log_probs or self.args.rpo_alpha is not None) - ): - reference_chosen_logps = batch["reference_chosen_logps"] - reference_rejected_logps = batch["reference_rejected_logps"] + # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model + if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: + ref_chosen_logps = batch["ref_chosen_logps"] + ref_rejected_logps = batch["ref_rejected_logps"] else: - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - reference_chosen_logps, reference_rejected_logps = self.concatenated_forward( - self.model, batch - )[:2] - else: - reference_chosen_logps, reference_rejected_logps = self.concatenated_forward( - self.ref_model, batch - )[:2] + ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss( - policy_chosen_logps, - 
policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, + model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps ) reward_accuracies = (chosen_rewards > rejected_rewards).float() if self.args.rpo_alpha is not None: - # RPO loss from V3 of the paper: - losses = losses + policy_nll_loss * self.args.rpo_alpha + losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper if self.use_weighting: - losses = losses * policy_weights + losses = losses * model_output["policy_weights"] + + if self.aux_loss_enabled: + losses = losses + self.aux_loss_coef * model_output["aux_loss"] prefix = "eval_" if train_eval == "eval" else "" metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() - metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() - metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() - metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() - metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + metrics[f"{prefix}logps/chosen"] = model_output["chosen_logps"].detach().mean().cpu() + metrics[f"{prefix}logps/rejected"] = model_output["rejected_logps"].detach().mean().cpu() + metrics[f"{prefix}logits/chosen"] = model_output["mean_chosen_logits"].detach().cpu() + metrics[f"{prefix}logits/rejected"] = model_output["mean_rejected_logits"].detach().cpu() if self.args.rpo_alpha is not None: - metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu() - + metrics[f"{prefix}nll_loss"] = model_output["nll_loss"].detach().mean().cpu() if self.aux_loss_enabled: - return losses.mean() + self.aux_loss_coef * aux_loss, metrics + metrics[f"{prefix}aux_loss"] = model_output["aux_loss"].detach().cpu() return losses.mean(), metrics @@ -1559,7 +1376,8 @@ def compute_loss( self.store_metrics(metrics, train_eval="train") if return_outputs: - return (loss, metrics) + return loss, metrics + return loss def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: @@ -1578,13 +1396,13 @@ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) pad_token_id=self.processing_class.pad_token_id, ) - # if reference_output in batch use that otherwise use the reference model - if "reference_output" in batch: - reference_output = batch["reference_output"] + # if ref_output in batch use that otherwise use the reference model + if "ref_output" in batch: + ref_output = batch["ref_output"] else: if self.ref_model is None: with self.null_ref_context(): - reference_output = self.model.generate( + ref_output = self.model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, @@ -1592,7 +1410,7 @@ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) pad_token_id=self.processing_class.pad_token_id, ) else: - reference_output = self.ref_model.generate( + ref_output = self.ref_model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, @@ -1603,10 +1421,10 @@ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) policy_output = 
pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) - reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) - return policy_output_decoded, reference_output_decoded + return policy_output_decoded, ref_output_decoded def prediction_step( self, @@ -1630,7 +1448,7 @@ def prediction_step( self.store_metrics(metrics, train_eval="eval") if prediction_loss_only: - return (loss.detach(), None, None) + return loss.detach(), None, None # logits for the chosen and rejected samples from model logits_dict = { From 84dab850f6b93f23b6e12e5e288a3a6aeccf0045 Mon Sep 17 00:00:00 2001 From: Cameron Chen Date: Mon, 21 Oct 2024 11:06:19 -0400 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=A7=BD=20Fix=20typo=20in=20dataset=20?= =?UTF-8?q?format=20doc=20(#2259)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit doc update --- docs/source/dataset_formats.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index cc92ec0ff1..d9f75973ef 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -589,7 +589,7 @@ dataset = dataset.remove_columns(["chosen", "rejected"]) ### From explicit to implicit prompt preference dataset -To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt. +To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt. ```python from datasets import Dataset