|
1 | 1 | from typing import Dict, List, Optional, Union |
2 | 2 |
|
3 | 3 | import torch |
4 | | -import torch.nn.functional as F |
5 | 4 | import triton |
6 | 5 | import triton.language as tl |
7 | 6 |
|
@@ -216,30 +215,82 @@ def triton_masked_index_gather(output, input, start_offsets, row_indices): |
216 | 215 | return |
217 | 216 |
|
218 | 217 |
|
@nvtx_range("[DG] act")
@torch.compile(dynamic=True)
def swiglu_fused_moe(x):
    # SwiGLU activation: the last dimension packs [value, gate]; gate the
    # value half with SiLU(gate).
    value, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * value
224 | | - |
225 | | - |
@nvtx_range("[DG] indexing")
@torch.compile(dynamic=True)
def indexing(x, mask):
    # Compact the rows whose mask entry is positive into a contiguous tensor.
    keep = mask > 0
    selected = x[keep, :]
    return selected.contiguous()
@triton.jit
def _preprocess_after_permute_kernel(
    expert_offsets_ptr,
    masked_m_ptr,
    token_map_ptr,
    TOTAL_TOKENS: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    # Fused preprocessing kernel launched on a 2-D grid:
    #   axis-1 program 0: one program per BLOCK_SIZE_M tokens; writes, for
    #       each token index, the id of the expert owning it (token_map_ptr).
    #   axis-1 program 1: one program per expert; writes the number of
    #       tokens routed to that expert (masked_m_ptr).
    # expert_offsets_ptr holds NUM_EXPERTS + 1 prefix offsets; tokens in
    # [offsets[e], offsets[e+1]) belong to expert e.
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    if pid_y == 0:
        token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        token_mask = token_offsets < TOTAL_TOKENS
        # get expert_id for each token in the block; default to the last
        # expert so entries never matched by a boundary still get a value
        expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
        found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
        # Linear scan over the compile-time-constant expert boundaries: the
        # first boundary strictly greater than the token index identifies
        # the owning expert; found_mask freezes already-assigned entries.
        for i in tl.static_range(NUM_EXPERTS):
            boundary = tl.load(expert_offsets_ptr + i + 1)
            cond = (token_offsets < boundary) & ~found_mask
            expert_ids = tl.where(cond, i, expert_ids)
            found_mask = found_mask | cond
        tl.store(token_map_ptr + token_offsets, expert_ids, mask=token_mask)

    elif pid_y == 1:
        # get num_tokens for each expert as the difference of consecutive
        # prefix offsets. pid_x can exceed NUM_EXPERTS when the grid was
        # sized by the token pass, so loads and the store are all masked.
        expert_mask = pid_x < NUM_EXPERTS
        next_offset = tl.load(expert_offsets_ptr + pid_x + 1,
                              mask=expert_mask,
                              other=0)
        current_offset = tl.load(expert_offsets_ptr + pid_x,
                                 mask=expert_mask,
                                 other=0)
        tokens_per_expert = next_offset - current_offset
        tl.store(masked_m_ptr + pid_x,
                 tokens_per_expert.to(tl.int32),
                 mask=expert_mask)
230 | 256 |
|
231 | 257 |
|
@nvtx_range("[DG] preprocess_after_permute")
def preprocess_after_permute(expert_first_token_offset_tensor,
                             permuted_data_tensor):
    """
    Launch a single fused kernel that computes both the token-to-expert map
    and the number of tokens per expert from the expert-first-token offsets.

    Args:
        expert_first_token_offset_tensor: 1-D tensor of num_experts + 1
            prefix offsets; tokens [offsets[e], offsets[e+1]) belong to
            expert e.
        permuted_data_tensor: permuted token tensor; only its leading
            dimension (total token count) is read.

    Returns:
        Tuple of int32 tensors (masked_m, token_to_expert_map) with shapes
        (num_experts,) and (total_tokens,).
    """
    total_tokens = permuted_data_tensor.shape[0]
    num_experts = expert_first_token_offset_tensor.shape[0] - 1
    # Allocate outputs on the inputs' device rather than assuming the
    # current CUDA device.
    device = permuted_data_tensor.device

    masked_m = torch.empty(num_experts, dtype=torch.int32, device=device)
    token_to_expert_map = torch.empty(total_tokens,
                                      dtype=torch.int32,
                                      device=device)

    if total_tokens == 0:
        # Nothing to map: every expert count is zero and the map is empty.
        # The kernel cannot be launched with a zero block size, so bail out.
        masked_m.zero_()
        return masked_m, token_to_expert_map

    # Size the grid. tl.arange in the kernel requires a power-of-two
    # extent, so a shrunken block size must be rounded up.
    DEFAULT_BLOCK_SIZE_M = 256
    grid_m_size = triton.cdiv(total_tokens, DEFAULT_BLOCK_SIZE_M)
    if grid_m_size >= num_experts:
        BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
        grid = (grid_m_size, 2)
    else:
        # Fewer token blocks than experts: widen the x-grid to num_experts
        # so the per-expert count pass (pid_y == 1) covers every expert.
        BLOCK_SIZE_M = triton.next_power_of_2(
            triton.cdiv(total_tokens, num_experts))
        grid = (num_experts, 2)

    # launch the fused kernel (y-axis 0: token->expert map, 1: counts)
    _preprocess_after_permute_kernel[grid](
        expert_first_token_offset_tensor,
        masked_m,
        token_to_expert_map,
        TOTAL_TOKENS=total_tokens,
        NUM_EXPERTS=num_experts,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
    )
    return masked_m, token_to_expert_map
243 | 294 |
|
244 | 295 |
|
245 | 296 | @nvtx_range("[DG]") |
|
0 commit comments