@@ -226,7 +226,6 @@ def _preprocess_after_permute_kernel(
 ):
     pid_x = tl.program_id(0)
     pid_y = tl.program_id(1)
-
     if pid_y == 0:
         token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
         token_mask = token_offsets < TOTAL_TOKENS
@@ -238,8 +237,9 @@ def _preprocess_after_permute_kernel(
             cond = (token_offsets < boundary) & ~found_mask
             expert_ids = tl.where(cond, i, expert_ids)
             found_mask = found_mask | cond
-        tl.store(token_map_ptr + token_offsets, expert_ids, mask=token_mask)
-
+        tl.store(token_map_ptr + token_offsets,
+                 expert_ids.to(tl.int64),
+                 mask=token_mask)
     elif pid_y == 1:
         # get num_tokens for each expert
         expert_mask = pid_x < NUM_EXPERTS
@@ -268,7 +268,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
     # create output tensors
     masked_m = torch.empty(num_experts, dtype=torch.int32, device='cuda')
     token_to_expert_map = torch.empty(total_tokens,
-                                      dtype=torch.int32,
+                                      dtype=torch.int64,
                                       device='cuda')
 
     # calculate the grid size
@@ -278,7 +278,8 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
         BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
         grid = (grid_m_size, 2)
     else:
-        BLOCK_SIZE_M = triton.cdiv(total_tokens, num_experts)
+        block_size_m = triton.cdiv(total_tokens, num_experts)
+        BLOCK_SIZE_M = triton.next_power_of_2(block_size_m)
         grid = (num_experts, 2)
 
     # launch the kernel
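
Note on the changes above: tl.arange(0, BLOCK_SIZE_M) inside the kernel requires a power-of-two range, so the host side now rounds the per-expert block size from triton.cdiv up with triton.next_power_of_2 before launching. The other change casts the stored expert ids to int64 and allocates token_to_expert_map as int64, presumably so the map can be consumed directly as an index tensor downstream. A minimal sketch of the rounding, using hypothetical token/expert counts (not values from this diff):

    import triton

    total_tokens = 1000  # hypothetical values for illustration only
    num_experts = 6
    block_size_m = triton.cdiv(total_tokens, num_experts)   # 167, not a power of two
    BLOCK_SIZE_M = triton.next_power_of_2(block_size_m)     # 256, valid for tl.arange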