Process-supervised RM Trainer #2127
base: main
Conversation
This is awesome @gaetanlop! Would you like some early feedback on the PR or would you prefer I wait a bit until it's more polished?
Hey @lewtun, thank you for the message. Currently, only a few of the files are more or less ready. Implementing a PRM seems to be pretty straightforward: it looks like a token classification task where only the prediction for the last token of each step gets assigned a label and the other tokens are ignored during loss calculation. If the dataset isn't pre-tokenized, I assume it should contain the following columns: a prompt, a list of completion steps, and a label for each step.
Are you aware of an HF dataset we could use to train PRMs for the example file? Also, how can I add a new subset to `trl-internal-testing/zen`? Thanks again for your time!
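For readers following along, here is a minimal sketch of that labeling scheme, not the code in this PR: the prompt and steps are made up, and it assumes the usual Hugging Face convention of `-100` as the ignore index for token classification.

```python
# Sketch: only the last token of each step receives that step's correctness
# label; all other tokens get the ignore index (-100) and are skipped by the loss.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

prompt = "Solve 2 + 2 * 3."
steps = ["First, 2 * 3 = 6.", "Then, 2 + 6 = 8."]
step_labels = [True, True]  # correctness of each reasoning step

input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
labels = [-100] * len(input_ids)  # prompt tokens are ignored in the loss

for step, is_correct in zip(steps, step_labels):
    step_ids = tokenizer("\n" + step, add_special_tokens=False)["input_ids"]
    input_ids += step_ids
    # every token of the step is ignored except the last one
    labels += [-100] * (len(step_ids) - 1) + [int(is_correct)]
```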
PR ready for review. I have changed the naming conventions that I used before. Tests: I created a dummy dataset, but we should add a subset to `trl-internal-testing/zen` as done in other scripts.
Thank you for the very clean PR @gaetanlop - this looks great! I've left some minor suggestions regarding the structure, but aside from that and having a smallish dataset in the right format so we can sanity check that the accuracy goes up and the loss goes down, I think this is quite close to being ready.
Full training:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/PLACEHOLDER \
What do you think about picking a subset from PRM800k to test that everything works?
You could create a subset in the expected format and then we can merge it with trl-lib/zen :)
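In case it helps, a hedged sketch of what such a conversion might look like: the source dataset id and its column names (`question`, `steps`, `ratings`) are placeholders, not the actual PRM800K schema, and the target repo and config name are only illustrative.

```python
# Sketch: load a small slice of a source dataset, map it to the expected
# prompt / completions / labels columns, and push the subset to the Hub.
from datasets import load_dataset

source = load_dataset("openai/prm800k", split="train[:1000]")  # placeholder id/split

def to_stepwise(example):
    return {
        "prompt": example["question"],
        "completions": example["steps"],                # list of reasoning steps
        "labels": [r > 0 for r in example["ratings"]],  # one bool per step
    }

subset = source.map(to_stepwise, remove_columns=source.column_names)
subset.push_to_hub("trl-lib/zen", config_name="stepwise_supervision")  # illustrative target
```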
I made two pull requests to trl-lib/zen (https://huggingface.co/datasets/trl-lib/zen/discussions/3) to add the subsets to trl-lib.
Thanks for looking at this @lewtun.
Co-authored-by: Quentin Gallouédec <[email protected]>
Hi, good job! When will this be merged?
@gaetanlop #2148 is merged, let's move on to this one now. Are you still interested in contributing?
### Stepwise preference

A stepwise preference dataset is similar to an unpaired preference dataset but instead of having a single `"completion"` and `"label"`, it includes a `"completion"` column that splits the completion into a list of steps and a `"labels"` column indicating whether each step is correct or not.

```python
steps_preference_example = {"prompt": "The sky is", "completion": [", let me think...", "blue."], "labels": [False, True]}
```
Remove in favour of "Stepwise supervision"
@@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
You can remove this file in favour of https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py
if type(args) is not StepwiseRewardConfig:
    raise ValueError(f"args should be an instance of `StepwiseRewardConfig` but got {type(args)}")
Suggested change: remove this check.
@article{uesato2022solving,
    title={Solving math word problems with process-and outcome-based feedback},
    author={Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
    journal={arXiv preprint arXiv:2211.14275},
    year={2022}
}"""
Suggested change:
@article{uesato2022solving,
  title = {Solving Math Word Problems With Process- and Outcome-Based Feedback},
  author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
  year = 2022,
  journal = {arXiv preprint arXiv:2211.14275}
}"""
Thanks for iterating @gaetanlop and apologies for the slow review on this one 🙈! Overall this is looking really good, and with some minor changes I think it's close to being ready.
> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun) and [Quentin Gallouédec](https://huggingface.co/qgallouedec)
Feel free to remove me since you did all the work on the implementation side :)
## Overview

Process-supervised Reward Models (PRMs) were proposed in [Solving math word problems with process- and outcome-based feedback](https://arxiv.org/pdf/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving and Irina Higgins.
nit since we don't need the acronym:
Suggested change:
Stepwise or process reward models were proposed in [Solving math word problems with process- and outcome-based feedback](https://arxiv.org/pdf/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving and Irina Higgins.
The [`StepwiseRewardTrainer`] is a wrapper around the [`Trainer`] class. It needs two parameters to be set via the [`StepwiseRewardConfig`], namely:
* `max_length`: controls the maximum length of the sequences, where a sequence is composed of the prompt and the concatenation of all completion steps.
* `step_separator`: indicate the separator used to separate each step of the reasoning process. By default, it is set to `"n"`.
Shouldn't this be a newline (`"\n"`)?
Suggested change:
* `step_separator`: indicates the separator used to separate each step of the reasoning process. By default, it is set to `"\n"`.
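As a side note for readers of this thread, here is a sketch of how the two options quoted above would be wired up. The `StepwiseRewardConfig` and `StepwiseRewardTrainer` names come from this PR, but the remaining keyword arguments are assumptions based on the usual `TrainingArguments`-style interface, not the final API; the dataset id is taken from the review suggestions further down.

```python
# Sketch only: keyword names other than max_length and step_separator are assumptions.
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import StepwiseRewardConfig, StepwiseRewardTrainer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/prm800k", split="train")  # dataset id from the review below

config = StepwiseRewardConfig(
    output_dir="Qwen2-0.5B-StepwiseReward",
    max_length=2048,      # prompt + concatenated steps are truncated to this length
    step_separator="\n",  # string inserted between reasoning steps
)
trainer = StepwiseRewardTrainer(
    model=model,
    args=config,
    tokenizer=tokenizer,  # depending on the TRL version, this keyword may be `processing_class`
    train_dataset=train_dataset,
)
trainer.train()
```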
"prompt": [ | ||
"Hi, how are you?", | ||
], | ||
"completion": [ |
"completion": [ | |
"completions": [ |
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=2)

train_dataset = Dataset.from_dict(
WDYT about using a math example like the one here? 76dbb1a#diff-9401f539a830b066fdca010e21b44ba7b439404436e3ed18c5dbea9dff582bf5R83-R88
I personally find this a bit easier to follow
## Expected dataset format

The dataset should be formatted as a [Name to find](dataset_formats#[Name to find]) which implies that the dataset should contain the following columns: `prompt`, `completion` and `labels` where `completion` contains a list of reasoning steps and `labels` a list of booleans indicating the correctness of each step.
Suggested change:
The dataset should be formatted as a [Stepwise Supervision](dataset_formats#stepwise-supervision) dataset, which implies that it should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.
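For concreteness, a single record in that stepwise supervision format might look like the following; the question and steps here are invented for illustration.

```python
# One example in the expected format: a prompt, a list of reasoning steps,
# and one correctness label per step.
stepwise_example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8 and the fractional part of 9.11 is 0.11.",
        "Since 0.8 > 0.11, the larger number is 9.8.",
    ],
    "labels": [True, True],
}
```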
Full training:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/openai-prm800k-15k \
Suggested change:
    --dataset_name trl-lib/prm800k \
LoRA:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/openai-prm800k-15k \
Suggested change:
    --dataset_name trl-lib/prm800k \
    model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
)
model = AutoModelForTokenClassification.from_pretrained(
    model_config.model_name_or_path, num_labels=3, trust_remote_code=model_config.trust_remote_code, **model_kwargs
With the new format, shouldn't this be just two labels?
Suggested change:
    model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    --max_length 2048

LoRA:
python examples/scripts/stepwise_reward_modeling.py \
If you have some compute, can you share some WandB logs from running these scripts? Otherwise I can run them myself :)
What does this PR do?
Adding support for process-supervised reward training to TRL, as requested in #2110.
List of papers using PRMs: [1], [2], [3], [4]...
Fixes #2110
Who can review?
@lewtun @kashif