Refactor and add PPO for math reasoning (#25)
* huge refactor to make structure clearer and more extendable

* sync

* fix

* update docs

* bump version

* update logo

* minor

* minor

* fix images

* minor
lkevinzc authored Jan 26, 2025
1 parent 37becae commit b966394
Showing 51 changed files with 3,014 additions and 707 deletions.
131 changes: 25 additions & 106 deletions README.md
@@ -1,5 +1,5 @@
<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/logo.png" width=90% alt="OAT" />
<img src="./docs/new_logo.png" width=90% alt="OAT" />
</p>

[![PyPI - Version](https://img.shields.io/pypi/v/oat-llm.svg)](https://pypi.org/project/oat-llm)
@@ -11,41 +11,28 @@

---

## Updates
* 26/01/2025: We support reinforcement learning with verifiable rewards (RLVR) for math reasoning.

## Introduction

Oat 🌾 is a simple yet efficient system for running online LLM alignment algorithms. Its key features include:
Oat 🌾 is a simple yet efficient framework for running **online** LLM alignment algorithms. Its key features include:

* **High Efficiency**: Oat implements a distributed *Actor-Learner-Oracle* architecture (see the conceptual sketch after this list), with each component optimized using state-of-the-art tools:
* `Actor`: Utilizes [vLLM](https://github.com/vllm-project/vllm) for accelerated online response sampling.
* `Learner`: Leverages [DeepSpeed](https://github.com/microsoft/DeepSpeed) ZeRO strategies to enhance memory efficiency.
* `Oracle`: Hosted by [Mosec](https://github.com/mosecorg/mosec) as a remote service, supporting dynamic batching, data parallelism and pipeline parallelism.
* `Oracle`: Model-based oracles are hosted by [Mosec](https://github.com/mosecorg/mosec) as remote services, supporting dynamic batching, data parallelism, and pipeline parallelism.
* **Simplified Workflow**: Oat simplifies the experimental pipeline of LLM alignment. With an `Oracle` served online, we can flexibly query it for preference data labeling as well as anytime model evaluation. All you need to do is launch experiments and monitor real-time learning curves (e.g., win rate) on wandb (see [reproduced results](https://wandb.ai/lkevinzc/oat-llm)) — no need for manual training, checkpointing and loading for evaluation.
* **Oracle Simulation**: Oat provides simulated preference oracles in various modes.
* **Oracle Simulation**: Oat provides a diverse set of oracles to simulate preference/reward/verification feedback.
* Verifiable rewards supported using rule-based functions.
* Lightweight reward models run within the actor's process, enabling quick testing on as few as two GPUs.
* Larger and more capable reward models can be served remotely, harnessing additional compute and memory resources.
* LLM-as-a-judge is supported by querying the OpenAI API for model-based pairwise ranking.
* **Ease of Use**: Oat's modular structure allows researchers to easily inherit and modify existing classes, enabling rapid prototyping and experimentation with new algorithms.
* **Cutting-Edge Algorithms**: Oat implements state-of-the-art LLM exploration (active alignment) algorithms, including [SEA](https://arxiv.org/abs/2411.01493), APL and XPO, along with popular direct optimizers such as DPO and SimPO, fostering innovation and fair benchmarking.

## LLM alignment as contextual dueling bandits

LLM alignment is essentially an online learning and decision making problem where the **agent** (e.g., the LLM policy with an optional built-in reward model) interacts with the **environment** (i.e., humans) to achieve either of the two distinct objectives: minimizing cumulative regret in the *Explore & Exploit* setting or minimizing anytime regret in the *Best Arm Identification* setting.

In our [paper](https://arxiv.org/abs/2411.01493), we formalize LLM alignment as a **contextual dueling bandit (CDB)** problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e0da719024bdc16fb4a993a8405e15cb0cf2b53a/interface.png" width=80%/>
</p>

The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.

Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/acbb25a20dd6c1e7619539b0fa449076ade2f873/compare.png" width=95%/>
</p>

For more details, please check out our [paper](https://arxiv.org/abs/2411.01493)!
* **Cutting-Edge Algorithms**: Oat implements state-of-the-art online algorithms, fostering innovation and fair benchmarking.
* PPO (online RL) for math reasoning.
* Online DPO/SimPO/IPO for online preference learning.
* Online exploration (active alignment) algorithms, including [SEA](https://arxiv.org/abs/2411.01493), APL and XPO.
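
To make the *Actor-Learner-Oracle* decomposition above concrete, here is a minimal single-process sketch of the online loop. All names below are illustrative placeholders rather than oat's actual API; in oat the three components run as separate distributed workers backed by vLLM (actor), DeepSpeed (learner), and Mosec (oracle).

```python
# Conceptual sketch only -- placeholder names, not oat's API.
def online_alignment_loop(actor, learner, oracle, prompt_loader, sync_every=1):
    for step, prompts in enumerate(prompt_loader):
        # Actor: sample candidate responses for each prompt.
        candidates = actor.generate(prompts, num_samples=2)
        # Oracle: label the candidates (preference model, rule-based verifier,
        # or LLM-as-a-judge), producing the learning signal.
        feedback = oracle.label(prompts, candidates)
        # Learner: update the policy on the freshly labeled data.
        learner.step(prompts, candidates, feedback)
        # Keep the actor (near) on-policy by periodically syncing weights.
        if (step + 1) % sync_every == 0:
            actor.load_weights(learner.get_weights())
```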

## Installation
In a Python environment with a supported version (`>=3.8, <=3.10`), you can install oat via PyPI:
@@ -59,86 +46,9 @@ cd oat
pip install vllm==0.6.2 && pip install -e .
```

## Usage
Below is an example of aligning a `1B Pythia` SFT model on the `tl;dr` dataset using `online SimPO` with `PairRM` as the preference oracle:

> [!WARNING]
> Aligning with `PairRM` provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the [examples](./examples/README.md#preference-oracles).
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
--wb-run-name 1b_pairrm_simpo_online
```
This example completes in **less than two hours on two A100-40G GPUs**!

To run an `offline SimPO` baseline for comparison, we disable weight synchronization from the learner to the actors by adjusting the `sync-params-every` argument:
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
- --sync-params-every 1 \
+ --sync-params-every 9999 \ # any number > total gradient steps (50000//128=390)
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_offline
```

Finally, we run `SEA SimPO` (with $\gamma=1$, see [here](https://arxiv.org/pdf/2411.01493#page=7.60) for the meaning of $\gamma$) to verify its capability for sample-efficient alignment. This experiment utilizes 4 GPUs, with a reduced per-device training batch size to accommodate the training of an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to ensure a global batch size of 128. Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling.
```diff
python -m oat.experiment.main \
- --gpus 2 \
+ --gpus 4 \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
- --rollout-batch-size-per-device 64 \
- --pi-buffer-maxlen-per-device 64 \
- --train-batch-size-per-device 8 \
+ --rollout-batch-size-per-device 32 \
+ --pi-buffer-maxlen-per-device 32 \
+ --train-batch-size-per-device 1 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --num_samples 10 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_sea
```

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/example_result.png" width=55%/>
</p>

Check out this [tutorial](./examples/) for more examples covering:
* Various direct optimizers, including DPO, IPO, and SLiC.
* Different modes of preference oracles, such as remote reward models and GPT-as-a-judge.
* Additional LLM exploration algorithms, e.g., APL, XPO, and EE4LLM.
## Usage
* [Improving math reasoning with PPO](./docs/reasoning_examples.md).
* [Online preference learning with active exploration](./docs/alignment_as_cdb.md).

## Benchmarking
The benchmarking compares oat with the online DPO implementation from [huggingface/trl](https://huggingface.co/docs/trl/main/en/online_dpo_trainer). Below, we outline the configurations used for oat and present the benchmarking results. Notably, oat 🌾 is up to **2.5x** more computationally efficient than trl 🤗.
@@ -154,7 +64,16 @@ The benchmarking compares oat with the online DPO implementation from [huggingfa
Please refer to [Appendix C of our paper](https://arxiv.org/pdf/2411.01493#page=17.64) for a detailed discussion of the benchmarking methods and results.

## Citation
If you find this work useful for your research, please consider citing
If you find this codebase useful for your research, please consider citing
```
@misc{liu2025oat,
author = {Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
title = {OAT: A research-friendly framework for LLM online alignment},
howpublished = {https://github.com/sail-sg/oat},
year = {2025}
}
```

```
@article{
liu2024sea,
101 changes: 101 additions & 0 deletions docs/alignment_as_cdb.md
@@ -0,0 +1,101 @@
## LLM alignment as contextual dueling bandits

LLM alignment is essentially an online learning and decision making problem where the **agent** (e.g., the LLM policy with an optional built-in reward model) interacts with the **environment** (i.e., humans) to achieve either of the two distinct objectives: minimizing cumulative regret in the *Explore & Exploit* setting or minimizing anytime regret in the *Best Arm Identification* setting.

In our [paper](https://arxiv.org/abs/2411.01493), we formalize LLM alignment as a **contextual dueling bandit (CDB)** problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e0da719024bdc16fb4a993a8405e15cb0cf2b53a/interface.png" width=80%/>
</p>

The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.

Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/acbb25a20dd6c1e7619539b0fa449076ade2f873/compare.png" width=95%/>
</p>

For more details, please check out our [paper](https://arxiv.org/abs/2411.01493)!
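
As a toy illustration of the Thompson-sampling idea (not the exact algorithm from the paper), epistemic uncertainty over rewards can be represented by a small ensemble, and each duel is formed by acting greedily under independently sampled ensemble members:

```python
import numpy as np

# Toy sketch only. Each "reward model" is a random linear scorer over
# 8-d response features; Thompson sampling picks one member per decision.
rng = np.random.default_rng(0)
ensemble = [rng.normal(size=8) for _ in range(4)]  # 4 ensemble members

def thompson_pick(candidate_features):
    """Sample one ensemble member and pick the response it scores highest."""
    member = ensemble[rng.integers(len(ensemble))]
    return int(np.argmax(candidate_features @ member))

candidates = rng.normal(size=(10, 8))  # 10 candidate responses for one prompt
first, second = thompson_pick(candidates), thompson_pick(candidates)
print(f"dueling pair for the oracle: response {first} vs response {second}")
```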


## Examples
Below is an example of aligning a `1B Pythia` SFT model on the `tl;dr` dataset using `online SimPO` with `PairRM` as the preference oracle:

> [!WARNING]
> Aligning with `PairRM` provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the [examples](./examples/README.md#preference-oracles).
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
--wb-run-name 1b_pairrm_simpo_online
```
This example completes in **less than two hours on two A100-40G GPUs**!

To run an `offline SimPO` baseline for comparison, we disable weight synchronization from the learner to the actors by adjusting the `sync-params-every` argument:
```diff
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
- --sync-params-every 1 \
+ --sync-params-every 9999 \ # any number > total gradient steps (50000//128=390)
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_offline
```

Finally, we run `SEA SimPO` (with $\gamma=1$, see [here](https://arxiv.org/pdf/2411.01493#page=7.60) for the meaning of $\gamma$) to verify its capability for sample-efficient alignment. This experiment utilizes 4 GPUs, with a reduced per-device training batch size to accommodate the training of an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to ensure a global batch size of 128. Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling.
```diff
python -m oat.experiment.main \
- --gpus 2 \
+ --gpus 4 \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
- --rollout-batch-size-per-device 64 \
- --pi-buffer-maxlen-per-device 64 \
- --train-batch-size-per-device 8 \
+ --rollout-batch-size-per-device 32 \
+ --pi-buffer-maxlen-per-device 32 \
+ --train-batch-size-per-device 1 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --num_samples 10 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_sea
```

<p align="center">
<img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/example_result.png" width=55%/>
</p>

Check out this [tutorial](./preference_learning.md) for more examples covering:
* Various direct optimizers, including DPO, IPO, and SLiC.
* Different modes of preference oracles, such as remote reward models and GPT-as-a-judge.
* Additional LLM exploration algorithms, e.g., APL, XPO, and EE4LLM.
Binary file added docs/new_logo.png
Binary file added docs/ppo_temperature.png
File renamed without changes.
64 changes: 64 additions & 0 deletions docs/reasoning_examples.md
@@ -0,0 +1,64 @@
We can easily extend oat 🌾 by running RL with rule-based rewards (result verification) to improve a language model's reasoning capability. Below is an example that runs PPO on GSM8K, significantly improving the test score **from 40.6% to 55.7% with pure RL**! (An illustrative sketch of a rule-based verifier appears after the command.)

```
python -m oat.experiment.run_ppo \
--gpus 8 \
--collocate \
--gradient-checkpointing \
--flash-attn \
--bf16 \
--rnd-seed \
--learning_rate 0.000001 \
--critic_learning_rate 0.000001 \
--lr_scheduler polynomial \
--kl_penalty_coef 0 \
--num_ppo_epochs 2 \
--beta 0.0001 \
--non_stop_penalty 0 \
--oracle_type reward \
--oracle gsm8k \
--pretrain lkevinzc/rho-1b-sft-GSM8K \
--apply-chat-template \
--zero-stage 2 \
--prompt_data lkevinzc/gsm8k \
--max-train 9999999 \
--num_prompt_epoch 6 \
--prompt_max_length 1000 \
--sync_params_every 1 \
--num_samples 8 \
--max_step_adjustment 16 \
--critic_max_step_adjustment 16 \
--temperature 1.2 \
--top_p 0.9 \
--generate_max_length 1024 \
--save_steps -1 \
--input_key question \
--output_key final_answer \
--train_split train \
--train_batch_size 64 \
--train_batch_size_per_device 8 \
--mini_train_batch_size_per_device 8 \
--rollout_batch_size 64 \
--rollout_batch_size_per_device 8 \
--pi_buffer_maxlen_per_device 64 \
--max_eval 1500 \
--eval_batch_size 200 \
--eval_generate_max_length 1024 \
--eval_steps 50 \
--eval_temperature 0.35 \
--eval_top_p 0.9 \
--eval_n 16 \
--eval_input_key question \
--eval_output_key final_answer \
--use-wb \
--wb-run-name rho-1b-ppo-gsm8k
```
The experiment finishes in about 6 hours with 8 A100 GPUs.
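
For intuition, the rule-based verification behind the `--oracle gsm8k` setting can be as simple as extracting the final number from a response and comparing it with the reference answer. The sketch below is illustrative only; the actual oracle shipped with oat may parse answers differently.

```python
import re
from typing import Optional

def extract_final_number(text: str) -> Optional[str]:
    """Return the last number that appears in a response, if any."""
    matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return matches[-1] if matches else None

def gsm8k_style_reward(response: str, reference_answer: str) -> float:
    """1.0 iff the response's final number equals the reference answer."""
    predicted = extract_final_number(response)
    if predicted is None:
        return 0.0
    try:
        return float(float(predicted) == float(reference_answer))
    except ValueError:
        return 0.0

print(gsm8k_style_reward("18 + 24 = 42. The answer is 42.", "42"))  # -> 1.0
```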

With this setup, we ablated how *Boltzmann exploration* (controlled by the sampling temperature) affects learning efficiency:

<p align="center">
<img src="./ppo_temperature.png" width=55%/>
</p>
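
For reference, Boltzmann exploration here simply means sampling from a temperature-scaled softmax over the logits: a higher temperature flattens the distribution (more exploration), while a lower temperature sharpens it. A minimal numpy sketch using the training and evaluation temperatures from the command above:

```python
import numpy as np

rng = np.random.default_rng(0)

def boltzmann_sample(logits, temperature):
    """Sample an index from softmax(logits / temperature)."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)

logits = [2.0, 1.0, 0.2]
for t in (0.35, 1.2):  # the evaluation and training temperatures used above
    counts = np.bincount([boltzmann_sample(logits, t) for _ in range(1000)], minlength=3)
    print(f"T={t}: empirical sampling distribution {counts / 1000}")
```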

We look forward to future studies on other efficient exploration strategies for enhancing LLM reasoning.
2 changes: 1 addition & 1 deletion oat/__about__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Version."""
__version__ = "0.0.5"
__version__ = "0.0.6"
18 changes: 18 additions & 0 deletions oat/actors/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2024 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oat.actors.preference import PreferenceActor
from oat.actors.reward import RewardActor

__all__ = ["PreferenceActor", "RewardActor"]