From 63d214ee97d238f5b5171a4c5d999611fa2a3141 Mon Sep 17 00:00:00 2001
From: gaetanlop
Date: Mon, 7 Oct 2024 22:38:04 -0400
Subject: [PATCH] add more tests

---
 tests/test_cgpo_trainer.py | 24 +++++++++++++++++++++++-
 trl/trainer/cgpo_config.py |  3 +++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/test_cgpo_trainer.py b/tests/test_cgpo_trainer.py
index 97907bccbb..b3b258e7a1 100644
--- a/tests/test_cgpo_trainer.py
+++ b/tests/test_cgpo_trainer.py
@@ -269,6 +269,28 @@ def test_cgpo_trainer_wrong_rlhf_optimizer(self):
                     report_to="none",
                 )
 
+    def test_cgpo_trainer_no_kl_threshold(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertRaisesRegex(
+                ValueError,
+                expected_regex="Training without setting the KL divergence threshold is not supported.",
+            ):
+                CGPOConfig(
+                    output_dir=tmp_dir,
+                    rlhf_optimizer="crraft",
+                    k=4,
+                    kl_threshold=None,
+                    temperature=0.9,
+                    max_new_tokens=4,
+                    per_device_train_batch_size=4,
+                    max_steps=3,
+                    remove_unused_columns=False,
+                    gradient_accumulation_steps=1,
+                    learning_rate=9e-1,
+                    eval_strategy="steps",
+                    report_to="none",
+                )
+
     @parameterized.expand(["crraft", "crpg", "codpo"])
     def test_cgpo_trainer_with_missing_eos_penalty(self, rlhf_optimizer):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -436,7 +458,7 @@ def test_cgpo_trainer_without_providing_ref_model_with_lora(self):
 
             trainer.train()
 
-            assert trainer.state.log_history[-1]["train_loss"] is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
             # check the params have changed
             for n, param in previous_trainable_params.items():
diff --git a/trl/trainer/cgpo_config.py b/trl/trainer/cgpo_config.py
index e585bcb31e..8a29cbbbb2 100644
--- a/trl/trainer/cgpo_config.py
+++ b/trl/trainer/cgpo_config.py
@@ -75,3 +75,6 @@ def __post_init__(self):
             raise ValueError(
                 f"Invalid value for rlhf_optimizer: {self.rlhf_optimizer}. Must be one of 'crraft', 'codpo', or 'crpg'."
             )
+
+        if self.kl_threshold is None:
+            raise ValueError("Training without setting the KL divergence threshold is not supported.")