Commit 95eac2c

[https://nvbugs/5537738][fix] Add fp8 post-quant allgather support (#8008)
Signed-off-by: Christina Zhang <[email protected]>
1 parent 77b68d9 commit 95eac2c

File tree: 2 files changed, +18 −11 lines


tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 18 additions & 10 deletions
@@ -199,9 +199,7 @@ def forward_impl(
         topk_group = None
         routed_scaling_factor = None
 
-        # Don't support post-quant allgather for fp8 block scale and has_w4a16_mxfp4 for now.
-        is_post_quant_allgather_supported = self.has_nvfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1 and is_post_quant_allgather_supported
+        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
 
         x_sf = None
         token_selected_experts = None
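With the gate removed, post-quant allgather now runs for every quantization mode under attention DP (previously only nvfp4 and the w4a8 mxfp4 variants), which is what lets the fp8 block-scale path use it. A minimal sketch of the resulting dataflow, assuming hypothetical allgather() and quantize_fp8() helpers standing in for TRT-LLM's attention-DP collective and the torch.ops.trtllm.fp8_quantize_1x128 kernel:

import torch

def allgather(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: the real attention-DP allgather concatenates
    # shards from every rank along dim 0; on a single rank it is the identity.
    return t

def quantize_fp8(x: torch.Tensor):
    # Stand-in for torch.ops.trtllm.fp8_quantize_1x128: one per-tensor
    # scale here, 1x128 block scales in the real kernel.
    scale = x.abs().amax().clamp(min=1e-6) / 448.0
    return (x / scale).to(torch.float8_e4m3fn), scale

def gather_moe_inputs(x, router_logits, run_post_quant_allgather: bool, k: int = 2):
    if run_post_quant_allgather:
        # Route and quantize the local shard first, then gather the compact
        # fp8 payload plus routing results instead of bf16 activations.
        topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, -1), k)
        x_val, x_scale = quantize_fp8(x)
        return (allgather(x_val), allgather(x_scale),
                allgather(topk_ids), allgather(topk_weights), None)
    # Otherwise the MoE runner receives the raw logits and routes internally.
    x_val, x_scale = quantize_fp8(x)
    return x_val, x_scale, None, None, router_logits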
@@ -239,6 +237,11 @@ def forward_impl(
                     x, False, alignment=self.quant_method.weight_alignment)
                 # Update x_row and x_col to the padded shape
                 x_row, x_col = x.shape[0], x.shape[1]
+            elif self.has_deepseek_fp8_block_scales:
+                pass
+            elif self.has_w4a16_mxfp4:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
             else:
                 raise ValueError(
                     f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
@@ -266,8 +269,8 @@ def forward_impl(
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -284,6 +287,8 @@ def forward_impl(
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
+                topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
         elif self.has_nvfp4:
             scale_factor_use_ue8m0 = False
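When the gathered path is taken, routing was already computed on each rank before the allgather, so the runner receives None for router_logits/routing_bias and the precomputed topk_weights=token_final_scales and topk_ids=token_selected_experts instead. A sketch of that either-or contract as read from the diff (the actual fp8_block_scale_moe_runner kernel may enforce it differently):

from typing import Optional
import torch

def resolve_routing(router_logits: Optional[torch.Tensor],
                    topk_ids: Optional[torch.Tensor],
                    topk_weights: Optional[torch.Tensor],
                    k: int = 2):
    # Exactly one source of routing: precomputed ids/weights, or raw logits.
    if topk_ids is not None and topk_weights is not None:
        assert router_logits is None, "precomputed routing replaces the logits"
        return topk_ids, topk_weights
    weights, ids = torch.topk(torch.softmax(router_logits.float(), dim=-1), k)
    return ids, weights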
@@ -324,8 +329,8 @@ def forward_impl(
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
                 do_finalize=do_finalize,
-                topk_ids=token_selected_experts,
                 topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
 
             if not do_finalize:
@@ -335,14 +340,17 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
+            if not run_post_quant_allgather:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
+            else:
+                x = x
 
-            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            x = torch.nn.functional.pad(x, (0, pad_size))
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
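The mxfp4 branch now pads only when post-quant allgather did not run, because the pre-allgather branch added in the earlier hunk already padded x before gathering; the else: x = x arm is a no-op that keeps the branch explicit. The pad-once invariant, sketched with a hypothetical helper:

import torch
import torch.nn.functional as F

def ensure_padded(x: torch.Tensor, target_width: int, already_padded: bool) -> torch.Tensor:
    # x is padded exactly once: either before the allgather (post-quant
    # path) or locally here; never both.
    if already_padded:
        return x  # corresponds to the diff's no-op `x = x` arm
    return F.pad(x, (0, target_width - x.shape[-1]))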

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -341,7 +341,6 @@ accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plug
 accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights SKIP (https://nvbugs/5532023)
 accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5532023)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2 SKIP (https://nvbugs/5532023)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5537738)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5503479)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5541494)
 unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5541545)
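The removed waiver cites the same bug id as the commit title (https://nvbugs/5537738), so this change re-enables the previously skipped TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] accuracy test now that the fp8 block-scale path supports post-quant allgather.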
