crpg mini batch
gaetanlop committed Oct 7, 2024
1 parent a588956 commit 4f222c4
Showing 3 changed files with 115 additions and 25 deletions.
46 changes: 46 additions & 0 deletions tests/test_cgpo_trainer.py
@@ -86,6 +86,52 @@ def test_cgpo_trainer(self, rlhf_optimizer):
if param.sum() != 0:
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

@parameterized.expand(["crraft", "crpg", "codpo"])
def test_cgpo_trainer_with_missing_eos_penalty(self, rlhf_optimizer):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CGPOConfig(
output_dir=tmp_dir,
rlhf_optimizer=rlhf_optimizer,
k=4,
missing_eos_penalty=1.0,
kl_threshold=5.0,
temperature=0.9,
max_new_tokens=4,
per_device_train_batch_size=4,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

trainer = CGPOTrainer(
model=self.model,
ref_model=self.ref_model,
reward_model=self.reward_model,
mixture_of_judges=self.moj,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

def test_cgpo_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CGPOConfig(
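
For reference, the `missing_eos_penalty` behaviour that this new test exercises is implemented further down in `cgpo_trainer.py`: any completion that never emits the EOS token has a fixed penalty subtracted from its reward. A minimal sketch of that mechanic with made-up values:

import torch

# Toy example: 2 plays the role of eos_token_id; the second completion never emits it.
completion_ids = torch.tensor([[5, 7, 2, 0], [5, 7, 9, 9]])
rewards = torch.tensor([1.3, 0.8])
missing_eos_penalty = 1.0

contain_eos_token = torch.any(completion_ids == 2, dim=-1)
rewards[~contain_eos_token] -= missing_eos_penalty
print(rewards)  # tensor([ 1.3000, -0.2000])
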
3 changes: 3 additions & 0 deletions trl/trainer/cgpo_config.py
@@ -40,6 +40,8 @@ class CGPOConfig(TrainingArguments):
lamb (`float`, *optional*, defaults to `5.0`):
Only used when rlhf_optimizer is set to `codpo`.
Parameter controlling the importance of the regularization term added to the vanilla DPO loss.
local_generation_batch_size (`int`, *optional*, defaults to `None`):
The size of the local mini-batch used during the generation phase. If `None`, it defaults to `per_device_train_batch_size`.
max_new_tokens (`int`, *optional*, defaults to `64`):
Maximum number of tokens to generate per completion.
max_length (`int`, *optional*, defaults to `None`):
Expand All @@ -59,6 +61,7 @@ class CGPOConfig(TrainingArguments):
kl_threshold: float = None
beta: float = 0.1
lamb: float = 5.0
local_generation_batch_size: int = None
max_new_tokens: int = 64
max_length: int = None
temperature: float = 0.9
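
As wired up in the trainer below, leaving `local_generation_batch_size` at `None` makes it fall back to `per_device_train_batch_size`, so existing configs keep their old behaviour. A hypothetical usage sketch (values are illustrative; the import path follows the file location in this diff):

from trl.trainer.cgpo_config import CGPOConfig

training_args = CGPOConfig(
    output_dir="cgpo-output",
    rlhf_optimizer="crpg",
    k=4,
    kl_threshold=5.0,
    per_device_train_batch_size=8,
    local_generation_batch_size=4,  # generate for 4 prompts at a time to bound peak memory
)
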
91 changes: 66 additions & 25 deletions trl/trainer/cgpo_trainer.py
@@ -211,6 +211,9 @@ def __init__(
self.beta = args.beta
self.kl_threshold = args.kl_threshold
self.lamb = args.lamb
self.local_generation_batch_size = (
args.local_generation_batch_size if args.local_generation_batch_size else args.per_device_train_batch_size
)
self._tag_names.append(args.rlhf_optimizer)

super().__init__(
@@ -283,27 +286,46 @@ def crpg_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torc
context_length,
)

baseline_rewards = baseline_rewards.repeat_interleave(repeats=self.k, dim=0)
inputs["baseline_rewards"] = baseline_rewards.repeat_interleave(repeats=self.k, dim=0)

rewards = inputs["rewards"]
judgements = inputs["judgements"]
total_loss = torch.tensor(0.0, device=self.model.device)
for i in range(self.k):
# simulate gradient accumulation to avoid running OOM on a batch of size bs * self.k
mini_batch_rewards = inputs["rewards"][i * bs : (i + 1) * bs]
mini_batch_judgements = inputs["judgements"][i * bs : (i + 1) * bs]
mini_batch_prompt_completion_ids = inputs["prompt_completion_ids"][i * bs : (i + 1) * bs]
mini_batch_prompt_completion_mask = inputs["prompt_completion_mask"][i * bs : (i + 1) * bs]
mini_batch_baseline_rewards = inputs["baseline_rewards"][i * bs : (i + 1) * bs]

calibrated_rewards = torch.sigmoid(rewards - baseline_rewards)
mini_batch_calibrated_rewards = torch.sigmoid(mini_batch_rewards - mini_batch_baseline_rewards)

# compute kl_divergence
logprobs, ref_logprobs = self._get_batch_logprobs(
inputs["prompt_completion_ids"], inputs["prompt_completion_mask"], context_length
)
# compute kl_divergence
logprobs, ref_logprobs = self._get_batch_logprobs(
mini_batch_prompt_completion_ids, mini_batch_prompt_completion_mask, context_length
)

with torch.no_grad():
# kl_div is used as a regularization term here
kl_div = logprobs - ref_logprobs
kl_div_regularization = torch.clamp(1 - kl_div / self.kl_threshold, min=0)
calibrated_regularized_rewards = judgements * calibrated_rewards * kl_div_regularization
with torch.no_grad():
# kl_div is used as a regularization term here
kl_div = logprobs - ref_logprobs
kl_div_regularization = torch.clamp(1 - kl_div / self.kl_threshold, min=0)

calibrated_regularized_rewards = (
mini_batch_judgements * mini_batch_calibrated_rewards * kl_div_regularization
)

losses = -logprobs * (calibrated_regularized_rewards - calibrated_regularized_rewards.mean())
losses = -logprobs * (calibrated_regularized_rewards - calibrated_regularized_rewards.mean())

return losses.sum() / bs
loss = losses.mean() / self.k

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)

total_loss += loss

return total_loss
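
The loop above is effectively hand-rolled gradient accumulation: the k * bs generated completions are processed in k slices of size bs, each slice's loss is scaled by 1 / k, and backward runs per slice so that only one slice's activations are alive at a time. A self-contained sketch of that pattern (names are illustrative, not the trainer's API):

import torch

def accumulate_over_mini_batches(compute_loss, backward_fn, full_batch, k, bs):
    """Process a batch of k * bs samples in k slices, backpropagating slice by slice."""
    total_loss = torch.tensor(0.0)
    for i in range(k):
        mini_batch = {key: t[i * bs : (i + 1) * bs] for key, t in full_batch.items()}
        loss = compute_loss(mini_batch) / k  # scale so the accumulated gradient matches one large batch
        backward_fn(loss)                    # e.g. accelerator.backward(loss)
        total_loss += loss.detach()
    return total_loss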

def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""Implementation of the Calibrated Regularized Reward Ranking Finetuning (CRRAFT) policy opttimizer."""
@@ -324,6 +346,11 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> to
inputs["context_length"],
)

if self.args.missing_eos_penalty is not None:
baseline_ids = inputs["prompt_baseline_ids"][:, context_length:]
contain_eos_token = torch.any(baseline_ids == self.tokenizer.eos_token_id, dim=-1)
baseline_rewards[~contain_eos_token] -= self.args.missing_eos_penalty

baseline_judgements = self.moj.judge(inputs["prompt"], inputs["completion"])
baseline_judgements = torch.tensor(baseline_judgements, device=self.model.device, dtype=torch.bool)

@@ -390,12 +417,20 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> to
losses = -logprobs * filtered_calibrated_rewards

# simulate skipping samples instead of using .mean()
return (
loss = (
losses.sum() / filtered_calibrated_rewards.sum()
if filtered_calibrated_rewards.sum() != 0
else losses.sum()
)

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)

return loss
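
The normalization above is a weighted mean: `filtered_calibrated_rewards` acts as per-sample weights, so dividing the summed weighted losses by the weight sum averages only over samples that survived the reward and constraint filtering, with a plain sum as a fallback when every weight is zero. A small sketch of that reduction, assuming the weights are the same tensor used to scale the losses:

import torch

def weighted_mean_loss(logprobs, weights):
    """Weight -logprobs per sample and average over retained samples; fall back to a sum if all weights are zero."""
    losses = -logprobs * weights
    denom = weights.sum()
    return losses.sum() / denom if denom != 0 else losses.sum()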

def codpo_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
bs = inputs["bs"]
context_length = inputs["context_length"]
@@ -434,7 +469,15 @@ def codpo_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> tor
# eqn (14) in the paper
losses = -(F.logsigmoid(self.beta * logits) + self.lamb / chosen_length * chosen_logprobs)

return losses.mean()
loss = losses.mean()

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)

return loss
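
For reference, the CoDPO objective computed above (equation (14) of the CGPO paper) adds a length-normalized regularizer on the chosen completion's log-probability to the usual DPO sigmoid loss. A toy numeric sketch, assuming `logits` is the standard DPO policy-minus-reference log-ratio difference (its computation sits in the lines elided from this hunk):

import torch
import torch.nn.functional as F

beta, lamb = 0.1, 5.0                           # defaults from CGPOConfig
logits = torch.tensor([2.0, -0.5])              # assumed DPO preference logits per pair
chosen_logprobs = torch.tensor([-12.0, -20.0])  # summed logprobs of the chosen completions
chosen_length = torch.tensor([10.0, 16.0])      # token counts used for length normalization

losses = -(F.logsigmoid(beta * logits) + lamb / chosen_length * chosen_logprobs)
print(losses.mean())  # larger margins and higher per-token chosen logprobs lower the loss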

def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
bs, context_length = inputs["prompt_ids"].shape
Expand All @@ -445,7 +488,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
query_responses_ids, _ = batch_generation(
unwrapped_model,
prompt_ids,
bs,
self.local_generation_batch_size,
self.tokenizer.pad_token_id,
self.generation_config,
)
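
Passing `self.local_generation_batch_size` here lets `batch_generation` sample completions over smaller chunks of prompts than the full training batch, which bounds peak memory during generation. A rough sketch of that chunking idea (not `batch_generation`'s actual implementation):

import torch

def generate_in_chunks(generate_fn, prompt_ids, chunk_size):
    """Generate for slices of the prompt batch and concatenate the results."""
    outputs = []
    for start in range(0, prompt_ids.shape[0], chunk_size):
        # assumes generate_fn pads its outputs to a common length so torch.cat is valid
        outputs.append(generate_fn(prompt_ids[start : start + chunk_size]))
    return torch.cat(outputs, dim=0)
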
Expand All @@ -466,6 +509,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,

rewards = []
for i in range(0, prompt_completion_ids.shape[0], bs):
# score rewards on mini-batches of size bs instead of the full batch of size bs * self.k
mini_batch_prompt_completion_ids = prompt_completion_ids[i : i + bs]
with torch.no_grad():
_, mini_batch_rewards, _ = get_reward(
Expand All @@ -477,8 +521,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,

rewards.append(mini_batch_rewards)

# TO DO: Add penalty for samples that do not contain the eos token
rewards = torch.cat(rewards, dim=0)
# Completions that do not contain an eos token id are penalized.
if self.args.missing_eos_penalty is not None:
contain_eos_token = torch.any(completion_ids == self.tokenizer.eos_token_id, dim=-1)
rewards[~contain_eos_token] -= self.args.missing_eos_penalty

inputs["rewards"] = rewards
inputs["judgements"] = torch.tensor(judgements, device=self.model.device, dtype=torch.float)
Expand All @@ -499,12 +546,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
self.stats["constraints/judgements"].append(self.accelerator.gather(inputs["judgements"]).mean().item())
self.stats["constraints/rewards"].append(self.accelerator.gather(inputs["rewards"]).mean().item())

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)

return loss.detach() / self.args.gradient_accumulation_steps

# Same as Trainer.evaluate but log our metrics
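
Note the structural change in this file: the apex/accelerator backward call was removed from `training_step` and now lives inside each of the crpg, crraft, and codpo optimization methods, so gradients are produced per mini-batch and `training_step` only returns the detached loss scaled by `gradient_accumulation_steps`. A rough outline of the resulting control flow (illustrative, not the literal trainer code):

def training_step_outline(trainer, inputs):
    # 1. generate self.k completions per prompt, in chunks of local_generation_batch_size
    # 2. score them with the reward model and the mixture of judges, apply missing_eos_penalty
    # 3. run the configured rlhf optimizer; each one now calls backward() internally
    loss = trainer.crpg_optimization(inputs)  # or crraft_optimization / codpo_optimization
    return loss.detach() / trainer.args.gradient_accumulation_steps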
