Commit

Add SimPO+WPO+DPO norm loss (#237)
* Add SimPO loss

* lint

* whoops, simpo needs logp avging

* first pass at WPO

* lint

* add dpo_norm as an option
hamishivi authored Aug 14, 2024
1 parent cb40645 commit bd720d9
Showing 4 changed files with 157 additions and 21 deletions.
25 changes: 25 additions & 0 deletions configs/train_configs/dpo/tulu_3_preview_test_simpo.yaml
@@ -0,0 +1,25 @@
model_name_or_path: /model
model_revision: main
use_flash_attn: true
gradient_checkpointing: true
tokenizer_name: /model
use_slow_tokenizer: true
dataset_name: princeton-nlp/llama3-ultrafeedback-armorm
max_seq_length: 2048
preprocessing_num_workers: 16
per_device_train_batch_size: 1
gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128
learning_rate: 5.0e-7
lr_scheduler_type: linear
warmup_ratio: 0.1
weight_decay: 0.0
num_train_epochs: 1
output_dir: /output
with_tracking: true
report_to:
- wandb
logging_steps: 1
use_lora: false
dpo_loss_type: simpo
dpo_gamma_beta_ratio: 0.3
dpo_beta: 10
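
For orientation: with these hyperparameters, the SimPO loss added below works on length-averaged log probabilities and enforces a target reward margin of gamma = dpo_beta * dpo_gamma_beta_ratio = 10 * 0.3 = 3. A sketch of the objective this config implies (label smoothing off), matching the new simpo_loss in open_instruct/dpo_utils.py:

\mathcal{L}_{\mathrm{SimPO}} = -\log\sigma\!\Big(\tfrac{\beta}{|y_w|}\log\pi_\theta(y_w\mid x) \;-\; \tfrac{\beta}{|y_l|}\log\pi_\theta(y_l\mid x) \;-\; \gamma\Big), \qquad \gamma = \beta \cdot \texttt{dpo\_gamma\_beta\_ratio}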
69 changes: 51 additions & 18 deletions open_instruct/dpo_tune.py
@@ -53,6 +53,8 @@
DataCollatorForSeq2SeqDPO,
concatenated_forward,
dpo_loss,
simpo_loss,
wpo_loss,
)
from open_instruct.utils import (
ArgumentParserPlus,
@@ -317,10 +319,14 @@ def load_model():
return model

model = load_model()
if not args.use_lora:
reference_model = load_model()
# only simpo is reference-model-free right now
if args.dpo_loss_type != "simpo":
if not args.use_lora:
reference_model = load_model()
else:
reference_model = model
else:
reference_model = model
reference_model = None

# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
@@ -498,7 +504,8 @@ def load_model():
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
if not args.use_lora:
# reference model may be None with e.g. SimPO loss.
if not args.use_lora and reference_model is not None:
reference_model = prepare_deepspeed(accelerator, reference_model)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -583,23 +590,49 @@ def load_model():
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
# we need to average the log probs for simpo and dpo_norm losses
average_log_prob_loss_types = ["simpo", "dpo_norm"]
average_log_prob = args.dpo_loss_type in average_log_prob_loss_types
for step, batch in enumerate(active_dataloader):
# dpo forward pass & loss
with accelerator.accumulate(model):
policy_chosen_logps, policy_rejected_logps = concatenated_forward(model, batch)
with torch.no_grad():
if args.use_lora:
with accelerator.unwrap_model(model).disable_adapter():
reference_chosen_logps, reference_rejected_logps = concatenated_forward(model, batch)
else:
reference_chosen_logps, reference_rejected_logps = concatenated_forward(reference_model, batch)
losses, _, _ = dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
beta=args.dpo_beta,
)
policy_chosen_logps, policy_rejected_logps = concatenated_forward(model, batch, average_log_prob=average_log_prob)
if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm":
with torch.no_grad():
if args.use_lora:
with accelerator.unwrap_model(model).disable_adapter():
reference_chosen_logps, reference_rejected_logps = concatenated_forward(model, batch, average_log_prob=average_log_prob)
else:
reference_chosen_logps, reference_rejected_logps = concatenated_forward(reference_model, batch, average_log_prob=average_log_prob)
losses, _, _ = dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
beta=args.dpo_beta,
label_smoothing=args.dpo_label_smoothing,
)
elif args.dpo_loss_type == "simpo":
losses, _, _ = simpo_loss(
policy_chosen_logps,
policy_rejected_logps,
beta=args.dpo_beta,
gamma_beta_ratio=args.dpo_gamma_beta_ratio,
label_smoothing=args.dpo_label_smoothing,
)
elif args.dpo_loss_type == "wpo":
losses, _, _ = wpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
beta=args.dpo_beta,
label_smoothing=args.dpo_label_smoothing,
chosen_loss_mask=batch["chosen_labels"] != -100,
rejected_loss_mask=batch["rejected_labels"] != -100,
)
else:
raise ValueError(f"Invalid dpo loss type {args.dpo_loss_type}.")
# TODO: metric logging
loss = losses.mean()
# We keep track of the loss at each logged step
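
An illustrative sketch, not from this commit, of what the dpo_loss_type dispatch above boils down to, assuming the package is importable as open_instruct.dpo_utils and using dummy per-example log probabilities of shape (batch_size,); wpo is omitted here because it additionally needs per-token loss masks:

import torch
from open_instruct.dpo_utils import dpo_loss, simpo_loss

# Dummy summed ("dpo") or length-averaged ("dpo_norm"/"simpo") log probs, shape (batch_size,).
policy_chosen = torch.tensor([-1.2, -0.8])
policy_rejected = torch.tensor([-1.5, -1.4])
reference_chosen = torch.tensor([-1.3, -0.9])
reference_rejected = torch.tensor([-1.4, -1.3])

# "dpo" and "dpo_norm" share dpo_loss; dpo_norm only changes how the log probs
# are computed (average_log_prob=True in concatenated_forward).
dpo_losses, _, _ = dpo_loss(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected,
    beta=0.1, label_smoothing=0.0,
)

# "simpo" is reference-free: no reference log probs are needed.
simpo_losses, _, _ = simpo_loss(
    policy_chosen, policy_rejected,
    beta=10.0, gamma_beta_ratio=0.3, label_smoothing=0.0,
)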
72 changes: 69 additions & 3 deletions open_instruct/dpo_utils.py
@@ -35,6 +35,7 @@ def dpo_loss(
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
label_smoothing: float = 0.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
@@ -66,13 +67,78 @@

logits = pi_logratios - ref_logratios

losses = -F.logsigmoid(beta * logits)
losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

return losses, chosen_rewards, rejected_rewards


def wpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
label_smoothing: float = 0.0,
chosen_loss_mask: torch.BoolTensor = None,
rejected_loss_mask: torch.BoolTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

# compute average logps and use them to compute the weights
policy_chosen_logps_average = (policy_chosen_logps * chosen_loss_mask).sum(-1) / chosen_loss_mask.sum(-1)
policy_rejected_logps_average = (policy_rejected_logps * rejected_loss_mask).sum(-1) / rejected_loss_mask.sum(-1)
policy_weights = torch.clamp(torch.exp(policy_chosen_logps_average + policy_rejected_logps_average), max=1)

logits = pi_logratios - ref_logratios

losses = (
-F.logsigmoid(beta * logits) * (1 - label_smoothing) * policy_weights
- F.logsigmoid(-beta * logits) * label_smoothing * policy_weights
)

chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

return losses, chosen_rewards, rejected_rewards


# From https://github.com/princeton-nlp/SimPO/blob/main/scripts/simpo_trainer.py#L560C1-L595C56
def simpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
beta: float,
gamma_beta_ratio: float,
label_smoothing: float = 0.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the SimPO loss for a batch of policy model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the SimPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
logits = pi_logratios - gamma_beta_ratio

# sigmoid loss type from SimPO.
losses = (
-F.logsigmoid(beta * logits) * (1 - label_smoothing)
- F.logsigmoid(-beta * logits) * label_smoothing
)

chosen_rewards = beta * policy_chosen_logps.detach()
rejected_rewards = beta * policy_rejected_logps.detach()

return losses, chosen_rewards, rejected_rewards


def _get_batch_logps(
logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
) -> torch.FloatTensor:
@@ -139,7 +205,7 @@ def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict


def concatenated_forward(
model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]], average_log_prob: bool = False
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@@ -150,7 +216,7 @@ def concatenated_forward(
input_ids=concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
).logits.to(torch.float32)
all_logps = _get_batch_logps(all_logits, concatenated_batch["concatenated_labels"], average_log_prob=False)
all_logps = _get_batch_logps(all_logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob)
chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]]
rejected_logps = all_logps[batch["chosen_input_ids"].shape[0] :]
return chosen_logps, rejected_logps
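
In short, the weighting that distinguishes wpo_loss above from plain DPO: each example's label-smoothed DPO loss is scaled by a weight built from the policy's length-averaged log probabilities of the chosen and rejected responses, clamped at 1,

w = \min\!\Big(\exp\big(\overline{\log\pi_\theta(y_w\mid x)} + \overline{\log\pi_\theta(y_l\mid x)}\big),\; 1\Big), \qquad \mathcal{L}_{\mathrm{WPO}} = w \cdot \mathcal{L}_{\mathrm{DPO}}

where the bars denote averages over tokens whose labels are not -100.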
12 changes: 12 additions & 0 deletions open_instruct/utils.py
@@ -389,6 +389,18 @@ class FlatArguments:
default=0.1,
metadata={"help": "Beta parameter for DPO loss. Default is 0.1."},
)
dpo_loss_type: str = field(
default="dpo",
metadata={"help": "Type of DPO loss to use. Options are 'dpo', 'dpo_norm', 'simpo', 'wpo'."},
)
dpo_gamma_beta_ratio: float = field(
default=0.3,
metadata={"help": "Gamma to beta ratio for SimPO loss. Default is 0.3. Not used for DPO loss."},
)
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "Label smoothing for DPO/SimPO loss. Default is 0 (no smoothing)."},
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
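
For reference, the dpo_label_smoothing value ε added here maps onto the smoothed objective now shared by dpo_loss, wpo_loss, and simpo_loss above:

\mathcal{L} = -(1-\varepsilon)\,\log\sigma(\beta z) \;-\; \varepsilon\,\log\sigma(-\beta z)

where z is the reference-adjusted log-ratio for dpo/dpo_norm/wpo and the margin-adjusted log-ratio for simpo; ε = 0 (the default) recovers the unsmoothed loss.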
