@@ -193,8 +193,12 @@ def __init__(
193
193
model_config .mapping )
194
194
self .deep_ep_buffer .reserve (hidden_size , dtype )
195
195
elif self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency :
196
- self .deep_ep_max_num_tokens = min (model_config .max_num_tokens ,
197
- self .moe_max_num_tokens )
196
+ self .deep_ep_max_num_tokens = int (
197
+ os .environ .get (
198
+ "TRTLLM_DEEP_EP_TOKEN_LIMIT" ,
199
+ str (
200
+ min (model_config .max_num_tokens ,
201
+ self .moe_max_num_tokens ))))
198
202
self .deep_ep_buffer = buffer_pool .get_low_latency_buffer (
199
203
model_config .mapping )
200
204
self .deep_ep_buffer .reserve (self .deep_ep_max_num_tokens ,
@@ -277,6 +281,16 @@ def enable_alltoall(self):
277
281
"""
278
282
return self .alltoall_method_type != AlltoallMethodType .NotEnabled
279
283
284
+ def can_use_alltoall (self , input ):
285
+ num_tokens = input .shape [0 ]
286
+
287
+ # For DeepEPLowLatency, check if tokens exceed the threshold
288
+ if (self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency
289
+ and num_tokens > self .deep_ep_max_num_tokens ):
290
+ return False
291
+
292
+ return self .enable_alltoall
293
+
280
294
def _get_quant_method (self ):
281
295
if self .quant_config is not None and self .quant_config .layer_quant_mode .has_any_quant (
282
296
exclude_kv_cache = True ):
@@ -319,11 +333,12 @@ def dummy_allreduce(self):
319
333
def reducescatter_or_allreduce (
320
334
self ,
321
335
inputs ,
336
+ use_all_to_all : bool ,
322
337
all_rank_num_tokens : Optional [List [int ]] = None ,
323
338
use_dp_padding : Optional [bool ] = None ,
324
339
):
325
340
outputs = inputs
326
- if self .parallel_size > 1 and not self . enable_alltoall :
341
+ if self .parallel_size > 1 and not use_all_to_all :
327
342
if self .use_dp :
328
343
if self .enable_dummy_allreduce :
329
344
self .dummy_allreduce ()
@@ -340,6 +355,7 @@ def forward_chunk(
340
355
self ,
341
356
x : Union [torch .Tensor , Fp4QuantizedTensor ],
342
357
router_logits : torch .Tensor ,
358
+ use_all_to_all : bool ,
343
359
cutlass_min_latency_mode : bool = False ,
344
360
output_dtype : Optional [torch .dtype ] = None ,
345
361
all_rank_num_tokens : Optional [List [int ]] = None ,
@@ -412,7 +428,7 @@ def forward_chunk(
412
428
ExpertStatistic .set_layer (self .layer_idx )
413
429
ExpertStatistic .maybe_add_info (self .num_slots , token_selected_slots )
414
430
415
- if self . enable_alltoall :
431
+ if use_all_to_all :
416
432
if self .alltoall_method_type == AlltoallMethodType .MNNVL :
417
433
if self .enable_dummy_allreduce :
418
434
self .dummy_allreduce ()
@@ -499,7 +515,7 @@ def forward_chunk(
499
515
)
500
516
501
517
if self .use_dp and self .parallel_size > 1 and not disable_fp4_allgather (
502
- ) and not self . enable_alltoall :
518
+ ) and not use_all_to_all :
503
519
if self .enable_dummy_allreduce :
504
520
self .dummy_allreduce ()
505
521
x , x_sf , token_selected_slots , token_final_scales , gathered_token_selected_experts_for_statistic = allgather (
@@ -588,7 +604,7 @@ def forward_chunk(
588
604
f"Not available alltoall method type: { self .alltoall_method_type !r} "
589
605
)
590
606
591
- if self . enable_alltoall :
607
+ if use_all_to_all :
592
608
# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
593
609
# TODO: remove the adapter by changing APIs
594
610
if self .alltoall_method_type == AlltoallMethodType .DeepEP :
@@ -628,7 +644,7 @@ def forward_chunk(
628
644
ep_rank = ep_rank ,
629
645
cluster_size = cluster_size ,
630
646
cluster_rank = cluster_rank ,
631
- enable_alltoall = self . enable_alltoall ,
647
+ enable_alltoall = use_all_to_all ,
632
648
use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale ,
633
649
use_w4a8_group_scaling = use_w4a8_group_scaling ,
634
650
min_latency_mode = cutlass_min_latency_mode ,
@@ -648,7 +664,7 @@ def forward_chunk(
648
664
# Otherwise, the output should be unpacked as a single tensor.
649
665
final_hidden_states = final_hidden_states [0 ]
650
666
651
- if self . enable_alltoall :
667
+ if use_all_to_all :
652
668
if self .alltoall_method_type == AlltoallMethodType .MNNVL :
653
669
if self .enable_dummy_allreduce :
654
670
self .dummy_allreduce ()
@@ -726,11 +742,14 @@ def forward(
726
742
else :
727
743
all_rank_num_tokens_padded = all_rank_num_tokens
728
744
if num_chunks == 1 :
745
+ use_all_to_all = self .can_use_alltoall (x )
746
+
729
747
is_first_call = self .repeat_idx == 0
730
748
is_last_call = self .repeat_idx == self .repeat_count - 1
731
749
outputs = self .forward_chunk (
732
750
x ,
733
751
router_logits ,
752
+ use_all_to_all ,
734
753
cutlass_min_latency_mode ,
735
754
output_dtype ,
736
755
all_rank_num_tokens = all_rank_num_tokens_padded ,
@@ -739,10 +758,13 @@ def forward(
739
758
repeating_info = (is_first_call , is_last_call ))
740
759
outputs = self .reducescatter_or_allreduce (
741
760
outputs ,
761
+ use_all_to_all ,
742
762
all_rank_num_tokens = all_rank_num_tokens_padded ,
743
763
use_dp_padding = use_dp_padding )
744
764
else :
745
765
766
+ use_all_to_all = False
767
+
746
768
def split_chunk (split_token_num : int , split_num_chunks : int ):
747
769
val_div = split_token_num // split_num_chunks
748
770
val_mod = split_token_num % split_num_chunks
@@ -761,7 +783,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
761
783
all_rank_max_num_tokens_list = split_chunk (
762
784
all_rank_max_num_tokens , num_chunks )
763
785
chunk_size_list = all_rank_chunk_size_list [self .rank ]
764
- if self . enable_alltoall :
786
+ if use_all_to_all :
765
787
all_rank_num_tokens_list = [[
766
788
1 if val == 0 else val for val in val_list
767
789
] for val_list in all_rank_num_tokens_list ]
@@ -777,7 +799,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
777
799
x_list = x .split (chunk_size_list )
778
800
router_logits_list = router_logits .split (chunk_size_list )
779
801
780
- if not self . enable_alltoall :
802
+ if not use_all_to_all :
781
803
self .event_dict [EventType .Main ].record ()
782
804
with torch .cuda .stream (self .aux_stream ):
783
805
self .event_dict [EventType .Main ].wait ()
@@ -788,12 +810,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
788
810
zip (x_list , router_logits_list )):
789
811
is_first_call = idx_chunk == 0 and self .repeat_idx == 0
790
812
is_last_call = idx_chunk == num_chunks - 1 and self .repeat_idx == self .repeat_count - 1
791
- if not self . enable_alltoall :
813
+ if not use_all_to_all :
792
814
if idx_chunk % 2 == 0 :
793
815
with torch .cuda .stream (self .aux_stream ):
794
816
outputs = self .forward_chunk (
795
817
x ,
796
818
router_logits ,
819
+ use_all_to_all ,
797
820
all_rank_num_tokens = all_rank_num_tokens_list [
798
821
idx_chunk ] if self .use_dp else None ,
799
822
all_rank_max_num_tokens =
@@ -804,13 +827,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
804
827
if idx_chunk > 0 :
805
828
outputs_list [- 1 ] = self .reducescatter_or_allreduce (
806
829
outputs_list [- 1 ],
830
+ use_all_to_all ,
807
831
all_rank_num_tokens = all_rank_num_tokens_list [
808
832
idx_chunk - 1 ],
809
833
use_dp_padding = use_dp_padding )
810
834
else :
811
835
outputs = self .forward_chunk (
812
836
x ,
813
837
router_logits ,
838
+ use_all_to_all ,
814
839
all_rank_num_tokens = all_rank_num_tokens_list [
815
840
idx_chunk ] if self .use_dp else None ,
816
841
all_rank_max_num_tokens = all_rank_max_num_tokens_list [
@@ -820,30 +845,34 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
820
845
with torch .cuda .stream (self .aux_stream ):
821
846
outputs_list [- 1 ] = self .reducescatter_or_allreduce (
822
847
outputs_list [- 1 ],
848
+ use_all_to_all ,
823
849
all_rank_num_tokens = all_rank_num_tokens_list [
824
850
idx_chunk - 1 ],
825
851
use_dp_padding = use_dp_padding )
826
852
else :
827
853
outputs = self .forward_chunk (
828
854
x ,
829
855
router_logits ,
856
+ use_all_to_all ,
830
857
all_rank_num_tokens = all_rank_num_tokens_list [idx_chunk ]
831
858
if self .use_dp else None ,
832
859
all_rank_max_num_tokens = all_rank_max_num_tokens_list [
833
860
idx_chunk ] if self .use_dp else None ,
834
861
repeating_info = (is_first_call , is_last_call ))
835
862
836
863
outputs_list .append (outputs )
837
- if not self . enable_alltoall :
864
+ if not use_all_to_all :
838
865
if num_chunks % 2 == 0 :
839
866
outputs_list [- 1 ] = self .reducescatter_or_allreduce (
840
867
outputs_list [- 1 ],
868
+ use_all_to_all ,
841
869
all_rank_num_tokens = all_rank_num_tokens_list [- 1 ],
842
870
use_dp_padding = use_dp_padding )
843
871
else :
844
872
with torch .cuda .stream (self .aux_stream ):
845
873
outputs_list [- 1 ] = self .reducescatter_or_allreduce (
846
874
outputs_list [- 1 ],
875
+ use_all_to_all ,
847
876
all_rank_num_tokens = all_rank_num_tokens_list [- 1 ],
848
877
use_dp_padding = use_dp_padding )
849
878
with torch .cuda .stream (self .aux_stream ):
0 commit comments