@@ -278,7 +278,16 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    def can_use_alltoall(self, input):
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
         num_tokens = input.shape[0]
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
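As a quick illustration of the chunking rule introduced above: calculate_num_chunks is a ceiling division of the total token count across all DP ranks by moe_max_num_tokens, and can_use_alltoall now returns False as soon as more than one chunk is required (ignoring the additional per-backend checks that follow in the method). A minimal standalone sketch of that logic, assuming an illustrative MOE_MAX_NUM_TOKENS constant and free functions that are stand-ins, not code from this repository:

from typing import List

MOE_MAX_NUM_TOKENS = 256  # illustrative budget; the real value comes from the model config


def calculate_num_chunks(all_rank_num_tokens: List[int]) -> int:
    # Ceiling division: total tokens across all ranks over the per-forward budget.
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + MOE_MAX_NUM_TOKENS - 1) // MOE_MAX_NUM_TOKENS


def can_use_alltoall(all_rank_num_tokens: List[int]) -> bool:
    # Alltoall is ruled out whenever the input has to be split into chunks.
    return calculate_num_chunks(all_rank_num_tokens) == 1


print(calculate_num_chunks([100, 100, 100, 100]))  # 400 tokens -> 2 chunks
print(can_use_alltoall([100, 100, 100, 100]))      # False: chunking kicks in
print(can_use_alltoall([60, 60, 60, 60]))          # True: 240 tokens fit in one chunk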
@@ -521,7 +530,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )
 
-        if use_allgather and not use_all_to_all:
+        if use_allgather:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -766,20 +775,17 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
-            use_all_to_all = self.can_use_alltoall(x)
-
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
@@ -798,8 +804,6 @@ def forward(
                 use_dp_padding=use_dp_padding)
         else:
 
-            use_all_to_all = False
-
             def split_chunk(split_token_num: int, split_num_chunks: int):
                 val_div = split_token_num // split_num_chunks
                 val_mod = split_token_num % split_num_chunks