@@ -193,8 +193,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -273,6 +277,10 @@ def enable_alltoall(self):
273277 """
274278 return self .alltoall_method_type != AlltoallMethodType .NotEnabled
275279
280+ def can_use_alltoall (self , input ):
281+ num_tokens = input .shape [0 ]
282+ return self .enable_alltoall and num_tokens <= self .deep_ep_max_num_tokens
283+
276284 def _get_quant_method (self ):
277285 if self .quant_config is not None and self .quant_config .layer_quant_mode .has_any_quant (
278286 exclude_kv_cache = True ):
@@ -305,11 +313,12 @@ def create_weights(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if self.parallel_size > 1 and not self.enable_alltoall:
+        if self.parallel_size > 1 and not use_all_to_all:
             if self.use_dp:
                 outputs = reducescatter(
                     inputs,
@@ -324,6 +333,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         cutlass_min_latency_mode: bool = False,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
@@ -396,7 +406,7 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 token_count = x.shape[0]
                 alltoall_info = None
@@ -483,7 +493,7 @@ def forward_chunk(
         )
 
         if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
-        ) and not self.enable_alltoall:
+        ) and not use_all_to_all:
             x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
                 [
                     x,
@@ -570,7 +580,7 @@ def forward_chunk(
570580 f"Not available alltoall method type: { self .alltoall_method_type !r} "
571581 )
572582
573- if self . enable_alltoall :
583+ if use_all_to_all :
574584 # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
575585 # TODO: remove the adapter by changing APIs
576586 if self .alltoall_method_type == AlltoallMethodType .DeepEP :
@@ -610,7 +620,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=cutlass_min_latency_mode,
@@ -630,7 +640,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 final_hidden_states = self.alltoall_combine(
                     final_hidden_states, alltoall_info, token_count)
@@ -691,11 +701,14 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            use_all_to_all = self.can_use_alltoall(x)
+
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -704,10 +717,13 @@ def forward(
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
 
+            use_all_to_all = False
+
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
@@ -726,7 +742,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                 all_rank_max_num_tokens_list = split_chunk(
                     all_rank_max_num_tokens, num_chunks)
                 chunk_size_list = all_rank_chunk_size_list[self.rank]
-                if self.enable_alltoall:
+                if use_all_to_all:
                     all_rank_num_tokens_list = [[
                         1 if val == 0 else val for val in val_list
                     ] for val_list in all_rank_num_tokens_list]
@@ -742,7 +758,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -753,12 +769,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk] if self.use_dp else None,
                                 all_rank_max_num_tokens=
@@ -769,13 +786,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk] if self.use_dp else None,
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -785,30 +804,34 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk]
                         if self.use_dp else None,
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk] if self.use_dp else None,
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
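
Taken together, the diff turns the all-to-all decision from the static `enable_alltoall` property into a per-call check: the DeepEP token budget can be overridden through the `TRTLLM_DEEP_EP_TOKEN_LIMIT` environment variable, `can_use_alltoall` only permits the all-to-all path when the current batch fits within that budget, and the chunked branch of `forward` forces `use_all_to_all = False`. The snippet below is a minimal standalone sketch of that gating, not the real class: `AlltoallGate` and its constructor arguments are hypothetical names for illustration, while the environment variable, `deep_ep_max_num_tokens`, and `can_use_alltoall` mirror the diff.

import os

import torch


class AlltoallGate:
    """Hypothetical stand-in for the gating added in this commit."""

    def __init__(self, max_num_tokens: int, moe_max_num_tokens: int,
                 alltoall_enabled: bool):
        # TRTLLM_DEEP_EP_TOKEN_LIMIT overrides the default budget of
        # min(max_num_tokens, moe_max_num_tokens), as in the __init__ hunk.
        self.deep_ep_max_num_tokens = int(
            os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT",
                           str(min(max_num_tokens, moe_max_num_tokens))))
        self.alltoall_enabled = alltoall_enabled

    def can_use_alltoall(self, input: torch.Tensor) -> bool:
        # All-to-all is taken only when it is enabled and the batch fits
        # the token budget reserved for the DeepEP low-latency buffer.
        return (self.alltoall_enabled
                and input.shape[0] <= self.deep_ep_max_num_tokens)


# Example: with a 1024-token budget, a 2048-token batch falls back to the
# allgather / reduce-scatter path instead of all-to-all.
gate = AlltoallGate(max_num_tokens=8192, moe_max_num_tokens=1024,
                    alltoall_enabled=True)
assert gate.can_use_alltoall(torch.empty(512, 16))
assert not gate.can_use_alltoall(torch.empty(2048, 16))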