Refactor chunked preference functions and distillation base class #491

Open · wants to merge 11 commits into main
18 changes: 10 additions & 8 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
-   def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+   def preference_loss_fn(
+       chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1
+   ):
"""
Paper: https://arxiv.org/pdf/2401.08417

@@ -26,14 +28,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        - D: Dataset of preferences

        Args:
-           chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-           rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-           full_target (torch.Tensor): Non chunked full target tensor
-           beta (float): Weight for the CPO loss
+           chosen_logps_chunk (torch.Tensor): Avg log probabilities of chosen tokens in the chunk. Shape: (batch_size,).
+           rejected_logps_chunk (torch.Tensor): Avg log probabilities of rejected tokens in the chunk. Shape: (batch_size,).
+           full_target (torch.Tensor): Non chunked full target tensor.
Collaborator (inline review comment on the full_target line):
I wonder: is it full_target or actually target_chunk? From the fused function, we are feeding in target_chunk:

        def fused_fwd_bwd(
            input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk
        ):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            if bias is not None:
                return torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1, 3), has_aux=True
                )(
                    input_chunk,
                    weight,
                    target_chunk,
                    bias,
                    ref_input_chunk=ref_input_chunk,
                    preference_labels=preference_labels_chunk,
                )

+           beta (float): Weight for the CPO loss.
        """
-       logits = beta * (chosen_logps - rejected_logps)
-       loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-       return loss
+       logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk)
+       loss_chunk = F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)
+       return loss_chunk

    @staticmethod
    def forward(
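To sanity-check the chunked formulation (and the question above): full_target appears here only in the normalizer, so each chunk's loss is divided by the number of preference pairs in the full batch, and the per-chunk losses sum to the unchunked loss; target_chunk is what feeds the model forward, while full_target just supplies the global batch size. A minimal standalone sketch, not part of the PR (tensor names and sizes are illustrative):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    beta = 0.1
    chosen_logps = torch.randn(8)     # avg log-probs of 8 chosen sequences
    rejected_logps = torch.randn(8)   # avg log-probs of 8 rejected sequences
    full_target = torch.empty(16, 4)  # concatenated chosen + rejected rows

    # Unchunked loss over the whole batch (sign convention as in the diff).
    full_loss = F.logsigmoid(beta * (chosen_logps - rejected_logps)).sum() / (
        full_target.shape[0] // 2
    )

    # Chunked: two pairs at a time, each chunk normalized by the FULL batch size.
    chunk_losses = [
        F.logsigmoid(beta * (c - r)).sum() / (full_target.shape[0] // 2)
        for c, r in zip(chosen_logps.split(2), rejected_logps.split(2))
    ]

    # Summing the per-chunk losses reproduces the unchunked loss.
    assert torch.allclose(torch.stack(chunk_losses).sum(), full_loss)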
38 changes: 21 additions & 17 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -10,11 +10,11 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(
-       chosen_logps,
-       rejected_logps,
+       chosen_logps_chunk,
+       rejected_logps_chunk,
        full_target,
-       ref_chosen_logps=None,
-       ref_rejected_logps=None,
+       ref_chosen_logps_chunk=None,
+       ref_rejected_logps_chunk=None,
        beta=0.1,
    ):
"""
Expand All @@ -32,25 +32,29 @@ def preference_loss_fn(
        - E: Expected value over the dataset

        Args:
-           chosen_logps: Log probabilities of chosen tokens (batch_size,)
-           rejected_logps: Log probabilities of rejected tokens (batch_size,)
+           chosen_logps_chunk: Log probabilities of chosen tokens in the chunk (batch_size,)
+           rejected_logps_chunk: Log probabilities of rejected tokens in the chunk (batch_size,)
            full_target: Non chunked full target tensor
-           ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
-           ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+           ref_chosen_logps_chunk: Reference log probs of chosen tokens in the chunk (batch_size,)
+           ref_rejected_logps_chunk: Reference log probs of rejected tokens in the chunk (batch_size,)
            beta: Weight for the direct preference loss
        """

-       if ref_chosen_logps is None:
-           ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
-       if ref_rejected_logps is None:
-           ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+       if ref_chosen_logps_chunk is None:
+           ref_chosen_logps_chunk = torch.tensor(0.0, device=chosen_logps_chunk.device)
+       if ref_rejected_logps_chunk is None:
+           ref_rejected_logps_chunk = torch.tensor(
+               0.0, device=rejected_logps_chunk.device
+           )

-       chosen_logratios = chosen_logps - ref_chosen_logps
-       rejected_logratios = rejected_logps - ref_rejected_logps
+       chosen_logratios_chunk = chosen_logps_chunk - ref_chosen_logps_chunk
+       rejected_logratios_chunk = rejected_logps_chunk - ref_rejected_logps_chunk

-       logits_diff = beta * (chosen_logratios - rejected_logratios)
-       loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-       return loss
+       logits_diff_chunk = beta * (chosen_logratios_chunk - rejected_logratios_chunk)
+       loss_chunk = -F.logsigmoid(logits_diff_chunk).sum() / (
+           full_target.shape[0] // 2
+       )
+       return loss_chunk

    @staticmethod
    def forward(
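For reference, a standalone restatement (not from the PR) of the refactored DPO chunk loss above, showing the reference-free fallback: when the ref_*_chunk arguments are left as None, the zero tensors make the log-ratios collapse to the policy log-probs alone.

    import torch
    import torch.nn.functional as F

    def dpo_chunk_loss(
        chosen_logps_chunk,
        rejected_logps_chunk,
        full_target,
        ref_chosen_logps_chunk=None,
        ref_rejected_logps_chunk=None,
        beta=0.1,
    ):
        # Default the reference log-probs to zero (reference-free DPO).
        if ref_chosen_logps_chunk is None:
            ref_chosen_logps_chunk = torch.tensor(0.0, device=chosen_logps_chunk.device)
        if ref_rejected_logps_chunk is None:
            ref_rejected_logps_chunk = torch.tensor(
                0.0, device=rejected_logps_chunk.device
            )

        chosen_logratios_chunk = chosen_logps_chunk - ref_chosen_logps_chunk
        rejected_logratios_chunk = rejected_logps_chunk - ref_rejected_logps_chunk

        # Negative log-sigmoid of the scaled ratio gap, normalized by the number
        # of preference pairs in the FULL (non-chunked) batch.
        logits_diff_chunk = beta * (chosen_logratios_chunk - rejected_logratios_chunk)
        return -F.logsigmoid(logits_diff_chunk).sum() / (full_target.shape[0] // 2)

    chosen = torch.tensor([-1.0, -2.0])    # policy avg log-probs, chosen
    rejected = torch.tensor([-3.0, -4.0])  # policy avg log-probs, rejected
    full_target = torch.empty(4, 1)        # two chosen + two rejected rows
    print(dpo_chunk_loss(chosen, rejected, full_target))  # ≈ tensor(0.5981)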