From 110d0884c75f1cbb3ae8e041bce8436fa7596054 Mon Sep 17 00:00:00 2001 From: Seungjae Jung Date: Sat, 26 Oct 2024 01:15:08 +0900 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=8F=81=20Add=20`bos=5Ftoken=5Fid`=20o?= =?UTF-8?q?nly=20if=20it=20exists=20(#2279)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: sean.jung --- trl/trainer/bco_trainer.py | 25 ++++++++++++++----------- trl/trainer/kto_trainer.py | 25 ++++++++++++++----------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index c6ce2d4902..baf43a42bb 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -223,17 +223,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, ** ) # add BOS, which affects both prompt and the full completion - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] # add EOS, which affects only the full completion if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index fa1541bb44..e6991182a7 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -218,17 +218,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, ** ) # add BOS, which affects both prompt and the full completion - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + 
f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] # add EOS, which affects only the full completion if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ From ea7a1be92c1262b95a57c55a6602a9251ad4afa6 Mon Sep 17 00:00:00 2001 From: Daniil Tiapkin Date: Fri, 25 Oct 2024 18:16:02 +0200 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=A7=AE=20Fix=20the=20computation=20of?= =?UTF-8?q?=20KL=20divergence=20loss=20(#2277)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/nash_md_trainer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 7702802bd1..68b40db4b9 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -297,20 +297,19 @@ def _compute_losses( ref_logprobs_model_data, probability, ): - # Compute log probs - model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) - ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) - # reinforce score where 0.5 is a control variate - score = (probability - 0.5) * model_logprobs_model_data_sum + score = (probability - 0.5) * model_logprobs_model_data.sum(1) - # kl divergence - kl_div = model_logprobs_model_data_sum - ref_logprobs_model_data_sum + # kl divergence via reinforce + with torch.no_grad(): + log_ratio = model_logprobs_model_data - ref_logprobs_model_data + kl_div_log = log_ratio.sum(1) + kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) # final loss - loss = self.beta * kl_div - score + loss = self.beta * kl_div_loss - score - return loss.mean(), score, kl_div + return loss.mean(), score, kl_div_log def _log_statistics( self, From e155cb8a6643c6e48eae87190ade39d20cae5a25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:40:51 +0100 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9B=93=EF=B8=8F=E2=80=8D=F0=9F=92=A5=20D?= =?UTF-8?q?on't=20use=20`eval=5Fdataset`=20in=20scripts=20when=20no=20eval?= =?UTF-8?q?=20strategy=20(#2270)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/scripts/bco.py | 2 +- examples/scripts/cpo.py | 2 +- examples/scripts/dpo.py | 2 +- examples/scripts/dpo_online.py | 2 +- examples/scripts/dpo_vlm.py | 2 +- examples/scripts/gkd.py | 2 +- examples/scripts/kto.py | 2 +- examples/scripts/nash_md.py | 2 +- examples/scripts/orpo.py | 2 +- examples/scripts/ppo/ppo_tldr.py | 8 +++++--- examples/scripts/reward_modeling.py | 2 +- examples/scripts/rloo/rloo_tldr.py | 2 +- examples/scripts/sft.py | 2 +- examples/scripts/sft_vlm.py | 2 +- examples/scripts/xpo.py | 2 +- 15 files changed, 19 insertions(+), 17 deletions(-) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index ac46766c2d..d37a1fb3a9 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -151,7 +151,7 @@ def mean_pooling(model_output, attention_mask): ref_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + 
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_args), embedding_func=embedding_func, diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 34e897c557..019e9ef714 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -91,7 +91,7 @@ model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 302e42c59c..7c1114eedd 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -120,7 +120,7 @@ ref_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=peft_config, ) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 5c2c47d56c..33dea12ff2 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -121,7 +121,7 @@ judge=judge, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 08f5687afe..0781edd9e3 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -113,7 +113,7 @@ ref_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor, peft_config=peft_config, ) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 7c37d811c5..1ceaf5b085 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -121,7 +121,7 @@ teacher_model=training_args.teacher_model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 50dbcd5f36..56205ce57e 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -99,7 +99,7 @@ ref_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_args), ) diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index ec731d0384..c492faf76f 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -129,7 +129,7 @@ judge=judge, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - 
eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, ) generation_config = GenerationConfig( diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 163655e1e3..ac3f598f88 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -91,7 +91,7 @@ model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 45f813f765..fd20d6693d 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -95,7 +95,7 @@ ################ dataset = load_dataset(script_args.dataset_name) train_dataset = dataset[script_args.dataset_train_split] - eval_dataset = dataset[script_args.dataset_test_split] + eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None def prepare_dataset(dataset, tokenizer): """pre-tokenize the dataset before training; only collate during training""" @@ -118,10 +118,12 @@ def tokenize(element): # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): train_dataset = prepare_dataset(train_dataset, tokenizer) - eval_dataset = prepare_dataset(eval_dataset, tokenizer) + if eval_dataset is not None: + eval_dataset = prepare_dataset(eval_dataset, tokenizer) # filtering train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) - eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" ################ diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 57027bc10c..40924a3ad3 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -115,7 +115,7 @@ processing_class=tokenizer, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, peft_config=get_peft_config(model_config), ) trainer.train() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 2e8312272c..36f759208a 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -94,7 +94,7 @@ ################ dataset = load_dataset(script_args.dataset_name) train_dataset = dataset[script_args.dataset_train_split] - eval_dataset = dataset[script_args.dataset_test_split] + eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None def prepare_dataset(dataset, tokenizer): """pre-tokenize the dataset before training; only collate during training""" diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index b6c789b3d2..422fa89a6f 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -98,7 +98,7 @@ 
model=model_config.model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 361bcc0e50..49654bc408 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -118,7 +118,7 @@ def collate_fn(examples): args=training_args, data_collator=collate_fn, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor.tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 5f32c161b4..8a07e96b3d 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -114,7 +114,7 @@ judge=judge, args=training_args, train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, ) generation_config = GenerationConfig( From 0ce3b659280fc61bb1c6b18b6c35a2b1d6f30703 Mon Sep 17 00:00:00 2001 From: anch0vy Date: Mon, 28 Oct 2024 19:49:35 +0900 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=94=8C=20Fix=20type=20hint=20in=20`Lo?= =?UTF-8?q?gCompletionsCallback`=20(#2285)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update callbacks.py for fix small python type error * Update callbacks.py --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index eefe83ce96..24e33a7566 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -364,9 +364,9 @@ class LogCompletionsCallback(WandbCallback): column containing the prompts for generating completions. generation_config (`GenerationConfig`, *optional*): The generation config to use for generating completions. - num_prompts (`int`, *optional*): + num_prompts (`int` or `None`, *optional*): The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset. - freq (`int`, *optional*): + freq (`int` or `None`, *optional*): The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. 
""" @@ -374,8 +374,8 @@ def __init__( self, trainer: Trainer, generation_config: Optional[GenerationConfig] = None, - num_prompts: int = None, - freq: int = None, + num_prompts: Optional[int] = None, + freq: Optional[int] = None, ): super().__init__() self.trainer = trainer From b2696578ce6db1749a250661b507bf8b90e14dd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:21:40 +0100 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=8D=AC=20Use=20any=20reward=20model?= =?UTF-8?q?=20for=20online=20methods=20(#2276)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor reward processing in OnlineDPOTrainer * Refactor completion decoding and reward processing * remove strip * remove warning * Add reward_tokenizer to training script * Add reward_tokenizer and reward_processing_class to OnlineDPOTrainer test * propagate to xpo and nash * style * reduce memory requirement with inference_mode * fix tests * pairrm judge llmblender * setUpClass(cls) * Add setUpClass method to TestJudges class * truncation left for reward tokenizer * don't logcompletion without eval dataset * only eval when possible --- docs/source/online_dpo_trainer.md | 9 ++-- examples/scripts/dpo.py | 8 ++-- examples/scripts/dpo_online.py | 21 ++++++--- examples/scripts/gkd.py | 12 ++++-- examples/scripts/nash_md.py | 14 +++--- examples/scripts/reward_modeling.py | 8 ++-- examples/scripts/xpo.py | 14 +++--- tests/test_judges.py | 6 +++ tests/test_online_dpo_trainer.py | 34 ++++++++++----- tests/test_trainers_args.py | 5 ++- tests/test_xpo_trainer.py | 4 +- trl/trainer/nash_md_trainer.py | 1 + trl/trainer/online_dpo_trainer.py | 66 +++++++++++++++++++++-------- trl/trainer/xpo_trainer.py | 1 + 14 files changed, 138 insertions(+), 65 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 823f2e56cc..49e40957c1 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -79,20 +79,17 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht - judge = PairRMJudge() + reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) ++ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward") trainer = OnlineDPOTrainer( ... - judge=judge, + reward_model=reward_model, ++ reward_processing_class=reward_tokenizer, + ... ) ``` - - -Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training. - - - ### Encourage EOS token generation When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. 
If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]: diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 7c1114eedd..ed14461725 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -126,9 +126,11 @@ ) trainer.train() - metrics = trainer.evaluate() - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) # Save and push to hub trainer.save_model(training_args.output_dir) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 33dea12ff2..58d7c4a2c0 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -93,8 +93,15 @@ trust_remote_code=model_config.trust_remote_code, **model_kwargs, ) + reward_tokenizer = AutoTokenizer.from_pretrained( + training_args.reward_model_path, + trust_remote_code=model_config.trust_remote_code, + truncation=True, + truncation_side="left", # since we judge the completion, truncating left is more appropriate + ) else: reward_model = None + reward_tokenizer = None if training_args.judge is not None: judge_cls = JUDGES[training_args.judge] @@ -123,13 +130,17 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, + reward_processing_class=reward_tokenizer, peft_config=get_peft_config(model_config), ) - generation_config = GenerationConfig( - max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature - ) - completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) - trainer.add_callback(completions_callback) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + trainer.train() # Save and push to hub diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 1ceaf5b085..ed5d07a47e 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -46,7 +46,7 @@ from accelerate import PartialState from datasets import load_dataset -from transformers import AutoTokenizer +from transformers import AutoTokenizer, GenerationConfig from trl import ( GKDConfig, @@ -125,8 +125,14 @@ processing_class=tokenizer, peft_config=get_peft_config(model_config), ) - completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8) - trainer.add_callback(completions_callback) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + trainer.train() # Save and push to hub diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index c492faf76f..b9492e6042 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -132,12 +132,14 @@ eval_dataset=dataset[script_args.dataset_test_split] if 
training_args.eval_strategy != "no" else None, processing_class=tokenizer, ) - generation_config = GenerationConfig( - max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature - ) - completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) - trainer.add_callback(completions_callback) - # train the model + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + trainer.train() # Save and push to hub diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 40924a3ad3..073016bc77 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -124,9 +124,11 @@ # Save model and push to Hub ############################ trainer.save_model(training_args.output_dir) - metrics = trainer.evaluate() - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) # Save and push to hub trainer.save_model(training_args.output_dir) diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 8a07e96b3d..1c465d90e6 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -117,12 +117,14 @@ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, ) - generation_config = GenerationConfig( - max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature - ) - completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) - trainer.add_callback(completions_callback) - # train the model + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + trainer.train() # Save and push to hub diff --git a/tests/test_judges.py b/tests/test_judges.py index def75066f1..0b393ff9fd 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -18,6 +18,12 @@ class TestJudges(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues + # where concurrent tests attempt to load the model while it’s still downloading. 
+ PairRMJudge() + def _get_prompts_and_completions(self): prompts = ["The capital of France is", "The biggest planet in the solar system is"] completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 9a4e7680a2..9058462bed 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -20,7 +20,8 @@ from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge, is_llmblender_available +from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llmblender_available +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE if is_peft_available(): @@ -33,6 +34,9 @@ def setUp(self): self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) + self.reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m") + self.reward_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer.pad_token = self.tokenizer.eos_token @@ -53,9 +57,10 @@ def test_training(self, config_name): model=self.model, reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, ) trainer.train() @@ -79,9 +84,10 @@ def test_training_with_ref_model(self): ref_model=self.ref_model, reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, ) trainer.train() @@ -103,9 +109,11 @@ def test_ref_model_is_model(self): OnlineDPOTrainer( model=self.model, ref_model=self.model, # ref_model can't be the same as model + reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, ) @require_peft @@ -126,9 +134,10 @@ def test_training_with_peft(self): model=self.model, reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, peft_config=lora_config, ) @@ -156,9 +165,10 @@ def test_training_with_peft_and_ref_model(self): ref_model=self.ref_model, reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, peft_config=lora_config, ) @@ -188,9 +198,10 @@ def test_training_with_peft_model_and_peft_config(self): model=model, reward_model=self.reward_model, args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, peft_config=lora_train_config, ) @@ -200,7 +211,8 @@ 
def test_training_with_peft_model_and_peft_config(self): self.assertIn("train_loss", trainer.state.log_history[-1]) @unittest.skipIf(not is_llmblender_available(), "llm-blender is not available") - def test_training_with_judge(self): + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training_with_judge(self, config_name): with tempfile.TemporaryDirectory() as tmp_dir: training_args = OnlineDPOConfig( output_dir=tmp_dir, @@ -210,15 +222,15 @@ def test_training_with_judge(self): eval_strategy="steps", report_to="none", ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) trainer = OnlineDPOTrainer( model=self.model, - judge=PairRMJudge(), + judge=RandomPairwiseJudge(), args=training_args, - processing_class=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, ) trainer.train() diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index c7994b5564..0e49d8dd98 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -272,12 +272,13 @@ def test_online_dpo(self, beta_list): reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m") trainer = OnlineDPOTrainer( - args=training_args, - processing_class=tokenizer, model=model, ref_model=ref_model, reward_model=reward_model, + args=training_args, train_dataset=dataset, + processing_class=tokenizer, + reward_processing_class=tokenizer, ) self.assertEqual(trainer.args.max_new_tokens, 42) self.assertEqual(trainer.args.temperature, 0.5) diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index e03b855d5a..ad8050ac00 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -20,7 +20,7 @@ from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from trl import PairRMJudge, XPOConfig, XPOTrainer, is_llmblender_available +from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llmblender_available if is_peft_available(): @@ -171,7 +171,7 @@ def test_xpo_trainer_judge_training(self, config_name): report_to="none", ) dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) - judge = PairRMJudge() + judge = RandomPairwiseJudge() trainer = XPOTrainer( model=self.model, diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 68b40db4b9..c998174765 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -122,6 +122,7 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, + reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model peft_config=peft_config, compute_metrics=compute_metrics, callbacks=callbacks, diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 21ce78d8bc..2575c2f886 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -44,7 +44,7 @@ from transformers.training_args import OptimizerNames from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging -from ..data_utils import is_conversational, maybe_apply_chat_template +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..models import create_reference_model 
from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge @@ -137,6 +137,7 @@ def __init__( processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, + reward_processing_class: Optional[PreTrainedTokenizerBase] = None, peft_config: Optional[Dict] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, @@ -161,6 +162,7 @@ def __init__( raise ValueError("Either `reward_model` or `judge` must be provided.") self.reward_model = reward_model + self.reward_processing_class = reward_processing_class self.judge = judge if args.missing_eos_penalty is not None and judge is not None: @@ -428,18 +430,23 @@ def training_step( ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs # free memory - # Get the reward from the reward model or judge: - if self.judge is not None: - completions = self.processing_class.batch_decode( - prompt_completion_ids[:, context_length:], skip_special_tokens=True - ) - completions = [completion.strip() for completion in completions] # remove the leading space + # Decode the completions, and format them if the input is conversational + device = prompt_completion_ids.device + completions_ids = prompt_completion_ids[:, context_length:] + completions = self.processing_class.batch_decode(completions_ids, skip_special_tokens=True) + if is_conversational({"prompt": prompts[0]}): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] + # Get the reward from the reward model or judge + if self.judge is not None: + # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not + # directly understandable by the judge and could alter its judgment. To avoid this and make the judge + # independent of the model's chat template, we use the raw conversation data, and apply our own chat + # template to it. if is_conversational({"prompt": prompts[0]}): - completions = [[{"role": "assistant", "content": completion}] for completion in completions] environment = jinja2.Environment() template = environment.from_string(SIMPLE_CHAT_TEMPLATE) - prompts = [template.render(messages=message) for message in prompts] + prompts = [template.render(messages=prompt) for prompt in prompts] completions = [template.render(messages=completion) for completion in completions] ranks_of_first_completion = self.judge.judge( @@ -449,16 +456,39 @@ def training_step( # convert ranks to a True/False mask: # when rank == 0, it means the first completion is the best # when rank == 1, it means the second completion is the best - mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=prompt_completion_ids.device) + mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device) else: - _, scores, _ = get_reward( - self.reward_model, prompt_completion_ids, self.processing_class.pad_token_id, context_length - ) + # The reward model may not have the same chat template or tokenizer as the model, so we need to use the + # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class. 
+ prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1] + if is_conversational({"prompt": prompts[0]}): + examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)] + examples = [apply_chat_template(example, self.reward_processing_class) for example in examples] + prompts = [example["prompt"] for example in examples] + completions = [example["completion"] for example in examples] + + # Tokenize the prompts + prompts_ids = self.reward_processing_class( + prompts, padding=True, return_tensors="pt", padding_side="left" + )["input_ids"].to(device) + context_length = prompts_ids.shape[1] + + # Tokenize the completions + completions_ids = self.reward_processing_class( + completions, padding=True, return_tensors="pt", padding_side="right" + )["input_ids"].to(device) + + # Concatenate the prompts and completions and get the reward + prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1) + with torch.inference_mode(): + _, scores, _ = get_reward( + self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length + ) - # Filter completion. Ensure that the sample contains stop_token_id - # Completions not passing that filter will receive a lower score. - if self.args.missing_eos_penalty is not None: - scores[~contain_eos_token] -= self.args.missing_eos_penalty + # Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty # Split the scores in 2 (the prompts of the first half are the same as the second half) first_half, second_half = scores.split(num_examples) @@ -466,7 +496,7 @@ def training_step( # Get the indices of the chosen and rejected examples mask = first_half >= second_half - num_examples_range = torch.arange(num_examples, device=prompt_completion_ids.device) + num_examples_range = torch.arange(num_examples, device=device) chosen_indices = num_examples_range + (~mask * num_examples) rejected_indices = num_examples_range + (mask * num_examples) diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index d643dab651..4ec501c7f0 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -121,6 +121,7 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, + reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model peft_config=peft_config, compute_metrics=compute_metrics, callbacks=callbacks,
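
Illustrative sketch (not part of the patch series) for PATCH 1/5: the BCO/KTO `_process_tokens` change prepends a BOS token only when the tokenizer actually defines one, because some tokenizers report `bos_token_id = None` and the old code would then prepend `None` to the input ids. A minimal stand-alone version of that guard, with simplified, hypothetical names:

```python
# Minimal sketch of the BOS guard from PATCH 1/5; names are simplified stand-ins for the
# prefixed batch keys handled inside `_process_tokens`.
from typing import List, Optional, Tuple

def prepend_bos(
    input_ids: List[int], attention_mask: List[int], bos_token_id: Optional[int]
) -> Tuple[List[int], List[int]]:
    """Prepend BOS only if the tokenizer has one and it is not already the first token."""
    if bos_token_id is not None:  # tokenizers without a BOS token return None here
        if len(input_ids) == 0 or input_ids[0] != bos_token_id:
            input_ids = [bos_token_id] + input_ids
            attention_mask = [1] + attention_mask
    return input_ids, attention_mask

print(prepend_bos([5, 6, 7], [1, 1, 1], bos_token_id=None))  # unchanged: ([5, 6, 7], [1, 1, 1])
print(prepend_bos([5, 6, 7], [1, 1, 1], bos_token_id=1))     # ([1, 5, 6, 7], [1, 1, 1, 1])
```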
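
Illustrative sketch for PATCH 2/5: after the fix, the Nash-MD trainer logs the plain summed log-ratio as its KL estimate while optimizing a REINFORCE-style surrogate, the detached log-ratio multiplied by the model log-probabilities, so the gradient flows through the log-probabilities of the sampled tokens. A toy reproduction of that split (shapes and values are arbitrary, not the trainer's API):

```python
# Toy reproduction of the KL handling from PATCH 2/5 (illustrative tensors only).
import torch

torch.manual_seed(0)
model_logprobs = torch.randn(2, 4, requires_grad=True)  # per-token log-probs under the model
ref_logprobs = torch.randn(2, 4)                        # per-token log-probs under the reference

with torch.no_grad():
    log_ratio = model_logprobs - ref_logprobs  # detached: acts as a constant coefficient
    kl_div_log = log_ratio.sum(1)              # per-sequence KL estimate, used for logging

# REINFORCE surrogate: its gradient is log_ratio times the gradient of the sampled tokens'
# log-probs, an estimator of the KL gradient.
kl_div_loss = (log_ratio * model_logprobs).sum(1)

kl_div_loss.mean().backward()
print(kl_div_log)                 # logged value, no grad attached
print(model_logprobs.grad.shape)  # torch.Size([2, 4])
```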
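
Illustrative sketch for PATCH 3/5: the example scripts now pass `None` as the eval dataset (and skip eval-only preprocessing and callbacks) whenever `eval_strategy` is `"no"`, so no eval split is loaded or tokenized needlessly. The pattern in isolation, with a plain dict and a dataclass standing in for the real `DatasetDict` and `TrainingArguments`:

```python
# Stand-alone version of the guard used throughout PATCH 3/5 (types are simplified stand-ins).
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingArgsStub:
    eval_strategy: str = "no"

def select_eval_split(dataset: dict, test_split: str, args: TrainingArgsStub) -> Optional[list]:
    # Return None so the trainer never receives (or preprocesses) an eval set it will not use.
    return dataset[test_split] if args.eval_strategy != "no" else None

dataset = {"train": [0, 1, 2], "test": [3, 4]}
print(select_eval_split(dataset, "test", TrainingArgsStub(eval_strategy="no")))     # None
print(select_eval_split(dataset, "test", TrainingArgsStub(eval_strategy="steps")))  # [3, 4]
```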
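
Illustrative usage sketch for PATCH 5/5: with the new `reward_processing_class` argument, `OnlineDPOTrainer` re-applies the reward model's own chat template and tokenizer to the decoded completions, so the policy and the reward model no longer need to share a tokenizer. A hedged end-to-end sketch (the policy checkpoint and prompt dataset are placeholders; the reward checkpoint is the one named in the docs change above):

```python
# Sketch of training with a separate reward tokenizer after PATCH 5/5 (names are placeholders).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer

policy_id = "Qwen/Qwen2-0.5B-Instruct"   # placeholder policy checkpoint
reward_id = "trl-lib/Qwen2-0.5B-Reward"  # reward checkpoint from the updated docs

model = AutoModelForCausalLM.from_pretrained(policy_id)
tokenizer = AutoTokenizer.from_pretrained(policy_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_id, num_labels=1)
# As in the updated dpo_online.py script: truncate on the left, since the completion is what gets scored.
reward_tokenizer = AutoTokenizer.from_pretrained(reward_id, truncation=True, truncation_side="left")

dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")  # placeholder prompt-only dataset

trainer = OnlineDPOTrainer(
    model=model,
    reward_model=reward_model,
    args=OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO"),
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_processing_class=reward_tokenizer,  # new argument introduced by this patch
)
trainer.train()
```

As in the updated example scripts, eval-only extras (the eval split and `LogCompletionsCallback`) would only be wired up when `training_args.eval_strategy != "no"`.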