Skip to content

Commit ac839cb

Browse files
committed
Support TRTLLM_DEEP_EP_TOKEN_LIMIT to allow running DeepEP on memory-constrained GPUs.
DeepEP requires additional RDMA memory for communication, and on memory-constrained GPUs, we may not have enough memory to enable DeepEP for both the context and decoding phases. In disaggregated serving scenarios, it's straightforward to enable DeepEP only on the decoding server. However, for inflight batching, we need to apply a token limit so that DeepEP is only used during decoding. Signed-off-by: Vincent Huang <[email protected]>
1 parent aa72d39 commit ac839cb

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

examples/pytorch/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def add_llm_args(parser):
5151
parser.add_argument('--moe_backend',
5252
type=str,
5353
default='CUTLASS',
54-
choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
54+
choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
5555
parser.add_argument('--enable_attention_dp',
5656
default=False,
5757
action='store_true')

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,12 @@ def __init__(
193193
model_config.mapping)
194194
self.deep_ep_buffer.reserve(hidden_size, dtype)
195195
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
196-
self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
197-
self.moe_max_num_tokens)
196+
self.deep_ep_max_num_tokens = int(
197+
os.environ.get(
198+
"TRTLLM_DEEP_EP_TOKEN_LIMIT",
199+
str(
200+
min(model_config.max_num_tokens,
201+
self.moe_max_num_tokens))))
198202
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
199203
model_config.mapping)
200204
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -273,6 +277,10 @@ def enable_alltoall(self):
273277
"""
274278
return self.alltoall_method_type != AlltoallMethodType.NotEnabled
275279

280+
def can_use_alltoall(self, input):
281+
num_tokens = input.shape[0]
282+
return self.enable_alltoall and num_tokens <= self.deep_ep_max_num_tokens
283+
276284
def _get_quant_method(self):
277285
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
278286
exclude_kv_cache=True):
@@ -305,11 +313,12 @@ def create_weights(self):
305313
def reducescatter_or_allreduce(
306314
self,
307315
inputs,
316+
use_all_to_all: bool,
308317
all_rank_num_tokens: Optional[List[int]] = None,
309318
use_dp_padding: Optional[bool] = None,
310319
):
311320
outputs = inputs
312-
if self.parallel_size > 1 and not self.enable_alltoall:
321+
if self.parallel_size > 1 and not use_all_to_all:
313322
if self.use_dp:
314323
outputs = reducescatter(
315324
inputs,
@@ -324,6 +333,7 @@ def forward_chunk(
324333
self,
325334
x: Union[torch.Tensor, Fp4QuantizedTensor],
326335
router_logits: torch.Tensor,
336+
use_all_to_all: bool,
327337
cutlass_min_latency_mode: bool = False,
328338
output_dtype: Optional[torch.dtype] = None,
329339
all_rank_num_tokens: Optional[List[int]] = None,
@@ -396,7 +406,7 @@ def forward_chunk(
396406
ExpertStatistic.set_layer(self.layer_idx)
397407
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
398408

399-
if self.enable_alltoall:
409+
if use_all_to_all:
400410
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
401411
token_count = x.shape[0]
402412
alltoall_info = None
@@ -483,7 +493,7 @@ def forward_chunk(
483493
)
484494

485495
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
486-
) and not self.enable_alltoall:
496+
) and not use_all_to_all:
487497
x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
488498
[
489499
x,
@@ -570,7 +580,7 @@ def forward_chunk(
570580
f"Not available alltoall method type: {self.alltoall_method_type!r}"
571581
)
572582

573-
if self.enable_alltoall:
583+
if use_all_to_all:
574584
# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
575585
# TODO: remove the adapter by changing APIs
576586
if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -610,7 +620,7 @@ def forward_chunk(
610620
ep_rank=ep_rank,
611621
cluster_size=cluster_size,
612622
cluster_rank=cluster_rank,
613-
enable_alltoall=self.enable_alltoall,
623+
enable_alltoall=use_all_to_all,
614624
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
615625
use_w4a8_group_scaling=use_w4a8_group_scaling,
616626
min_latency_mode=cutlass_min_latency_mode,
@@ -630,7 +640,7 @@ def forward_chunk(
630640
# Otherwise, the output should be unpacked as a single tensor.
631641
final_hidden_states = final_hidden_states[0]
632642

633-
if self.enable_alltoall:
643+
if use_all_to_all:
634644
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
635645
final_hidden_states = self.alltoall_combine(
636646
final_hidden_states, alltoall_info, token_count)
@@ -691,11 +701,14 @@ def forward(
691701
else:
692702
all_rank_num_tokens_padded = all_rank_num_tokens
693703
if num_chunks == 1:
704+
use_all_to_all = self.can_use_alltoall(x)
705+
694706
is_first_call = self.repeat_idx == 0
695707
is_last_call = self.repeat_idx == self.repeat_count - 1
696708
outputs = self.forward_chunk(
697709
x,
698710
router_logits,
711+
use_all_to_all,
699712
cutlass_min_latency_mode,
700713
output_dtype,
701714
all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -704,10 +717,13 @@ def forward(
704717
repeating_info=(is_first_call, is_last_call))
705718
outputs = self.reducescatter_or_allreduce(
706719
outputs,
720+
use_all_to_all,
707721
all_rank_num_tokens=all_rank_num_tokens_padded,
708722
use_dp_padding=use_dp_padding)
709723
else:
710724

725+
use_all_to_all = False
726+
711727
def split_chunk(split_token_num: int, split_num_chunks: int):
712728
val_div = split_token_num // split_num_chunks
713729
val_mod = split_token_num % split_num_chunks
@@ -726,7 +742,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
726742
all_rank_max_num_tokens_list = split_chunk(
727743
all_rank_max_num_tokens, num_chunks)
728744
chunk_size_list = all_rank_chunk_size_list[self.rank]
729-
if self.enable_alltoall:
745+
if use_all_to_all:
730746
all_rank_num_tokens_list = [[
731747
1 if val == 0 else val for val in val_list
732748
] for val_list in all_rank_num_tokens_list]
@@ -742,7 +758,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
742758
x_list = x.split(chunk_size_list)
743759
router_logits_list = router_logits.split(chunk_size_list)
744760

745-
if not self.enable_alltoall:
761+
if not use_all_to_all:
746762
self.event_dict[EventType.Main].record()
747763
with torch.cuda.stream(self.aux_stream):
748764
self.event_dict[EventType.Main].wait()
@@ -753,12 +769,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
753769
zip(x_list, router_logits_list)):
754770
is_first_call = idx_chunk == 0 and self.repeat_idx == 0
755771
is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
756-
if not self.enable_alltoall:
772+
if not use_all_to_all:
757773
if idx_chunk % 2 == 0:
758774
with torch.cuda.stream(self.aux_stream):
759775
outputs = self.forward_chunk(
760776
x,
761777
router_logits,
778+
use_all_to_all,
762779
all_rank_num_tokens=all_rank_num_tokens_list[
763780
idx_chunk] if self.use_dp else None,
764781
all_rank_max_num_tokens=
@@ -769,13 +786,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
769786
if idx_chunk > 0:
770787
outputs_list[-1] = self.reducescatter_or_allreduce(
771788
outputs_list[-1],
789+
use_all_to_all,
772790
all_rank_num_tokens=all_rank_num_tokens_list[
773791
idx_chunk - 1],
774792
use_dp_padding=use_dp_padding)
775793
else:
776794
outputs = self.forward_chunk(
777795
x,
778796
router_logits,
797+
use_all_to_all,
779798
all_rank_num_tokens=all_rank_num_tokens_list[
780799
idx_chunk] if self.use_dp else None,
781800
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -785,30 +804,34 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
785804
with torch.cuda.stream(self.aux_stream):
786805
outputs_list[-1] = self.reducescatter_or_allreduce(
787806
outputs_list[-1],
807+
use_all_to_all,
788808
all_rank_num_tokens=all_rank_num_tokens_list[
789809
idx_chunk - 1],
790810
use_dp_padding=use_dp_padding)
791811
else:
792812
outputs = self.forward_chunk(
793813
x,
794814
router_logits,
815+
use_all_to_all,
795816
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
796817
if self.use_dp else None,
797818
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
798819
idx_chunk] if self.use_dp else None,
799820
repeating_info=(is_first_call, is_last_call))
800821

801822
outputs_list.append(outputs)
802-
if not self.enable_alltoall:
823+
if not use_all_to_all:
803824
if num_chunks % 2 == 0:
804825
outputs_list[-1] = self.reducescatter_or_allreduce(
805826
outputs_list[-1],
827+
use_all_to_all,
806828
all_rank_num_tokens=all_rank_num_tokens_list[-1],
807829
use_dp_padding=use_dp_padding)
808830
else:
809831
with torch.cuda.stream(self.aux_stream):
810832
outputs_list[-1] = self.reducescatter_or_allreduce(
811833
outputs_list[-1],
834+
use_all_to_all,
812835
all_rank_num_tokens=all_rank_num_tokens_list[-1],
813836
use_dp_padding=use_dp_padding)
814837
with torch.cuda.stream(self.aux_stream):

0 commit comments

Comments
 (0)