
Commit 35ba966

all2all check from both chunking and threshold
Signed-off-by: Vincent Huang <[email protected]>
1 parent ea2c89c commit 35ba966

2 files changed: +21 -12 lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 8 additions & 3 deletions
@@ -512,11 +512,16 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
+            # MoE use static heuristic to check alltoall enabled or not, however, for wide_ep, the alltoall could also be dynamically disabled when chunking is used or TRTLLM_DEEP_EP_TOKEN_LIMIT is hit.
+            is_wide_ep_alltoall_disabled = isinstance(
+                self.experts, WideEPMoE) and not self.experts.can_use_alltoall(
+                    hidden_states, all_rank_num_tokens)
+            alltoall_enabled = self.experts.enable_alltoall and not is_wide_ep_alltoall_disabled
+
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.can_use_alltoall(hidden_states_fp4 or hidden_states)) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if (disable_fp4_allgather() and not alltoall_enabled) or isinstance(
+                    self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
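For context, a minimal sketch of the gating this hunk introduces in compute_routed_output: the static enable_alltoall flag is combined with the per-iteration can_use_alltoall check, which only applies to the wide-EP MoE. The resolve_alltoall_enabled helper and the stand-in experts object below are illustrative, not part of the TensorRT-LLM API.

from types import SimpleNamespace


def resolve_alltoall_enabled(experts, is_wide_ep, hidden_states,
                             all_rank_num_tokens):
    # Static heuristic: the MoE module decides up front whether alltoall is
    # enabled at all. Dynamic check (wide EP only): chunking or the DeepEP
    # token limit can still disable it for this forward pass.
    is_wide_ep_alltoall_disabled = is_wide_ep and not experts.can_use_alltoall(
        hidden_states, all_rank_num_tokens)
    return experts.enable_alltoall and not is_wide_ep_alltoall_disabled


# Stand-in experts: alltoall statically enabled but dynamically rejected this step.
experts = SimpleNamespace(enable_alltoall=True,
                          can_use_alltoall=lambda h, t: False)
print(resolve_alltoall_enabled(experts, True, None, [128, 128]))   # False
print(resolve_alltoall_enabled(experts, False, None, [128, 128]))  # True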

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 13 additions & 9 deletions
@@ -278,7 +278,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    def can_use_alltoall(self, input):
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
         num_tokens = input.shape[0]
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
@@ -521,7 +530,7 @@ def forward_chunk(
                     f"unsupported quantization mode: {self.quant_config.quant_mode}"
                 )
 
-        if use_allgather and not use_all_to_all:
+        if use_allgather:
            # using allgather case.
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
@@ -766,20 +775,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -798,8 +804,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:
 
-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
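The hunk ends before split_chunk returns; a plausible completion (an assumption based on the common even-split pattern, not taken from the commit) spreads the remainder over the leading chunks:

def split_chunk(split_token_num: int, split_num_chunks: int):
    # Even split with the remainder distributed across the first chunks
    # (assumed continuation; the diff only shows the two divisions above).
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)


print(split_chunk(10, 3))  # [4, 3, 3]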
