Summary of this patch (text before the first "diff --git" header is
commentary and is skipped by patch tools):

  * .gitignore: ignore a local `venv3.10/` virtualenv directory.
  * docs/rlhf.qmd: the ORPO example now uses dataset type
    `chat_template.argilla` instead of `orpo.chat_template`.
  * examples/mistral/mistral-qlora-orpo.yml: new example config that sets
    `rl: orpo` with a QLoRA adapter on mistralai/Mistral-7B-v0.1.
  * tests/e2e/test_dpo.py: new `test_orpo_lora` test that runs 20 training
    steps and asserts that `checkpoint-20/adapter_model.safetensors` exists.

diff --git a/.gitignore b/.gitignore
index 589440abf6..e6dfee67db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,6 +133,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+venv3.10/
 
 # Spyder project settings
 .spyderproject
diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 7db68915ad..b8b2bded09 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -49,7 +49,7 @@ remove_unused_columns: false
 chat_template: chatml
 datasets:
   - path: argilla/ultrafeedback-binarized-preferences-cleaned
-    type: orpo.chat_template
+    type: chat_template.argilla
 ```
 
 #### Using local dataset files
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
new file mode 100644
index 0000000000..7727fd7485
--- /dev/null
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -0,0 +1,82 @@
+base_model: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: chat_template.argilla
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./mistral-qlora-orpo-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index e28df7411f..9596b1873f 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -158,3 +158,50 @@ def test_ipo_lora(self, temp_dir):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_orpo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "orpo",
+                "orpo_alpha": 0.1,
+                "remove_unused_columns": False,
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
+                        "type": "chat_template.argilla",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()