Custom DPO losses support (#2427)
Co-authored-by: Mark Obozov <[email protected]>
krammnic and Mark Obozov authored Feb 27, 2025
1 parent 8bf8647 commit 3d745a2
Showing 8 changed files with 160 additions and 107 deletions.
22 changes: 22 additions & 0 deletions docs/source/recipes/dpo.rst
@@ -59,6 +59,28 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
loss=torchtune.modules.loss.RSOLoss \
gamma=0.5
You can also pass your own custom loss to our recipes. Note that its ``forward`` method should match the following signature:

.. code-block:: python

    def forward(self, policy_inputs: ChosenRejectedOutputs, reference_inputs: ChosenRejectedOutputs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ...
Here, ``ChosenRejectedOutputs`` is a dataclass returned by ``concatenated_forward``:

.. code-block:: python

    @dataclass
    class ChosenRejectedOutputs:
        chosen_logps: torch.Tensor
        rejected_logps: torch.Tensor
        chosen_logits: torch.Tensor
        rejected_logits: torch.Tensor
If this is not sufficient and you need to compute additional values from the logits, you can modify ``concatenated_forward`` directly. To do this, use ``tune cp`` to copy the desired recipe, and don't forget to return your own dataclass from it!
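
For illustration, here is a minimal sketch of a custom loss that satisfies this interface. The class name ``SimpleDPOLoss`` and its ``beta`` hyperparameter are hypothetical; the math follows the standard sigmoid DPO objective and is only a sketch, not the implementation shipped with torchtune:

.. code-block:: python

    from typing import Tuple

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from torchtune.rlhf import ChosenRejectedOutputs


    class SimpleDPOLoss(nn.Module):
        """Hypothetical sigmoid DPO loss following the recipe's expected interface."""

        def __init__(self, beta: float = 0.1):
            super().__init__()
            self.beta = beta

        def forward(
            self,
            policy_inputs: ChosenRejectedOutputs,
            reference_inputs: ChosenRejectedOutputs,
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Log-ratio of chosen over rejected for the policy and the frozen reference
            policy_logratios = policy_inputs.chosen_logps - policy_inputs.rejected_logps
            reference_logratios = (
                reference_inputs.chosen_logps - reference_inputs.rejected_logps
            )

            # Per-example losses; the recipe reduces them with a mean before backprop
            logits = policy_logratios - reference_logratios
            losses = -F.logsigmoid(self.beta * logits)

            # Rewards are detached because the recipes only use them for logging
            chosen_rewards = self.beta * (
                policy_inputs.chosen_logps - reference_inputs.chosen_logps
            ).detach()
            rejected_rewards = self.beta * (
                policy_inputs.rejected_logps - reference_inputs.rejected_logps
            ).detach()

            return losses, chosen_rewards, rejected_rewards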

Refer to the TRL library for reference implementations of these losses. In particular, the loss calculations in its trainer classes are a useful starting point.
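
Once your custom loss is importable, you can point the recipe at it in the same way as the built-in losses shown above; the module path and the ``beta`` flag below are purely illustrative:

.. code-block:: bash

    loss=my_project.losses.SimpleDPOLoss \
    beta=0.1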

For a deeper understanding of the different levers you can pull when using this recipe,
see our documentation for the different PEFT training paradigms we support:

54 changes: 29 additions & 25 deletions recipes/full_dpo_distributed.py
@@ -20,6 +20,7 @@
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.rlhf import ChosenRejectedOutputs
from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
from torchtune.training.lr_schedulers import get_lr
from torchtune.utils import get_world_size_and_rank
@@ -797,7 +798,7 @@ def concatenated_forward(
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
activations_handling: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> ChosenRejectedOutputs:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
@@ -806,7 +807,7 @@
batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
Returns:
Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -836,7 +837,9 @@
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
return ChosenRejectedOutputs(
chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
)

def train(self) -> None:
"""
@@ -884,36 +887,35 @@ def train(self) -> None:

# batch is input_ids, labels
num_tokens += torch.tensor(batch[0].numel())
(
policy_chosen_log_probs,
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
policy_chosen_rejected_outputs = self.concatenated_forward(
self._model, batch
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
policy_chosen_logits_mean = (
policy_chosen_rejected_outputs.chosen_logits.detach().mean()
)
policy_rejected_logits_mean = (
policy_chosen_rejected_outputs.rejected_logits.detach().mean()
)

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits
del (
policy_chosen_rejected_outputs.chosen_logits,
policy_chosen_rejected_outputs.rejected_logits,
)

with torch.no_grad():
(
reference_chosen_log_probs,
reference_rejected_log_probs,
reference_chosen_logits,
reference_rejected_logits,
) = self.concatenated_forward(
reference_chosen_rejected_outputs = self.concatenated_forward(
self._ref_model, batch, activations_handling=False
)

del reference_chosen_logits, reference_rejected_logits
del (
reference_chosen_rejected_outputs.chosen_logits,
reference_chosen_rejected_outputs.rejected_logits,
)

loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
policy_chosen_rejected_outputs, reference_chosen_rejected_outputs
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

@@ -936,10 +938,12 @@ def train(self) -> None:
scaling_factor * reward_accuracies.mean()
)
running_metrics["log_probs/chosen"] += (
scaling_factor * policy_chosen_log_probs.detach().mean()
scaling_factor
* policy_chosen_rejected_outputs.chosen_logps.detach().mean()
)
running_metrics["log_probs/rejected"] += (
scaling_factor * policy_rejected_log_probs.detach().mean()
scaling_factor
* policy_chosen_rejected_outputs.rejected_logps.detach().mean()
)
running_metrics["logits/chosen"] += (
scaling_factor * policy_chosen_logits_mean
52 changes: 28 additions & 24 deletions recipes/lora_dpo_distributed.py
@@ -33,6 +33,7 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.rlhf import ChosenRejectedOutputs
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -614,7 +615,7 @@ def save_checkpoint(

def concatenated_forward(
self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> ChosenRejectedOutputs:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
@@ -623,7 +624,7 @@ def concatenated_forward(
batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
Returns:
Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -643,7 +644,9 @@
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
return ChosenRejectedOutputs(
chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
)

def train(self) -> None:
"""
@@ -690,31 +693,30 @@ def train(self) -> None:
# batch is input_ids, labels
num_tokens += torch.tensor(batch[0].numel())

(
policy_chosen_log_probs,
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
policy_chosen_rejected_outputs = self.concatenated_forward(
self._model, batch
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
policy_chosen_logits_mean = (
policy_chosen_rejected_outputs.chosen_logits.detach().mean()
)
policy_rejected_logits_mean = (
policy_chosen_rejected_outputs.rejected_logits.detach().mean()
)

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits
del (
policy_chosen_rejected_outputs.chosen_logits,
policy_chosen_rejected_outputs.rejected_logits,
)

with torch.no_grad(), disable_adapter(self._model):
(
reference_chosen_log_probs,
reference_rejected_log_probs,
_,
_,
) = self.concatenated_forward(self._model, batch)
reference_chosen_rejected_outputs = self.concatenated_forward(
self._model, batch
)
loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
policy_chosen_rejected_outputs,
reference_chosen_rejected_outputs,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

@@ -737,10 +739,12 @@
scaling_factor * reward_accuracies.mean()
)
running_metrics["log_probs/chosen"] += (
scaling_factor * policy_chosen_log_probs.detach().mean()
scaling_factor
* policy_chosen_rejected_outputs.chosen_logps.detach().mean()
)
running_metrics["log_probs/rejected"] += (
scaling_factor * policy_rejected_log_probs.detach().mean()
scaling_factor
* policy_chosen_rejected_outputs.rejected_logps.detach().mean()
)
running_metrics["logits/chosen"] += (
scaling_factor * policy_chosen_logits_mean
50 changes: 26 additions & 24 deletions recipes/lora_dpo_single_device.py
@@ -30,6 +30,7 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.rlhf import ChosenRejectedOutputs

from tqdm import tqdm

@@ -472,7 +473,7 @@ def save_checkpoint(self, epoch: int) -> None:

def concatenated_forward(
self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> ChosenRejectedOutputs:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
@@ -481,7 +482,7 @@ def concatenated_forward(
batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
Returns:
Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -501,7 +502,9 @@
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
return ChosenRejectedOutputs(
chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
)

def train(self) -> None:
"""
@@ -533,31 +536,30 @@ def train(self) -> None:

# batch is input_ids, labels
num_tokens += batch[0].numel()
(
policy_chosen_log_probs,
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
policy_chosen_rejected_outputs = self.concatenated_forward(
self._model, batch
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
policy_chosen_logits_mean = (
policy_chosen_rejected_outputs.chosen_logits.detach().mean()
)
policy_rejected_logits_mean = (
policy_chosen_rejected_outputs.rejected_logits.detach().mean()
)

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits
del (
policy_chosen_rejected_outputs.chosen_logits,
policy_chosen_rejected_outputs.rejected_logits,
)

with torch.no_grad(), disable_adapter(self._model):
(
reference_chosen_log_probs,
reference_rejected_log_probs,
_,
_,
) = self.concatenated_forward(self._model, batch)
reference_chosen_rejected_outputs = self.concatenated_forward(
self._model, batch
)
loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
policy_chosen_rejected_outputs,
reference_chosen_rejected_outputs,
)

loss = loss.mean()
@@ -596,10 +598,10 @@ def train(self) -> None:
"rewards/margins": (chosen_rewards - rejected_rewards)
.mean()
.cpu(),
"log_probs/rejected": policy_rejected_log_probs.detach()
"log_probs/rejected": policy_chosen_rejected_outputs.rejected_logps.detach()
.mean()
.cpu(),
"log_probs/chosen": policy_chosen_log_probs.detach()
"log_probs/chosen": policy_chosen_rejected_outputs.chosen_logps.detach()
.mean()
.cpu(),
"logits/rejected": policy_rejected_logits_mean.cpu(),
8 changes: 7 additions & 1 deletion tests/torchtune/rlhf/loss/test_dpo_loss.py
@@ -6,6 +6,7 @@

import pytest
import torch
from torchtune.rlhf._types import ChosenRejectedOutputs
from torchtune.rlhf.loss import DPOLoss, RSOLoss


@@ -39,11 +40,16 @@ def loss_inputs(self):
ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1])
ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1])

return (
return ChosenRejectedOutputs(
policy_chosen_logprobs,
policy_rejected_logprobs,
torch.tensor(0),
torch.tensor(0),
), ChosenRejectedOutputs(
ref_chosen_logprobs,
ref_rejected_logprobs,
torch.tensor(0),
torch.tensor(0),
)

def test_dpo_loss(self, dpo_loss, loss_inputs):
3 changes: 2 additions & 1 deletion torchtune/rlhf/__init__.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from ._types import PPOStats, Trajectory
from ._types import ChosenRejectedOutputs, PPOStats, Trajectory

from .rewards import (
estimate_advantages,
@@ -39,4 +39,5 @@
"PPOStats",
"get_batch_log_probs",
"Trajectory",
"ChosenRejectedOutputs",
]