add support for rpo_alpha (#1681)
* add support for rpo_alpha

* Add smoke test for dpo + nll loss
winglian authored Jun 4, 2024
1 parent 1f151c0 commit c996881
Showing 4 changed files with 58 additions and 3 deletions.
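
For context: rpo_alpha is the weight trl uses for its RPO-style regularizer, which mixes a supervised NLL term on the chosen responses into the DPO objective. A conceptual sketch of that combination (illustrative only; the exact weighting is whatever the pinned trl commit below implements, not code from this repository):

def combined_dpo_nll_loss(dpo_loss, chosen_nll_loss, rpo_alpha=None):
    # Conceptual sketch, not taken from trl or axolotl.
    if rpo_alpha is None:
        return dpo_loss  # plain DPO when the option is unset
    # alpha-weighted NLL (SFT) term on the chosen responses, added to the DPO loss
    return dpo_loss + rpo_alpha * chosen_nll_loss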
2 changes: 1 addition & 1 deletion requirements.txt
@@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs

-trl==0.8.6
+trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
zstandard==0.22.0
fastcore
13 changes: 11 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
@@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
@@ -1608,7 +1615,9 @@ def build_training_arguments(self, total_num_steps):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

-training_args_cls = AxolotlTrainingArguments
+training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -619,6 +619,7 @@ class Config:
neftune_noise_alpha: Optional[float] = None

orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None

kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
45 changes: 45 additions & 0 deletions tests/e2e/test_dpo.py
@@ -70,6 +70,51 @@ def test_dpo_lora(self, temp_dir):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
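
To run just this smoke test in isolation, an invocation along the lines of pytest tests/e2e/test_dpo.py -k test_dpo_nll_lora should work, assuming a CUDA-capable GPU is available, since the config above loads the base model in 8-bit and uses a paged 8-bit optimizer.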

@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
# pylint: disable=duplicate-code
