Skip to content

Commit ea2c89c

Browse files
committed
Check use_alltoall when deciding use_post_quant_alltoall
Signed-off-by: Vincent Huang <[email protected]>
1 parent 321f081 commit ea2c89c

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
515515
# FP4 all_gather moves this bf16 allgather to after topk and fp4 quantization
516516
# to reduce allreduce BW
517517
if (disable_fp4_allgather()
518-
and not self.experts.enable_alltoall) or isinstance(
518+
and not self.experts.can_use_alltoall(hidden_states_fp4 or hidden_states)) or isinstance(
519519
self.experts, TRTLLMGenFusedMoE):
520520
hidden_states = allgather(hidden_states,
521521
self.mapping,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def reducescatter_or_allreduce(
335335
use_dp_padding: Optional[bool] = None,
336336
):
337337
outputs = inputs
338-
if not use_all_toall:
338+
if not use_all_to_all:
339339
if self.enable_dummy_allreduce:
340340
self.dummy_allreduce()
341341
outputs = reducescatter(
@@ -398,7 +398,7 @@ def forward_chunk(
398398
) and is_first_call:
399399
self.layer_load_balancer.maybe_cudagraph_done_wait()
400400

401-
use_allgather = not self.enable_alltoall
401+
use_allgather = not use_all_to_all
402402

403403
loadbalancer_local_statistic_info = None
404404
gathered_loadbalancer_local_statistic_info = None
@@ -407,7 +407,7 @@ def forward_chunk(
407407
token_selected_slots = token_selected_experts
408408
else:
409409
if not self.layer_load_balancer.is_static_routing(
410-
) and self.enable_alltoall:
410+
) and use_all_to_all:
411411
self.layer_load_balancer.local_statistic(
412412
token_selected_experts,
413413
is_first_stage=is_first_call,
@@ -416,7 +416,7 @@ def forward_chunk(
416416
token_selected_experts, self.use_dp)
417417
if not self.layer_load_balancer.is_static_routing():
418418
# split into two part to get possible overlap with load balancer routing
419-
if self.enable_alltoall:
419+
if use_all_to_all:
420420
if is_last_call:
421421
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
422422
)
@@ -428,6 +428,8 @@ def forward_chunk(
428428
ExpertStatistic.set_layer(self.layer_idx)
429429
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
430430

431+
# If alltoall is disabled, we also need to disable use_postquant_alltoall
432+
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
431433
if use_all_to_all:
432434
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
433435
if self.enable_dummy_allreduce:
@@ -439,13 +441,14 @@ def forward_chunk(
439441
x,
440442
token_selected_slots,
441443
token_final_scales,
444+
use_postquant_alltoall,
442445
loadbalancer_local_statistic_info)
443446
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
444-
if not self.use_postquant_alltoall:
447+
if not use_postquant_alltoall:
445448
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
446449
self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
447450
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
448-
if not self.use_postquant_alltoall:
451+
if not use_postquant_alltoall:
449452
deep_ep_topk_idx = token_selected_slots.to(torch.int64)
450453
deep_ep_topk_weights = token_final_scales
451454
x, recv_expert_count, deep_ep_handle = \
@@ -485,7 +488,7 @@ def forward_chunk(
485488
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
486489
x, self.fc31_input_dequant)
487490
elif self.has_nvfp4:
488-
if use_allgather or self.use_postquant_alltoall:
491+
if use_allgather or use_postquant_alltoall:
489492
if isinstance(x, Fp4QuantizedTensor):
490493
if use_allgather:
491494
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -518,7 +521,7 @@ def forward_chunk(
518521
f"unsupported quantization mode: {self.quant_config.quant_mode}"
519522
)
520523

521-
if use_allgather and not use_all_to_all::
524+
if use_allgather and not use_all_to_all:
522525
# using allgather case.
523526
if self.enable_dummy_allreduce:
524527
self.dummy_allreduce()
@@ -541,7 +544,7 @@ def forward_chunk(
541544

542545
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
543546
):
544-
if self.enable_alltoall:
547+
if use_all_to_all:
545548
if is_last_call:
546549
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
547550
(self.mapping.moe_ep_size, self.num_experts))
@@ -561,7 +564,7 @@ def forward_chunk(
561564
cluster_rank = self.cluster_rank
562565
quant_scales = self.quant_scales
563566

564-
if self.use_postquant_alltoall:
567+
if use_postquant_alltoall:
565568
if x_sf is not None and self.has_nvfp4:
566569
assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
567570
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -814,7 +817,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
814817
all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
815818
num_chunks)
816819
chunk_size_list = all_rank_chunk_size_list[self.rank]
817-
if self.enable_alltoall:
820+
if use_all_to_all:
818821
all_rank_num_tokens_list = [[
819822
1 if val == 0 else val for val in val_list
820823
] for val_list in all_rank_num_tokens_list]
@@ -912,7 +915,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
912915
def alltoall_prepare_maybe_dispatch(
913916
self, all_rank_max_num_tokens: int, x: torch.Tensor,
914917
token_selected_slots: torch.Tensor,
915-
token_final_scales: torch.Tensor,
918+
token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
916919
local_statistic_tensor: Optional[torch.Tensor]):
917920
top_k = self.routing_method.experts_per_token
918921

@@ -958,7 +961,7 @@ def alltoall_prepare_maybe_dispatch(
958961
gathered_token_final_scales, all_rank_max_num_tokens,
959962
self.num_slots, top_k, self.ep_rank, self.ep_size)
960963

961-
if not self.use_postquant_alltoall:
964+
if not use_postquant_alltoall:
962965
assert not isinstance(
963966
x, Fp4QuantizedTensor
964967
), "pre-quant alltoall doesn't support fp4 tensor"

0 commit comments

Comments
 (0)