@@ -419,14 +419,15 @@ def reducescatter_or_allreduce(
         return outputs
 
     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -441,7 +442,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()
 
         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True
 
         weight_dtype = self.w3_w1_weight.dtype
 
@@ -706,7 +707,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -751,6 +753,7 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
@@ -778,7 +781,8 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -836,7 +840,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -852,7 +857,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -866,7 +872,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)
 
             outputs_list.append(outputs)
             if not use_all_to_all:
@@ -922,7 +929,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales
 
     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -935,7 +943,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)
 
         return final_hidden_states
 
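Note on the change above: the new `alltoall_result_do_sum` flag is forced back to `True` whenever the MNNVL all-to-all path is not in use, and on the MNNVL path it is threaded from `forward_impl` through `forward_chunk` into `alltoall_combine`, where it replaces the previously hardcoded `do_reduce=False`. The snippet below is a toy sketch of the semantics such a flag typically controls, assuming `do_reduce` decides whether the top-k per-expert partial results are summed inside the combine; `toy_combine` is a hypothetical stand-in, not the real MNNVL combine kernel.

# Toy sketch only: illustrates what a do_reduce-style flag typically controls
# in an MoE combine step. Not the TensorRT-LLM implementation.
import torch

def toy_combine(partials: torch.Tensor, do_reduce: bool) -> torch.Tensor:
    # partials: [token_count, top_k, hidden] per-expert results for each token.
    # do_reduce=True sums over the top_k experts inside the combine, yielding
    # [token_count, hidden]; do_reduce=False leaves the partials unreduced so
    # the caller can defer (or fuse) the summation elsewhere.
    return partials.sum(dim=1) if do_reduce else partials

tokens, top_k, hidden = 4, 2, 8
partials = torch.randn(tokens, top_k, hidden)
reduced = toy_combine(partials, do_reduce=True)    # shape [4, 8]
deferred = toy_combine(partials, do_reduce=False)  # shape [4, 2, 8]
assert torch.allclose(reduced, deferred.sum(dim=1))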