diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 3c5e1efd86..66fdc6b57b 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -42,6 +42,8 @@
title: ORPO
- local: ppo_trainer
title: PPO
+ - local: prm_trainer
+ title: PRM
- local: reward_trainer
title: Reward
- local: rloo_trainer
diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx
index a6dbf11d04..c8ab321506 100644
--- a/docs/source/dataset_formats.mdx
+++ b/docs/source/dataset_formats.mdx
@@ -266,6 +266,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
+| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
diff --git a/docs/source/prm_trainer.mdx b/docs/source/prm_trainer.mdx
new file mode 100644
index 0000000000..012b8ec071
--- /dev/null
+++ b/docs/source/prm_trainer.mdx
@@ -0,0 +1,123 @@
+# PRM Trainer
+
+<Tip warning={true}>
+
+PRM Trainer is an experimental API which is subject to change at any time.
+
+</Tip>
+
+## Overview
+
+Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.
+
+The abstract from the paper is the following:
+
+> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.
+
+This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
+
+
+## Quick start
+
+This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can preview the data on the dataset page.
+
+Below is the script to train the model:
+
+```python
+# train_prm.py
+from datasets import load_dataset
+from trl import PRMConfig, PRMTrainer
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
+train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
+
+training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
+trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
+trainer.train()
+```
+
+Execute the script using the following command:
+
+```bash
+accelerate launch train_prm.py
+```
+
+Distributed across 8 GPUs, the training takes approximately 1 hour.
+
+To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.
+
+
+```python
+from transformers import pipeline
+
+pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
+example = {
+ "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
+ "completions": [
+ "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
+ "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
+ "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
+ ],
+ "labels": [True, False, False],
+}
+
+
+separator = "\n" # It's important to use the same separator as the one used during training
+
+for idx in range(1, len(example["completions"]) + 1):
+ steps = example["completions"][0:idx]
+    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator after the prompt and each step
+ pred_entity = pipe(text)[-1]["entity"]
+ pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
+ label = example["labels"][idx - 1]
+ print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
+```
+
+```text
+Step 1 Predicted: True Label: True
+Step 2 Predicted: False Label: False
+Step 3 Predicted: False Label: False
+```
+
+It's a win!
+
+## Expected dataset type
+
+PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset.
+The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.
+
+The [`PRMTrainer`] only supports the [standard](dataset_formats#standard) dataset format.
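+
+As a reference, a minimal stepwise supervision dataset can be built in memory as follows (the prompt, steps, and labels below are made up for illustration):
+
+```python
+from datasets import Dataset
+
+dataset = Dataset.from_dict(
+    {
+        "prompt": ["Which number is larger, 9.8 or 9.11?"],
+        "completions": [["The integer parts are equal.", "Hence, 9.11 > 9.8."]],
+        "labels": [[True, False]],
+    }
+)
+```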
+
+## Example script
+
+We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py).
+
+To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:
+
+```bash
+accelerate launch examples/scripts/prm.py \
+ --model_name_or_path Qwen/Qwen2-0.5B \
+ --dataset_name trl-lib/math_shepherd \
+ --num_train_epochs 1 \
+ --logging_steps 25 \
+ --output_dir Qwen2-0.5B-Reward-Math-Sheperd
+```
+
+## PRMTrainer
+
+[[autodoc]] PRMTrainer
+
+## PRMConfig
+
+[[autodoc]] PRMConfig
diff --git a/examples/datasets/math_shepherd.py b/examples/datasets/math_shepherd.py
new file mode 100644
index 0000000000..c09e745ad5
--- /dev/null
+++ b/examples/datasets/math_shepherd.py
@@ -0,0 +1,131 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from itertools import chain
+from typing import Optional
+
+from datasets import load_dataset
+from transformers import HfArgumentParser
+
+
+@dataclass
+class ScriptArguments:
+ r"""
+ Arguments for the script.
+
+ Args:
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether to push the dataset to the Hugging Face Hub.
+ repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
+ Hugging Face repository ID to push the dataset to.
+ dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+ Number of workers to use for dataset processing.
+ """
+
+ push_to_hub: bool = False
+ repo_id: str = "trl-lib/math_shepherd"
+ dataset_num_proc: Optional[int] = None
+
+
+def process_example(example):
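+    # Raw Math-Shepherd rows contain an "input" string in which each reasoning step ends with the
+    # marker "ки", and a "label" string that repeats the same text with each marker replaced by "+"
+    # (correct step) or "-" (incorrect step). Illustrative, made-up row:
+    #   input: "How many legs do 3 dogs have? Step 1: 3 * 4 = 12 ки Step 2: The answer is: 12 ки"
+    #   label: "How many legs do 3 dogs have? Step 1: 3 * 4 = 12 + Step 2: The answer is: 12 +"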
+ # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
+ inputs = example["input"].replace("ки", "ⶻ")
+
+ # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
+ indexes = [m.start() for m in re.finditer("ⶻ", inputs)]
+
+    # Sanity check that the label at each index is either "+" or "-"
+ assert all(example["label"][idx] in ["+", "-"] for idx in indexes)
+
+ # Get the labels
+ labels = [example["label"][idx] == "+" for idx in indexes]
+
+    # Split the input into steps (note: the first element is the prompt, which may still contain the first step)
+ steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]
+
+ # Remove the last step (single ⶻ)
+ steps = steps[:-1]
+
+ # Get the prompt (first part) and completions (rest)
+ prompt = steps[0]
+ completions = steps[1:]
+
+    # Remove the leading "ⶻ" and the final whitespace from the completions
+ assert all(completion.startswith("ⶻ") for completion in completions)
+ completions = [completion[1:].strip() for completion in completions]
+
+ # At this point, we need to retrieve the first step from the prompt.
+    # First, we handle particular cases (annotation errors) where a step label appears before the end of the prompt.
+ if prompt.startswith(
+ (
+ "Mr. Rocky",
+ "Parker",
+ "What is the smallest positive",
+ " The Myth",
+ "Let $\\mathbf{a}$",
+ "Find the arithmetic",
+ "Determine an ordered pair",
+ "Determine the ordered pair",
+ "At the Quill and Scroll stationery",
+ "Round to the nearest",
+ r"Calculate $\sqrt{10p}",
+ r"Simplify $\sqrt{28x}",
+ )
+ ):
+        # Known dataset errors where an annotation appears inside the prompt: drop the spurious first label
+ labels = labels[1:]
+
+ # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
+ # (less common) "?".
+ elif "Step 1:" in prompt:
+ prompt, first_step = prompt.split("Step 1:")
+ first_step = "Step 1:" + first_step
+ completions = [first_step.strip()] + completions
+ elif "step 1:" in prompt:
+ prompt, first_step = prompt.split("step 1:")
+ first_step = "step 1:" + first_step
+ completions = [first_step.strip()] + completions
+ elif "?" in prompt:
+ prompt, first_step = prompt.split("?")
+ prompt = prompt + "?"
+ completions = [first_step.strip()] + completions
+ else:
+ raise ValueError(f"Prompt can't be processed: {prompt}")
+
+ # Strip the prompt
+ prompt = prompt.strip()
+
+ # Sanity check that the length of the completions is the same as the length of the labels
+ assert len(completions) == len(labels)
+
+ return {"prompt": prompt, "completions": completions, "labels": labels}
+
+
+if __name__ == "__main__":
+ parser = HfArgumentParser(ScriptArguments)
+ script_args = parser.parse_args_into_dataclasses()[0]
+
+ dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")
+
+ dataset = dataset.map(
+ process_example,
+ remove_columns=["input", "label", "task"],
+ num_proc=script_args.dataset_num_proc,
+ )
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
+
+ if script_args.push_to_hub:
+ dataset.push_to_hub(script_args.repo_id)
diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py
new file mode 100644
index 0000000000..ba7f9ce415
--- /dev/null
+++ b/examples/scripts/prm.py
@@ -0,0 +1,130 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Full training:
+python examples/scripts/prm.py \
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+ --dataset_name trl-lib/prm800k \
+ --output_dir Qwen2-0.5B-Reward \
+ --per_device_train_batch_size 8 \
+ --num_train_epochs 1 \
+ --gradient_checkpointing True \
+ --learning_rate 1.0e-5 \
+ --logging_steps 25 \
+ --eval_strategy steps \
+ --eval_steps 50
+
+LoRA:
+python examples/scripts/prm.py \
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+ --dataset_name trl-lib/prm800k \
+ --output_dir Qwen2-0.5B-Reward-LoRA \
+ --per_device_train_batch_size 8 \
+ --num_train_epochs 1 \
+ --gradient_checkpointing True \
+ --learning_rate 1.0e-4 \
+ --logging_steps 25 \
+ --eval_strategy steps \
+ --eval_steps 50
+ --use_peft \
+ --lora_r 32 \
+ --lora_alpha 16
+"""
+
+import warnings
+
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser
+
+from trl import (
+ ModelConfig,
+ PRMConfig,
+ PRMTrainer,
+ ScriptArguments,
+ get_kbit_device_map,
+ get_peft_config,
+ get_quantization_config,
+)
+
+
+if __name__ == "__main__":
+ parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig))
+ script_args, training_args, model_config = parser.parse_args_into_dataclasses()
+ training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+
+ ################
+ # Model & Tokenizer
+ ################
+ torch_dtype = (
+ model_config.torch_dtype
+ if model_config.torch_dtype in ["auto", None]
+ else getattr(torch, model_config.torch_dtype)
+ )
+ quantization_config = get_quantization_config(model_config)
+    model_kwargs = dict(
+        revision=model_config.model_revision,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+        use_cache=False if training_args.gradient_checkpointing else True,
+        torch_dtype=torch_dtype,
+    )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
+ )
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+ )
+ # Align padding tokens between tokenizer and model
+ model.config.pad_token_id = tokenizer.pad_token_id
+
+ if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
+ warnings.warn(
+ "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs"
+ " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT.",
+ UserWarning,
+ )
+
+ ##############
+ # Load dataset
+ ##############
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
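+    # Drop examples that have no completion steps, since they provide no step labels to train on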
+ dataset = dataset.filter(lambda x: len(x["completions"]) > 0)
+
+ ##########
+ # Training
+ ##########
+ trainer = PRMTrainer(
+ model=model,
+ processing_class=tokenizer,
+ args=training_args,
+ train_dataset=dataset[script_args.dataset_train_split],
+ eval_dataset=dataset[script_args.dataset_test_split],
+ peft_config=get_peft_config(model_config),
+ )
+ trainer.train()
+
+ ############################
+ # Save model and push to Hub
+ ############################
+ trainer.save_model(training_args.output_dir)
+ metrics = trainer.evaluate()
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Save and push to hub
+ trainer.save_model(training_args.output_dir)
+ if training_args.push_to_hub:
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
diff --git a/tests/test_judges.py b/tests/test_judges.py
index 4789d3bb3a..0f8b83d881 100644
--- a/tests/test_judges.py
+++ b/tests/test_judges.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import time
import unittest
diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py
new file mode 100644
index 0000000000..4f2c1c21c1
--- /dev/null
+++ b/tests/test_prm_trainer.py
@@ -0,0 +1,329 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+from unittest.mock import MagicMock
+
+import torch
+from datasets import Dataset, load_dataset
+from parameterized import parameterized
+from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase
+from transformers.testing_utils import require_peft
+from transformers.utils import is_peft_available
+
+from trl import PRMConfig, PRMTrainer
+
+
+if is_peft_available():
+ from peft import LoraConfig, TaskType
+
+
+class TestTokenizeRow(unittest.TestCase):
+ def setUp(self):
+ # Set up the mock tokenizer with specific behaviors
+ self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
+ self.tokenizer.bos_token_id = 0
+ self.tokenizer.eos_token_id = 2
+
+ def mock_encode(text, add_special_tokens):
+ token_map = {
+ "Which number is larger, 9.8 or 9.11?": [465, 6766, 318, 298],
+ "11 is greater than 8.": [4, 322, 12],
+ "Hence, 9.11 > 9.8.": [4995, 11, 22],
+ "\n": [1030],
+ "\n\n": [1030, 1030],
+ }
+
+ return token_map[text]
+
+ def mock_tokenizer_call(text, add_special_tokens):
+ return {"input_ids": mock_encode(text, add_special_tokens)}
+
+ self.tokenizer.encode.side_effect = mock_encode
+ self.tokenizer.side_effect = mock_tokenizer_call
+
+ def test_tokenize_row_no_truncation(self):
+ # Define the input features
+ features = {
+ "prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
+ "labels": [True, False],
+ }
+
+ # Call the method with no truncation
+ result = PRMTrainer.tokenize_row(
+ features=features,
+ tokenizer=self.tokenizer,
+ step_separator="\n",
+ max_length=None,
+ max_completion_length=None,
+ train_on_last_step_only=False,
+ is_eval=False,
+ )
+
+ self.assertEqual(
+ result,
+ {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
+ },
+ )
+
+ def test_tokenize_row_train_on_last_step_only(self):
+ # Define the input features
+ features = {
+ "prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
+ "labels": [True, False],
+ }
+
+ result = PRMTrainer.tokenize_row(
+ features=features,
+ tokenizer=self.tokenizer,
+ step_separator="\n",
+ max_length=None,
+ max_completion_length=None,
+ train_on_last_step_only=True,
+ is_eval=False,
+ )
+
+ self.assertEqual(
+ result,
+ {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
+ },
+ )
+
+ def test_tokenize_row_completion_truncation(self):
+ # Define the input features
+ features = {
+ "prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
+ "labels": [True, False],
+ }
+
+ # Call the method with truncation on the completion
+ result = PRMTrainer.tokenize_row(
+ features=features,
+ tokenizer=self.tokenizer,
+ step_separator="\n",
+ max_length=None,
+ max_completion_length=6,
+ train_on_last_step_only=False,
+ is_eval=False,
+ )
+
+ self.assertEqual(
+ result,
+ {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100],
+ },
+ )
+
+ def test_tokenize_row_prompt_completion_truncation(self):
+ # Define the input features
+ features = {
+ "prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
+ "labels": [True, False],
+ }
+
+ # Call the method with truncation on the prompt and completion
+ result = PRMTrainer.tokenize_row(
+ features=features,
+ tokenizer=self.tokenizer,
+ step_separator="\n",
+ max_length=9,
+ max_completion_length=None,
+ train_on_last_step_only=False,
+ is_eval=False,
+ )
+
+ self.assertEqual(
+ result,
+ {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1],
+ },
+ )
+
+ def test_tokenize_row_multi_token_separator(self):
+ # Define the input features
+ features = {
+ "prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
+ "labels": [True, False],
+ }
+
+ # Call the method using multiple tokens as step_separator
+ result = PRMTrainer.tokenize_row(
+ features=features,
+ tokenizer=self.tokenizer,
+ step_separator="\n\n",
+ max_length=None,
+ max_completion_length=None,
+ train_on_last_step_only=False,
+ is_eval=False,
+ )
+
+ self.assertEqual(
+ result,
+ {
+ "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030],
+ "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0],
+ },
+ )
+
+
+class PRMTrainerTester(unittest.TestCase):
+ def setUp(self):
+ model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+ self.model = AutoModelForTokenClassification.from_pretrained(model_id)
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ @parameterized.expand([True, False])
+ def test_train_full(self, train_on_last_step_only):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
+ training_args = PRMConfig(
+ output_dir=tmp_dir,
+ report_to="none",
+ train_on_last_step_only=train_on_last_step_only,
+ )
+ trainer = PRMTrainer(
+ model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+ )
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ # check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ # check the params have changed - ignore 0 biases
+ if param.sum() != 0:
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
+ def test_train_full_pretokenized(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = Dataset.from_dict(
+ {
+ "labels": [
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1, -100, -100, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, 0],
+ [-100, -100, -100, -100, -100, -100, -100, -100, 1],
+ [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
+ ],
+ "input_ids": [
+ [46518, 374, 2664, 1091, 11, 1077, 752, 1744, 1112, 198, 27261, 13, 198],
+ [98923, 374, 2664, 1091, 11, 315, 3308, 11, 198, 17995, 13, 198, 1576, 31273, 12850, 13, 198],
+ [16374, 374, 2664, 1091, 1112, 1077, 594, 2506, 432, 6770, 11, 198, 6351, 13, 198],
+ [31137, 374, 2664, 1091, 979, 4362, 11, 198, 16965, 13, 198],
+ [31019, 374, 2664, 1091, 304, 3793, 315, 5944, 11, 198, 24034, 13, 198],
+ [98491, 374, 2664, 1091, 1112, 5310, 369, 91494, 13, 198],
+ [4418, 2897, 14579, 5310, 979, 3800, 1349, 432, 13, 198],
+ [20366, 5048, 7629, 944, 3281, 3322, 11, 7241, 1112, 198, 807, 1795, 279, 5601, 13, 198],
+ [15802, 14976, 487, 33327, 1045, 31787, 63443, 11, 198, 52400, 13, 198],
+ [13877, 1265, 2581, 1494, 49394, 11, 198, 7241, 20975, 91681, 13, 198],
+ [641, 279, 3579, 315, 71768, 11, 25066, 279, 61361, 311, 7942, 13, 198],
+ [7039, 374, 2664, 1091, 2937, 13, 198],
+ [26155, 374, 3545, 2664, 1091, 34933, 26537, 13, 198],
+ [2679, 279, 8129, 374, 4135, 311, 10339, 11, 432, 2578, 387, 264, 1661, 2884, 13, 198],
+ ],
+ }
+ )
+
+ training_args = PRMConfig(output_dir=tmp_dir, report_to="none")
+ trainer = PRMTrainer(
+ model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ # check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ # check the params have changed - ignore 0 biases
+ if param.sum() != 0:
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
+ @require_peft
+ def test_train_lora(self):
+ peft_config = LoraConfig(
+ task_type=TaskType.TOKEN_CLS,
+ inference_mode=False,
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ )
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
+ training_args = PRMConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+ trainer = PRMTrainer(
+ model=self.model,
+ args=training_args,
+ processing_class=self.tokenizer,
+ train_dataset=dummy_dataset,
+ peft_config=peft_config,
+ )
+ previous_trainable_params = {}
+ previous_non_trainable_params = {}
+
+            # Include "modules_to_save" due to a change in the way modules to save are handled in PEFT
+ trainable_params_name = ["lora", "modules_to_save"]
+
+            # Snapshot the initial parameters to compare them after training
+ for n, param in trainer.model.named_parameters():
+ if any(t in n for t in trainable_params_name):
+ previous_trainable_params[n] = param.clone()
+ else:
+ previous_non_trainable_params[n] = param.clone()
+
+ trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+
+ # check the non trainable params have not changed
+ for n, param in previous_non_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+
+ def test_tags(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
+ training_args = PRMConfig(output_dir=tmp_dir, report_to="none")
+ trainer = PRMTrainer(
+ model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+ )
+ self.assertEqual(trainer.model.model_tags, trainer._tag_names)
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index d4466a0404..d977719765 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -17,12 +17,11 @@
import torch
from datasets import Dataset, load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
-from trl.trainer import compute_accuracy
from trl.trainer.reward_trainer import _tokenize
@@ -37,11 +36,6 @@ def setUp(self):
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
self.model.config.pad_token_id = self.tokenizer.pad_token_id
- def test_accuracy_metrics(self):
- dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
- accuracy = compute_accuracy(dummy_eval_predictions)
- self.assertEqual(accuracy["accuracy"], 0.5)
-
def test_preprocessing_conversational(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 210eae4306..a1cabcfc19 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -14,13 +14,15 @@
import unittest
+import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
-from trl.trainer.model_config import ModelConfig
+from trl import ModelConfig
+from trl.trainer import compute_accuracy
from trl.trainer.utils import (
DataCollatorForChatML,
batch_generation,
@@ -332,3 +334,70 @@ def test_single_batch_generation(self):
self.assertGreater(max_length_query, context_length)
self.assertEqual(query_responses.shape, (bs, max_length_query))
self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))
+
+
+class TestComputeAccuracy(unittest.TestCase):
+ def test_token_classification_task(self):
+ eval_pred = (
+ np.array(
+ [
+ [[0.1, 0.9], [0.8, 0.2]], # Batch 1
+ [[0.3, 0.7], [0.6, 0.4]], # Batch 2
+ ]
+ ),
+ np.array([[0, 1], [1, 0]]),
+ )
+ expected_accuracy = 0.5 # 2 matches, 2 mismatches
+ result = compute_accuracy(eval_pred)
+ self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+
+ def test_token_classification_task_with_ignored_tokens_0(self):
+ eval_pred = (
+ np.array(
+ [
+ [[0.1, 0.9], [0.8, 0.2]], # Batch 1
+ [[0.3, 0.7], [0.6, 0.4]], # Batch 2
+ ]
+ ),
+ np.array([[1, 0], [1, -100]]),
+ )
+ expected_accuracy = 1.0 # All non-ignored tokens match
+ result = compute_accuracy(eval_pred)
+ self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+
+ def test_token_classification_task_with_ignored_tokens_1(self):
+ eval_pred = (
+ np.array(
+ [
+ [[0.1, 0.9], [0.8, 0.2]], # Batch 1
+ [[0.3, 0.7], [0.6, 0.4]], # Batch 2
+ ]
+ ),
+ np.array([[1, 1], [0, -100]]),
+ )
+        expected_accuracy = 1 / 3  # 1 match, 2 mismatches, 1 ignored
+ result = compute_accuracy(eval_pred)
+ self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+
+ def test_rewards_comparison_task(self):
+ eval_pred = (
+ np.array(
+ [
+ [0.9, 0.1], # Batch 1
+ [0.6, 0.4], # Batch 2
+ [0.5, 0.5], # Batch 3 (equal)
+ ]
+ ),
+ np.array([0, 1, 1]),
+ )
+ expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored)
+
+ with self.assertWarns(UserWarning) as cm:
+ result = compute_accuracy(eval_pred)
+
+ self.assertAlmostEqual(result["accuracy"], expected_accuracy)
+ expected_warning = (
+ "There are 1 out of 3 instances where the predictions for both options are equal. "
+ "These instances are ignored in the accuracy computation."
+ )
+ self.assertEqual(str(cm.warning), expected_warning)
diff --git a/trl/__init__.py b/trl/__init__.py
index 8976eb2603..b05f75cd8b 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -82,6 +82,8 @@
"PairRMJudge",
"PPOConfig",
"PPOTrainer",
+ "PRMConfig",
+ "PRMTrainer",
"RewardConfig",
"RewardTrainer",
"RLOOConfig",
@@ -172,6 +174,8 @@
PairRMJudge,
PPOConfig,
PPOTrainer,
+ PRMConfig,
+ PRMTrainer,
RewardConfig,
RewardTrainer,
RLOOConfig,
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index e5599756f7..85a2e4d57c 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -62,6 +62,8 @@
"ppo_trainer": ["PPOTrainer"],
"ppov2_config": ["PPOv2Config"],
"ppov2_trainer": ["PPOv2Trainer"],
+ "prm_config": ["PRMConfig"],
+ "prm_trainer": ["PRMTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"rloo_config": ["RLOOConfig"],
@@ -130,6 +132,8 @@
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
+ from .prm_config import PRMConfig
+ from .prm_trainer import PRMTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .rloo_config import RLOOConfig
diff --git a/trl/trainer/prm_config.py b/trl/trainer/prm_config.py
new file mode 100644
index 0000000000..4558084572
--- /dev/null
+++ b/trl/trainer/prm_config.py
@@ -0,0 +1,51 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class PRMConfig(TrainingArguments):
+ r"""
+ Configuration class for the [`PRMTrainer`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ learning_rate (`float`, *optional*, defaults to `1e-5`):
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
+ [`~transformers.TrainingArguments`].
+ max_length (`Optional[int]`, *optional*, defaults to `None`):
+ Maximum length of the sequences (prompt + completion) used for truncation.
+ max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
+ step_separator (`str`, *optional*, defaults to `"\n"`):
+ Separator used to separate each step of the reasoning process.
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
+ Whether to train only on the last step.
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
+ Number of processes to use for processing the dataset.
+ """
+
+ learning_rate: float = 1e-5
+ max_length: Optional[int] = None
+ max_completion_length: Optional[int] = None
+ step_separator: str = "\n"
+ train_on_last_step_only: bool = False
+ dataset_num_proc: Optional[int] = None
diff --git a/trl/trainer/prm_trainer.py b/trl/trainer/prm_trainer.py
new file mode 100644
index 0000000000..dbb3558d57
--- /dev/null
+++ b/trl/trainer/prm_trainer.py
@@ -0,0 +1,330 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+import textwrap
+import warnings
+from itertools import chain
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from accelerate import PartialState
+from datasets import Dataset, features
+from transformers import (
+ BaseImageProcessor,
+ DataCollator,
+ DataCollatorForTokenClassification,
+ FeatureExtractionMixin,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ ProcessorMixin,
+ Trainer,
+ is_wandb_available,
+)
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_utils import EvalPrediction
+from transformers.utils import is_peft_available
+
+from .prm_config import PRMConfig
+from .utils import compute_accuracy, generate_model_card
+
+
+if is_peft_available():
+ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+if is_wandb_available():
+ import wandb
+
+
+class PRMTrainer(Trainer):
+ """
+ Initialize PRMTrainer.
+
+ Args:
+ model (`transformers.PreTrainedModel`):
+ The model to train, preferably an `AutoModelForTokenClassification`.
+ args (`PRMConfig`):
+ The arguments to use for training.
+ data_collator (`transformers.DataCollator`):
+            The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used,
+            which will pad the sequences to the maximum length of the sequences in the batch.
+ train_dataset (`datasets.Dataset`):
+ The dataset to use for training.
+ eval_dataset (`datasets.Dataset`):
+ The dataset to use for evaluation.
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+ reuse the fine-tuned model.
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
+        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
+ callbacks (`list[transformers.TrainerCallback]`):
+ The callbacks to use for training.
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+ The optimizer and scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+ The function to use to preprocess the logits before computing the metrics.
+ peft_config (`dict`, defaults to `None`):
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+ """
+
+ _tag_names = ["trl", "prm"]
+
+ def __init__(
+ self,
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+ args: Optional[PRMConfig] = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
+ None,
+ None,
+ ),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ peft_config: Optional[dict] = None,
+ ):
+ if not is_peft_available() and peft_config is not None:
+ raise ValueError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+ if not isinstance(model, PeftModel):
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+ inspect.signature(prepare_model_for_kbit_training).parameters
+ )
+
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+ warnings.warn(
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
+ )
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+
+ model = get_peft_model(model, peft_config)
+
+ if compute_metrics is None:
+ compute_metrics = compute_accuracy
+
+ if data_collator is None:
+ if processing_class is None:
+ raise ValueError(
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
+ )
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
+
+ if "input_ids" not in train_dataset.column_names:
+ with PartialState().local_main_process_first():
+ fn_kwargs = {
+ "tokenizer": processing_class,
+ "step_separator": args.step_separator,
+ "max_length": args.max_length,
+ "max_completion_length": args.max_completion_length,
+ "train_on_last_step_only": args.train_on_last_step_only,
+ }
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
+ train_dataset = train_dataset.map(
+ self.tokenize_row,
+ fn_kwargs=train_fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ remove_columns=train_dataset.features,
+ desc="Tokenizing train dataset",
+ features=features.Features( # needed to avoid map to cast labels to bool
+ {
+ "labels": features.Sequence(features.Value("int64")),
+ "input_ids": features.Sequence(features.Value("int64")),
+ }
+ ),
+ )
+
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(
+ self.tokenize_row,
+ fn_kwargs=eval_fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ remove_columns=eval_dataset.features,
+ desc="Tokenizing eval dataset",
+ features=features.Features( # needed to avoid map to cast labels to bool
+ {
+ "labels": features.Sequence(features.Value("int64")),
+ "input_ids": features.Sequence(features.Value("int64")),
+ }
+ ),
+ )
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ @staticmethod
+ def tokenize_row(
+ features, tokenizer, step_separator, max_length, max_completion_length, train_on_last_step_only, is_eval
+ ):
+ r"""
+ Tokenize a row of the dataset.
+
+ Args:
+ features (`dict[str, str]`):
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
+ tokenizer (`PreTrainedTokenizerBase`):
+ Tokenizer used to process the data.
+ step_separator (`str`):
+ Separator between steps in the completion.
+ max_length (`int` or `None`):
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
+ max_completion_length (`int` or `None`):
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
+ train_on_last_step_only (`bool`):
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
+ token of the completion.
+ is_eval (`bool`):
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
+
+ Returns:
+ `dict[str, list[int]]`:
+                Tokenized sequences with the keys `"input_ids"` and `"labels"`.
+
+ Example:
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
+ ... "completions": ["11 is greater than 8.",
+ ... "Hence, 9.11 > 9.8."],
+ ... "labels": [True, False]}
+        >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_length=None, max_completion_length=None, train_on_last_step_only=False, is_eval=False)
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
+     'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
+ ```
+ """
+ # Tokenize the prompt and completions
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
+ completions_ids = [
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
+ ]
+ if train_on_last_step_only and not is_eval:
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
+ else:
+ labels = [int(label) for label in features["labels"]]
+
+ # Get the ID of the separator token and add it to the completions
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
+ completions_ids = [completion + separator_ids for completion in completions_ids]
+
+        # Create the labels: only the separator token that closes each step carries the step label, all other tokens are ignored (-100)
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
+
+ # Join the completions and labels steps
+ completion_ids = list(chain(*completions_ids))
+ labels = list(chain(*labels))
+
+ if max_completion_length is not None:
+ completion_ids = completion_ids[:max_completion_length]
+ labels = labels[:max_completion_length]
+
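+        # Prepend the BOS token to the prompt if the tokenizer uses one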
+ if tokenizer.bos_token_id is not None:
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
+
+ input_ids = prompt_ids + completion_ids
+ labels = [-100] * len(prompt_ids) + labels
+
+ if max_length is not None:
+ input_ids = input_ids[:max_length]
+ labels = labels[:max_length]
+
+ return {"input_ids": input_ids, "labels": labels}
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+ Args:
+ model_name (`str`, *optional*, defaults to `None`):
+ The name of the model.
+ dataset_name (`str`, *optional*, defaults to `None`):
+ The name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent("""\
+ @article{uesato2022solving,
+ title = {Solving Math Word Problems With Process- and Outcome-Based Feedback},
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
+ year = 2022,
+ journal = {arXiv preprint arXiv:2211.14275}
+ }""")
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ trainer_name="PRM",
+ trainer_citation=citation,
+ paper_title="Solving math word problems with process-and outcome-based feedback",
+ )
+
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 699a447bb8..fd30bea929 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -37,6 +37,7 @@
from transformers import (
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
+ EvalPrediction,
GenerationConfig,
PreTrainedTokenizerBase,
TrainerState,
@@ -757,18 +758,38 @@ def get_global_statistics(
return global_mean.to(device), global_var.to(device), count.item()
-def compute_accuracy(eval_pred) -> dict[str, float]:
+def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]:
predictions, labels = eval_pred
- # Here, predictions is rewards_chosen and rewards_rejected.
- # We want to see how much of the time rewards_chosen > rewards_rejected.
- equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum()
- if equal_predictions_count > 0:
- warnings.warn(
- f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for "
- "both options are equal. As a consequence the accuracy can be misleading.",
- UserWarning,
+ if predictions.ndim == 3:
+ # Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len)
+ # Used to compute the accuracy in the prm_trainer.
+ predictions = np.argmax(predictions, axis=2)
+
+ # Flatten the predictions and labels to remove the ignored tokens.
+ predictions = np.array(
+ [p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100]
)
- predictions = np.argmax(predictions, axis=1)
+ labels = np.array([lbl for label in labels for lbl in label if lbl != -100])
+
+ else:
+ # Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,)
+ # We want to see how much of the time rewards_chosen > rewards_rejected.
+ equal_mask = predictions[:, 0] == predictions[:, 1]
+ equal_predictions_count = int(equal_mask.sum())
+
+ if equal_predictions_count > 0:
+ warnings.warn(
+ f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions "
+ "for both options are equal. These instances are ignored in the accuracy computation.",
+ UserWarning,
+ )
+
+ # Filter out equal predictions
+ predictions = predictions[~equal_mask]
+ labels = labels[~equal_mask]
+
+ # Use the remaining predictions for accuracy calculation
+ predictions = np.argmax(predictions, axis=1)
accuracy = np.array(predictions == labels, dtype=float).mean().item()
return {"accuracy": accuracy}