Refactor and add PPO for math reasoning (#25)
* huge refactor to make structure clearer and more extendable
* sync
* fix
* update docs
* bump version
* update logo
* minor
* minor
* fix images
* minor
Showing 51 changed files with 3,014 additions and 707 deletions.
@@ -0,0 +1,101 @@
## LLM alignment as contextual dueling bandits

LLM alignment is essentially an online learning and decision making problem where the **agent** (e.g., the LLM policy with an optional built-in reward model) interacts with the **environment** (i.e., humans) to achieve either of two distinct objectives: minimizing cumulative regret in the *Explore & Exploit* setting or minimizing anytime regret in the *Best Arm Identification* setting.

In our [paper](https://arxiv.org/abs/2411.01493), we formalize LLM alignment as a **contextual dueling bandit (CDB)** problem (see the illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.

<p align="center">
  <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e0da719024bdc16fb4a993a8405e15cb0cf2b53a/interface.png" width=80%/>
</p>
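
To make the interaction above concrete, the sketch below shows a bare-bones CDB loop with Thompson-sampling-style duel selection. It is purely illustrative: `policy`, `reward_posterior`, and `preference_oracle` are hypothetical objects, not oat's actual APIs.

```python
# Illustrative sketch only (hypothetical interfaces, not oat's API):
# the agent proposes two responses per prompt, the oracle reveals a duel
# outcome, and the agent updates its beliefs and its policy.
import random

def cdb_loop(prompts, policy, reward_posterior, preference_oracle, steps=1000):
    for _ in range(steps):
        x = random.choice(prompts)                     # context (prompt)
        candidates = policy.generate(x, n=8)           # candidate responses
        r1 = reward_posterior.sample()                 # one plausible reward function
        first = max(candidates, key=lambda y: r1(x, y))
        r2 = reward_posterior.sample()                 # an independent posterior sample
        second = max((y for y in candidates if y != first),
                     key=lambda y: r2(x, y))
        winner = preference_oracle(x, first, second)   # duel feedback (human/oracle)
        reward_posterior.update(x, first, second, winner)
        policy.update(x, first, second, winner)        # e.g., a direct optimizer step
```

The two knobs here, how the duel pair is chosen and how the feedback updates the policy, are exactly where existing alignment paradigms differ (see the comparison below).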

The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.

Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:

<p align="center">
  <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/acbb25a20dd6c1e7619539b0fa449076ade2f873/compare.png" width=95%/>
</p>

For more details, please check out our [paper](https://arxiv.org/abs/2411.01493)!

## Examples
Below is an example of aligning a `1B Pythia` SFT model on the `tl;dr` dataset using online SimPO, with `PairRM` as the preference oracle:

> [!WARNING]
> Aligning with `PairRM` provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the [examples](./examples/README.md#preference-oracles).
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
--wb-run-name 1b_pairrm_simpo_online
```
This example completes in **less than two hours on two A100-40G GPUs**!

To run an `offline SimPO` baseline for comparison, we disable weight synchronization from the learner to the actors by adjusting the `sync-params-every` argument:
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
- --sync-params-every 1 \
+ --sync-params-every 9999 \ # any number > total gradient steps (50000//128=390)
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_offline
```

Finally, we run `SEA SimPO` (with $\gamma=1$; see [here](https://arxiv.org/pdf/2411.01493#page=7.60) for the meaning of $\gamma$) to verify its sample-efficient alignment capability. This experiment utilizes 4 GPUs, with a reduced per-device training batch size to accommodate the training of an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to keep a global batch size of 128 (32 per device × 4 GPUs). Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling, as sketched after the command below.
```diff
python -m oat.experiment.main \
- --gpus 2 \
+ --gpus 4 \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
- --rollout-batch-size-per-device 64 \
- --pi-buffer-maxlen-per-device 64 \
- --train-batch-size-per-device 8 \
+ --rollout-batch-size-per-device 32 \
+ --pi-buffer-maxlen-per-device 32 \
+ --train-batch-size-per-device 1 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --num_samples 10 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_sea
```
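
For intuition only, one common way to instantiate BAI Thompson sampling over a small candidate set is to draw two samples from an epistemic reward model (e.g., an ensemble) and duel their respective argmax responses. The sketch below is a hypothetical illustration and may differ from the exact selection rule implemented by `EnnBAITS`.

```python
# Illustrative only: selecting a duel pair from N candidates with a
# BAI-Thompson-sampling-style rule over an ensemble reward model.
import numpy as np

def select_duel_pair(candidate_rewards: np.ndarray) -> tuple[int, int]:
    """candidate_rewards: (ensemble_size, num_candidates) reward estimates."""
    ensemble_size, _ = candidate_rewards.shape
    # First arm: argmax under one randomly drawn ensemble member (posterior sample).
    i = np.random.randint(ensemble_size)
    first = int(np.argmax(candidate_rewards[i]))
    # Second arm: argmax under another sample, excluding the first arm,
    # so the duel is informative about the remaining uncertainty.
    j = np.random.randint(ensemble_size)
    masked = candidate_rewards[j].copy()
    masked[first] = -np.inf
    second = int(np.argmax(masked))
    return first, second

# Example: 4 ensemble members scoring 10 candidate responses.
scores = np.random.randn(4, 10)
print(select_duel_pair(scores))
```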

<p align="center">
  <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/example_result.png" width=55%/>
</p>

Check out this [tutorial](./preference_learning.md) for more examples covering:
* Various direct optimizers, including DPO, IPO, and SLiC.
* Different modes of preference oracles, such as remote reward models and GPT-as-a-judge.
* Additional LLM exploration algorithms, e.g., APL, XPO, and EE4LLM.
File renamed without changes.
@@ -0,0 +1,64 @@
We can easily extend oat 🌾 by running RL with rule-based rewards (result verification) to improve a language model's reasoning capability. Below is an example that runs PPO on GSM8K, which improves the test score significantly **from 40.6% to 55.7% with pure RL**!

```
python -m oat.experiment.run_ppo \
--gpus 8 \
--collocate \
--gradient-checkpointing \
--flash-attn \
--bf16 \
--rnd-seed \
--learning_rate 0.000001 \
--critic_learning_rate 0.000001 \
--lr_scheduler polynomial \
--kl_penalty_coef 0 \
--num_ppo_epochs 2 \
--beta 0.0001 \
--non_stop_penalty 0 \
--oracle_type reward \
--oracle gsm8k \
--pretrain lkevinzc/rho-1b-sft-GSM8K \
--apply-chat-template \
--zero-stage 2 \
--prompt_data lkevinzc/gsm8k \
--max-train 9999999 \
--num_prompt_epoch 6 \
--prompt_max_length 1000 \
--sync_params_every 1 \
--num_samples 8 \
--max_step_adjustment 16 \
--critic_max_step_adjustment 16 \
--temperature 1.2 \
--top_p 0.9 \
--generate_max_length 1024 \
--save_steps -1 \
--input_key question \
--output_key final_answer \
--train_split train \
--train_batch_size 64 \
--train_batch_size_per_device 8 \
--mini_train_batch_size_per_device 8 \
--rollout_batch_size 64 \
--rollout_batch_size_per_device 8 \
--pi_buffer_maxlen_per_device 64 \
--max_eval 1500 \
--eval_batch_size 200 \
--eval_generate_max_length 1024 \
--eval_steps 50 \
--eval_temperature 0.35 \
--eval_top_p 0.9 \
--eval_n 16 \
--eval_input_key question \
--eval_output_key final_answer \
--use-wb \
--wb-run-name rho-1b-ppo-gsm8k
```
The experiment finishes in about 6 hours with 8 A100 GPUs.
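
The reward here comes from rule-based result verification rather than a learned reward model. As a rough sketch of what such a verifier can look like (illustrative only; the parsing in oat's built-in `gsm8k` oracle may differ), it extracts the final number from a completion and compares it with the reference answer:

```python
# Illustrative only: a minimal GSM8K-style answer verifier returning a
# binary reward. The real oracle in oat may use different parsing rules.
import re

def extract_final_number(text: str) -> str | None:
    """Return the last number appearing in the text, normalized."""
    numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
    if not numbers:
        return None
    return numbers[-1].replace(",", "").rstrip(".")

def gsm8k_reward(completion: str, reference_answer: str) -> float:
    """1.0 if the model's final number matches the reference, else 0.0."""
    predicted = extract_final_number(completion)
    target = extract_final_number(reference_answer)
    return float(predicted is not None and predicted == target)

print(gsm8k_reward("... so the answer is 42.", "#### 42"))  # 1.0
print(gsm8k_reward("... so the answer is 41.", "#### 42"))  # 0.0
```

Because such a reward is an exact match on the final answer, it cannot be hacked the way a learned reward model can, but it only provides sparse, binary feedback.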

With this setup, we ablated how *Boltzmann exploration* (controlled by the sampling temperature) affects learning efficiency:

<p align="center">
  <img src="./ppo_temperature.png" width=55%/>
</p>
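
For context, Boltzmann exploration here simply means sampling from the temperature-scaled softmax over logits: a higher temperature flattens the distribution and encourages exploration, while a lower temperature concentrates it. A minimal, illustrative sketch (not oat's actual sampling code):

```python
# Illustrative only: temperature-scaled (Boltzmann) sampling over logits.
import numpy as np

def boltzmann_sample(logits: np.ndarray, temperature: float) -> int:
    """Sample an index from softmax(logits / temperature)."""
    scaled = logits / temperature
    scaled -= scaled.max()  # subtract max for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return int(np.random.choice(len(logits), p=probs))

logits = np.array([2.0, 1.0, 0.5])
print(boltzmann_sample(logits, temperature=1.2))   # matches --temperature: more exploratory rollouts
print(boltzmann_sample(logits, temperature=0.35))  # matches --eval_temperature: closer to greedy
```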

We look forward to future studies on other efficient exploration strategies to enhance LLM reasoning.
@@ -0,0 +1,18 @@
```python
# Copyright 2024 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oat.actors.preference import PreferenceActor
from oat.actors.reward import RewardActor

__all__ = ["PreferenceActor", "RewardActor"]
```