@@ -452,14 +452,15 @@ def is_post_quant_all2all_supported(self):
452452 return False
453453
454454 def forward_chunk (
455- self ,
456- x : Union [torch .Tensor , Fp4QuantizedTensor ],
457- router_logits : torch .Tensor ,
458- use_all_to_all : bool ,
459- output_dtype : Optional [torch .dtype ] = None ,
460- all_rank_num_tokens : Optional [List [int ]] = None ,
461- use_dp_padding : Optional [bool ] = None ,
462- repeating_info : Tuple = (True , True ),
455+ self ,
456+ x : Union [torch .Tensor , Fp4QuantizedTensor ],
457+ router_logits : torch .Tensor ,
458+ use_all_to_all : bool ,
459+ output_dtype : Optional [torch .dtype ] = None ,
460+ all_rank_num_tokens : Optional [List [int ]] = None ,
461+ use_dp_padding : Optional [bool ] = None ,
462+ repeating_info : Tuple = (True , True ),
463+ alltoall_result_do_sum : bool = True ,
463464 ) -> torch .Tensor :
464465 all_rank_max_num_tokens = max (all_rank_num_tokens )
465466 if isinstance (x , Fp4QuantizedTensor ):
@@ -474,7 +475,7 @@ def forward_chunk(
474475 self .layer_load_balancer .start_wait_gpu_stage ()
475476
476477 if not use_all_to_all or self .alltoall_method_type != AlltoallMethodType .MNNVL :
477- pass
478+ alltoall_result_do_sum = True
478479
479480 weight_dtype = self .w3_w1_weight .dtype
480481
@@ -741,7 +742,8 @@ def forward_chunk(
741742 if self .enable_dummy_allreduce :
742743 self .dummy_allreduce ()
743744 final_hidden_states = self .alltoall_combine (
744- final_hidden_states , alltoall_info , token_count )
745+ final_hidden_states , alltoall_info , token_count ,
746+ alltoall_result_do_sum )
745747 elif self .alltoall_method_type == AlltoallMethodType .DeepEP :
746748 final_hidden_states = self .unpad_tensors (
747749 padded , final_hidden_states )
@@ -786,6 +788,7 @@ def forward_impl(
786788 output_dtype : Optional [torch .dtype ] = None ,
787789 all_rank_num_tokens : Optional [List [int ]] = None ,
788790 use_dp_padding : Optional [bool ] = None ,
791+ alltoall_result_do_sum : bool = True ,
789792 ** kwargs ,
790793 ) -> torch .Tensor :
791794 assert all_rank_num_tokens is not None
@@ -813,7 +816,8 @@ def forward_impl(
813816 output_dtype ,
814817 all_rank_num_tokens = all_rank_num_tokens_padded ,
815818 use_dp_padding = use_dp_padding ,
816- repeating_info = (is_first_call , is_last_call ))
819+ repeating_info = (is_first_call , is_last_call ),
820+ alltoall_result_do_sum = alltoall_result_do_sum )
817821 outputs = self .reducescatter_or_allreduce (
818822 outputs ,
819823 use_all_to_all ,
@@ -871,7 +875,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
871875 all_rank_num_tokens = all_rank_num_tokens_list [
872876 idx_chunk ],
873877 use_dp_padding = use_dp_padding ,
874- repeating_info = (is_first_call , is_last_call ))
878+ repeating_info = (is_first_call , is_last_call ),
879+ alltoall_result_do_sum = alltoall_result_do_sum )
875880 if idx_chunk > 0 :
876881 outputs_list [- 1 ] = self .reducescatter_or_allreduce (
877882 outputs_list [- 1 ],
@@ -887,7 +892,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
887892 all_rank_num_tokens = all_rank_num_tokens_list [
888893 idx_chunk ],
889894 use_dp_padding = use_dp_padding ,
890- repeating_info = (is_first_call , is_last_call ))
895+ repeating_info = (is_first_call , is_last_call ),
896+ alltoall_result_do_sum = alltoall_result_do_sum )
891897 with torch .cuda .stream (self .aux_stream ):
892898 outputs_list [- 1 ] = self .reducescatter_or_allreduce (
893899 outputs_list [- 1 ],
@@ -901,7 +907,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
901907 router_logits ,
902908 use_all_to_all ,
903909 all_rank_num_tokens = all_rank_num_tokens_list [idx_chunk ],
904- repeating_info = (is_first_call , is_last_call ))
910+ repeating_info = (is_first_call , is_last_call ),
911+ alltoall_result_do_sum = alltoall_result_do_sum )
905912
906913 outputs_list .append (outputs )
907914 if not use_all_to_all :
@@ -957,7 +964,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
957964 return x , x_sf , token_selected_slots , token_final_scales
958965
959966 def alltoall_combine (self , final_hidden_states : torch .Tensor ,
960- alltoall_info : MoEAlltoallInfo , token_count : int ):
967+ alltoall_info : MoEAlltoallInfo , token_count : int ,
968+ alltoall_result_do_sum : bool ):
961969 top_k = self .routing_method .experts_per_token
962970 if isinstance (final_hidden_states , list ):
963971 final_hidden_states = final_hidden_states [0 ]
@@ -970,7 +978,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
970978 top_k = top_k ,
971979 token_count = token_count ,
972980 use_low_precision_combine = self .use_low_precision_combine ,
973- do_reduce = False )
981+ do_reduce = alltoall_result_do_sum )
974982
975983 return final_hidden_states
976984
0 commit comments