@@ -199,9 +199,7 @@ def forward_impl(
         topk_group = None
         routed_scaling_factor = None

-        # Don't support post-quant allgather for fp8 block scale and has_w4a16_mxfp4 for now.
-        is_post_quant_allgather_supported = self.has_nvfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1 and is_post_quant_allgather_supported
+        run_post_quant_allgather = self.use_dp and self.parallel_size > 1

         x_sf = None
         token_selected_experts = None
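The hunk above removes the per-quant-mode gating, so the post-quant allgather path now runs whenever attention DP is enabled and parallel_size > 1, including the fp8 block-scale and w4a16_mxfp4 paths. A minimal sketch of the general idea, not TRT-LLM's implementation, assuming an already-initialized torch.distributed process group and using a stand-in per-tensor int8 quantizer in place of the fp8/nvfp4/mxfp4 kernels: each rank quantizes its local tokens first and all-gathers the compact payload plus scales, rather than all-gathering bf16 activations.

import torch
import torch.distributed as dist

def post_quant_allgather(x_local: torch.Tensor):
    # Quantize locally first (illustrative per-tensor int8; the real path uses
    # fp8 / nvfp4 / mxfp4 block-quantization kernels).
    scale = (x_local.abs().amax() / 127.0).clamp_min(1e-8).reshape(1)
    x_q = (x_local / scale).round().clamp(-128, 127).to(torch.int8)

    # All-gather the compact payload plus its scale across DP ranks;
    # this moves ~1 byte per element instead of 2 bytes for bf16.
    world = dist.get_world_size()
    q_chunks = [torch.empty_like(x_q) for _ in range(world)]
    s_chunks = [torch.empty_like(scale) for _ in range(world)]
    dist.all_gather(q_chunks, x_q)
    dist.all_gather(s_chunks, scale)
    return torch.cat(q_chunks, dim=0), torch.cat(s_chunks, dim=0)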
@@ -239,6 +237,11 @@ def forward_impl(
                     x, False, alignment=self.quant_method.weight_alignment)
                 # Update x_row and x_col to the padded shape
                 x_row, x_col = x.shape[0], x.shape[1]
+            elif self.has_deepseek_fp8_block_scales:
+                pass
+            elif self.has_w4a16_mxfp4:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
             else:
                 raise ValueError(
                     f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
@@ -266,8 +269,8 @@ def forward_impl(
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)

             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
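Passing None for router_logits and routing_bias when run_post_quant_allgather is set reflects that routing was already computed on the local tokens before the gather; the runner then consumes the precomputed token_selected_experts / token_final_scales (forwarded as topk_ids / topk_weights in the next hunk). As a rough illustration of how such top-k tensors are commonly derived from router logits (plain softmax top-k here; the module's routing_method may use grouped or sigmoid-scored routing instead):

import torch

num_tokens, num_experts, top_k = 16, 128, 8        # hypothetical sizes
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16)

probs = torch.softmax(router_logits.float(), dim=-1)
topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)   # ~ token_final_scales / token_selected_experts
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize over the kept experts
topk_ids = topk_ids.to(torch.int32)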
@@ -284,6 +287,8 @@ def forward_impl(
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
+                topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
         elif self.has_nvfp4:
             scale_factor_use_ue8m0 = False
@@ -324,8 +329,8 @@ def forward_impl(
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
                 do_finalize=do_finalize,
-                topk_ids=token_selected_experts,
                 topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )

             if not do_finalize:
@@ -335,14 +340,17 @@ def forward_impl(
                 final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
+            if not run_post_quant_allgather:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
+            else:
+                x = x

-            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            x = torch.nn.functional.pad(x, (0, pad_size))
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,