Commit 321f081

support TRTLLM_DEEP_EP_TOKEN_LIMIT to allow running DeepEP on memory-constrained GPUs.
DeepEP requires additional RDMA memory for communication, and on memory-constrained GPUs there may not be enough memory to enable DeepEP for both the context and decoding phases. In disaggregated serving scenarios it is straightforward to enable DeepEP only on the decoding server; for in-flight batching, however, a token limit is needed so that DeepEP is only used during decoding.

Signed-off-by: Vincent Huang <[email protected]>
1 parent 6d4b045 commit 321f081
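
To make the behavior concrete, below is a minimal standalone sketch of the gating this commit introduces. It is not the module's actual code: the numeric values are placeholders for model_config.max_num_tokens and self.moe_max_num_tokens, and the check is simplified to the DeepEPLowLatency case handled by the new can_use_alltoall method.

    import os

    # Placeholder values standing in for model_config.max_num_tokens and
    # self.moe_max_num_tokens; the real values come from the model config.
    model_max_num_tokens = 8192
    moe_max_num_tokens = 16384

    # The DeepEP low-latency buffer is reserved for this many tokens; the
    # TRTLLM_DEEP_EP_TOKEN_LIMIT environment variable overrides the default
    # cap of min(model_config.max_num_tokens, self.moe_max_num_tokens).
    deep_ep_max_num_tokens = int(
        os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT",
                       str(min(model_max_num_tokens, moe_max_num_tokens))))

    def can_use_alltoall(num_tokens: int, alltoall_enabled: bool) -> bool:
        # Batches larger than the cap (typically context-phase batches under
        # in-flight batching) fall back to the allgather/reduce-scatter path;
        # decode-sized batches keep the DeepEP all-to-all.
        if num_tokens > deep_ep_max_num_tokens:
            return False
        return alltoall_enabled

For example (an illustrative value, not one taken from this commit), launching examples/llm-api/quickstart_advanced.py with --moe_backend WIDEEP and TRTLLM_DEEP_EP_TOKEN_LIMIT=256 set in the environment caps the reserved DeepEP low-latency buffer at 256 tokens.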

File tree

2 files changed: 41 additions, 12 deletions

examples/llm-api/quickstart_advanced.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 40 additions & 11 deletions
@@ -192,8 +192,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -274,6 +278,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    def can_use_alltoall(self, input):
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
@@ -316,11 +330,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(
@@ -334,6 +349,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
@@ -412,7 +428,7 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -502,7 +518,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )
 
-        if use_allgather:
+        if use_allgather and not use_all_to_all:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -636,7 +652,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -676,7 +692,7 @@ def forward_chunk(
                 ep_rank=ep_rank,
                 cluster_size=cluster_size,
                 cluster_rank=cluster_rank,
-                enable_alltoall=self.enable_alltoall,
+                enable_alltoall=use_all_to_all,
                 use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
                 use_w4a8_group_scaling=use_w4a8_group_scaling,
                 min_latency_mode=False,
@@ -691,7 +707,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -759,22 +775,28 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            use_all_to_all = self.can_use_alltoall(x)
+
            is_first_call = self.repeat_idx == 0
            is_last_call = self.repeat_idx == self.repeat_count - 1
            outputs = self.forward_chunk(
                x,
                router_logits,
+                use_all_to_all,
                output_dtype,
                all_rank_num_tokens=all_rank_num_tokens_padded,
                all_rank_max_num_tokens=all_rank_max_num_tokens,
                use_dp_padding=use_dp_padding,
                repeating_info=(is_first_call, is_last_call))
            outputs = self.reducescatter_or_allreduce(
                outputs,
+                use_all_to_all,
                all_rank_num_tokens=all_rank_num_tokens_padded,
                use_dp_padding=use_dp_padding)
         else:
 
+            use_all_to_all = False
+
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
@@ -804,7 +826,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -815,12 +837,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=
@@ -830,13 +853,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -846,29 +871,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
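
A small worked example of the resulting fallback decision under an override; all numbers are illustrative assumptions, not values from this commit. Note also, from the forward() hunk above, that when the input is split into multiple chunks (num_chunks > 1) the commit forces use_all_to_all = False regardless of the limit, so the cap only matters on the single-chunk path.

    import os

    # Illustrative override; the default cap here would be min(8192, 16384).
    os.environ["TRTLLM_DEEP_EP_TOKEN_LIMIT"] = "384"

    deep_ep_max_num_tokens = int(
        os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT", str(min(8192, 16384))))
    assert deep_ep_max_num_tokens == 384

    # A 4096-token context-phase batch exceeds the cap and takes the
    # allgather / reduce-scatter path; a 256-token decoding batch stays
    # under the cap and keeps the DeepEP low-latency all-to-all.
    assert 4096 > deep_ep_max_num_tokens
    assert 256 <= deep_ep_max_num_tokens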
