BIG DPO [Add FP8 support, 2D parallelism, async checkpointing] #2621

Draft · wants to merge 10 commits into main
106 changes: 106 additions & 0 deletions recipes/configs/llama4/scout_17B_16E_dpo_full.yaml
@@ -0,0 +1,106 @@
# Config for multi-device DPO finetuning in full_dpo_distributed.py
# using a Llama4 17Bx16E MoE model with 2D parallelism
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 full_dpo_distributed --config llama4/scout_17B_16E_dpo_full
#
# You can add specific overrides through the command line. For example, to use a larger bsz:
# tune run --nproc_per_node 8 full_dpo_distributed --config llama4/scout_17B_16E_dpo_full batch_size=8
#
# This config is designed for 8xA100 or 16xH100 machines.

output_dir: /tmp/torchtune/llama4_17Bx16E/dpo

# Modeling arguments
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e

# 2D Parallelism configuration
tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # -1 infers the shard dim from the remaining world size after TP; effectively controls FSDP sharding
data_parallel_replicate_dim: 1
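# A sketch of how the mesh above resolves (assumption, not verified against the recipe): with the
# 8-device launch and tensor_parallel_dim: 2, data_parallel_shard_dim: -1 resolves to
# 8 / (2 * 1) = 4, i.e. a 2-way TP x 4-way FSDP mesh; the product of the parallel dims must equal world size.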

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

# Base model checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
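# Note (assuming torchtune's formatted checkpoint files convention): filename_format and max_filename
# above expand to model-00001-of-00050.safetensors through model-00050-of-00050.safetensors,
# zero-padded to the width of max_filename. The ref_checkpointer below uses the same layout.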

# Reference model checkpointer (for DPO)
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4

resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
  packed: False
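# DPO trains on preference data: each example pairs a prompt with a chosen and a rejected response,
# which stack_exchange_paired_dataset provides.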
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 8 # Use to increase effective batch size
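# Worked example (assuming the 8-device launch above with tensor_parallel_dim: 2, i.e. 4 data-parallel ranks):
# effective batch size = batch_size * gradient_accumulation_steps * DP ranks = 1 * 8 * 4 = 32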
optimizer:
  _component_: torch.optim.AdamW
  lr: 5e-7 # Lower learning rate for DPO
  fused: False
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
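# For reference, the standard DPO objective this loss corresponds to (pseudocode, not torchtune code):
# loss = -log(sigmoid(beta * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))))
# beta controls how strongly the policy is kept close to the reference model.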
clip_grad_norm: 1.0

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: True
fsdp_cpu_offload: False # Keep optimizer states on GPU
fsdp_reshard_after_forward: True
compile: False # torch.compile; set to True for perf/memory improvements

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

# Float8 training support
enable_fp8_training: False
fp8_recipe_name: null
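# To experiment with float8 (a sketch, assuming the FP8 path added in this PR builds on torchao's
# float8 training): set enable_fp8_training: True and optionally pick a torchao recipe via
# fp8_recipe_name; leaving it null uses the default recipe.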