add comments in scattermoe (#5)
mayank31398 authored Aug 11, 2024
1 parent 8b8d539 commit 3d3fdc4
Showing 14 changed files with 435 additions and 415 deletions.
9 changes: 1 addition & 8 deletions kernel_hyperdrive/__init__.py
@@ -1,13 +1,6 @@
 from .scattermoe import MoE_Torch, MoE_Triton
 from .utils import compile_helpers
-from .vector_addition import (
-    VectorAddition_CUDA,
-    VectorAddition_Torch,
-    VectorAddition_Triton,
-    vector_addition_cuda,
-    vector_addition_torch,
-    vector_addition_triton,
-)
+from .vector_addition import vector_addition_cuda, vector_addition_torch, vector_addition_triton


compile_helpers()
88 changes: 66 additions & 22 deletions kernel_hyperdrive/scattermoe/torch_implementation.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch.profiler import record_function


class Experts_Torch(nn.Module):
@@ -25,13 +26,23 @@ def __init__(

        self.reset_parameters()

-    def forward(self, input: torch.Tensor, num_experts_per_token: torch.Tensor) -> torch.Tensor:
-        input = input.split(num_experts_per_token.tolist(), dim=0)
+    def forward(
+        self,
+        input: torch.Tensor | tuple[torch.Tensor],
+        expert_frequency: torch.Tensor,
+        return_list: bool,
+    ) -> torch.Tensor | list[torch.Tensor]:
+        if isinstance(input, torch.Tensor):
+            input = input.split(expert_frequency.tolist(), dim=0)

        input = [
            F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i])
            for i in range(self.num_experts)
        ]
-        input = torch.cat(input, dim=0)
+
+        if not return_list:
+            input = torch.cat(input)

        return input
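For reference, the reworked Experts_Torch.forward groups a token-sorted input into per-expert chunks and applies each expert's weights separately. A minimal standalone sketch of that pattern, using hypothetical sizes and plain tensors in place of the module's parameters:

import torch
import torch.nn.functional as F

num_experts, hidden_size, out_size = 4, 8, 16     # hypothetical sizes
weight = torch.randn(num_experts, out_size, hidden_size)

x = torch.randn(10, hidden_size)                  # 10 tokens, already sorted by expert
expert_frequency = torch.tensor([3, 2, 4, 1])     # rows per expert, sums to 10

chunks = x.split(expert_frequency.tolist(), dim=0)                   # one chunk per expert
outs = [F.linear(chunks[i], weight[i]) for i in range(num_experts)]

packed = torch.cat(outs)   # return_list=False: a single (10, out_size) tensor
per_expert = outs          # return_list=True: a list of per-expert tensors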

def extra_repr(self):
@@ -89,64 +100,97 @@ def __init__(
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_shape = hidden_states.shape

+        # hidden_states -> (batch_size, query_length, hidden_size)
        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # hidden_states -> (total_q, hidden_size)
        router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states)
-        hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts)

+        # router_logits -> (total_q, num_experts)
+        # router_weights -> (total_q, top_k)
+        # selected_experts -> (total_q, top_k)

+        hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts)
        hidden_states = hidden_states.view(original_shape)

+        # hidden_states -> (batch_size, query_length, hidden_size)

        return hidden_states, router_logits
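The new shape comments follow one convention throughout: token dimensions are flattened to total_q = batch_size * query_length before routing, and restored afterwards. A small illustration with hypothetical sizes:

import torch

batch_size, query_length, hidden_size = 2, 5, 8   # hypothetical
hidden_states = torch.randn(batch_size, query_length, hidden_size)

original_shape = hidden_states.shape
flat = hidden_states.view(-1, hidden_size)        # (total_q, hidden_size), total_q = 10
# ... routing and expert computation run on the flat layout ...
restored = flat.view(original_shape)              # back to (batch_size, query_length, hidden_size)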

+    @record_function("MoE_Torch:_compute_routing_weights")
    def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
+        # hidden_states -> (total_q, hidden_size)
        router_logits = self.gate(hidden_states)
+        # router_logits -> (total_q, num_experts)

        router_weights, selected_experts = self._get_topk(router_logits)
-        router_weights = F.softmax(router_weights.float(), dim=-1)

-        # we cast back to the input dtype
+        # router_weights -> (total_q, top_k)
+        # selected_experts -> (total_q, top_k)

+        router_weights = F.softmax(router_weights.float(), dim=-1)
        router_weights = router_weights.type_as(hidden_states)

        return router_logits, router_weights, selected_experts
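The routing step is a plain top-k router: take the top_k logits per token, softmax over just those in float32 for numerical stability, then cast back to the activation dtype. A standalone sketch with hypothetical sizes (the source's _get_topk, not shown in full here, also special-cases top_k == 1):

import torch
import torch.nn.functional as F

total_q, num_experts, top_k = 6, 4, 2             # hypothetical
router_logits = torch.randn(total_q, num_experts, dtype=torch.bfloat16)

router_weights, selected_experts = router_logits.topk(top_k, dim=-1)
router_weights = F.softmax(router_weights.float(), dim=-1)  # normalize over the selected experts only
router_weights = router_weights.to(torch.bfloat16)          # cast back (type_as in the source)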

+    @record_function("MoE_Torch:_compute_experts")
    def _compute_experts(
        self, hidden_states: torch.Tensor, router_weights: torch.Tensor, selected_experts: torch.Tensor
    ) -> torch.Tensor:
        total_q = hidden_states.shape[0]

-        batch_index, batch_gates, num_experts_per_token = self._compute_expert_assignment(
-            router_weights, selected_experts
-        )
+        # hidden_states -> (total_q, hidden_size)
+        # router_weights -> (total_q, top_k)
+        # selected_experts -> (total_q, top_k)

-        expert_inputs = hidden_states[batch_index]
+        fan_in_index, batch_gates, expert_frequency = self._compute_expert_assignment(router_weights, selected_experts)

-        hidden_states = self.c_fc(expert_inputs, num_experts_per_token)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states, num_experts_per_token)
+        # fan_in_index -> (total_q * top_k)
+        # batch_gates -> (total_q * top_k)
+        # expert_frequency -> (num_experts)

-        hidden_states = hidden_states * batch_gates.unsqueeze(-1)  # [:, None]
+        hidden_states = hidden_states[fan_in_index]

+        # hidden_states -> (total_q * top_k, hidden_size)

+        hidden_states = self.c_fc(hidden_states, expert_frequency, return_list=True)
+        # hidden_states -> num_experts x (?, intermediate_size)
+        hidden_states = [self.act(i) for i in hidden_states]
+        # hidden_states -> num_experts x (?, intermediate_size)
+        hidden_states = self.c_proj(hidden_states, expert_frequency, return_list=False)
+        # hidden_states -> (total_q * top_k, hidden_size)

+        hidden_states = hidden_states * batch_gates.unsqueeze(-1)
        zeros = torch.zeros((total_q, self.hidden_size), dtype=hidden_states.dtype, device=hidden_states.device)
-        hidden_states = zeros.index_add(0, batch_index, hidden_states)
+        hidden_states = zeros.index_add(0, fan_in_index, hidden_states)

+        # hidden_states -> (total_q, hidden_size)

        return hidden_states
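_compute_experts is a gather/compute/scatter pipeline: fan each token out to its top_k expert slots, run the experts, weight by the gate values, and fan back in with index_add. A toy sketch of just that skeleton (indices hypothetical; expert_out stands in for the c_fc -> act -> c_proj pipeline):

import torch

total_q, top_k, hidden_size = 4, 2, 8             # hypothetical
hidden_states = torch.randn(total_q, hidden_size)

fan_in_index = torch.tensor([0, 2, 1, 3, 0, 3, 2, 1])  # (total_q * top_k,) source token per expert slot
batch_gates = torch.rand(total_q * top_k)

expert_in = hidden_states[fan_in_index]           # fan out: every token appears top_k times
expert_out = expert_in * 2.0                      # stand-in for c_fc -> act -> c_proj
expert_out = expert_out * batch_gates.unsqueeze(-1)

zeros = torch.zeros_like(hidden_states)
combined = zeros.index_add(0, fan_in_index, expert_out)  # fan in: gate-weighted sum per token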

+    @record_function("MoE_Torch:_compute_expert_assignment")
    def _compute_expert_assignment(
        self, router_weights: torch.Tensor, selected_experts: torch.Tensor
    ) -> tuple[torch.Tensor]:
+        # router_weights -> (total_q, top_k)
+        # selected_experts -> (total_q, top_k)

        selected_experts = selected_experts.flatten()
+        # selected_experts -> (total_q * top_k)

-        num_experts_per_token = selected_experts.bincount(minlength=self.num_experts)
+        expert_frequency = selected_experts.bincount(minlength=self.num_experts)
+        # expert_frequency -> (num_experts)

-        # sort and group input tokens according to expert assignment
-        _, index_sorted_experts = selected_experts.sort(0)  # [num_tokens * top_k]
-        batch_index = index_sorted_experts // self.top_k  # [num_tokens * top_k]
+        index_sorted_experts = selected_experts.argsort()
+        # index_sorted_experts -> (total_q * top_k)
+        fan_in_index = index_sorted_experts // self.top_k
+        # fan_in_index -> (total_q * top_k)

-        # gather the gate values for grouped input tokens
-        router_weights = router_weights.flatten()  # [num_tokens * top_k]
-        batch_gates = router_weights[index_sorted_experts]  # [num_tokens * top_k]
+        router_weights = router_weights.flatten()
+        # router_weights -> (total_q * top_k)
+        batch_gates = router_weights[index_sorted_experts]
+        # batch_gates -> (total_q * top_k)

-        return batch_index, batch_gates, num_experts_per_token
+        return fan_in_index, batch_gates, expert_frequency
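A worked toy example of the assignment logic above (3 tokens, 4 experts, top_k = 2; values hypothetical):

import torch

top_k, num_experts = 2, 4
selected_experts = torch.tensor([[0, 2], [1, 2], [0, 3]]).flatten()
# flat expert ids per slot: tensor([0, 2, 1, 2, 0, 3])

expert_frequency = selected_experts.bincount(minlength=num_experts)
# tensor([2, 1, 2, 1]): expert 0 serves 2 slots, expert 1 serves 1, ...

index_sorted_experts = selected_experts.argsort()
# slot positions grouped by expert, e.g. tensor([0, 4, 2, 1, 3, 5])

fan_in_index = index_sorted_experts // top_k
# source token row for each sorted slot, e.g. tensor([0, 2, 1, 0, 1, 2])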

def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.top_k == 1:
@@ -4,7 +4,7 @@
import torch.nn as nn

from ..torch_implementation import Experts_Torch, MoE_Torch
-from .kernel import flatten_and_sort, padded_block_indices, scattered_experts
+from .ops import padded_block_indices, scattered_experts


class Experts_Triton(Experts_Torch):
@@ -78,7 +78,7 @@ def _compute_experts(
        self, hidden_states: torch.Tensor, router_weights: torch.Tensor, selected_experts: torch.Tensor
    ) -> torch.Tensor:
        with torch.no_grad():
-            sorted_expert_idxs, sorted_scattered_idxs = flatten_and_sort(selected_experts)
+            sorted_expert_idxs, sorted_scattered_idxs = torch.sort(selected_experts.flatten())
            padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts)

        hidden_states = self.c_fc(
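The removed flatten_and_sort helper is replaced by an inline torch.sort over the flattened assignments; the two results are the ascending expert ids and the original flat positions they came from. A tiny illustration (hypothetical values; the order within ties may vary):

import torch

selected_experts = torch.tensor([[1, 0], [2, 1]])  # (total_q, top_k)

sorted_expert_idxs, sorted_scattered_idxs = torch.sort(selected_experts.flatten())
# sorted_expert_idxs    -> tensor([0, 1, 1, 2]), expert ids in ascending order
# sorted_scattered_idxs -> tensor([1, 0, 3, 2]), original flat positions of those ids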