This repo contains code to reproduce our work "Accelerating Direct Preference Optimization with Prefix Sharing" (NeurIPS 2024 FITML Workshop; arXiv link).
Each DPO training example consists of a shared prompt, a "chosen" response, and a "rejected" response. Instead of computing the shared prompt twice, we combine the prompt and pair of responses into a single sequence, using a custom attention mask to share the prefix across the responses. This approach lets us speed up DPO training while being numerically identical to standard DPO training.
To implement the custom attention mask, we use PyTorch's FlexAttention and leverage its block sparsity to skip empty blocks of the attention mask.
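As an illustration, here is a minimal sketch of such a mask written against the public FlexAttention API; the sequence layout and the example lengths are assumptions for the sketch, not the repo's exact implementation:

```python
# Minimal sketch of a prefix-sharing attention mask with FlexAttention.
# Illustrative layout and lengths only; not the repo's exact implementation.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

prompt_len, chosen_len, rejected_len = 512, 256, 256   # example lengths
seq_len = prompt_len + chosen_len + rejected_len        # [prompt | chosen | rejected]

def prefix_sharing_mask(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    # Rejected-response tokens must not attend to chosen-response tokens,
    # even though the chosen response precedes them in the packed sequence.
    q_in_rejected = q_idx >= prompt_len + chosen_len
    kv_in_chosen = (kv_idx >= prompt_len) & (kv_idx < prompt_len + chosen_len)
    return causal & ~(q_in_rejected & kv_in_chosen)

# Block sparsity: fully-masked 128x128 blocks are skipped entirely.
block_mask = create_block_mask(prefix_sharing_mask, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len)
q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)
```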
Our method works best when the prompt prefixes are much longer than the responses, such as in multi-turn chat or summarization tasks. For instance, on the Anthropic HH-RLHF multi-turn dataset, FlexAttention with prefix sharing and sequence packing is 1.41× faster than FlashAttention-3 with sequence packing. Even on the UltraFeedback dataset, where the responses are much longer than the prompt prefixes, we still get a 1.17× speedup.
For more performance improvement statistics across other datasets, see our paper or the Speedups section.
Installation instructions:
- Python 3.10+
- CUDA 12.1 or above
- PyTorch 2.5.0

```bash
pip install -r requirements.txt
```
Our trainer is based on the DPO Trainer from 🤗TRL, so you can run training as you would for a typical DPO run. We further enable prefix sharing and sequence packing with the flags `--attn_implementation flex_attention`, `--prefix_sharing`, `--enable_packing`, and `--packing_length ...`. Some example commands are provided below for the Capybara dataset. We've implemented support for the Mistral and Llama 3 models.

To train `meta-llama/Meta-Llama-3.1-8B-Instruct` with prefix sharing, use `--attn_implementation flex_attention` and `--prefix_sharing`:
```bash
accelerate launch --config_file 'configs/zero3.yaml' train_dpo.py \
--dataset_name='argilla/distilabel-capybara-dpo-7k-binarized' \
--model_name_or_path='meta-llama/Meta-Llama-3.1-8B-Instruct' \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--beta 0.1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--warmup_ratio 0.1 \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--output_dir capybara_no_packing \
--max_prompt_length 2180 \
--max_length 2842 \
--gradient_checkpointing True \
--save_strategy no \
--dataset_train_split train \
--num_train_epochs 1 \
--dataloader_num_workers 4 \
--dataset_num_proc 8 \
--attn_implementation flex_attention \
--prefix_sharing
```
Optionally, add the reporting flags `--report_to wandb --run_name <my_run>` to log to Weights & Biases.
NOTE: The current training config assumes training on an 8xA100 or an 8xH100 node. Tweak as needed for your setup.
For prefix sharing with sequence packing, use `--attn_implementation flex_attention`, `--prefix_sharing`, `--enable_packing`, and `--packing_length {PACKING_LENGTH}`. For example, run:
```bash
accelerate launch --config_file 'configs/zero3.yaml' train_dpo.py \
--dataset_name='argilla/distilabel-capybara-dpo-7k-binarized' \
--model_name_or_path='meta-llama/Meta-Llama-3.1-8B-Instruct' \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--beta 0.1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--warmup_ratio 0.1 \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--output_dir capybara_packing \
--max_prompt_length 2180 \
--max_length 2842 \
--gradient_checkpointing True \
--save_strategy no \
--dataset_train_split train \
--num_train_epochs 1 \
--dataloader_num_workers 4 \
--dataset_num_proc 8 \
--attn_implementation flex_attention \
--prefix_sharing \
--enable_packing \
--packing_length 7936
```
NOTE: We calculate the packing length based on statistics for prefix-shared inputs. In the case of Capybara, we choose a packing length of 3968 (calculated as …).
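For illustration only, here is a rough sketch of how such length statistics could be computed; the message-list schema under `chosen`/`rejected` and the tokenizer choice are assumptions, not the repo's actual preprocessing:

```python
# Rough sketch: estimate prefix-shared sequence lengths for a DPO dataset.
# Assumes `chosen`/`rejected` are full message lists sharing the same prompt
# turns, with the final message being the response (a schema assumption).
from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
ds = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")

def prefix_shared_len(ex):
    prompt = len(tok.apply_chat_template(ex["chosen"][:-1], tokenize=True))
    chosen = len(tok(ex["chosen"][-1]["content"])["input_ids"])
    rejected = len(tok(ex["rejected"][-1]["content"])["input_ids"])
    # One prefix-shared sequence = shared prompt + chosen + rejected tokens.
    return prompt + chosen + rejected

lengths = sorted(prefix_shared_len(ex) for ex in ds)
print("max:", lengths[-1], "p99:", lengths[int(0.99 * len(lengths))])
# Choose --packing_length as a multiple of the typical prefix-shared length,
# ideally rounded to the FlexAttention block size (128).
```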
To reproduce the results from our paper for Capybara, run: `bash benchmark/run_capybara.sh`
Note: These experiments were run on an 8xH100 setup, but we expect similar improvements for A100s.
Below we show prefix sharing's speedups without sample packing. Overall, our approach does best when the dataset has a high prefix-to-completion ratio and a high overall sequence length.
| Dataset | Median Overall Length | Median Prefix-to-Completion Length Ratio | FlashAttn-3 (samples/sec) | FlexAttn (samples/sec) | FlexAttn + Prefix Sharing (samples/sec; speedup over FA3, FlexAttn) |
|---|---|---|---|---|---|
| Capybara | 1160 | 1.59 | 8.38 | 7.75 | 11.90 (1.42×, 1.54×) |
| HH-RLHF | 186 | 2.15 | 33.71 | 30.25 | 36.11 (1.07×, 1.19×) |
| MetaMath-DPO | 872 | 3.91 | 13.86 | 13.02 | 19.13 (1.38×, 1.47×) |
| TLDR | 416 | 11.14 | 31.43 | 29.53 | 35.36 (1.12×, 1.20×) |
| Tulu-Helpsteer | 775 | 6.34 | 14.83 | 13.93 | 21.75 (1.47×, 1.56×) |
| Ultrafeedback | 409 | 0.42 | 18.40 | 17.31 | 20.46 (1.11×, 1.18×) |
When combined with sequence packing, prefix sharing shows more consistent speedups on datasets with shorter overall sequence lengths (HH-RLHF: 1.07× => 1.41×, TLDR: 1.12× => 1.35×). Sequence packing itself is also generally much more efficient than the non-packing implementation (sometimes up to 5× faster).
| Dataset | Median Overall Length | Median Prefix-to-Completion Length Ratio | FlashAttn-3 (samples/sec) | FlexAttn (samples/sec) | FlexAttn + Prefix Sharing (samples/sec; speedup over FA3, FlexAttn) |
|---|---|---|---|---|---|
| Capybara | 1160 | 1.59 | 17.89 | 17.63 | 23.89 (1.34×, 1.36×) |
| HH-RLHF | 186 | 2.15 | 109.77 | 104.99 | 155.04 (1.41×, 1.48×) |
| MetaMath-DPO | 872 | 3.91 | 24.21 | 23.83 | 38.07 (1.57×, 1.60×) |
| TLDR | 416 | 11.14 | 44.11 | 43.22 | 59.76 (1.35×, 1.38×) |
| Tulu-Helpsteer | 775 | 6.34 | 29.85 | 28.98 | 44.10 (1.48×, 1.52×) |
| Ultrafeedback | 409 | 0.42 | 45.46 | 44.13 | 53.21 (1.17×, 1.21×) |
- `data`: Data processing and dataloading related files, including chat templating, the sampler for sequence packing, patches for dataloading, etc.
- `modeling`: Contains custom attention masks and model patches to use FlexAttention instead of FlashAttention.
- `train_dpo.py`: Main entrypoint.
- `trainer.py`: Contains a modified `DPOTrainer` that works with prefix-shared inputs (with optional sequence packing). We mainly customize the preprocessing (to form prefix-shared inputs), the forward pass (with a custom block mask per input for FlexAttention), and the log-probability computation; a rough sketch of the log-probability split is shown below.
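For intuition, here is a rough sketch, with hypothetical names rather than the repo's exact code, of how per-response log-probabilities can be read off a single prefix-shared forward pass:

```python
# Sketch: split chosen/rejected log-probs out of one prefix-shared sequence.
# logits: (seq_len, vocab); labels: (seq_len,) for a [prompt|chosen|rejected] layout.
import torch
import torch.nn.functional as F

def response_logps(logits, labels, prompt_len, chosen_len):
    # Token t's log-prob is predicted at position t-1, hence the shift.
    token_logps = torch.gather(F.log_softmax(logits[:-1], dim=-1), 1,
                               labels[1:].unsqueeze(-1)).squeeze(-1)
    chosen = token_logps[prompt_len - 1 : prompt_len + chosen_len - 1].sum()
    rejected = token_logps[prompt_len + chosen_len - 1 :].sum()
    return chosen, rejected
```

The point is that a single forward pass over `[prompt | chosen | rejected]` yields both response log-probabilities needed for the DPO loss.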
Our code is based on 🤗TRL. The packing sampler implementation is from Axolotl, a more distributed-training-friendly version of the Multipack Sampler.
We also make use of FlexAttention for the custom attention mask implementation.
If you find our work useful, please cite:

```bibtex
@inproceedings{wang2024accelerating,
  title={Accelerating Direct Preference Optimization with Prefix Sharing},
  author={Franklin Wang and Sumanth Hegde},
  booktitle={NeurIPS 2024 Workshop on Fine-Tuning in Modern Machine Learning: Principles and Scalability},
  year={2024},
  url={https://openreview.net/forum?id=d4dRhZiTdm}
}
```