From 9801b45f2bbaaf09f69e7dab8682fd3d4a500378 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 30 Dec 2024 11:21:19 -0500
Subject: [PATCH] handle token/logprob shifting

---
 src/axolotl/core/trainers/kd/__init__.py        | 27 +++++++++++++------
 .../prompt_strategies/chat_template.py          |  2 +-
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/src/axolotl/core/trainers/kd/__init__.py b/src/axolotl/core/trainers/kd/__init__.py
index 6a0492ae0..ad68055c9 100644
--- a/src/axolotl/core/trainers/kd/__init__.py
+++ b/src/axolotl/core/trainers/kd/__init__.py
@@ -41,7 +41,12 @@ def _set_signature_columns_if_needed(self):
             self._signature_columns += columns_to_add
 
     def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
+        self,
+        model,
+        inputs,
+        return_outputs=False,
+        num_items_in_batch=None,
+        shift_targets=False,
     ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -65,16 +70,22 @@ def compute_loss(
         # FIXME: account for tokenizer.padding_side
         student_logits = outputs["logits"][:, :seq_len, :].contiguous()
 
-        shift_logits = student_logits[..., :-1, :].contiguous()
-        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
-        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
-        shift_target_mask = target_mask[..., 1:, :].contiguous()
+        if shift_targets:
+            shift_logits = student_logits[..., :-1, :].contiguous()
+            target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
+            target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
+            target_mask_for_loss = target_mask[..., 1:, :].contiguous()
+        else:
+            shift_logits = student_logits.contiguous()
+            target_logprobs_for_loss = target_logprobs.contiguous()
+            target_token_ids_for_loss = target_token_ids.contiguous()
+            target_mask_for_loss = target_mask.contiguous()
 
         loss_kd = topk_kd_loss(
             shift_logits,
-            shift_target_token_ids,
-            shift_target_logprobs,
-            shift_target_mask,
+            target_token_ids_for_loss,
+            target_logprobs_for_loss,
+            target_mask_for_loss,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index ba524eb48..1bf9b9e43 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -502,7 +502,7 @@ def transform_logprobs(self, sample):
 
         # fill with -inf for padding_len tokens for top_k tokens
         # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
-        for _ in range(input_padding_len):
+        for _ in range(1, input_padding_len):  # start at 1 since this is causal
             target_logprobs.append([-float("inf")] * top_k)
             target_token_ids.append(list(range(top_k)))
             target_mask.append([0] * top_k)