@@ -277,7 +277,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    def can_use_alltoall(self, input):
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
         num_tokens = input.shape[0]
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
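The two helpers above can be read in isolation. A minimal standalone sketch, assuming only a moe_max_num_tokens cap (the per-method token checks that follow in can_use_alltoall, such as the DeepEPLowLatency threshold, are left out here):

# Sketch of the chunk-count / alltoall gating introduced above.
# moe_max_num_tokens is assumed to be the per-iteration token cap.
from typing import List

def calculate_num_chunks(all_rank_num_tokens: List[int],
                         moe_max_num_tokens: int) -> int:
    num_rows = sum(all_rank_num_tokens)
    # Ceiling division: how many chunks are needed to stay under the cap.
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens

def can_use_alltoall(all_rank_num_tokens: List[int],
                     moe_max_num_tokens: int) -> bool:
    # Alltoall is disabled as soon as the input must be chunked;
    # further per-method checks would follow here in the real method.
    return calculate_num_chunks(all_rank_num_tokens, moe_max_num_tokens) <= 1

For example, four ranks with [3000, 3000, 2000, 1000] tokens and a cap of 4096 give ceil(9000 / 4096) = 3 chunks, so the alltoall path is skipped.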
@@ -507,7 +516,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )
 
-        if use_allgather and not use_all_to_all:
+        if use_allgather:
            # using allgather case.
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
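With use_all_to_all now decided once in forward() (see the next hunk), the extra "and not use_all_to_all" guard becomes redundant, presumably because the caller already selects at most one communication path. A hypothetical illustration of that mutual exclusion; the actual selection logic in forward_chunk is not shown in this diff, and needs_dp_communication is an invented placeholder:

# Hypothetical: once use_all_to_all is fixed up front, the allgather
# path is only taken when alltoall was not selected.
use_all_to_all = can_use_alltoall(all_rank_num_tokens, moe_max_num_tokens)
use_allgather = needs_dp_communication and not use_all_to_all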
@@ -752,20 +761,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -784,8 +790,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:
 
-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks
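split_chunk is truncated by the diff context right after the divmod. A plausible completion of an even token split, shown only as a guess at the intent rather than the file's actual code:

def split_chunk(split_token_num: int, split_num_chunks: int):
    # Spread tokens as evenly as possible: the first val_mod chunks
    # receive one extra token each.
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)

# e.g. split_chunk(10, 3) -> [4, 3, 3]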