Skip to content

Commit

Permalink
implement length-regularized DPO (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkevinzc authored Dec 21, 2024
1 parent ed7d3b1 commit 37becae
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
2 changes: 2 additions & 0 deletions oat/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class OATArgs:
label_smoothing: float = 0
# SimPO https://arxiv.org/pdf/2405.14734.
gamma_beta_ratio: float = 0.5
# Length-Regularized DPO https://arxiv.org/pdf/2403.19159.
len_reg_alpha: float = 0.0

# Oracle.
preference_oracle: str = "pairrm"
Expand Down
22 changes: 15 additions & 7 deletions oat/learners/dap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def _init(self, args, actors) -> None:
super()._init(args, actors)

if self.algo in [DAPAlgo.DPO, DAPAlgo.LR_DPO, DAPAlgo.IPO, DAPAlgo.SLiC]:
self.loss = DPOLoss(args.beta, args.label_smoothing, dap_algo=self.algo)
self.loss = DPOLoss(
args.beta, args.label_smoothing, args.len_reg_alpha, dap_algo=self.algo
)
elif self.algo == DAPAlgo.SimPO:
self.loss = SimPOLoss(
args.beta, args.gamma_beta_ratio, args.label_smoothing
Expand Down Expand Up @@ -74,13 +76,13 @@ def learning_step(self, data):
)

else:
chosen_logps, rejected_logps, _ = self.concatenated_forward(
chosen_logps, rejected_logps, _, token_masks = self.concatenated_forward(
self.model, chosen_ids, c_mask, rejected_ids, r_mask, prompt_id_lens
)

if self.ref_model is not None:
with torch.no_grad():
reference_chosen_logps, reference_rejected_logps, _ = (
reference_chosen_logps, reference_rejected_logps, _, _ = (
self.concatenated_forward(
self.ref_model,
chosen_ids,
Expand All @@ -96,6 +98,7 @@ def learning_step(self, data):
reference_chosen_logps,
reference_rejected_logps,
loss_masks,
token_masks,
)
else:
preference_loss, chosen_reward, rejected_reward = self.loss(
Expand Down Expand Up @@ -128,7 +131,7 @@ def concatenated_forward(

if self.algo != DAPAlgo.BNF:

all_logps = self.get_batch_logps(
all_logps, token_masks = self.get_batch_logps(
all_logits,
input_ids,
att_masks,
Expand All @@ -140,7 +143,12 @@ def concatenated_forward(
rejected_logps = all_logps[chosen_ids.shape[0] :]
aux_loss = output.aux_loss if "aux_loss" in output else []

return chosen_logps, rejected_logps, aux_loss
return (
chosen_logps,
rejected_logps,
aux_loss,
token_masks,
)

else:

Expand Down Expand Up @@ -226,9 +234,9 @@ def get_batch_logps(
length = loss_masks.sum(-1)

if average_log_prob:
return (target_logps * loss_masks).sum(-1) / length
return (target_logps * loss_masks).sum(-1) / length, loss_masks
else:
return (target_logps * loss_masks).sum(-1)
return (target_logps * loss_masks).sum(-1), loss_masks

else:

Expand Down
10 changes: 10 additions & 0 deletions oat/learners/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ def __init__(
self,
beta: float,
label_smoothing: float = 0.0,
len_reg_alpha: float = 0.0,
dap_algo=DAPAlgo.DPO,
) -> None:
super().__init__()
self.beta = beta
self.label_smoothing = label_smoothing
self.dap_algo = dap_algo
self.len_reg_alpha = len_reg_alpha

def forward(
self,
Expand All @@ -44,6 +46,7 @@ def forward(
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
loss_masks: torch.Tensor,
token_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
Expand All @@ -56,6 +59,13 @@ def forward(
elif self.dap_algo == DAPAlgo.SLiC:
losses = torch.relu(1 - self.beta * logits)
else:
if self.len_reg_alpha > 0:
y_length = token_masks.sum(-1)
length_diff = (
y_length[: len(y_length) // 2] - y_length[len(y_length) // 2 :]
)
# Eq. 9 https://arxiv.org/pdf/2403.19159; Length Reg in loss.
logits += self.len_reg_alpha / self.beta * length_diff
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
Expand Down

0 comments on commit 37becae

Please sign in to comment.