Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gaetanlop committed Oct 8, 2024
1 parent 0ac178f commit 63d214e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/test_cgpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ def test_cgpo_trainer_wrong_rlhf_optimizer(self):
report_to="none",
)

def test_cgpo_trainer_no_kl_threshold(self):
    """CGPOConfig must raise a ValueError when `kl_threshold` is None."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Collect the config arguments first so the failing field is easy to spot.
        config_kwargs = dict(
            output_dir=tmp_dir,
            rlhf_optimizer="crraft",
            k=4,
            kl_threshold=None,  # the invalid value under test
            temperature=0.9,
            max_new_tokens=4,
            per_device_train_batch_size=4,
            max_steps=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=1,
            learning_rate=9e-1,
            eval_strategy="steps",
            report_to="none",
        )
        with self.assertRaisesRegex(
            ValueError,
            expected_regex="Training without setting the KL divergence threshold is not supported.",
        ):
            CGPOConfig(**config_kwargs)

@parameterized.expand(["crraft", "crpg", "codpo"])
def test_cgpo_trainer_with_missing_eos_penalty(self, rlhf_optimizer):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down Expand Up @@ -436,7 +458,7 @@ def test_cgpo_trainer_without_providing_ref_model_with_lora(self):

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
Expand Down
3 changes: 3 additions & 0 deletions trl/trainer/cgpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ def __post_init__(self):
raise ValueError(
f"Invalid value for rlhf_optimizer: {self.rlhf_optimizer}. Must be one of 'crraft', 'codpo', or 'crpg'."
)

if self.kl_threshold is None:
raise ValueError("Training without setting the KL divergence threshold is not supported.")

0 comments on commit 63d214e

Please sign in to comment.