
Commit 7024f73

support TRTLLM_DEEP_EP_TOKEN_LIMIT to allow running DeepEP on memory-constrained GPUs.
DeepEP requires additional RDMA memory for communication, and on memory-constrained GPUs we may not have enough memory to enable DeepEP for both the context and decoding phases. In disaggregated serving scenarios it is straightforward to enable DeepEP only on the decoding server. For inflight batching, however, we need to apply a token limit so that DeepEP is only used during decoding.

Signed-off-by: Vincent Huang <[email protected]>
1 parent e3ccca0 · commit 7024f73
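
As a quick reference, here is a minimal, self-contained sketch of the fallback introduced by this commit (the helper name resolve_deep_ep_token_limit is illustrative, not part of the change): when TRTLLM_DEEP_EP_TOKEN_LIMIT is set, the DeepEP low-latency buffer is sized to that value instead of min(max_num_tokens, moe_max_num_tokens).

import os

# Illustrative sketch; the helper name is hypothetical. With the variable unset,
# behavior is unchanged: the limit stays at min(max_num_tokens, moe_max_num_tokens).
def resolve_deep_ep_token_limit(max_num_tokens: int, moe_max_num_tokens: int) -> int:
    default = min(max_num_tokens, moe_max_num_tokens)
    return int(os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT", str(default)))

# Example with made-up numbers: setting TRTLLM_DEEP_EP_TOKEN_LIMIT=64 reserves the
# low-latency buffer for only 64 tokens; larger batches fall back to the regular
# allgather/reduce-scatter path (see can_use_alltoall in the diff below).

Lowering the limit shrinks the RDMA buffer reserved via deep_ep_buffer.reserve(...), at the cost of routing prefill-sized batches through the non-alltoall path.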

File tree: 2 files changed, +42 −13 lines


examples/pytorch/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 41 additions & 12 deletions
@@ -193,8 +193,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -277,6 +281,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled

+    def can_use_alltoall(self, input):
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
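
A toy illustration of the gate above, with made-up numbers: given a 64-token limit, decode-sized batches keep the DeepEP low-latency alltoall path, while prefill-sized batches fall back to allgather/reduce-scatter.

# Hypothetical standalone mirror of can_use_alltoall for the DeepEPLowLatency case;
# names and numbers are illustrative only.
def can_use_alltoall(num_tokens: int, deep_ep_max_num_tokens: int,
                     alltoall_enabled: bool = True) -> bool:
    if num_tokens > deep_ep_max_num_tokens:
        return False          # batch too large for the reserved RDMA buffer
    return alltoall_enabled   # otherwise keep the usual enable_alltoall decision

assert can_use_alltoall(32, 64)        # decode step: alltoall is used
assert not can_use_alltoall(2048, 64)  # prefill chunk: falls back
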
@@ -319,11 +333,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
             self,
             inputs,
+            use_all_to_all: bool,
             all_rank_num_tokens: Optional[List[int]] = None,
             use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if self.parallel_size > 1 and not self.enable_alltoall:
+        if self.parallel_size > 1 and not use_all_to_all:
             if self.use_dp:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -340,6 +355,7 @@ def forward_chunk(
             self,
             x: Union[torch.Tensor, Fp4QuantizedTensor],
             router_logits: torch.Tensor,
+            use_all_to_all: bool,
             cutlass_min_latency_mode: bool = False,
             output_dtype: Optional[torch.dtype] = None,
             all_rank_num_tokens: Optional[List[int]] = None,
@@ -412,7 +428,7 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -499,7 +515,7 @@ def forward_chunk(
             )

         if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
-        ) and not self.enable_alltoall:
+        ) and not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
@@ -588,7 +604,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )

-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -628,7 +644,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=cutlass_min_latency_mode,
@@ -648,7 +664,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]

-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -726,11 +742,14 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            use_all_to_all = self.can_use_alltoall(x)
+
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -739,10 +758,13 @@ def forward(
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:

+            use_all_to_all = False
+
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
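
Condensed, the dispatch added to forward() reads roughly as below (the free function is a hypothetical rewrite for illustration; in the diff the logic is inline): the single-chunk path gates on the actual token count via can_use_alltoall, while the chunked path never uses alltoall.

# Hypothetical condensed view of the decision in forward(); `moe` stands in for
# the MoE module defined in this file.
def pick_use_all_to_all(moe, x, num_chunks: int) -> bool:
    if num_chunks == 1:
        return moe.can_use_alltoall(x)  # gate on the actual number of input tokens
    return False                        # chunked (large-batch) path skips alltoall
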
@@ -761,7 +783,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                 all_rank_max_num_tokens_list = split_chunk(
                     all_rank_max_num_tokens, num_chunks)
                 chunk_size_list = all_rank_chunk_size_list[self.rank]
-                if self.enable_alltoall:
+                if use_all_to_all:
                     all_rank_num_tokens_list = [[
                         1 if val == 0 else val for val in val_list
                     ] for val_list in all_rank_num_tokens_list]
@@ -777,7 +799,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)

-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -788,12 +810,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk] if self.use_dp else None,
                                 all_rank_max_num_tokens=
@@ -804,13 +827,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
                             if self.use_dp else None,
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -820,30 +845,34 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
                         if self.use_dp else None,
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk] if self.use_dp else None,
                         repeating_info=(is_first_call, is_last_call))

                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
