@@ -277,7 +277,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled

-    def can_use_alltoall(self, input):
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
         num_tokens = input.shape[0]

         # For DeepEPLowLatency, check if tokens exceed the threshold
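For reference, the chunk count introduced above is a plain ceiling division of the global token count by `moe_max_num_tokens`. A minimal standalone sketch of that math follows; the token counts and the `moe_max_num_tokens` value are made-up examples, not values from this PR:

from typing import List

def calculate_num_chunks(all_rank_num_tokens: List[int], moe_max_num_tokens: int) -> int:
    # Ceiling division: how many chunks of at most moe_max_num_tokens rows
    # are needed to cover the tokens gathered from all ranks.
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens

# More than one chunk -> can_use_alltoall() returns False and the allgather path is used.
assert calculate_num_chunks([3000, 5000], moe_max_num_tokens=4096) == 2
# A single chunk keeps the alltoall path available.
assert calculate_num_chunks([1000, 1000], moe_max_num_tokens=4096) == 1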
@@ -507,7 +516,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )

-        if use_allgather and not use_all_to_all:
+        if use_allgather:
            # using allgather case.
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
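Dropping the `and not use_all_to_all` guard appears to rely on the two flags now being mutually exclusive: with the new `can_use_alltoall()` returning False for any chunked run, the caller no longer produces a state where both paths are enabled. A toy check of that reasoning (the helper name below is illustrative, not part of the module):

def allgather_branch_taken(use_allgather: bool, use_all_to_all: bool) -> bool:
    # New condition in forward_chunk: only the allgather flag is consulted.
    return use_allgather

# Old and new conditions agree on every combination the caller can still
# produce, i.e. whenever the two flags are not both True.
for use_allgather in (True, False):
    for use_all_to_all in (True, False):
        if use_allgather and use_all_to_all:
            continue  # assumed to be ruled out by the new can_use_alltoall()
        old = use_allgather and not use_all_to_all
        assert old == allgather_branch_taken(use_allgather, use_all_to_all)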
@@ -752,20 +761,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)

         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)

         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -784,8 +790,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:

-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
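The truncated `split_chunk` helper above starts the usual even-split pattern. A standalone sketch of how that pattern typically completes is below; the return line is an assumption for illustration, since the rest of the real helper is not shown in this hunk:

def split_chunk(split_token_num: int, split_num_chunks: int) -> list:
    # Distribute split_token_num tokens over split_num_chunks chunks as evenly
    # as possible, giving the first val_mod chunks one extra token each.
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)

assert split_chunk(10, 3) == [4, 3, 3]
assert sum(split_chunk(10, 3)) == 10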