Skip to content

Commit

Permalink
Add ORPO example and e2e test (#1572)
Browse files Browse the repository at this point in the history
* add example for mistral orpo

* sample_packing: false for orpo

* go to load_dataset (since load_rl_datasets require a transfom_fn, which only dpo uses currently)
  • Loading branch information
tokestermw authored Apr 27, 2024
1 parent 68601ec commit 98c25e1
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ venv/
ENV/
env.bak/
venv.bak/
venv3.10/

# Spyder project settings
.spyderproject
Expand Down
2 changes: 1 addition & 1 deletion docs/rlhf.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
type: chat_template.argilla
```
#### Using local dataset files
Expand Down
82 changes: 82 additions & 0 deletions examples/mistral/mistral-qlora-orpo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
47 changes: 47 additions & 0 deletions tests/e2e/test_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,50 @@ def test_ipo_lora(self, temp_dir):

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

@with_temp_dir
def test_orpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

0 comments on commit 98c25e1

Please sign in to comment.