@@ -278,7 +278,16 @@ def enable_alltoall(self):
278278 """
279279 return self .alltoall_method_type != AlltoallMethodType .NotEnabled
280280
281- def can_use_alltoall (self , input ):
281+ def calculate_num_chunks (self , all_rank_num_tokens : List [int ]) -> int :
282+ num_rows = sum (all_rank_num_tokens )
283+ return (num_rows + self .moe_max_num_tokens -
284+ 1 ) // self .moe_max_num_tokens
285+
286+ def can_use_alltoall (self , input , all_rank_num_tokens ):
287+ # Disable alltoall when chunking is used
288+ if self .calculate_num_chunks (all_rank_num_tokens ) > 1 :
289+ return False
290+
282291 num_tokens = input .shape [0 ]
283292
284293 # For DeepEPLowLatency, check if tokens exceed the threshold
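For context on the new helpers: the chunk count is a plain ceiling division of the summed per-rank token counts by `moe_max_num_tokens`, and `can_use_alltoall` now returns `False` whenever that count exceeds 1, mirroring the old behavior where the chunked path forced `use_all_to_all = False`. A minimal standalone sketch (the `MOE_MAX_NUM_TOKENS` constant is a hypothetical stand-in for `self.moe_max_num_tokens`):

```python
from typing import List

# Hypothetical stand-in for self.moe_max_num_tokens; the real value comes
# from the model/runtime configuration.
MOE_MAX_NUM_TOKENS = 8192


def calculate_num_chunks(all_rank_num_tokens: List[int]) -> int:
    """Ceiling division of the global token count by the per-pass limit."""
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + MOE_MAX_NUM_TOKENS - 1) // MOE_MAX_NUM_TOKENS


# Four ranks with 3000 tokens each: 12000 tokens exceed the 8192 limit,
# so two chunks are needed and can_use_alltoall would return False.
assert calculate_num_chunks([3000, 3000, 3000, 3000]) == 2
# 4000 tokens fit in a single chunk, so chunking does not veto alltoall.
assert calculate_num_chunks([1000, 1000, 1000, 1000]) == 1
```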
@@ -521,7 +530,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )

-        if use_allgather and not use_all_to_all:
+        if use_allgather:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -766,20 +775,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)

         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)

         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -798,8 +804,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:

-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
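The hunk is truncated after the first two lines of `split_chunk`. For readers, a helper of this shape typically distributes the remainder over the leading chunks; the return statement below is an assumed completion for illustration, not taken from the PR:

```python
from typing import List


def split_chunk(split_token_num: int, split_num_chunks: int) -> List[int]:
    # Assumed completion: split as evenly as possible, giving the first
    # `val_mod` chunks one extra token each so the sizes sum to the total.
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)


# 12000 tokens over 2 chunks split evenly; 10 over 3 puts the extra token
# in the first chunk.
assert split_chunk(12000, 2) == [6000, 6000]
assert split_chunk(10, 3) == [4, 3, 3]
```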