@@ -192,8 +192,12 @@ def __init__(
                     model_config.mapping)
                 self.deep_ep_buffer.reserve(hidden_size, dtype)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                                  self.moe_max_num_tokens)
+                self.deep_ep_max_num_tokens = int(
+                    os.environ.get(
+                        "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                        str(
+                            min(model_config.max_num_tokens,
+                                self.moe_max_num_tokens))))
                 self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                     model_config.mapping)
                 self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
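
With this change, the DeepEP low-latency token cap can be overridden through the `TRTLLM_DEEP_EP_TOKEN_LIMIT` environment variable; when the variable is unset, the cap stays at `min(model_config.max_num_tokens, self.moe_max_num_tokens)`. A minimal sketch of that resolution logic in isolation (the helper name `resolve_deep_ep_token_limit` is illustrative, not part of the change):

```python
import os


def resolve_deep_ep_token_limit(max_num_tokens: int,
                                moe_max_num_tokens: int) -> int:
    # Hypothetical helper mirroring the logic above: the default cap is the
    # smaller of the model's max token count and the MoE token limit.
    default_limit = min(max_num_tokens, moe_max_num_tokens)
    # The environment variable, if set, overrides the default; it is read once
    # at construction time, so changing it later has no effect.
    return int(os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT", str(default_limit)))


# With max_num_tokens=8192 and moe_max_num_tokens=4096 the cap is 4096
# unless TRTLLM_DEEP_EP_TOKEN_LIMIT is exported with another value.
print(resolve_deep_ep_token_limit(8192, 4096))
```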
@@ -274,6 +278,16 @@ def enable_alltoall(self):
274278 """
275279 return self .alltoall_method_type != AlltoallMethodType .NotEnabled
276280
281+ def can_use_alltoall (self , input ):
282+ num_tokens = input .shape [0 ]
283+
284+ # For DeepEPLowLatency, check if tokens exceed the threshold
285+ if (self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency
286+ and num_tokens > self .deep_ep_max_num_tokens ):
287+ return False
288+
289+ return self .enable_alltoall
290+
277291 def _get_quant_method (self ):
278292 if self .quant_config is not None and self .quant_config .layer_quant_mode .has_any_quant (
279293 exclude_kv_cache = True ):
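
The new `can_use_alltoall` check complements the existing `enable_alltoall` property: the property reflects the static configuration, while this method additionally rejects batches whose token count exceeds `deep_ep_max_num_tokens` in the DeepEP low-latency case. A rough usage sketch, assuming `moe` is an initialized module exposing the attributes referenced above (the helper `dispatch_strategy` is only for illustration):

```python
import torch


def dispatch_strategy(moe, hidden_states: torch.Tensor) -> str:
    # `moe` stands for a constructed MoE module exposing the attributes used in
    # this diff (alltoall_method_type, deep_ep_max_num_tokens, enable_alltoall).
    if moe.can_use_alltoall(hidden_states):
        # All-to-all is enabled and, for DeepEPLowLatency, the batch fits
        # within deep_ep_max_num_tokens.
        return "alltoall"
    # Either all-to-all is disabled or the batch is too large for the
    # low-latency buffer; the layer falls back to allgather/reducescatter.
    return "allgather/reducescatter"
```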
@@ -316,11 +330,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(
@@ -334,6 +349,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
@@ -412,7 +428,7 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -502,7 +518,7 @@ def forward_chunk(
502518 f"unsupported quantization mode: { self .quant_config .quant_mode } "
503519 )
504520
505- if use_allgather :
521+ if use_allgather and not use_all_to_all : :
506522 # using allgather case.
507523 if self .enable_dummy_allreduce :
508524 self .dummy_allreduce ()
@@ -636,7 +652,7 @@ def forward_chunk(
636652 f"Not available alltoall method type: { self .alltoall_method_type !r} "
637653 )
638654
639- if self . enable_alltoall :
655+ if use_all_to_all :
640656 # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
641657 # TODO: remove the adapter by changing APIs
642658 if self .alltoall_method_type == AlltoallMethodType .DeepEP :
@@ -676,7 +692,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
@@ -691,7 +707,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -759,22 +775,28 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            use_all_to_all = self.can_use_alltoall(x)
+
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
 
+            use_all_to_all = False
+
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
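
`split_chunk` is only partially visible in this hunk; judging from the quotient/remainder computation, it splits a token count into near-equal chunk sizes, with the first `val_mod` chunks one token larger. A self-contained sketch of that presumed behavior:

```python
def split_chunk(split_token_num: int, split_num_chunks: int) -> list:
    # Presumed behavior: distribute tokens as evenly as possible, giving the
    # first `val_mod` chunks one extra token so the sizes sum to the input.
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)


assert split_chunk(10, 3) == [4, 3, 3]
assert sum(split_chunk(7, 4)) == 7
```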
@@ -804,7 +826,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -815,12 +837,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=
@@ -830,13 +853,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -846,29 +871,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):