Commit 5bcd60c: formatting
gaetanlop committed Sep 29, 2024
1 parent a793436 commit 5bcd60c
Showing 2 changed files with 64 additions and 16 deletions.
43 changes: 41 additions & 2 deletions tests/test_dpo_trainer.py
@@ -230,8 +230,6 @@ def setUp(self):
            ["t5", "exo_pair", True],
            ["gpt2", "apo_zero", True],
            ["t5", "apo_down", False],
            ["gpt2", "wpo", False],
            ["t5", "wpo", True],
        ]
    )
    def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -283,6 +281,47 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                if param.sum() != 0:
                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

    def test_dpo_trainer_with_weighting(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                eval_strategy="steps",
                beta=0.1,
                loss_type="sigmoid",
                precompute_ref_log_probs=False,
                use_weighting=True,
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

            trainer = DPOTrainer(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None

            # check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # check the params have changed - ignore 0 biases
                if param.sum() != 0:
                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

    @parameterized.expand(
        [
            [None, "Test when rpo_alpha is set to None"],
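For reference, the flag exercised by the new test is enabled the same way from user code. The snippet below is a minimal, illustrative sketch based on the test above; the model checkpoint name is a placeholder, while the dataset and the DPOConfig/DPOTrainer arguments are the ones the test already uses.

# Minimal sketch: enable WPO-style weighting in DPO training (illustrative only).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "your-small-causal-lm"  # placeholder checkpoint, not taken from this commit
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(
    output_dir="dpo-with-weighting",
    loss_type="sigmoid",
    use_weighting=True,  # the flag exercised by this commit's new test
    report_to="none",
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()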
37 changes: 23 additions & 14 deletions trl/trainer/dpo_trainer.py
@@ -1316,7 +1316,7 @@ def get_batch_logps(
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
        use_weighting: bool = False
        use_weighting: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]:
        """Compute the log probabilities of the given labels under the given logits.
@@ -1326,7 +1326,7 @@ def get_batch_logps(
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.

        Returns:
            A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing:
            - The sum of log probabilities of the given labels under the given logits.
@@ -1347,17 +1347,17 @@ def get_batch_logps(
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        all_logps = (per_token_logps * loss_mask).sum(-1)

        all_weights = None
        if use_weighting:
            # eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827
            probs = F.softmax(logits, dim=-1)
            weights_adjustment_factor = torch.log((probs ** 2).sum(-1))
            weights_adjustment_factor = torch.log((probs**2).sum(-1))
            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
            all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)

        return all_logps, loss_mask.sum(-1), all_weights

    def concatenated_forward(
@@ -1432,13 +1432,12 @@ def cross_entropy_loss(logits, labels):

        if self.loss_type == "ipo":
            all_logps = all_logps / size_completion
        policy_weights = None

        policy_weights = None
        if self.args.use_weighting:
            chosen_weights = all_weights[:len_chosen]
            rejected_weights = all_weights[len_chosen:]
            policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)


        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]
@@ -1447,9 +1446,17 @@ def cross_entropy_loss(logits, labels):
        rejected_logits = all_logits[len_chosen:]

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights, outputs.aux_loss)
            return (
                chosen_logps,
                rejected_logps,
                chosen_logits,
                rejected_logits,
                nll_loss,
                policy_weights,
                outputs.aux_loss,
            )

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, policy_weights, nll_loss)
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights)

    def get_batch_loss_metrics(
        self,
@@ -1503,7 +1510,7 @@ def get_batch_loss_metrics(
        if self.args.rpo_alpha is not None:
            # RPO loss from V3 of the paper:
            losses = losses + policy_nll_loss * self.args.rpo_alpha

        if self.args.use_weighting:
            losses = losses * policy_weights

@@ -1730,15 +1737,17 @@ def create_model_card(
        else:
            base_model = None

        citation = textwrap.dedent("""\
        citation = textwrap.dedent(
            """\
            @inproceedings{rafailov2023direct,
                title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
                author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
                year = 2023,
                booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
                url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
                editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
            }"""
        )

        model_card = generate_model_card(
            base_model=base_model,
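As a note on the change above: with use_weighting=True, each preference pair's DPO loss is scaled by a weight derived from eqn (2) of the WPO paper. The sketch below recomputes that weight on random toy tensors, mirroring get_batch_logps and concatenated_forward from this diff; the shapes, the all-ones loss mask, and the len_chosen split are illustrative assumptions, not values from the commit.

# Illustrative recomputation of the WPO weight (eqn (2)) on toy tensors.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 4, 6, 10   # first half "chosen", second half "rejected"
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
loss_mask = torch.ones(batch_size, seq_len)  # 1 on completion tokens, 0 on prompt/padding

# per-token log-probabilities of the labels, as in get_batch_logps
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

# eqn (2): subtract log sum_v p(v)^2 from each token log-prob, then average over the completion
probs = F.softmax(logits, dim=-1)
weights_adjustment_factor = torch.log((probs**2).sum(-1))
all_weights = ((per_token_logps - weights_adjustment_factor) * loss_mask).sum(-1) / loss_mask.sum(-1)

# combine the chosen and rejected halves and cap at 1; this factor scales the per-pair DPO loss
len_chosen = batch_size // 2
policy_weights = torch.clamp(torch.exp(all_weights[:len_chosen] + all_weights[len_chosen:]), max=1)
print(policy_weights)  # one weight in (0, 1] per preference pair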
