From 3d745a28bedb40dadf9761e265662da86da04662 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Fri, 28 Feb 2025 02:58:06 +0300 Subject: [PATCH] Custom DPO losses support (#2427) Co-authored-by: Mark Obozov --- docs/source/recipes/dpo.rst | 22 +++++++++ recipes/full_dpo_distributed.py | 54 ++++++++++---------- recipes/lora_dpo_distributed.py | 52 +++++++++++--------- recipes/lora_dpo_single_device.py | 50 ++++++++++--------- tests/torchtune/rlhf/loss/test_dpo_loss.py | 8 ++- torchtune/rlhf/__init__.py | 3 +- torchtune/rlhf/_types.py | 21 ++++++++ torchtune/rlhf/loss/dpo.py | 57 ++++++++++------------ 8 files changed, 160 insertions(+), 107 deletions(-) diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst index efe2a0e126..9db18f5d90 100644 --- a/docs/source/recipes/dpo.rst +++ b/docs/source/recipes/dpo.rst @@ -59,6 +59,28 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r loss=torchtune.modules.loss.RSOLoss \ gamma=0.5 +You can also pass your own custom loss to our recipes. Note that its ``forward`` method must match the following signature and return a tuple of per-example losses, chosen rewards, and rejected rewards: + +.. code-block:: python + + def forward(self, policy_inputs: ChosenRejectedOutputs, reference_inputs: ChosenRejectedOutputs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ... + +Here, ``ChosenRejectedOutputs`` is a dataclass returned by ``concatenated_forward``: + +.. code-block:: python + + @dataclass + class ChosenRejectedOutputs: + chosen_logps: torch.Tensor + rejected_logps: torch.Tensor + chosen_logits: torch.Tensor + rejected_logits: torch.Tensor + +If this is not sufficient and you need to compute additional values from the logits, you can modify ``concatenated_forward`` directly. To do this, use ``tune cp`` to copy the desired recipe, and don’t forget to return your own dataclass! + +Refer to the TRL library for reference implementations of popular preference-optimization losses; in particular, the loss computations in its trainer classes are a useful starting point. + For a deeper understanding of the different levers you can pull when using this recipe, see our documentation for the different PEFT training paradigms we support: diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py index e81df88b5a..df2db9a698 100644 --- a/recipes/full_dpo_distributed.py +++ b/recipes/full_dpo_distributed.py @@ -20,6 +20,7 @@ from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.rlhf import ChosenRejectedOutputs from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY from torchtune.training.lr_schedulers import get_lr from torchtune.utils import get_world_size_and_rank @@ -797,7 +798,7 @@ def concatenated_forward( model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor], activations_handling: Optional[bool] = True, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ChosenRejectedOutputs: """ Run forward pass of the model with chosen and rejected samples concatenated. @@ -806,7 +807,7 @@ def concatenated_forward( batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. Returns: - Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
""" concatenated_input_ids, concatenated_labels = batch concatenated_input_ids = concatenated_input_ids.to(self._device) @@ -836,7 +837,9 @@ def concatenated_forward( chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] - return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + return ChosenRejectedOutputs( + chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits + ) def train(self) -> None: """ @@ -884,36 +887,35 @@ def train(self) -> None: # batch is input_ids, labels num_tokens += torch.tensor(batch[0].numel()) - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = self.concatenated_forward(self._model, batch) + policy_chosen_rejected_outputs = self.concatenated_forward( + self._model, batch + ) - policy_chosen_logits_mean = policy_chosen_logits.detach().mean() - policy_rejected_logits_mean = policy_rejected_logits.detach().mean() + policy_chosen_logits_mean = ( + policy_chosen_rejected_outputs.chosen_logits.detach().mean() + ) + policy_rejected_logits_mean = ( + policy_chosen_rejected_outputs.rejected_logits.detach().mean() + ) # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging - del policy_chosen_logits, policy_rejected_logits + del ( + policy_chosen_rejected_outputs.chosen_logits, + policy_chosen_rejected_outputs.rejected_logits, + ) with torch.no_grad(): - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - reference_chosen_logits, - reference_rejected_logits, - ) = self.concatenated_forward( + reference_chosen_rejected_outputs = self.concatenated_forward( self._ref_model, batch, activations_handling=False ) - del reference_chosen_logits, reference_rejected_logits + del ( + reference_chosen_rejected_outputs.chosen_logits, + reference_chosen_rejected_outputs.rejected_logits, + ) loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, - policy_rejected_log_probs, - reference_chosen_log_probs, - reference_rejected_log_probs, + policy_chosen_rejected_outputs, reference_chosen_rejected_outputs ) reward_accuracies = (chosen_rewards > rejected_rewards).float() @@ -936,10 +938,12 @@ def train(self) -> None: scaling_factor * reward_accuracies.mean() ) running_metrics["log_probs/chosen"] += ( - scaling_factor * policy_chosen_log_probs.detach().mean() + scaling_factor + * policy_chosen_rejected_outputs.chosen_logps.detach().mean() ) running_metrics["log_probs/rejected"] += ( - scaling_factor * policy_rejected_log_probs.detach().mean() + scaling_factor + * policy_chosen_rejected_outputs.rejected_logps.detach().mean() ) running_metrics["logits/chosen"] += ( scaling_factor * policy_chosen_logits_mean diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 3c94a84d02..720576605a 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -33,6 +33,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.rlhf import ChosenRejectedOutputs from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -614,7 +615,7 @@ def save_checkpoint( def concatenated_forward( self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ChosenRejectedOutputs: """ Run forward pass of the model with chosen and rejected samples concatenated. 
@@ -623,7 +624,7 @@ def concatenated_forward( batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. Returns: - Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits. """ concatenated_input_ids, concatenated_labels = batch concatenated_input_ids = concatenated_input_ids.to(self._device) @@ -643,7 +644,9 @@ def concatenated_forward( chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] - return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + return ChosenRejectedOutputs( + chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits + ) def train(self) -> None: """ @@ -690,31 +693,30 @@ def train(self) -> None: # batch is input_ids, labels num_tokens += torch.tensor(batch[0].numel()) - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = self.concatenated_forward(self._model, batch) + policy_chosen_rejected_outputs = self.concatenated_forward( + self._model, batch + ) - policy_chosen_logits_mean = policy_chosen_logits.detach().mean() - policy_rejected_logits_mean = policy_rejected_logits.detach().mean() + policy_chosen_logits_mean = ( + policy_chosen_rejected_outputs.chosen_logits.detach().mean() + ) + policy_rejected_logits_mean = ( + policy_chosen_rejected_outputs.rejected_logits.detach().mean() + ) # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging - del policy_chosen_logits, policy_rejected_logits + del ( + policy_chosen_rejected_outputs.chosen_logits, + policy_chosen_rejected_outputs.rejected_logits, + ) with torch.no_grad(), disable_adapter(self._model): - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward(self._model, batch) + reference_chosen_rejected_outputs = self.concatenated_forward( + self._model, batch + ) loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, - policy_rejected_log_probs, - reference_chosen_log_probs, - reference_rejected_log_probs, + policy_chosen_rejected_outputs, + reference_chosen_rejected_outputs, ) reward_accuracies = (chosen_rewards > rejected_rewards).float() @@ -737,10 +739,12 @@ def train(self) -> None: scaling_factor * reward_accuracies.mean() ) running_metrics["log_probs/chosen"] += ( - scaling_factor * policy_chosen_log_probs.detach().mean() + scaling_factor + * policy_chosen_rejected_outputs.chosen_logps.detach().mean() ) running_metrics["log_probs/rejected"] += ( - scaling_factor * policy_rejected_log_probs.detach().mean() + scaling_factor + * policy_chosen_rejected_outputs.rejected_logps.detach().mean() ) running_metrics["logits/chosen"] += ( scaling_factor * policy_chosen_logits_mean diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 1416ed4076..046710048b 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -30,6 +30,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.rlhf import ChosenRejectedOutputs from tqdm import tqdm @@ -472,7 +473,7 @@ def save_checkpoint(self, epoch: int) -> None: def concatenated_forward( self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ChosenRejectedOutputs: """ Run forward pass of the model with chosen and 
rejected samples concatenated. @@ -481,7 +482,7 @@ def concatenated_forward( batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. Returns: - Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits. """ concatenated_input_ids, concatenated_labels = batch concatenated_input_ids = concatenated_input_ids.to(self._device) @@ -501,7 +502,9 @@ def concatenated_forward( chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] - return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) + return ChosenRejectedOutputs( + chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits + ) def train(self) -> None: """ @@ -533,31 +536,30 @@ def train(self) -> None: # batch is input_ids, labels num_tokens += batch[0].numel() - ( - policy_chosen_log_probs, - policy_rejected_log_probs, - policy_chosen_logits, - policy_rejected_logits, - ) = self.concatenated_forward(self._model, batch) + policy_chosen_rejected_outputs = self.concatenated_forward( + self._model, batch + ) - policy_chosen_logits_mean = policy_chosen_logits.detach().mean() - policy_rejected_logits_mean = policy_rejected_logits.detach().mean() + policy_chosen_logits_mean = ( + policy_chosen_rejected_outputs.chosen_logits.detach().mean() + ) + policy_rejected_logits_mean = ( + policy_chosen_rejected_outputs.rejected_logits.detach().mean() + ) # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging - del policy_chosen_logits, policy_rejected_logits + del ( + policy_chosen_rejected_outputs.chosen_logits, + policy_chosen_rejected_outputs.rejected_logits, + ) with torch.no_grad(), disable_adapter(self._model): - ( - reference_chosen_log_probs, - reference_rejected_log_probs, - _, - _, - ) = self.concatenated_forward(self._model, batch) + reference_chosen_rejected_outputs = self.concatenated_forward( + self._model, batch + ) loss, chosen_rewards, rejected_rewards = self._loss_fn( - policy_chosen_log_probs, - policy_rejected_log_probs, - reference_chosen_log_probs, - reference_rejected_log_probs, + policy_chosen_rejected_outputs, + reference_chosen_rejected_outputs, ) loss = loss.mean() @@ -596,10 +598,10 @@ def train(self) -> None: "rewards/margins": (chosen_rewards - rejected_rewards) .mean() .cpu(), - "log_probs/rejected": policy_rejected_log_probs.detach() + "log_probs/rejected": policy_chosen_rejected_outputs.rejected_logps.detach() .mean() .cpu(), - "log_probs/chosen": policy_chosen_log_probs.detach() + "log_probs/chosen": policy_chosen_rejected_outputs.chosen_logps.detach() .mean() .cpu(), "logits/rejected": policy_rejected_logits_mean.cpu(), diff --git a/tests/torchtune/rlhf/loss/test_dpo_loss.py b/tests/torchtune/rlhf/loss/test_dpo_loss.py index ab1bfefa4c..8f5bade186 100644 --- a/tests/torchtune/rlhf/loss/test_dpo_loss.py +++ b/tests/torchtune/rlhf/loss/test_dpo_loss.py @@ -6,6 +6,7 @@ import pytest import torch +from torchtune.rlhf._types import ChosenRejectedOutputs from torchtune.rlhf.loss import DPOLoss, RSOLoss @@ -39,11 +40,16 @@ def loss_inputs(self): ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1]) ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1]) - return ( + return ChosenRejectedOutputs( policy_chosen_logprobs, policy_rejected_logprobs, + torch.tensor(0), + torch.tensor(0), + ), ChosenRejectedOutputs( ref_chosen_logprobs, ref_rejected_logprobs, + torch.tensor(0), + torch.tensor(0), ) def 
test_dpo_loss(self, dpo_loss, loss_inputs): diff --git a/torchtune/rlhf/__init__.py b/torchtune/rlhf/__init__.py index 3589d096f3..60375a5bb7 100644 --- a/torchtune/rlhf/__init__.py +++ b/torchtune/rlhf/__init__.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from ._types import PPOStats, Trajectory +from ._types import ChosenRejectedOutputs, PPOStats, Trajectory from .rewards import ( estimate_advantages, @@ -39,4 +39,5 @@ "PPOStats", "get_batch_log_probs", "Trajectory", + "ChosenRejectedOutputs", ] diff --git a/torchtune/rlhf/_types.py b/torchtune/rlhf/_types.py index 729a4035fc..b28a438974 100644 --- a/torchtune/rlhf/_types.py +++ b/torchtune/rlhf/_types.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass from typing import NamedTuple import torch @@ -67,3 +68,23 @@ class PPOStats(NamedTuple): ratios: torch.Tensor clipfrac: torch.Tensor approx_policy_kls: torch.Tensor + + +@dataclass +class ChosenRejectedOutputs: + """ + Contains `concatenated_forward` outputs. + + Attributes: + chosen_logps (torch.Tensor): Log probabilities of the policy/reference model + for the chosen responses. Shape: (batch_size) + rejected_logps (torch.Tensor): Log probabilities of the policy/reference model + for the rejected responses. Shape: (batch_size) + chosen_logits (torch.Tensor): logits of the policy/reference model + rejected_logits (torch.Tensor): logits of the policy/reference model + """ + + chosen_logps: torch.Tensor + rejected_logps: torch.Tensor + chosen_logits: torch.Tensor + rejected_logits: torch.Tensor diff --git a/torchtune/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py index c09d36a261..9d871d0c49 100644 --- a/torchtune/rlhf/loss/dpo.py +++ b/torchtune/rlhf/loss/dpo.py @@ -10,6 +10,10 @@ import torch.nn as nn import torch.nn.functional as F +from torchtune.rlhf._types import ChosenRejectedOutputs + +from torchtune.utils._logging import deprecated + class DPOLoss(nn.Module): """ @@ -46,23 +50,15 @@ def __init__( def forward( self, - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, + policy_inputs: ChosenRejectedOutputs, + reference_inputs: ChosenRejectedOutputs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the DPO loss for a batch of policy and reference model log probabilities. Args: - policy_chosen_logps (torch.Tensor): Log probabilities of the policy model - for the chosen responses. Shape: (batch_size) - policy_rejected_logps (torch.Tensor): Log probabilities of the policy model - for the rejected responses. Shape: (batch_size) - reference_chosen_logps (torch.Tensor): Log probabilities of the reference model - for the chosen responses. Shape: (batch_size) - reference_rejected_logps (torch.Tensor): Log probabilities of the reference model - for the rejected responses. Shape: (batch_size) + policy_inputs (ChosenRejectedOutputs): Policy log-probs and logits required for the calculation. + reference_inputs (ChosenRejectedOutputs): Reference log-probs and logits required for the calculation. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors: @@ -71,8 +67,8 @@ def forward( - rejected_rewards: Rewards for the rejected responses. 
""" - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps + pi_logratios = policy_inputs.chosen_logps - policy_inputs.rejected_logps + ref_logratios = reference_inputs.chosen_logps - reference_inputs.rejected_logps logits = pi_logratios - ref_logratios @@ -85,15 +81,18 @@ def forward( ) chosen_rewards = ( - self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + self.beta + * (policy_inputs.chosen_logps - reference_inputs.chosen_logps).detach() ) rejected_rewards = ( - self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + self.beta + * (policy_inputs.rejected_logps - reference_inputs.rejected_logps).detach() ) return losses, chosen_rewards, rejected_rewards +@deprecated(msg="RSOLoss will be deprecated in an upcoming release.") class RSOLoss(nn.Module): """ Statistical Rejection Sampling Optimization (RSO) or "hinge" loss module: https://arxiv.org/abs/2309.06657. @@ -118,23 +117,15 @@ def __init__( def forward( self, - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, + policy_inputs: ChosenRejectedOutputs, + reference_inputs: ChosenRejectedOutputs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the RSO loss for a batch of policy and reference model log probabilities. Args: - policy_chosen_logps (torch.Tensor): Log probabilities of the policy model - for the chosen responses. Shape: (batch_size) - policy_rejected_logps (torch.Tensor): Log probabilities of the policy model - for the rejected responses. Shape: (batch_size) - reference_chosen_logps (torch.Tensor): Log probabilities of the reference model - for the chosen responses. Shape: (batch_size) - reference_rejected_logps (torch.Tensor): Log probabilities of the reference model - for the rejected responses. Shape: (batch_size) + policy_inputs (ChosenRejectedOutputs): Policy log-probs and logits required for the calculation. + reference_inputs (ChosenRejectedOutputs): Reference log-probs and logits required for the calculation. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors: @@ -143,18 +134,20 @@ def forward( - rejected_rewards: Rewards for the rejected responses. """ - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps + pi_logratios = policy_inputs.chosen_logps - policy_inputs.rejected_logps + ref_logratios = reference_inputs.chosen_logps - reference_inputs.rejected_logps logits = pi_logratios - ref_logratios losses = torch.relu(1 - self.gamma * logits) chosen_rewards = ( - self.gamma * (policy_chosen_logps - reference_chosen_logps).detach() + self.gamma + * (policy_inputs.chosen_logps - reference_inputs.chosen_logps).detach() ) rejected_rewards = ( - self.gamma * (policy_rejected_logps - reference_rejected_logps).detach() + self.gamma + * (policy_inputs.rejected_logps - reference_inputs.rejected_logps).detach() ) return losses, chosen_rewards, rejected_rewards