[DPO] add SLiC hinge loss to DPOTrainer (huggingface#866)
* add SLiC hinge loss

* fix links

* beta when loss is hinge is reciprocal of margin

* fix tests

* fix docs

* doc strings

* fix method name

* raise error if loss_type is not correct

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Leandro von Werra <[email protected]>

* fix formatting

---------

Co-authored-by: Leandro von Werra <[email protected]>
kashif and lvwerra authored Oct 16, 2023
1 parent eb4d2f3 commit 14b6bc6
Showing 3 changed files with 19 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -80,6 +80,11 @@ dpo_trainer.train()

Note that `beta` is the temperature parameter for the DPO loss, typically in the range of `0.1` to `0.5`. In the limit `beta` -> 0, the reference model is ignored.

## Loss function

Given the preference data, we can fit a binary classifier according to the Bradley-Terry model; in fact, the DPO authors propose fitting a logistic regression by applying a sigmoid loss, via `logsigmoid`, to the normalized likelihood.
The [RSO](https://arxiv.org/abs/2309.06657) authors instead propose a hinge loss on the normalized likelihood, following the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument; in this case `beta` is the reciprocal of the hinge margin.
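
For intuition, here is a minimal sketch of the two objectives on a batch of log-ratio differences (the tensor values and variable names are purely illustrative, not part of the TRL API); `logits` stands for the gap between the policy and reference log-ratios, as computed inside `DPOTrainer.dpo_loss`:

```python
import torch
import torch.nn.functional as F

# Difference between the policy and reference log-ratios, one value per preference pair
logits = torch.tensor([2.0, 0.5, -1.0])
beta = 0.1

# Default DPO objective: logistic loss on the scaled log-ratio difference
sigmoid_losses = -F.logsigmoid(beta * logits)

# SLiC-style hinge objective: a linear penalty that reaches zero once
# beta * logits >= 1, i.e. once the log-ratio gap exceeds the margin 1 / beta
hinge_losses = torch.relu(1 - beta * logits)
```

The hinge formulation makes the reciprocal-margin reading of `beta` explicit: the loss vanishes exactly when the log-ratio gap reaches `1 / beta`.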

## Logging

While training and evaluating we record the following reward metrics:
10 changes: 3 additions & 7 deletions tests/test_dpo_trainer.py
@@ -74,13 +74,8 @@ def _init_dummy_dataset(self):
        # fmt: on
        return Dataset.from_dict(dummy_dataset_dict)

    @parameterized.expand(
        [
            ["gpt2"],
            ["t5"],
        ]
    )
    def test_dpo_trainer(self, name):
    @parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"]])
    def test_dpo_trainer(self, name, loss_type):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
@@ -107,6 +102,7 @@ def test_dpo_trainer(self, name):
                model=model,
                ref_model=ref_model,
                beta=0.1,
                loss_type=loss_type,
                args=training_args,
                tokenizer=tokenizer,
                train_dataset=dummy_dataset,
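
As a usage note that goes beyond the diff above, the new argument is simply forwarded through the constructor. A rough end-to-end sketch (model names, hyperparameters, and the tiny preference dataset below are illustrative assumptions, mirroring the dummy data used by this test) might look as follows:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Tiny illustrative preference dataset with the prompt/chosen/rejected columns DPOTrainer expects
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello,", "How are"],
        "chosen": [" nice to meet you.", " you doing well?"],
        "rejected": [" leave me alone.", " whatever."],
    }
)

training_args = TrainingArguments(
    output_dir="dpo_hinge_output",
    per_device_train_batch_size=2,
    max_steps=1,
    remove_unused_columns=False,  # keep the raw text columns for DPOTrainer's own preprocessing
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    beta=0.1,            # with loss_type="hinge", beta acts as the reciprocal of the margin
    loss_type="hinge",   # new argument from this commit; defaults to "sigmoid"
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```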
12 changes: 11 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -55,6 +55,8 @@ class DPOTrainer(Trainer):
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        beta (`float`, defaults to 0.1):
            The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
        loss_type (`str`, defaults to `"sigmoid"`):
            The type of DPO loss to use. Either `"sigmoid"`, the default DPO loss, or `"hinge"`, the hinge loss from the SLiC paper.
        args (`transformers.TrainingArguments`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
@@ -104,6 +106,7 @@ def __init__(
        model: Union[PreTrainedModel, nn.Module] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        beta: float = 0.1,
        loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
@@ -223,6 +226,7 @@ def __init__(
        self.padding_value = padding_value

        self.beta = beta
        self.loss_type = loss_type

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -356,7 +360,13 @@ def dpo_loss(

        logits = pi_logratios - ref_logratios

        losses = -F.logsigmoid(self.beta * logits)
        if self.loss_type == "sigmoid":
            losses = -F.logsigmoid(self.beta * logits)
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

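
As a quick sanity check, not part of the commit, of the claim that `beta` is the reciprocal of the margin: in the hinge branch above, the loss falls linearly with the policy's log-ratio advantage over the reference and hits zero exactly when that advantage reaches `1 / beta`.

```python
import torch

beta = 0.5  # corresponds to a margin of 1 / beta = 2.0
logits = torch.tensor([3.0, 2.0, 1.0, -1.0])  # policy-vs-reference log-ratio differences

hinge_losses = torch.relu(1 - beta * logits)
print(hinge_losses)  # tensor([0.0000, 0.0000, 0.5000, 1.5000]) -- zero once logits >= 2.0
```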
