From 3d3fdc4f8cc250796ad45eeef5c296fcb4fa78ad Mon Sep 17 00:00:00 2001 From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com> Date: Sun, 11 Aug 2024 15:36:38 -0400 Subject: [PATCH] add comments in scattermoe (#5) --- kernel_hyperdrive/__init__.py | 9 +- .../scattermoe/torch_implementation.py | 88 +++-- .../triton_implementation/__init__.py | 4 +- .../triton_implementation/kernel.py | 329 +----------------- .../scattermoe/triton_implementation/ops.py | 313 +++++++++++++++++ kernel_hyperdrive/utils.h | 9 - kernel_hyperdrive/utils.py | 1 + kernel_hyperdrive/vector_addition/__init__.py | 6 +- .../cuda_implementation/__init__.py | 10 +- .../cuda_implementation/vector_addition.cpp | 12 +- .../vector_addition/torch_implementation.py | 6 - .../triton_implementation/__init__.py | 20 +- tests/scattermoe_test.py | 41 +-- tests/vector_addition_test.py | 2 + 14 files changed, 435 insertions(+), 415 deletions(-) create mode 100644 kernel_hyperdrive/scattermoe/triton_implementation/ops.py delete mode 100644 kernel_hyperdrive/utils.h diff --git a/kernel_hyperdrive/__init__.py b/kernel_hyperdrive/__init__.py index a1f86f6c..b7c7429c 100644 --- a/kernel_hyperdrive/__init__.py +++ b/kernel_hyperdrive/__init__.py @@ -1,13 +1,6 @@ from .scattermoe import MoE_Torch, MoE_Triton from .utils import compile_helpers -from .vector_addition import ( - VectorAddition_CUDA, - VectorAddition_Torch, - VectorAddition_Triton, - vector_addition_cuda, - vector_addition_torch, - vector_addition_triton, -) +from .vector_addition import vector_addition_cuda, vector_addition_torch, vector_addition_triton compile_helpers() diff --git a/kernel_hyperdrive/scattermoe/torch_implementation.py b/kernel_hyperdrive/scattermoe/torch_implementation.py index a6cf7a64..41ceb7db 100644 --- a/kernel_hyperdrive/scattermoe/torch_implementation.py +++ b/kernel_hyperdrive/scattermoe/torch_implementation.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.profiler import record_function class Experts_Torch(nn.Module): @@ -25,13 +26,23 @@ def __init__( self.reset_parameters() - def forward(self, input: torch.Tensor, num_experts_per_token: torch.Tensor) -> torch.Tensor: - input = input.split(num_experts_per_token.tolist(), dim=0) + def forward( + self, + input: torch.Tensor | tuple[torch.Tensor], + expert_frequency: torch.Tensor, + return_list: bool, + ) -> torch.Tensor | list[torch.Tensor]: + if isinstance(input, torch.Tensor): + input = input.split(expert_frequency.tolist(), dim=0) + input = [ F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i]) for i in range(self.num_experts) ] - input = torch.cat(input, dim=0) + + if not return_list: + input = torch.cat(input) + return input def extra_repr(self): @@ -89,64 +100,97 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: original_shape = hidden_states.shape + # hidden_states -> (batch_size, query_length, hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size) + # hidden_states -> (total_q, hidden_size) router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) - hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts) + # router_logits -> (total_q, num_experts) + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) + + hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts) hidden_states = hidden_states.view(original_shape) + # hidden_states -> 
(batch_size, query_length, hidden_size) + return hidden_states, router_logits + @record_function("MoE_Torch:_compute_routing_weights") def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]: # hidden_states -> (total_q, hidden_size) router_logits = self.gate(hidden_states) # router_logits -> (total_q, num_experts) router_weights, selected_experts = self._get_topk(router_logits) - router_weights = F.softmax(router_weights.float(), dim=-1) - # we cast back to the input dtype + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) + + router_weights = F.softmax(router_weights.float(), dim=-1) router_weights = router_weights.type_as(hidden_states) return router_logits, router_weights, selected_experts + @record_function("MoE_Torch:_compute_experts") def _compute_experts( self, hidden_states: torch.Tensor, router_weights: torch.Tensor, selected_experts: torch.Tensor ) -> torch.Tensor: total_q = hidden_states.shape[0] - batch_index, batch_gates, num_experts_per_token = self._compute_expert_assignment( - router_weights, selected_experts - ) + # hidden_states -> (total_q, hidden_size) + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) - expert_inputs = hidden_states[batch_index] + fan_in_index, batch_gates, expert_frequency = self._compute_expert_assignment(router_weights, selected_experts) - hidden_states = self.c_fc(expert_inputs, num_experts_per_token) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states, num_experts_per_token) + # fan_in_index -> (total_q * top_k) + # batch_gates -> (total_q * top_k) + # expert_frequency -> (num_experts) - hidden_states = hidden_states * batch_gates.unsqueeze(-1) # [:, None] + hidden_states = hidden_states[fan_in_index] + + # hidden_states -> (total_q * top_k, hidden_size) + + hidden_states = self.c_fc(hidden_states, expert_frequency, return_list=True) + # hidden_states -> num_experts x (?, hidden_size) + hidden_states = [self.act(i) for i in hidden_states] + # hidden_states -> num_experts x (?, intermediate_size) + hidden_states = self.c_proj(hidden_states, expert_frequency, return_list=False) + # hidden_states -> (total_q * top_k, hidden_size) + + hidden_states = hidden_states * batch_gates.unsqueeze(-1) zeros = torch.zeros((total_q, self.hidden_size), dtype=hidden_states.dtype, device=hidden_states.device) - hidden_states = zeros.index_add(0, batch_index, hidden_states) + hidden_states = zeros.index_add(0, fan_in_index, hidden_states) + + # hidden_states -> (total_q, hidden_size) return hidden_states + @record_function("MoE_Torch:_compute_expert_assignment") def _compute_expert_assignment( self, router_weights: torch.Tensor, selected_experts: torch.Tensor ) -> tuple[torch.Tensor]: + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) selected_experts = selected_experts.flatten() + # selected_experts -> (total_q * top_k) - num_experts_per_token = selected_experts.bincount(minlength=self.num_experts) + expert_frequency = selected_experts.bincount(minlength=self.num_experts) + # expert_frequency -> (num_experts) - # sort and group input tokens according to expert assignment - _, index_sorted_experts = selected_experts.sort(0) # [num_tokens * top_k] - batch_index = index_sorted_experts // self.top_k # [num_tokens * top_k] + index_sorted_experts = selected_experts.argsort() + # index_sorted_experts -> (total_q * top_k) + fan_in_index = index_sorted_experts // self.top_k + # fan_in_index -> (total_q * top_k) # gather the 
gate values for grouped input tokens - router_weights = router_weights.flatten() # [num_tokens * top_k] - batch_gates = router_weights[index_sorted_experts] # [num_tokens * top_k] + router_weights = router_weights.flatten() + # router_weights -> (total_q * top_k) + batch_gates = router_weights[index_sorted_experts] + # batch_gates -> (total_q * top_k) - return batch_index, batch_gates, num_experts_per_token + return fan_in_index, batch_gates, expert_frequency def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.top_k == 1: diff --git a/kernel_hyperdrive/scattermoe/triton_implementation/__init__.py b/kernel_hyperdrive/scattermoe/triton_implementation/__init__.py index 8232cee5..f91e4ae9 100644 --- a/kernel_hyperdrive/scattermoe/triton_implementation/__init__.py +++ b/kernel_hyperdrive/scattermoe/triton_implementation/__init__.py @@ -4,7 +4,7 @@ import torch.nn as nn from ..torch_implementation import Experts_Torch, MoE_Torch -from .kernel import flatten_and_sort, padded_block_indices, scattered_experts +from .ops import padded_block_indices, scattered_experts class Experts_Triton(Experts_Torch): @@ -78,7 +78,7 @@ def _compute_experts( self, hidden_states: torch.Tensor, router_weights: torch.Tensor, selected_experts: torch.Tensor ) -> torch.Tensor: with torch.no_grad(): - sorted_expert_idxs, sorted_scattered_idxs = flatten_and_sort(selected_experts) + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(selected_experts.flatten()) padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts) hidden_states = self.c_fc( diff --git a/kernel_hyperdrive/scattermoe/triton_implementation/kernel.py b/kernel_hyperdrive/scattermoe/triton_implementation/kernel.py index 3d3eb122..6a64203d 100644 --- a/kernel_hyperdrive/scattermoe/triton_implementation/kernel.py +++ b/kernel_hyperdrive/scattermoe/triton_implementation/kernel.py @@ -1,4 +1,3 @@ -import torch import triton import triton.language as tl @@ -6,33 +5,6 @@ BLOCK_M = 128 -@torch.compile -def flatten_and_sort(expert_idxs: torch.Tensor): - flattened_expert_idxs = expert_idxs.flatten() - sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) - return sorted_expert_idxs, sorted_scattered_idxs - - -@torch.compile -def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): - expert_counts = torch.bincount(sorted_experts_idxs, minlength=k) - padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 - padded_expert_block_end = padded_block_counts.cumsum(-1) - expert_boundaries_end = expert_counts.cumsum(-1) - expert_boundaries_start = expert_boundaries_end - expert_counts - padded_expert_block_start = padded_expert_block_end - padded_block_counts - - block_idxs = torch.arange( - padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device - ).unsqueeze(1) - - block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) - expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start - expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) - - return expanded_block_idxs, expert_boundaries_end - - @triton.autotune( configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], key=["M", "N", "K"], @@ -44,7 +16,7 @@ def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE } ) @triton.jit -def _scatter2scatter( +def scatter2scatter_triton_kernel( 
X_ptr, stride_xm, stride_xk, @@ -128,97 +100,6 @@ def _scatter2scatter( tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) -def scatter2scatter( - X, W, sorted_expert_idxs, sorted_scattered_idxs, k, padded_block_idxs, x_grouped=False, y_grouped=False, out=None -): - assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) - assert sorted_scattered_idxs.size(0) == X.size(0) * k - # Pre-kernel setup - x_dim = X.size(-1) - y_dim = W.size(-1) - L_scattered = sorted_expert_idxs.size(0) - if out is None: - O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) - else: - assert out.size(0) == L_scattered and out.size(1) == y_dim - O = out - - def grid(META): - grid_num = (padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]),) - return grid_num - - with torch.cuda.device(X.device): - _scatter2scatter[grid]( - # X_ptr, stride_xm, stride_xk, - X, - X.stride(0), - X.stride(1), - # W_ptr, stride_we, stride_wk, stride_wn, - W, - W.stride(0), - W.stride(1), - W.stride(2), - # Y_ptr, stride_ym, stride_yn, - O, - O.stride(0), - O.stride(1), - grouped_idx_ptr=sorted_scattered_idxs, - expert_idxs_ptr=sorted_expert_idxs, - block_start_idx_ptr=padded_block_idxs, - FAN_OUT=k, - M=X.size(0), - K=X.size(1), - N=O.size(1), - E=W.size(0), - BLOCK_M=BLOCK_M, - ACC_TYPE=tl.float32, - OUT_M=O.size(0), - allow_tf32=True, - x_grouped=x_grouped, - y_grouped=y_grouped, - ) - return O - - -def group_bwd_W(DY, X, expert_offsets, E): - DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) - DW = DWt.permute(0, 2, 1) - - def grid(META): - grid = ( - E * triton.cdiv(META["K"], META["BLOCK_K"]), - triton.cdiv(META["N"], META["BLOCK_N"]), - ) - return grid - - with torch.cuda.device(DY.device): - _groupXtY[grid]( - # DY_ptr, stride_dym, stride_dyk, - DY, - DY.stride(0), - DY.stride(1), - # X_ptr, stride_xm, stride_xn, - X, - X.stride(0), - X.stride(1), - # DW_ptr, stride_dwe, stride_dwk, stride_dwn, - DW, - DW.stride(0), - DW.stride(1), - DW.stride(2), - # expert_offsets_ptr, - expert_offsets, - # K: tl.constexpr, N: tl.constexpr, - M=DY.size(0), - N=DY.size(-1), - K=X.size(-1), - # ACC_TYPE: tl.constexpr, - ACC_TYPE=tl.float32, - allow_tf32=True, - ) - return DW - - @triton.autotune( configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4)], key=["M", "N", "K"], @@ -230,7 +111,7 @@ def grid(META): } ) @triton.jit -def _groupXtY( +def groupXtY_triton_kernel( DY_ptr, stride_dym, stride_dyk, @@ -307,54 +188,10 @@ def _groupXtY( tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) -def _config_grouping(): - return [ - triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), - ] - - -def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): - N = sorted_expert_idxs.size(0) - K = A.size(1) - assert A.size(0) * fan_out == N - if out is not None: - Y = out - else: - Y = torch.empty((N, K), dtype=A.dtype, device=A.device) - # print("grp init:", Y.size()) - - def grid(META): - grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),) - return grid_num - - with torch.cuda.device(A.device): - _group[grid]( - # A_ptr, stride_an, stride_ai, - A, - A.stride(0), - A.stride(1), - coeff is not None, - coeff, - fan_out, - # Y_ptr, stride_yn, stride_yk, - Y, - Y.stride(0), - Y.stride(1), - # grouped_idx_ptr, - sorted_expert_idxs, - 
# N: tl.constexpr, K: tl.constexpr, - N, - K, - ) - return Y - - -@triton.autotune(configs=_config_grouping(), key=["K"]) +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) @triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0}) @triton.jit -def _group( +def group_triton_kernel( src_ptr, stride_sn, stride_sk, @@ -403,161 +240,3 @@ def _group( tl.store(tgt_blk_ptrs, block, mask=mask) src_blk_ptrs += BLOCK_K * stride_sk tgt_blk_ptrs += BLOCK_K * stride_ti - - -class _ScatteredExperts(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - expert_weights, - k, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates=None, - grouped_in=False, - grouped_out=False, - ): - output = scatter2scatter( - X=x, - W=expert_weights, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - k=k, - x_grouped=grouped_in, - y_grouped=grouped_out, - ) - - if gates is None: - output_expanded = None - else: - output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) - output = torch.bmm(gates[:, None, :], output_expanded).squeeze(1) - - ctx.save_for_backward( - x, - expert_weights, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates, - output_expanded, - ) - - ctx.grouped_in = grouped_in - ctx.grouped_out = grouped_out - ctx.k = k - - return output - - @staticmethod - def backward(ctx, grad_out): - ( - x, - expert_weights, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates, - output_expanded, - ) = ctx.saved_tensors - k = ctx.k - grouped_in = ctx.grouped_in - grouped_out = ctx.grouped_out - - if gates is None: - d_gates = None - gates_flat = None - gate_fan = 1 - grouped_grad_out = None - else: - # calculate gates gradient - d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) - gates_flat = gates.flatten() - gate_fan = gates.size(1) - # print("expanded and grouping") - grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later - - if grouped_out: - grouped_grad_out = grad_out - else: - grouped_grad_out = group( - grad_out, sorted_scattered_idxs, fan_out=gate_fan, coeff=gates_flat, out=grouped_grad_out - ) - - if grouped_in: - grouped_x = x - d_expanded_input = None - else: - grouped_x = group(x, sorted_scattered_idxs, fan_out=k) - d_expanded_input = grouped_x - - d_weights = group_bwd_W( - DY=grouped_grad_out, X=grouped_x, expert_offsets=expert_offsets, E=expert_weights.size(0) - ) - - d_expanded_input = scatter2scatter( - X=grouped_grad_out, - x_grouped=True, - W=expert_weights.permute(0, 2, 1), - padded_block_idxs=padded_block_idxs, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - k=1, - y_grouped=grouped_in, - out=d_expanded_input, # Reuse grouped_x buffer - ) - - if k == 1: - d_input = d_expanded_input - else: - d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) - - # print("backward end.") - return ( - # x, expert_weights, k, - d_input, - d_weights, - None, - # sorted_expert_idxs, sorted_scattered_idxs, - None, - None, - # padded_block_idxs, expert_offsets, - None, - None, - # gates - d_gates, - None, - None, - ) - - -def scattered_experts( - inputs, - expert_weights, - k, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates=None, - 
grouped_in=False, - grouped_out=False, -): - return _ScatteredExperts.apply( - inputs, - expert_weights, - k, - sorted_expert_idxs, - sorted_scattered_idxs, - padded_block_idxs, - expert_offsets, - gates, - grouped_in, - grouped_out, - ) diff --git a/kernel_hyperdrive/scattermoe/triton_implementation/ops.py b/kernel_hyperdrive/scattermoe/triton_implementation/ops.py new file mode 100644 index 00000000..ad98d63b --- /dev/null +++ b/kernel_hyperdrive/scattermoe/triton_implementation/ops.py @@ -0,0 +1,313 @@ +import torch +import triton +import triton.language as tl + +from .kernel import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel + + +BLOCK_M = 128 + + +@torch.compile +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + expert_counts = torch.bincount(sorted_experts_idxs, minlength=k) + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end + + +def scatter2scatter( + X, W, sorted_expert_idxs, sorted_scattered_idxs, k, padded_block_idxs, x_grouped=False, y_grouped=False, out=None +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = W.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + + def grid(META): + grid_num = (padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]),) + return grid_num + + with torch.cuda.device(X.device): + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + O, + O.stride(0), + O.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=k, + M=X.size(0), + K=X.size(1), + N=O.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + OUT_M=O.size(0), + allow_tf32=True, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + return O + + +def group_bwd_W(DY, X, expert_offsets, E): + DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) + DW = DWt.permute(0, 2, 1) + + def grid(META): + grid = ( + E * triton.cdiv(META["K"], META["BLOCK_K"]), + triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid + + with torch.cuda.device(DY.device): + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + 
DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + M=DY.size(0), + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=True, + ) + return DW + + +def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + if out is not None: + Y = out + else: + Y = torch.empty((N, K), dtype=A.dtype, device=A.device) + # print("grp init:", Y.size()) + + def grid(META): + grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),) + return grid_num + + with torch.cuda.device(A.device): + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + Y, + Y.stride(0), + Y.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + return Y + + +class _ScatteredExperts(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + output = scatter2scatter( + X=x, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + k=k, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + + if gates is None: + output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates[:, None, :], output_expanded).squeeze(1) + + ctx.save_for_backward( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) + + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + + return output + + @staticmethod + def backward(ctx, grad_out): + ( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + if gates is None: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # print("expanded and grouping") + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + + if grouped_out: + grouped_grad_out = grad_out + else: + grouped_grad_out = group( + grad_out, sorted_scattered_idxs, fan_out=gate_fan, coeff=gates_flat, out=grouped_grad_out + ) + + if grouped_in: + grouped_x = x + d_expanded_input = None + else: + grouped_x = group(x, sorted_scattered_idxs, fan_out=k) + d_expanded_input = grouped_x + + d_weights = group_bwd_W( + DY=grouped_grad_out, X=grouped_x, expert_offsets=expert_offsets, E=expert_weights.size(0) + ) + + d_expanded_input = scatter2scatter( + X=grouped_grad_out, + x_grouped=True, + W=expert_weights.permute(0, 2, 1), + padded_block_idxs=padded_block_idxs, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + y_grouped=grouped_in, + out=d_expanded_input, # Reuse grouped_x buffer + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view(x.size(0), k, 
d_expanded_input.size(-1)).sum(-2) + + # print("backward end.") + return ( + # x, expert_weights, k, + d_input, + d_weights, + None, + # sorted_expert_idxs, sorted_scattered_idxs, + None, + None, + # padded_block_idxs, expert_offsets, + None, + None, + # gates + d_gates, + None, + None, + ) + + +def scattered_experts( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, +): + return _ScatteredExperts.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) diff --git a/kernel_hyperdrive/utils.h b/kernel_hyperdrive/utils.h deleted file mode 100644 index 18c387a0..00000000 --- a/kernel_hyperdrive/utils.h +++ /dev/null @@ -1,9 +0,0 @@ -#include - -// C++ interface -#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " is not on CUDA device") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " is not a contiguous tensor") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x); diff --git a/kernel_hyperdrive/utils.py b/kernel_hyperdrive/utils.py index 7422dfad..2d35fbd2 100644 --- a/kernel_hyperdrive/utils.py +++ b/kernel_hyperdrive/utils.py @@ -12,5 +12,6 @@ def compile_helpers() -> None: ], with_cuda=True, extra_cflags=["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"], + build_directory="build", verbose=True, ) diff --git a/kernel_hyperdrive/vector_addition/__init__.py b/kernel_hyperdrive/vector_addition/__init__.py index 0a59e537..2a493c05 100644 --- a/kernel_hyperdrive/vector_addition/__init__.py +++ b/kernel_hyperdrive/vector_addition/__init__.py @@ -1,3 +1,3 @@ -from .cuda_implementation import VectorAddition_CUDA, vector_addition_cuda -from .torch_implementation import VectorAddition_Torch, vector_addition_torch -from .triton_implementation import VectorAddition_Triton, vector_addition_triton +from .cuda_implementation import vector_addition_cuda +from .torch_implementation import vector_addition_torch +from .triton_implementation import vector_addition_triton diff --git a/kernel_hyperdrive/vector_addition/cuda_implementation/__init__.py b/kernel_hyperdrive/vector_addition/cuda_implementation/__init__.py index eb3dea9d..4a1abf66 100644 --- a/kernel_hyperdrive/vector_addition/cuda_implementation/__init__.py +++ b/kernel_hyperdrive/vector_addition/cuda_implementation/__init__.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn class _VectorAddition_CUDA(torch.autograd.Function): @@ -14,10 +13,13 @@ def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor return output_grad, output_grad +# this registers the kernel with PyTorch to make it work with torch.compile +@torch.library.custom_op("khd::vector_addition_cuda", mutates_args=()) def vector_addition_cuda(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return _VectorAddition_CUDA.apply(x, y) -class VectorAddition_CUDA(nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return vector_addition_cuda(x, y) +# this tells torch.compile the output shape given the input shape +@vector_addition_cuda.register_fake +def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) diff --git a/kernel_hyperdrive/vector_addition/cuda_implementation/vector_addition.cpp b/kernel_hyperdrive/vector_addition/cuda_implementation/vector_addition.cpp index 75c00c14..70854ade 100644 --- 
a/kernel_hyperdrive/vector_addition/cuda_implementation/vector_addition.cpp +++ b/kernel_hyperdrive/vector_addition/cuda_implementation/vector_addition.cpp @@ -1,16 +1,18 @@ #include -#include "../../utils.h" // CUDA kernel declarations torch::Tensor vector_addition_forward_kernel_launcher(torch::Tensor x, torch::Tensor y, const int BLOCK_SIZE); torch::Tensor vector_addition_forward(torch::Tensor x, torch::Tensor y) { - CHECK_INPUT(x); - CHECK_INPUT(y); + TORCH_CHECK(x.device().is_cuda(), "tensor x is not on GPU") + TORCH_CHECK(y.device().is_cuda(), "tensor y is not on GPU") - TORCH_CHECK(x.dim() == 1, "tensor should be 1 dimensional") - TORCH_CHECK(y.dim() == 1, "tensor should be 1 dimensional") + TORCH_CHECK(x.is_contiguous(), "tensor x is not a contiguous") + TORCH_CHECK(y.is_contiguous(), "tensor y is not a contiguous") + + TORCH_CHECK(x.dim() == 1, "tensor x should be 1 dimensional") + TORCH_CHECK(y.dim() == 1, "tensor y should be 1 dimensional") TORCH_CHECK(x.numel() == y.numel(), "both tensors should have same number of elements"); TORCH_CHECK(x.type() == y.type(), "both tensors should have same dtype"); diff --git a/kernel_hyperdrive/vector_addition/torch_implementation.py b/kernel_hyperdrive/vector_addition/torch_implementation.py index 2fa914a9..e7e0464f 100644 --- a/kernel_hyperdrive/vector_addition/torch_implementation.py +++ b/kernel_hyperdrive/vector_addition/torch_implementation.py @@ -1,11 +1,5 @@ import torch -import torch.nn as nn def vector_addition_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y - - -class VectorAddition_Torch(nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return vector_addition_torch(x, y) diff --git a/kernel_hyperdrive/vector_addition/triton_implementation/__init__.py b/kernel_hyperdrive/vector_addition/triton_implementation/__init__.py index 7e673219..8489e1d6 100644 --- a/kernel_hyperdrive/vector_addition/triton_implementation/__init__.py +++ b/kernel_hyperdrive/vector_addition/triton_implementation/__init__.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn import triton from .kernel import vector_addition_forward_triton_kernel @@ -8,13 +7,17 @@ class _VectorAddition_Triton(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - assert x.dim() == 1 + assert x.is_cuda, "tensor x is not on GPU" + assert y.is_cuda, "tensor y is not on GPU" - assert x.is_cuda - assert y.is_cuda + assert x.is_contiguous(), "tensor x is not a contiguous" + assert y.is_contiguous(), "tensor y is not a contiguous" - assert x.is_contiguous() - assert y.is_contiguous() + assert x.dim() == 1, "tensor x should be 1 dimensional" + assert y.dim() == 1, "tensor y should be 1 dimensional" + + assert x.numel() == y.numel(), "both tensors should have same number of elements" + assert x.type() == y.type(), "both tensors should have same dtype" output = torch.empty_like(x) @@ -32,8 +35,3 @@ def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor def vector_addition_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return _VectorAddition_Triton.apply(x, y) - - -class VectorAddition_Triton(nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return vector_addition_triton(x, y) diff --git a/tests/scattermoe_test.py b/tests/scattermoe_test.py index 47996b3c..79cdce94 100644 --- a/tests/scattermoe_test.py +++ b/tests/scattermoe_test.py @@ -62,27 +62,28 @@ def _test_scattermoe( f"skipping test since number of 
experts per token ({num_experts_per_tok}) is more than number of experts ({num_experts})" ) - moe = module_class( - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation_function=self.get_activation_function(is_glu=is_glu), - is_glu=is_glu, - add_bias=False, - std=0.02, - ).to(dtype=dtype, device=device) + with torch.device(device): + moe = module_class( + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_function=self.get_activation_function(is_glu=is_glu), + is_glu=is_glu, + add_bias=False, + std=0.02, + ).to(dtype=dtype) - moe_torch = MoE_Torch( - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation_function=self.get_activation_function(is_glu=is_glu), - is_glu=is_glu, - add_bias=False, - std=0.02, - ).to(dtype=dtype, device=device) + moe_torch = MoE_Torch( + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_function=self.get_activation_function(is_glu=is_glu), + is_glu=is_glu, + add_bias=False, + std=0.02, + ).to(dtype=dtype) moe_torch.load_state_dict(moe.state_dict()) diff --git a/tests/vector_addition_test.py b/tests/vector_addition_test.py index 5e303a04..20f95456 100644 --- a/tests/vector_addition_test.py +++ b/tests/vector_addition_test.py @@ -16,6 +16,7 @@ class VectorAdditionTest(TestCommons): ) def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None: self._test_vector_addition(size, device, dtype, vector_addition_cuda) + self._test_vector_addition(size, device, dtype, torch.compile(vector_addition_cuda)) @parameterized.expand( TestCommons.make_args_matrix( @@ -24,6 +25,7 @@ def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torc ) def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None: self._test_vector_addition(size, device, dtype, vector_addition_triton) + self._test_vector_addition(size, device, dtype, torch.compile(vector_addition_triton)) def _test_vector_addition(self, size: int, device: torch.device, dtype: torch.dtype, function: Callable) -> None: x = torch.randn(size, device=device, dtype=dtype)
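
Reviewer note (not part of the patch): the two sketches below illustrate how the pieces touched by this commit are meant to be used. They are minimal examples under assumed tensor sizes and dtypes, not code from the diff itself.

First, the ScatterMoE host-side wrappers now live in kernel_hyperdrive/scattermoe/triton_implementation/ops.py, while kernel.py keeps only the Triton kernels. A rough sketch of driving the new ops directly follows; the sizes (total_q, top_k, hidden_size, intermediate_size), the fp16/CUDA placement, and the expert-weight layout (num_experts, in_features, out_features) are assumptions inferred from scatter2scatter, not values taken from the patch.

    import torch

    from kernel_hyperdrive.scattermoe.triton_implementation.ops import (
        padded_block_indices,
        scattered_experts,
    )

    total_q, top_k = 1024, 2  # illustrative sizes only
    num_experts, hidden_size, intermediate_size = 8, 512, 2048

    hidden_states = torch.randn(total_q, hidden_size, device="cuda", dtype=torch.float16)
    selected_experts = torch.randint(num_experts, (total_q, top_k), device="cuda")

    # the Triton MoE no longer calls flatten_and_sort; it sorts the flattened
    # expert assignments inline (see MoE_Triton._compute_experts above)
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(selected_experts.flatten())
    padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, num_experts)

    # one fused grouped GEMM over all experts; weights are assumed to be laid out
    # as (num_experts, in_features, out_features), matching scatter2scatter's strides
    c_fc_weight = torch.randn(
        num_experts, hidden_size, intermediate_size, device="cuda", dtype=torch.float16
    )
    up = scattered_experts(
        hidden_states,
        c_fc_weight,
        top_k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        padded_block_idxs,
        expert_offsets,
    )
    # up -> (total_q * top_k, intermediate_size), one row per token-expert pair

Second, vector_addition_cuda is now registered through torch.library.custom_op with a register_fake meta implementation, which is what lets the new torch.compile cases in tests/vector_addition_test.py trace through the op. A minimal usage sketch, assuming a CUDA build environment (compile_helpers() builds the extension on import):

    import torch

    from kernel_hyperdrive import vector_addition_cuda

    x = torch.randn(1024, device="cuda")
    y = torch.randn(1024, device="cuda")

    # the compiled wrapper uses the registered fake kernel for shape inference,
    # then dispatches to the CUDA kernel at runtime
    compiled_add = torch.compile(vector_addition_cuda)
    assert torch.allclose(compiled_add(x, y), x + y)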