Commit

fix small typos
gaetanlop committed Oct 11, 2024
1 parent 6e6b6b6 commit 38e41f8
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions trl/trainer/cgpo_trainer.py
@@ -309,15 +309,15 @@ def _get_batch_logprobs(
         return logprobs[:, context_length - 1 : -1].sum(1)
 
     def crpg_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
-        bs = inputs["bs"]
-        mini_bs = self.local_genscore_mini_batch_size
         context_length = inputs["context_length"]
         prompt_completion_ids = inputs["prompt_completion_ids"]
-        full_bs = prompt_completion_ids.shape[0]
         prompt_completion_mask = inputs["prompt_completion_mask"]
         judgements = inputs["judgements"]
         rewards = inputs["rewards"]
         completion_logprobs = inputs["completion_logprobs"]
+        bs = inputs["bs"]
+        mini_bs = self.local_genscore_mini_batch_size
+        full_bs = prompt_completion_ids.shape[0]
 
         with torch.no_grad():
             _, baseline_rewards, _ = get_reward(
@@ -439,7 +439,6 @@ def codpo_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
     def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         prompt_baseline_ids = inputs["input_ids"]
         prompt_baseline_mask = inputs["attention_mask"]
-
         rewards = inputs["rewards"]
         judgements = inputs["judgements"]
         prompt_completion_ids = inputs["prompt_completion_ids"]
@@ -448,6 +447,7 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
 
         bs = inputs["bs"]
         mini_bs = self.local_genscore_mini_batch_size
+        full_bs = prompt_completion_ids.shape[0]
         context_length = inputs["context_length"]
 
         # get baseline rewards
@@ -466,7 +466,7 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
 
         # compute kl div regularizer
         ref_logprobss = []
-        for i in range(0, prompt_completion_ids.shape[0], mini_bs):
+        for i in range(0, full_bs, mini_bs):
             mini_batch_ids = prompt_completion_ids[i : i + mini_bs]
             mini_batch_mask = prompt_completion_mask[i : i + mini_bs]
             with torch.no_grad():
@@ -501,8 +501,7 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         no_positive_completion_mask = best_rewards == 0
         if no_positive_completion_mask.any():
             # only compute the baseline logprobs if we need to (ie some prompts do not have any positive samples)
-            # the baseline calibrated reward is always equal to 0.5 ie sigmoid(baseline_rewards - baselines_rewards)
-            # Check that the baseline satisfy all constraints: judgements and kl
+            # check that the baseline satisfy all constraints: judgements and kl
             baseline_judgements = torch.zeros_like(best_rewards, dtype=torch.bool)
             prompts_text = [text for i, text in enumerate(inputs["prompts_text"]) if no_positive_completion_mask[i]]
             baseline_completions_text = [
@@ -543,7 +542,7 @@ def crraft_optimization(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         # compute loss as done in eqn (18) of the CGPO paper: https://huggingface.co/papers/2409.20370
         losses = -best_completion_logprobs * best_rewards
         # simulate skipping samples instead of using .mean()
-        loss = losses.sum() / ((rewards != 0).sum() + self.epsilon)
+        loss = losses.sum() / ((best_rewards != 0).sum() + self.epsilon)
 
         if self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
@@ -593,7 +592,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
 
         rewards = []
         for i in range(0, prompt_completion_ids.shape[0], self.local_genscore_mini_batch_size):
-            # Compute rewards on a mini-batch of size `bs`, instead of the full batch (`bs` * self.k)
            mini_batch_prompt_completion_ids = prompt_completion_ids[i : i + self.local_genscore_mini_batch_size]
             with torch.no_grad():
                 _, mini_batch_rewards, _ = get_reward(
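The main fix above is the loss normalizer in crraft_optimization (eqn (18) of the CGPO paper): the summed per-sample losses are now divided by the number of non-zero best_rewards rather than non-zero rewards, so samples that fail the constraints are excluded from both the numerator and the denominator. A minimal, self-contained sketch of that normalization follows; the function and variable names are illustrative, not the trainer's actual API.

import torch

def reward_weighted_loss(best_completion_logprobs: torch.Tensor,
                         best_rewards: torch.Tensor,
                         epsilon: float = 1e-8) -> torch.Tensor:
    # per-sample loss: -log pi(y|x) weighted by the calibrated reward
    losses = -best_completion_logprobs * best_rewards
    # average only over samples with a non-zero reward (the only ones that
    # contribute to the sum), instead of a plain .mean() over the whole batch
    return losses.sum() / ((best_rewards != 0).sum() + epsilon)

# example: only two of the three samples carry a non-zero reward
logprobs = torch.tensor([-1.2, -0.8, -2.0])
rewards = torch.tensor([0.7, 0.0, 0.4])
print(reward_weighted_loss(logprobs, rewards))  # (0.84 + 0.80) / 2 ≈ 0.82

In the trainer, rewards presumably spans all generated completions while best_rewards holds one entry per prompt, so normalizing by non-zero rewards would over-count the effective batch size; counting non-zero best_rewards matches the "skip the failed samples" intent of the comment.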
