Skip to content

Commit 4cee276

Browse files
committed
fix.
Signed-off-by: Fanrong Li <[email protected]>
1 parent bb82e26 commit 4cee276

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,6 @@ def _preprocess_after_permute_kernel(
226 226   ):
227 227   pid_x = tl.program_id(0)
228 228   pid_y = tl.program_id(1)
229     -
230 229   if pid_y == 0:
231 230   token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
232 231   token_mask = token_offsets < TOTAL_TOKENS
@@ -238,8 +237,9 @@ def _preprocess_after_permute_kernel(
238 237   cond = (token_offsets < boundary) & ~found_mask
239 238   expert_ids = tl.where(cond, i, expert_ids)
240 239   found_mask = found_mask | cond
241     - tl.store(token_map_ptr + token_offsets, expert_ids, mask=token_mask)
242     -
    240 + tl.store(token_map_ptr + token_offsets,
    241 + expert_ids.to(tl.int64),
    242 + mask=token_mask)
243 243   elif pid_y == 1:
244 244   # get num_tokens for each expert
245 245   expert_mask = pid_x < NUM_EXPERTS
@@ -268,7 +268,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
268 268   # create output tensors
269 269   masked_m = torch.empty(num_experts, dtype=torch.int32, device='cuda')
270 270   token_to_expert_map = torch.empty(total_tokens,
271     - dtype=torch.int32,
    271 + dtype=torch.int64,
272 272   device='cuda')
273 273
274 274   # calculate the grid size
@@ -278,7 +278,8 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
278 278   BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
279 279   grid = (grid_m_size, 2)
280 280   else:
281     - BLOCK_SIZE_M = triton.cdiv(total_tokens, num_experts)
    281 + block_size_m = triton.cdiv(total_tokens, num_experts)
    282 + BLOCK_SIZE_M = triton.next_power_of_2(block_size_m)
282 283   grid = (num_experts, 2)
283 284
284 285   # launch the kernel

0 commit comments

Comments
 (0)