handle token/logprob shifting
winglian committed Dec 30, 2024
1 parent 93ea657 commit 9801b45
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 19 additions & 8 deletions src/axolotl/core/trainers/kd/__init__.py
@@ -41,7 +41,12 @@ def _set_signature_columns_if_needed(self):
             self._signature_columns += columns_to_add
 
     def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
+        self,
+        model,
+        inputs,
+        return_outputs=False,
+        num_items_in_batch=None,
+        shift_targets=False,
     ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -65,16 +70,22 @@ def compute_loss(
         # FIXME: account for tokenizer.padding_side
         student_logits = outputs["logits"][:, :seq_len, :].contiguous()
 
-        shift_logits = student_logits[..., :-1, :].contiguous()
-        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
-        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
-        shift_target_mask = target_mask[..., 1:, :].contiguous()
+        if shift_targets:
+            shift_logits = student_logits[..., :-1, :].contiguous()
+            target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
+            target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
+            target_mask_for_loss = target_mask[..., 1:, :].contiguous()
+        else:
+            shift_logits = student_logits.contiguous()
+            target_logprobs_for_loss = target_logprobs.contiguous()
+            target_token_ids_for_loss = target_token_ids.contiguous()
+            target_mask_for_loss = target_mask.contiguous()
 
         loss_kd = topk_kd_loss(
             shift_logits,
-            shift_target_token_ids,
-            shift_target_logprobs,
-            shift_target_mask,
+            target_token_ids_for_loss,
+            target_logprobs_for_loss,
+            target_mask_for_loss,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )
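For readers skimming the diff, here is a minimal sketch of what the new shift_targets path does. This is not the axolotl topk_kd_loss implementation: the function name shifted_topk_kd_sketch, the tensor shapes (student logits of shape batch x seq_len x vocab, teacher tensors of shape batch x seq_len x top_k, token ids as long tensors), and the simplified forward-KL term are all assumptions; the real loss additionally applies kd_temperature and num_items_in_batch normalization.

# Sketch only: student logits at position t predict the token at t+1, so the
# student sequence is truncated by one and the teacher's top-k tensors are
# shifted forward by one before the loss is computed.
import torch
import torch.nn.functional as F


def shifted_topk_kd_sketch(
    student_logits, target_token_ids, target_logprobs, target_mask, shift_targets=True
):
    if shift_targets:
        # Drop the last student position and the first teacher position
        # so the two sequences line up causally.
        student_logits = student_logits[..., :-1, :]
        target_token_ids = target_token_ids[..., 1:, :]
        target_logprobs = target_logprobs[..., 1:, :]
        target_mask = target_mask[..., 1:, :]

    # Student log-probabilities gathered at the teacher's top-k token ids.
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    student_topk = torch.gather(student_logprobs, -1, target_token_ids)

    # Simplified forward KL on the top-k support; -inf padding rows are
    # neutralized via the mask so they contribute nothing.
    teacher_logprobs = target_logprobs.masked_fill(target_mask == 0, 0.0)
    teacher_probs = teacher_logprobs.exp() * target_mask
    kl = teacher_probs * (teacher_logprobs - student_topk)
    return kl.sum() / target_mask.sum().clamp(min=1)

With shift_targets=False, the tensors are used as they come from the dataset, which presumably corresponds to teacher logprobs that were generated already aligned to the student's next-token targets.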
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/chat_template.py
@@ -502,7 +502,7 @@ def transform_logprobs(self, sample):
 
         # fill with -inf for padding_len tokens for top_k tokens
         # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
-        for _ in range(input_padding_len):
+        for _ in range(1, input_padding_len):  # start at 1 since this is causal
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)
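As a toy illustration of the loop change above (placeholder values, not the axolotl code): range(1, input_padding_len) iterates input_padding_len - 1 times, so one fewer row of -inf padding is appended, presumably because the causal shift in the trainer drops one teacher position.

# Toy values only; top_k and input_padding_len are placeholders.
top_k = 3
input_padding_len = 4

target_logprobs, target_token_ids, target_mask = [], [], []
for _ in range(1, input_padding_len):  # iterates input_padding_len - 1 times
    target_logprobs.append([-float("inf")] * top_k)
    target_token_ids.append(list(range(top_k)))
    target_mask.append([0] * top_k)

assert len(target_logprobs) == input_padding_len - 1  # 3 padding rows, not 4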
