
Commit 0523f77

support TRTLLM_DEEP_EP_TOKEN_LIMIT to allow run deep-ep on memory-con… (#5684)
Signed-off-by: Vincent Huang <[email protected]>
1 parent e761231 commit 0523f77

3 files changed: +63 -30 lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
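
The only functional change above is the extra 'WIDEEP' choice for --moe_backend. A minimal, self-contained sketch of the resulting parser behaviour (illustrative only, not the full quickstart script):

import argparse

# Hedged sketch: only the two arguments touched by this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
parser.add_argument('--enable_attention_dp', default=False, action='store_true')

args = parser.parse_args(['--moe_backend', 'WIDEEP', '--enable_attention_dp'])
print(args.moe_backend, args.enable_attention_dp)  # WIDEEP True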

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 5 deletions

@@ -62,8 +62,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
 from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
-from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
-                     disable_fp4_allgather)
+from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, filter_weights,
                              register_auto_model)

@@ -514,9 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.enable_alltoall) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 60 additions & 24 deletions

@@ -192,8 +192,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
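
The new TRTLLM_DEEP_EP_TOKEN_LIMIT environment variable caps the token count that the DeepEP low-latency buffer is reserved for; when it is unset, the previous behaviour (the minimum of the model's max_num_tokens and moe_max_num_tokens) is kept. A small runnable sketch of how the override resolves, with illustrative numbers standing in for the model-config values:

import os

# Illustrative stand-ins for model_config.max_num_tokens and self.moe_max_num_tokens.
max_num_tokens = 8192
moe_max_num_tokens = 4096

# Same resolution order as the diff: the env var wins if set, otherwise min() of the limits.
deep_ep_max_num_tokens = int(
    os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT",
                   str(min(max_num_tokens, moe_max_num_tokens))))

print(deep_ep_max_num_tokens)  # 4096 by default; 256 if TRTLLM_DEEP_EP_TOKEN_LIMIT=256

A smaller limit shrinks the reserved buffer, at the cost of falling back to the non-alltoall path for batches that exceed it (see can_use_alltoall below).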

@@ -274,6 +278,25 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
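
The two new helpers centralize the chunking and alltoall-gating decisions that forward() uses later in this diff. A self-contained sketch of the same logic as plain functions (names and example numbers are illustrative, not part of the codebase):

from typing import List

def calculate_num_chunks(all_rank_num_tokens: List[int],
                         moe_max_num_tokens: int) -> int:
    # Ceiling division: how many chunks are needed to stay within moe_max_num_tokens.
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens

def can_use_alltoall(num_tokens: int, all_rank_num_tokens: List[int],
                     moe_max_num_tokens: int, deep_ep_max_num_tokens: int,
                     is_deep_ep_low_latency: bool, alltoall_enabled: bool) -> bool:
    # Chunked execution disables alltoall entirely.
    if calculate_num_chunks(all_rank_num_tokens, moe_max_num_tokens) > 1:
        return False
    # DeepEP low-latency buffers are sized to deep_ep_max_num_tokens, so larger
    # local batches also fall back to the allgather/reduce-scatter path.
    if is_deep_ep_low_latency and num_tokens > deep_ep_max_num_tokens:
        return False
    return alltoall_enabled

# Hypothetical example: 4 ranks, TRTLLM_DEEP_EP_TOKEN_LIMIT resolved to 256.
print(calculate_num_chunks([300] * 4, 4096))                    # 1
print(can_use_alltoall(300, [300] * 4, 4096, 256, True, True))  # False: 300 > 256
print(can_use_alltoall(200, [200] * 4, 4096, 256, True, True))  # True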

@@ -316,11 +339,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(

@@ -334,6 +358,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,

@@ -382,7 +407,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all
 
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None

@@ -391,7 +416,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,

@@ -400,7 +425,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )

@@ -412,7 +437,9 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        # If alltoall is disabled, we need also disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()

@@ -423,15 +450,16 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
                     padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
                         x, None, recv_topk_idx, token_final_scales)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \

@@ -471,7 +499,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
         elif self.has_nvfp4:
-            if use_allgather or self.use_postquant_alltoall:
+            if use_allgather or use_postquant_alltoall:
                 if isinstance(x, Fp4QuantizedTensor):
                     if use_allgather:
                         assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"

@@ -527,7 +555,7 @@ def forward_chunk(
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))

@@ -547,7 +575,7 @@ def forward_chunk(
             cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:

@@ -640,7 +668,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:

@@ -666,7 +694,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,

@@ -681,7 +709,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()

@@ -737,11 +765,10 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens

@@ -754,13 +781,15 @@ def forward(
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:

@@ -782,7 +811,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
                 ] for val_list in all_rank_num_tokens_list]

@@ -794,7 +823,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()

@@ -805,12 +834,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=

@@ -820,13 +850,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[

@@ -836,29 +868,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):

@@ -873,7 +909,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
 

@@ -919,7 +955,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)
 
-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"
