From 93ea657fdd5b208ce3c9ce70b83ab775171e8ccd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 30 Dec 2024 10:40:05 -0500
Subject: [PATCH] remove references to triton kd for now

---
 src/axolotl/core/trainers/kd/__init__.py  |  54 -----
 src/axolotl/integrations/kd/kernels/kd.py | 249 ----------------------
 2 files changed, 303 deletions(-)
 delete mode 100644 src/axolotl/integrations/kd/kernels/kd.py

diff --git a/src/axolotl/core/trainers/kd/__init__.py b/src/axolotl/core/trainers/kd/__init__.py
index 605a52f9b3..6a0492ae05 100644
--- a/src/axolotl/core/trainers/kd/__init__.py
+++ b/src/axolotl/core/trainers/kd/__init__.py
@@ -40,60 +40,6 @@ def _set_signature_columns_if_needed(self):
         if columns_to_add:
             self._signature_columns += columns_to_add
 
-    # def compute_loss_w_triton(
-    #     self, model, inputs, return_outputs=False, num_items_in_batch=None
-    # ):
-    #     target_logprobs = inputs.pop("target_logprobs")
-    #     target_token_ids = inputs.pop("target_token_ids")
-    #     target_mask = inputs.pop("target_mask")
-    #
-    #     if self.model_accepts_loss_kwargs:
-    #         loss_kwargs = {}
-    #         if num_items_in_batch is not None:
-    #             loss_kwargs["num_items_in_batch"] = num_items_in_batch
-    #         inputs = {**inputs, **loss_kwargs}
-    #     outputs = model(**inputs)
-    #
-    #     student_logits = outputs["logits"]
-    #     # Slice or gather student logits to match teacher seq len
-    #     # e.g.:
-    #     teacher_seq_len = target_token_ids.shape[1]
-    #     student_logits_for_kd = student_logits[
-    #         :, :teacher_seq_len, :
-    #     ]  # [B, seq_len, vocab_size]
-    #
-    #     # GATHER top-K from student
-    #     student_logits_topk = torch.gather(
-    #         student_logits_for_kd,
-    #         dim=-1,
-    #         index=target_token_ids,  # same shape [B, seq_len, K]
-    #     )
-    #
-    #     # Now call the Triton-based KD loss
-    #     kd_sum = kd_loss_triton(
-    #         student_logits_topk,
-    #         target_logprobs,  # teacher logprobs [B, seq_len, K]
-    #         target_mask,  # mask [B, seq_len, K]
-    #     )
-    #
-    #     # Normalize however you want
-    #     if num_items_in_batch is not None:
-    #         loss_kd = kd_sum / num_items_in_batch
-    #     else:
-    #         # or do e.g. average over valid tokens
-    #         # quick example:
-    #         total_valid = target_mask.sum()
-    #         loss_kd = kd_sum / (total_valid + 1e-8)
-    #
-    #     # optionally combine with CE loss
-    #     if self.args.kd_ce_alpha > 0:
-    #         kd_alpha = self.args.kd_alpha
-    #         loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
-    #     else:
-    #         loss = loss_kd
-    #
-    #     return (loss, outputs) if return_outputs else loss
-
     def compute_loss(
         self, model, inputs, return_outputs=False, num_items_in_batch=None
     ):
diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py
deleted file mode 100644
index 28a3c8e676..0000000000
--- a/src/axolotl/integrations/kd/kernels/kd.py
+++ /dev/null
@@ -1,249 +0,0 @@
-"""
-Triton kernel for optimized kl divergence loss
-"""
-
-import torch
-import triton
-import triton.language as tl
-
-# --------------------------------------------------------
-# Triton Kernel for forward pass
-# --------------------------------------------------------
-# We'll assume:
-# - B * seq_len threads in 1D dimension
-# - Each thread handles K tokens (the top-K from teacher).
-# - For large K, you might want a more 2D approach to keep good occupancy.
-#
-# Pseudocode steps inside kernel:
-# 1) compute index for [batch, seq_position]
-# 2) read top-K token IDs from teacher_token_ids
-# 3) gather student_logits_topk
-# 4) compute logsumexp for those K logits
-# 5) compute student_logprobs_topk
-# 6) read teacher_logprobs
-# 7) compute teacher_probs = exp(teacher_logprobs)
-# 8) compute partial KL = sum(teacher_probs * (teacher_logprobs - student_logprobs_topk))
-# 9) store partial KL in a buffer
-#
-# Later, we'll do a reduction on partial KL across all threads.
-#
-# NOTE: This is a reference skeleton. You must adapt indexing carefully.
-#
-
-
-@triton.jit
-def kd_forward_kernel(
-    # student_logits after gather: [B, seq_len, K] flattened to 1D in row-major
-    student_logits_ptr: tl.tensor,
-    # teacher_logprobs: [B, seq_len, K] flattened
-    teacher_logprobs_ptr: tl.tensor,
-    # mask: [B, seq_len, K] flattened (bool or 0/1)
-    mask_ptr: tl.tensor,
-    # partial_kd: [B*seq_len] flattened buffer to store partial sums
-    partial_kd_ptr: tl.tensor,
-    B: tl.int32,  # pylint: disable=invalid-name
-    seq_len: tl.int32,
-    K: tl.int32,  # pylint: disable=invalid-name
-    BLOCK_SIZE: tl.constexpr,  # pylint: disable=invalid-name
-):
-    """
-    For each position in [0..B*seq_len), we:
-    - gather the K student logits
-    - compute logsumexp
-    - compute the KL sum = sum_{k} t_prob_k * ( t_log_k - s_logprob_k )
-    - store that partial sum into partial_kd_ptr[offset].
-    """
-    # 1) Identify which [B*seq_len] index this block handles
-    pid = tl.program_id(0)
-
-    # 2) Vector of [0..BLOCK_SIZE) local offsets
-    offsets = tl.arange(0, BLOCK_SIZE)
-    # 3) Global indices = pid * BLOCK_SIZE + offsets
-    idx = pid * BLOCK_SIZE + offsets
-
-    # 4) Mask to ensure we don't read out-of-bounds
-    total_positions = B * seq_len
-    mask_pos = idx < total_positions
-
-    # 5) Convert a 1D `idx` => (b_idx, s_idx)
-    #    b_idx is the batch number, s_idx is the sequence position
-    b_idx = idx // seq_len
-    s_idx = idx % seq_len
-
-    # We'll accumulate the KL for each index in a register array
-    kl_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-
-    # -------------------------------------------------------------------------
-    # First pass: find max logits over K to implement logsumexp
-    # -------------------------------------------------------------------------
-    max_val = tl.full([BLOCK_SIZE], -1e30, dtype=tl.float32)
-
-    # Python-level loops are allowed in Triton as long as the
-    # operations inside are Triton ops, not torch or Python math.
-    for k in range(K):
-        # pointer offset in the flattened [B, seq_len, K] = b_idx*(seq_len*K) + s_idx*K + k
-        offset_k = b_idx * (seq_len * K) + s_idx * K + k
-
-        # load student logits, masked out-of-bounds with a large negative
-        # so they don't affect the max
-        student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
-        # update running max
-        max_val = tl.where(student_val > max_val, student_val, max_val)
-
-    # -------------------------------------------------------------------------
-    # Second pass: sum of exp(...) to complete logsumexp
-    # -------------------------------------------------------------------------
-    exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-    for k in range(K):
-        offset_k = b_idx * (seq_len * K) + s_idx * K + k
-        student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
-        # exponent
-        exponent = tl.exp(student_val - max_val)
-        exp_sum += exponent
-
-    # final logsumexp
-    logsumexp_val = max_val + tl.log(exp_sum)
-
-    # -------------------------------------------------------------------------
-    # Third pass: compute partial KL per position
-    # KL = sum_{k in valid} p^T_k * (teacher_logprobs_k - student_logprobs_k)
-    #
-    # - teacher_logprobs_k => t_log
-    # - teacher_prob_k = exp(t_log)
-    # - student_logprobs_k = s_val - logsumexp_val
-    # -------------------------------------------------------------------------
-    for k in range(K):
-        offset_k = b_idx * (seq_len * K) + s_idx * K + k
-        # teacher logprobs
-        t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30)
-        # teacher prob
-        t_prob = tl.exp(t_log)
-
-        # student logit
-        s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
-        # student logprob
-        s_logprob = s_val - logsumexp_val
-
-        # local KL
-        kl_val = t_prob * (t_log - s_logprob)
-
-        # also read mask to disable invalid tokens if mask is not purely sequence-based
-        valid_k = tl.load(mask_ptr + offset_k)
-        # if mask is bool => use 'valid_k != 0', if it's 0/1 => same
-        is_valid = valid_k > 0
-
-        # zero out if either this index is out-of-bounds or mask is invalid
-        kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
-
-        # accumulate
-        kl_acc += kl_val
-
-    # -------------------------------------------------------------------------
-    # Store the partial KL in partial_kd_ptr for each element in idx.
-    # Later in Python, you can do partial_kd.sum() to get the total KL.
-    # -------------------------------------------------------------------------
-    tl.store(partial_kd_ptr + idx, kl_acc, mask=mask_pos)
-
-
-def kd_forward_pass_triton(
-    student_logits,  # [B, seq_len, K] (already gathered)
-    teacher_logprobs,  # [B, seq_len, K]
-    mask,  # [B, seq_len, K] bool or 0/1
-    BLOCK_SIZE=1024,  # pylint: disable=invalid-name
-):
-    """
-    Returns total KL (float). We do the sum on the Python side.
-    NOTE: No normalization is done here.
-          You might divide by `num_items_in_batch` or # valid tokens afterward.
-    """
-    B, seq_len, K = student_logits.shape  # pylint: disable=invalid-name
-    # Flatten
-    student_logits_flat = student_logits.reshape(-1)
-    teacher_logprobs_flat = teacher_logprobs.reshape(-1)
-    mask_flat = mask.reshape(-1)
-
-    total_positions = B * seq_len
-    # We'll store partial KL sums for each of the B*seq_len positions
-    partial_kd = torch.empty(
-        total_positions, dtype=student_logits.dtype, device=student_logits.device
-    )
-
-    # Grid config
-    grid = ((total_positions + BLOCK_SIZE - 1) // BLOCK_SIZE,)
-
-    kd_forward_kernel[grid](
-        student_logits_flat,
-        teacher_logprobs_flat,
-        mask_flat,
-        partial_kd,
-        B,
-        seq_len,
-        K,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    # Sum on CPU or GPU
-    kd_sum = partial_kd.sum()
-    return kd_sum
-
-
-class _KLDivergenceTritonFn(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, student_logits, teacher_logprobs, mask):
-        """
-        student_logits: (B, seq_len, K)
-        teacher_logprobs: (B, seq_len, K)
-        mask: (B, seq_len, K)
-        """
-        kd_sum = kd_forward_pass_triton(student_logits, teacher_logprobs, mask)
-        kd_loss = kd_sum  # Not normalized here. You can do that externally.
-
-        # Save for backward
-        ctx.save_for_backward(student_logits, teacher_logprobs, mask)
-        return kd_loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # We'll do naive PyTorch re-computation for gradient wrt student_logits
-        student_logits, teacher_logprobs, mask = ctx.saved_tensors
-        # grad_output is dLoss/dOut => a scalar
-        # Let's compute dLoss/dStudentLogits with the same formula as your original code
-
-        with torch.enable_grad():
-            stl = student_logits.clone().detach().requires_grad_(True)
-            t_log = teacher_logprobs
-            # mask might be bool or 0/1
-            # compute logsumexp
-            lse = torch.logsumexp(stl, dim=-1, keepdim=True)
-            s_logprob = stl - lse
-            t_prob = t_log.exp()
-
-            # forward KL = sum_{k} p^T_k ( t_log_k - s_logprob_k )
-            kl_val = t_prob * (t_log - s_logprob)
-            # mask out
-            kl_val = kl_val * mask  # zero out invalid
-
-            kd_loss = kl_val.sum()
-            # now compute dLoss/d stl
-            grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]
-
-        return grad_stl, None, None
-
-
-def kd_loss_triton(
-    student_logits,  # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
-    teacher_logprobs,
-    mask,
-    num_items_in_batch=None,  # pylint: disable=unused-argument
-):
-    """
-    Wrapper that calls our Triton-based forward+backward for KD.
-    For production, you likely want to do the gather (teacher top-K) outside
-    or inside a separate kernel. This function expects that you've *already*
-    called gather on student_logits -> shape [B, seq_len, K].
-    """
-    return _KLDivergenceTritonFn.apply(
-        student_logits,
-        teacher_logprobs,
-        mask,  # num_items_in_batch
-    )
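For reference, and not part of the patch itself: the top-K forward KL that the removed Triton path computed can be written with plain PyTorch ops, which mirrors the math the backward() above recomputes. This is a minimal sketch, assuming the student logits have already been gathered at the teacher's top-K token ids so that all three tensors are [B, seq_len, K]; the function name kd_loss_topk_reference is illustrative and does not exist in axolotl.

import torch


def kd_loss_topk_reference(student_logits_topk, teacher_logprobs, mask, num_items_in_batch=None):
    # Student log-probs restricted to the gathered top-K slots (logsumexp over K).
    s_logprobs = student_logits_topk - torch.logsumexp(student_logits_topk, dim=-1, keepdim=True)
    t_probs = teacher_logprobs.exp()
    # Forward KL per slot: p_teacher * (log p_teacher - log p_student), with invalid slots masked out.
    kd_sum = (t_probs * (teacher_logprobs - s_logprobs) * mask).sum()
    # Normalize by num_items_in_batch if given, otherwise by the number of valid slots.
    if num_items_in_batch is not None:
        return kd_sum / num_items_in_batch
    return kd_sum / (mask.sum() + 1e-8)

As in the commented-out compute_loss_w_triton removed above, the gathered student logits would typically come from torch.gather(student_logits[:, :teacher_seq_len, :], dim=-1, index=target_token_ids).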