Skip to content

Commit

Permalink
remove references to triton kd for now
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 30, 2024
1 parent f704e86 commit 93ea657
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 303 deletions.
54 changes: 0 additions & 54 deletions src/axolotl/core/trainers/kd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,60 +40,6 @@ def _set_signature_columns_if_needed(self):
if columns_to_add:
self._signature_columns += columns_to_add

# def compute_loss_w_triton(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# target_logprobs = inputs.pop("target_logprobs")
# target_token_ids = inputs.pop("target_token_ids")
# target_mask = inputs.pop("target_mask")
#
# if self.model_accepts_loss_kwargs:
# loss_kwargs = {}
# if num_items_in_batch is not None:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch
# inputs = {**inputs, **loss_kwargs}
# outputs = model(**inputs)
#
# student_logits = outputs["logits"]
# # Slice or gather student logits to match teacher seq len
# # e.g.:
# teacher_seq_len = target_token_ids.shape[1]
# student_logits_for_kd = student_logits[
# :, :teacher_seq_len, :
# ] # [B, seq_len, vocab_size]
#
# # GATHER top-K from student
# student_logits_topk = torch.gather(
# student_logits_for_kd,
# dim=-1,
# index=target_token_ids, # same shape [B, seq_len, K]
# )
#
# # Now call the Triton-based KD loss
# kd_sum = kd_loss_triton(
# student_logits_topk,
# target_logprobs, # teacher logprobs [B, seq_len, K]
# target_mask, # mask [B, seq_len, K]
# )
#
# # Normalize however you want
# if num_items_in_batch is not None:
# loss_kd = kd_sum / num_items_in_batch
# else:
# # or do e.g. average over valid tokens
# # quick example:
# total_valid = target_mask.sum()
# loss_kd = kd_sum / (total_valid + 1e-8)
#
# # optionally combine with CE loss
# if self.args.kd_ce_alpha > 0:
# kd_alpha = self.args.kd_alpha
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
# else:
# loss = loss_kd
#
# return (loss, outputs) if return_outputs else loss

def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
Expand Down
249 changes: 0 additions & 249 deletions src/axolotl/integrations/kd/kernels/kd.py

This file was deleted.

0 comments on commit 93ea657

Please sign in to comment.