
Commit cf24bf7

all2all check from both chunking and threshold
Signed-off-by: Vincent Huang <[email protected]>
1 parent 6c7ad69 commit cf24bf7

2 files changed: 21 additions & 12 deletions


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 8 additions & 3 deletions

@@ -512,11 +512,16 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
+            # MoE uses a static heuristic to check whether alltoall is enabled; however, for wide_ep, alltoall can also be dynamically disabled when chunking is used or TRTLLM_DEEP_EP_TOKEN_LIMIT is hit.
+            is_wide_ep_alltoall_disabled = isinstance(
+                self.experts, WideEPMoE) and not self.experts.can_use_alltoall(
+                    hidden_states, all_rank_num_tokens)
+            alltoall_enabled = self.experts.enable_alltoall and not is_wide_ep_alltoall_disabled
+
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.can_use_alltoall(hidden_states_fp4 or hidden_states)) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if (disable_fp4_allgather() and not alltoall_enabled) or isinstance(
+                    self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
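Editor's note: for readers skimming the hunk, here is a minimal standalone sketch of the gating logic it introduces: the static enable_alltoall flag is combined with the dynamic wide-EP check, and the bf16 allgather path is taken only when alltoall is unavailable (or the backend is TRTLLMGenFusedMoE). This is not the TensorRT-LLM implementation; the helper name should_allgather_hidden_states, its boolean parameters, and the hasattr check (standing in for the isinstance(self.experts, WideEPMoE) test in the diff) are illustrative assumptions.

def should_allgather_hidden_states(experts, hidden_states, all_rank_num_tokens,
                                   fp4_allgather_disabled: bool,
                                   is_trtllm_gen_fused_moe: bool) -> bool:
    # Hypothetical sketch, not repository code.
    # Static heuristic: was the MoE module constructed with an alltoall method?
    alltoall_enabled = experts.enable_alltoall
    # Dynamic check for wide EP: chunking or the TRTLLM_DEEP_EP_TOKEN_LIMIT
    # threshold can disable alltoall for this particular batch.
    if alltoall_enabled and hasattr(experts, "can_use_alltoall"):
        alltoall_enabled = experts.can_use_alltoall(hidden_states, all_rank_num_tokens)
    # Fall back to the explicit bf16 allgather only when alltoall is unavailable.
    return (fp4_allgather_disabled and not alltoall_enabled) or is_trtllm_gen_fused_moe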

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 13 additions & 9 deletions

@@ -277,7 +277,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled

-    def can_use_alltoall(self, input):
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
         num_tokens = input.shape[0]

         # For DeepEPLowLatency, check if tokens exceed the threshold
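Editor's note: this hunk adds the first of the two checks named in the commit title (chunking); the second (the DeepEPLowLatency token threshold) is only visible here as a trailing comment. A self-contained sketch of how the two checks compose, where the class, the deep_ep_token_limit attribute, and the numeric values are illustrative assumptions rather than repository code:

from typing import List

class _WideEPSketch:
    def __init__(self, moe_max_num_tokens: int, deep_ep_token_limit: int):
        self.moe_max_num_tokens = moe_max_num_tokens
        self.deep_ep_token_limit = deep_ep_token_limit  # stands in for TRTLLM_DEEP_EP_TOKEN_LIMIT

    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
        num_rows = sum(all_rank_num_tokens)
        # Ceil division: any remainder forces one extra chunk.
        return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens

    def can_use_alltoall(self, num_tokens: int, all_rank_num_tokens: List[int]) -> bool:
        # Check 1 (added by this commit): chunking disables alltoall.
        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
            return False
        # Check 2 (assumed form of the pre-existing threshold check): the local
        # token count must stay under the DeepEP low-latency limit.
        return num_tokens <= self.deep_ep_token_limit

sketch = _WideEPSketch(moe_max_num_tokens=8192, deep_ep_token_limit=256)
print(sketch.can_use_alltoall(128, [128, 128, 128, 128]))     # True: 1 chunk, under the limit
print(sketch.can_use_alltoall(512, [512, 512, 512, 512]))     # False: over the token limit
print(sketch.can_use_alltoall(64, [4096, 4096, 4096, 4096]))  # False: 16384 rows -> 2 chunks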
@@ -507,7 +516,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )

-        if use_allgather and not use_all_to_all:
+        if use_allgather:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -752,20 +761,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)

         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)

         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -784,8 +790,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:

-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
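Editor's note: taken together, the two forward() hunks move both decisions to the top of the call, so the single-chunk and chunked branches now share one use_all_to_all value instead of computing or overriding it per branch. A hedged sketch of the resulting control flow; the simplified forward_chunk call and the _forward_in_chunks helper are hypothetical stand-ins for the elided repository code.

def forward(self, x, all_rank_num_tokens, all_rank_max_num_tokens, use_dp_padding):
    # Sketch only; argument handling and padding logic are omitted.
    assert all_rank_num_tokens is not None
    assert use_dp_padding is not None

    # Both decisions are made once, up front. can_use_alltoall() already returns
    # False when calculate_num_chunks() > 1 or the DeepEP token threshold is hit,
    # so the chunked branch no longer needs to force use_all_to_all = False.
    num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
    use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)

    if num_chunks == 1:
        # Single pass: alltoall may or may not be used, depending on the checks above.
        outputs = self.forward_chunk(x, use_all_to_all=use_all_to_all)
    else:
        # Chunk-splitting loop elided; every chunk sees the same (already False)
        # use_all_to_all value, so each chunk takes the allgather path.
        outputs = self._forward_in_chunks(x, num_chunks, use_all_to_all)
    return outputs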
