Process-supervised RM Trainer #2127
base: main
@@ -250,25 +250,34 @@ stepwise_example = {
}
```

### Stepwise preference

A stepwise preference dataset is similar to an unpaired preference dataset, but instead of a single `"completion"` and `"label"`, it includes a `"completion"` column that splits the completion into a list of steps and a `"labels"` column indicating whether each step is correct.

```python
steps_preference_example = {"prompt": "The sky is", "completion": [", let me think...", "blue."], "labels": [False, True]}
```
## Which dataset type to use?

Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
| Trainer                   | Expected dataset type                                                                                   |
| ------------------------- | ------------------------------------------------------------------------------------------------------- |
| [`BCOTrainer`]            | [Unpaired preference](#unpaired-preference)                                                             |
| [`CPOTrainer`]            | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`DPOTrainer`]            | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`GKDTrainer`]            | [Prompt-completion](#prompt-completion)                                                                 |
| [`IterativeSFTTrainer`]   | [Unpaired preference](#unpaired-preference)                                                             |
| [`KTOTrainer`]            | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference)  |
| [`NashMDTrainer`]         | [Prompt-only](#prompt-only)                                                                             |
| [`OnlineDPOTrainer`]      | [Prompt-only](#prompt-only)                                                                             |
| [`ORPOTrainer`]           | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`PPOTrainer`]            | Tokenized language modeling                                                                             |
| [`RewardTrainer`]         | [Preference (implicit prompt recommended)](#preference)                                                 |
| [`SFTTrainer`]            | [Language modeling](#language-modeling)                                                                 |
| [`StepwiseRewardTrainer`] | [Stepwise preference](#stepwise-preference)                                                             |
| [`XPOTrainer`]            | [Prompt-only](#prompt-only)                                                                             |
<Tip>

@@ -0,0 +1,85 @@
# Stepwise Reward Modeling

## Overview

Process-supervised Reward Models (PRMs) were proposed in [Solving math word problems with process- and outcome-based feedback](https://arxiv.org/pdf/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving and Irina Higgins.
Review comment: nit, since we don't need the acronym.
The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun) and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
Review comment: Feel free to remove me since you did all the work on the implementation side :)
## Usage tips

The [`StepwiseRewardTrainer`] is a wrapper around the [`Trainer`] class. It requires two parameters to be set via the [`StepwiseRewardConfig`], namely (a short sketch of how they interact is shown after the list):

* `max_length`: controls the maximum length of the sequences, where a sequence is composed of the prompt and the concatenation of the completion steps.
* `step_separator`: indicates the separator used to separate each step of the reasoning process. By default, it is set to `"\n"`.

Review comment (on `step_separator`): shouldn't this be on new lines?
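To make these two options concrete, here is a rough, illustrative sketch of how a stepwise example could be flattened into the sequence the model sees. It is not the trainer's actual preprocessing code, and the variable names are made up for the example:

```python
# Illustration only: a sketch of how `step_separator` and `max_length` relate to the
# final sequence, not the StepwiseRewardTrainer's actual preprocessing.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prompt = "The sky is"
completion_steps = [", let me think...", "blue."]
step_separator = "\n"
max_length = 512

# The prompt is followed by every completion step, each terminated by the separator.
text = prompt + "".join(step + step_separator for step in completion_steps)

# The tokenized sequence is capped at `max_length` tokens.
input_ids = tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
print(tokenizer.decode(input_ids))
```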
The basic API is as follows:

```python
from datasets import Dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import StepwiseRewardTrainer, StepwiseRewardConfig

NUM_DUMMY_SAMPLES = 100

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hi, how are you?"] * NUM_DUMMY_SAMPLES,
        "completion": [[", let me think...", "I'm great thanks."]] * NUM_DUMMY_SAMPLES,
        "labels": [[False, True]] * NUM_DUMMY_SAMPLES,
    }
)
eval_dataset = Dataset.from_dict(
    {
        "prompt": ["What colour is the sky?"] * NUM_DUMMY_SAMPLES,
        "completion": [["Hmm,", "great question...", "The sky is blue."]] * NUM_DUMMY_SAMPLES,
        "labels": [[False, False, True]] * NUM_DUMMY_SAMPLES,
    }
)

training_args = StepwiseRewardConfig(output_dir="stepwise-reward-model", per_device_train_batch_size=1, max_length=512, step_separator="\n")
trainer = StepwiseRewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
```

Review comment (on the `train_dataset` example): WDYT about using a math example like the one here? 76dbb1a#diff-9401f539a830b066fdca010e21b44ba7b439404436e3ed18c5dbea9dff582bf5R83-R88 I personally find this a bit easier to follow

Review comment (on the `eval_dataset` example): Same question here about using a simple math example
## Expected dataset format

The dataset should be formatted as a [stepwise preference dataset](dataset_formats#stepwise-preference), which implies that the dataset should contain the following columns: `prompt`, `completion` and `labels`, where `completion` contains a list of reasoning steps and `labels` a list of booleans indicating the correctness of each step.

The [`StepwiseRewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids`, `attention_mask` and `labels`.
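For illustration, the sketch below builds such columns by hand. The label layout used here (the token that ends each step carries that step's label, every other position is set to `-100`) is an assumption, not necessarily the trainer's convention, so treat it as a shape reference only:

```python
# Sketch of a hand-built pretokenized example. The label placement (step label on the
# token that closes each step, -100 elsewhere) is an assumption for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prompt = "The sky is"
steps = [", let me think...", "blue."]
step_labels = [0, 1]  # 0 = incorrect step, 1 = correct step
separator_ids = tokenizer("\n", add_special_tokens=False)["input_ids"]

input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
labels = [-100] * len(input_ids)  # prompt tokens carry no step label
for step, step_label in zip(steps, step_labels):
    step_ids = tokenizer(step, add_special_tokens=False)["input_ids"] + separator_ids
    input_ids += step_ids
    labels += [-100] * (len(step_ids) - 1) + [step_label]

pretokenized_example = {
    "input_ids": input_ids,
    "attention_mask": [1] * len(input_ids),
    "labels": labels,
}
```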
## StepwiseRewardTrainer

[[autodoc]] StepwiseRewardTrainer

## StepwiseRewardConfig

[[autodoc]] StepwiseRewardConfig
@@ -0,0 +1,54 @@

Review comment: You can remove this file in favour of https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/openai-prm800k-15k"`):
            Hugging Face repository ID to push the dataset to.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/openai-prm800k-15k"


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("gaetanlop/openai-prm800k-15k-stage2-conversational")

    def reformat_labels(examples):
        # openai-prm800k step labels are -1, 0, or +1: -1 means incorrect, 0 means the step is not
        # incorrect but makes no progress, and +1 means correct. Shift them to 0, 1, 2.
        examples["labels"] = [[label + 1 for label in labels] for labels in examples["labels"]]
        return examples

    dataset = dataset.map(reformat_labels, batched=True)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
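As an aside, here is a toy illustration of the label shift performed by `reformat_labels` above (not part of the script):

```python
# Toy illustration: PRM800K step ratings -1/0/+1 become class ids 0/1/2.
batch = {"labels": [[-1, 0, 1], [1, 1]]}
shifted = {"labels": [[label + 1 for label in labels] for labels in batch["labels"]]}
print(shifted)  # {'labels': [[0, 1, 2], [2, 2]]}
```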
@@ -0,0 +1,133 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Full training:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/openai-prm800k-15k \
    --output_dir Qwen2-0.5B-Reward \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --learning_rate 1.0e-5 \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --max_length 2048

LoRA:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/openai-prm800k-15k \
    --output_dir Qwen2-0.5B-Reward-LoRA \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --learning_rate 1.0e-4 \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --max_length 2048 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16
"""

Review comment (on the LoRA command): If you have some compute, can you share some WandB logs from running these scripts? Otherwise I can run them myself :)
import warnings

import torch
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser

from trl import (
    ModelConfig,
    StepwiseRewardConfig,
    StepwiseRewardTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)
from trl.commands.cli_utils import RewardScriptArguments


if __name__ == "__main__":
    parser = HfArgumentParser((RewardScriptArguments, StepwiseRewardConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_config.model_name_or_path, num_labels=3, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

Review comment (on `num_labels=3`): With the new format, shouldn't this be just two labels?

    # Align padding tokens between tokenizer and model
    model.config.pad_token_id = tokenizer.pad_token_id

    # If post-training a base model, use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs."
            " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT."
        )

    ##############
    # Load dataset
    ##############
    dataset = load_dataset(script_args.dataset_name)

    ##########
    # Training
    ##########
    trainer = StepwiseRewardTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    ############################
    # Save model and push to Hub
    ############################
    trainer.save_model(training_args.output_dir)
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
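Once trained, the model is a plain token-classification model, so one plausible way to score the steps of a new solution is sketched below. The read-out convention (joining steps with `"\n"` and taking the class probabilities at the end of each step) is an assumption for illustration, not an official inference API, and `Qwen2-0.5B-Reward` is simply the `output_dir` of the full-training command above:

```python
# Sketch: score each reasoning step with the trained token-classification model.
# Assumes steps are joined with "\n" and that the logits at the end of each step are
# meaningful step scores; check the trainer's tokenization code for the exact convention.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_id = "Qwen2-0.5B-Reward"  # local output_dir from the training command above
model = AutoModelForTokenClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "What is 2 + 2 * 3?"
steps = ["First compute 2 * 3 = 6.", "Then 2 + 6 = 8."]
separator_ids = tokenizer("\n", add_special_tokens=False)["input_ids"]

# Build the sequence step by step and remember where each step ends.
input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
step_end_positions = []
for step in steps:
    input_ids += tokenizer(step, add_special_tokens=False)["input_ids"] + separator_ids
    step_end_positions.append(len(input_ids) - 1)

with torch.no_grad():
    logits = model(input_ids=torch.tensor([input_ids])).logits[0]  # (seq_len, num_labels)

# Probability of the most-correct class at the end of each step.
step_scores = logits[step_end_positions].softmax(dim=-1)[:, -1]
for step, score in zip(steps, step_scores):
    print(f"{score:.3f}  {step}")
```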
Review comment: Remove in favour of "Stepwise supervision"