LLM alignment is essentially an online learning and decision-making problem in which the agent (e.g., the LLM policy, optionally with a built-in reward model) interacts with the environment (i.e., humans) to achieve one of two distinct objectives: minimizing cumulative regret in the Explore & Exploit setting, or minimizing anytime regret in the Best Arm Identification setting.
In our paper, we formalize LLM alignment as a contextual dueling bandit (CDB) problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.
The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.
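To make the CDB view concrete, the sketch below shows one online-alignment round: the agent proposes two responses for a prompt, queries the preference oracle for a duel outcome, and updates its policy. All names here (`Policy`, `PreferenceOracle`, `align`) are illustrative placeholders, not Oat's actual API; an exploring agent (e.g., one based on Thompson sampling) would additionally use reward uncertainty to decide which pair to duel.

```python
# Minimal sketch of the contextual dueling bandit (CDB) loop for LLM alignment.
# Class and method names are illustrative placeholders, not Oat's classes.
import random


class Policy:
    """Stand-in for the LLM policy; samples a response given a prompt."""

    def respond(self, prompt: str) -> str:
        return f"response-{random.randint(0, 9)} to {prompt!r}"

    def update(self, prompt: str, winner: str, loser: str) -> None:
        pass  # e.g., a direct alignment (DAP) step such as DPO/SimPO on the pair


class PreferenceOracle:
    """Stand-in for humans / PairRM / GPT-as-a-judge."""

    def duel(self, prompt: str, a: str, b: str) -> bool:
        return random.random() < 0.5  # True if `a` is preferred over `b`


def align(policy: Policy, oracle: PreferenceOracle, prompts: list[str]) -> None:
    for prompt in prompts:          # online interaction with the environment
        a = policy.respond(prompt)  # the agent proposes two candidate actions
        b = policy.respond(prompt)
        a_wins = oracle.duel(prompt, a, b)  # binary preference feedback
        winner, loser = (a, b) if a_wins else (b, a)
        policy.update(prompt, winner, loser)


align(Policy(), PreferenceOracle(), ["Summarize: ...", "Explain: ..."])
```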
Using the CDB framework, existing LLM alignment paradigms can be summarized and compared in a unified view. For more details, please check out our paper!
Below is an example of aligning a 1B Pythia SFT model on the TL;DR dataset using online SimPO, with PairRM as the preference oracle:
Warning
Aligning with PairRM provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the examples.
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
--wb-run-name 1b_pairrm_simpo_online
This example completes in less than two hours on two A100-40G GPUs!
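As a side note, PairRM can also be queried directly through the llm-blender package, which is handy for sanity-checking the preference labels the oracle produces. The snippet below follows the usage shown on the PairRM model card as best recalled; treat the exact call signatures as assumptions and verify them against the llm-blender documentation.

```python
# Rough sketch of querying PairRM as a pairwise preference oracle via llm-blender.
# The API used here (Blender, loadranker, compare) is assumed from the PairRM
# model card and may differ in your installed version -- verify before relying on it.
import llm_blender

blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")  # downloads the PairRM ranker

prompts = ["Summarize: The cat sat on the mat all day..."]
responses_a = ["The cat stayed on the mat."]
responses_b = ["A dog ran in the park."]

# compare() should return, per prompt, whether response A is preferred over B.
preferences = blender.compare(prompts, responses_a, responses_b)
print(preferences)  # e.g. [True]
```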
To run an offline SimPO baseline for comparison, we disable weight synchronization from the learner to actors by adjusting the sync-params-every argument:
python -m oat.experiment.main \
--gpus 2 \
--collocate \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
- --sync-params-every 1 \
+ --sync-params-every 9999 \ # any number > total gradient steps (50000//128=390)
--rollout-batch-size-per-device 64 \
--pi-buffer-maxlen-per-device 64 \
--train-batch-size-per-device 8 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_offline
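The only substantive change is the sync frequency: when the sync interval exceeds the total number of update rounds, the actors never receive updated weights and keep sampling from the initial SFT policy, which is exactly the offline setting. The toy loop below (an illustration, not Oat's internals) shows this effect.

```python
# Illustration only: why a large --sync-params-every turns the online recipe
# into an offline baseline.
def run(total_steps: int, sync_params_every: int) -> None:
    actor_version = 0  # which policy checkpoint the actor samples responses from
    for step in range(1, total_steps + 1):
        # actor generates response pairs with `actor_version`; learner trains on them
        if step % sync_params_every == 0:
            actor_version = step  # online: actor tracks the learner's latest policy
    print(f"sync_every={sync_params_every}: actor ended on policy version {actor_version}")


run(total_steps=390, sync_params_every=1)     # online: actor refreshed every round
run(total_steps=390, sync_params_every=9999)  # offline: actor stuck at the initial SFT policy
```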
Finally, we run SEA SimPO (with the modifications highlighted in the diff below):
python -m oat.experiment.main \
- --gpus 2 \
+ --gpus 4 \
--dap-algo SimPO \
--beta 2 \
--preference-oracle pairrm \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--output_key pythia-1b-reference \
--sync-params-every 1 \
- --rollout-batch-size-per-device 64 \
- --pi-buffer-maxlen-per-device 64 \
- --train-batch-size-per-device 8 \
+ --rollout-batch-size-per-device 32 \
+ --pi-buffer-maxlen-per-device 32 \
+ --train-batch-size-per-device 1 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --num_samples 10 \
--use-wb \
- --wb-run-name 1b_pairrm_simpo_online
+ --wb-run-name 1b_pairrm_simpo_sea
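Compared to the plain online run, SEA additionally learns a reward model on the fly (`--learn-rm`) and uses it to actively choose which pair of sampled responses to send to the preference oracle (`--exp-method EnnBAITS`, with `--num_samples 10` presumably controlling the number of candidates per prompt). The snippet below is a generic sketch of ensemble-based Thompson sampling for this pair-selection step; it conveys the idea only and is not Oat's EnnBAITS implementation.

```python
# Generic sketch of epistemic (ensemble-based) Thompson sampling for choosing
# which response pair to duel. Illustrative only -- Oat's EnnBAITS differs in
# its model and selection-rule details.
import numpy as np

rng = np.random.default_rng(0)


def select_duel(candidate_features: np.ndarray, ensemble_weights: np.ndarray) -> tuple[int, int]:
    """Pick two of the sampled candidate responses to send to the oracle.

    candidate_features: (num_candidates, d) features of sampled responses.
    ensemble_weights:   (num_ensemble, d) heads of an epistemic reward model.
    """
    # Posterior reward samples: one column per ensemble member.
    reward_samples = candidate_features @ ensemble_weights.T  # (num_candidates, num_ensemble)

    # First arm: Thompson sampling -- act greedily w.r.t. one sampled reward head.
    head = rng.integers(reward_samples.shape[1])
    first = int(np.argmax(reward_samples[:, head]))

    # Second arm: the strongest challenger, i.e. the candidate most often
    # believed (across ensemble members) to beat the first one.
    beats_first = (reward_samples > reward_samples[first]).mean(axis=1)
    beats_first[first] = -np.inf
    second = int(np.argmax(beats_first))
    return first, second


# Toy usage with 10 sampled candidates and a 4-member ensemble.
features = rng.normal(size=(10, 16))
ensemble = rng.normal(size=(4, 16))
print(select_duel(features, ensemble))
```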
Check out this tutorial for more examples covering:
- Various direct optimizers, including DPO, IPO, and SLiC.
- Different modes of preference oracles, such as remote reward models and GPT-as-a-judge.
- Additional LLM exploration algorithms, e.g., APL, XPO, and EE4LLM.