From c996881ec2f6cc23e540ef9ebefd87a8e5e5387d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 4 Jun 2024 16:09:51 -0400
Subject: [PATCH] add support for rpo_alpha (#1681)

* add support for rpo_alpha

* Add smoke test for dpo + nll loss
---
 requirements.txt                            |  2 +-
 src/axolotl/core/trainer_builder.py         | 13 +++++-
 .../config/models/input/v0_4_1/__init__.py  |  1 +
 tests/e2e/test_dpo.py                       | 45 +++++++++++++++++++
 4 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 557c5293fd..b5114bbf62 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@
 s3fs
 gcsfs
 # adlfs
-trl==0.8.6
+trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 0551ddbc0e..b88cc42210 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
 
 
+@dataclass
+class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
+    """
+    DPO config for DPO training
+    """
+
+
 @dataclass
 class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
     """
@@ -1608,7 +1615,9 @@ def build_training_arguments(self, total_num_steps):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls = AxolotlTrainingArguments
+        training_args_cls = AxolotlDPOConfig
+        if self.cfg.rpo_alpha is not None:
+            training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
         if self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index f363ebfdce..240a816edc 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -619,6 +619,7 @@ class Config:
     neftune_noise_alpha: Optional[float] = None
 
     orpo_alpha: Optional[float] = None
+    rpo_alpha: Optional[float] = None
 
     kto_desirable_weight: Optional[float] = None
     kto_undesirable_weight: Optional[float] = None
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 5d2522bdfd..5f03e6bc1b 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -70,6 +70,51 @@ def test_dpo_lora(self, temp_dir):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
 
+    @with_temp_dir
+    def test_dpo_nll_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "rpo_alpha": 0.5,
+                "datasets": [
+                    {
+                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
         # pylint: disable=duplicate-code
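
Note (reviewer addition, not part of the patch): trl's rpo_alpha option adds
an NLL term on the chosen completions to the DPO objective, following the
iterative RPO formulation, which is why the new smoke test is described as
"dpo + nll loss". The sketch below is a minimal illustration of that combined
loss, not trl's actual implementation: the function name, signature, and
defaults are invented for illustration, and it assumes per-example sequence
log-probs plus a precomputed NLL over the chosen responses.

import torch
import torch.nn.functional as F

def dpo_plus_nll_loss(
    policy_chosen_logps: torch.Tensor,    # log-prob of chosen, policy model
    policy_rejected_logps: torch.Tensor,  # log-prob of rejected, policy model
    ref_chosen_logps: torch.Tensor,       # log-prob of chosen, reference model
    ref_rejected_logps: torch.Tensor,     # log-prob of rejected, reference model
    chosen_nll_loss: torch.Tensor,        # NLL over the chosen completions
    beta: float = 0.1,
    rpo_alpha: float = 0.5,               # matches the smoke test's value
) -> torch.Tensor:
    # Standard sigmoid DPO loss on the policy-vs-reference log-ratio margin.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    dpo_losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # RPO adds an alpha-weighted NLL term so the policy keeps a direct
    # supervision signal on the preferred responses.
    return (dpo_losses + rpo_alpha * chosen_nll_loss).mean()

In axolotl terms, rpo_alpha set in the YAML config (0.5 in the new smoke
test) flows through AxolotlDPOConfig into trl's DPOConfig, and the extra
term is only wired up when the value is not None, so existing DPO configs
are unaffected.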