Commit 6c7ad69

check use_alltoall in use_post_quant_alltoall
Signed-off-by: Vincent Huang <[email protected]>
1 parent 00d06e8 · commit 6c7ad69

Showing 2 changed files with 17 additions and 14 deletions.


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion

@@ -515,7 +515,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
         # to reduce allreduce BW
         if (disable_fp4_allgather()
-                and not self.experts.enable_alltoall) or isinstance(
+                and not self.experts.can_use_alltoall(hidden_states_fp4 or hidden_states)) or isinstance(
                 self.experts, TRTLLMGenFusedMoE):
             hidden_states = allgather(hidden_states,
                                       self.mapping,
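
Note on this call-site change: the static self.experts.enable_alltoall flag is replaced by a per-input can_use_alltoall(...) query, so the bf16 allgather fallback is taken whenever alltoall cannot be used for the tensor actually being routed. A minimal sketch of what such a check could look like is below; it is written as a free function with explicit parameters, and the capacity criterion is an assumption for illustration, not the actual TensorRT-LLM implementation of can_use_alltoall.

# Hypothetical sketch of a per-input alltoall capability check; the real
# can_use_alltoall() in TensorRT-LLM may apply different criteria.
def can_use_alltoall(enable_alltoall: bool, num_tokens: int,
                     moe_max_num_tokens: int) -> bool:
    # The module-level switch must be on ...
    if not enable_alltoall:
        return False
    # ... and (assumed here) the current chunk must fit whatever capacity
    # the alltoall path was sized for; otherwise fall back to allgather.
    return num_tokens <= moe_max_num_tokens

# Example: alltoall is enabled, but the chunk exceeds the assumed limit,
# so the caller would take the allgather path shown in the diff above.
print(can_use_alltoall(enable_alltoall=True, num_tokens=8192, moe_max_num_tokens=4096))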

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 16 additions & 13 deletions

@@ -334,7 +334,7 @@ def reducescatter_or_allreduce(
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not use_all_toall:
+        if not use_all_to_all:
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
            outputs = reducescatter(
@@ -385,7 +385,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()

-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all

         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
@@ -394,7 +394,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,
@@ -403,7 +403,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )
@@ -415,6 +415,8 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

+        # If alltoall is disabled, we need also disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
@@ -426,13 +428,14 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \
@@ -471,7 +474,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
                         if use_allgather:
                             assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -504,7 +507,7 @@ def forward_chunk(
                     f"unsupported quantization mode: {self.quant_config.quant_mode}"
                 )

-        if use_allgather and not use_all_to_all::
+        if use_allgather and not use_all_to_all:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -527,7 +530,7 @@ def forward_chunk(

         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -547,7 +550,7 @@ def forward_chunk(
         cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales

-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -800,7 +803,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
         all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                    num_chunks)
         chunk_size_list = all_rank_chunk_size_list[self.rank]
-        if self.enable_alltoall:
+        if use_all_to_all:
             all_rank_num_tokens_list = [[
                 1 if val == 0 else val for val in val_list
             ] for val_list in all_rank_num_tokens_list]
@@ -898,7 +901,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token

@@ -942,7 +945,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)

-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"
