@@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            print(
+                f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
+            )
             return False
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                 and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
+            print(
+                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
             return False
 
+        print(f"all to all type {self.alltoall_method_type}")
         return self.enable_alltoall
 
     def _get_quant_method(self):
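
Editor's note: for reference, here is a minimal sketch of the same alltoall gating with the new diagnostics routed through Python's standard logging instead of bare print calls. The standalone function and logger name below are illustrative assumptions, not part of the patch; only the checks themselves mirror the hunk above.

```python
# Hedged sketch only: mirrors the can_use_alltoall checks above with
# logger.debug so the traces can stay in without flooding stdout.
import logging

logger = logging.getLogger("wide_ep_moe")  # illustrative logger name


def can_use_alltoall_sketch(num_chunks: int, method_type: str,
                            all_rank_max_num_tokens: int,
                            deep_ep_max_num_tokens: int,
                            enable_alltoall: bool) -> bool:
    # Chunking forces the non-alltoall path.
    if num_chunks > 1:
        logger.debug("cannot use alltoall due to chunking (%d chunks)",
                     num_chunks)
        return False
    # DeepEPLowLatency only supports up to deep_ep_max_num_tokens per rank.
    if (method_type == "DeepEPLowLatency"
            and all_rank_max_num_tokens > deep_ep_max_num_tokens):
        logger.debug("cannot use alltoall: %d > deep_ep_max_num_tokens %d",
                     all_rank_max_num_tokens, deep_ep_max_num_tokens)
        return False
    logger.debug("alltoall method type: %s", method_type)
    return enable_alltoall
```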
@@ -323,9 +330,18 @@ def _get_quant_method(self):
             if self.quant_config.layer_quant_mode.has_fp8_qdq():
                 return FP8QDQFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
+                print(
+                    f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+                )
                 if get_sm_version() == 100:
+                    print(
+                        "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
                 else:
+                    print(
+                        "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_nvfp4():
                 return NVFP4CutlassFusedMoEMethod()
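
Editor's note on the `get_sm_version() == 100` branch: the SM version here appears to follow the usual `major * 10 + minor` convention for CUDA compute capability, so 100 corresponds to compute capability 10.0 (Blackwell-class GPUs) while Hopper reports 90. A small illustrative check using only standard PyTorch APIs, not the patch's own helper:

```python
# Illustrative only: how an SM version like the one used in the dispatch above
# can be derived from the device's compute capability (e.g. 90 on H100).
import torch


def sm_version_sketch(device: int = 0) -> int:
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor


# sm_version_sketch() == 100 would select the DeepGemm block-scales path above.
```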
@@ -399,6 +415,10 @@ def forward_chunk(
 
         is_first_call, is_last_call = repeating_info
 
+        # print(
+        #     f"xxi shape 1: enter wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}"
+        # )
+
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
@@ -475,7 +495,7 @@ def forward_chunk(
                     self.dummy_allreduce()
                 token_count = x.shape[0]
                 alltoall_info = None
-                if is_last_call:
+                if self.layer_load_balancer and is_last_call:
                     loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                     )
                 else:
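
Editor's note: the one-line change above is the real fix in this hunk. `self.layer_load_balancer` can be unset (the surrounding code already guards it with `if self.layer_load_balancer and is_first_call:`), so calling into it purely on `is_last_call` could fail. A minimal, hypothetical reproduction of the pattern, not code from the patch:

```python
# Hypothetical sketch: without the extra guard, a layer with no load balancer
# configured (layer_load_balancer is None) would raise AttributeError here.
class _LayerSketch:

    def __init__(self, layer_load_balancer=None):
        self.layer_load_balancer = layer_load_balancer

    def local_statistics(self, is_last_call: bool):
        if self.layer_load_balancer and is_last_call:
            return self.layer_load_balancer.get_local_statistic_tensor()
        return None  # illustrative; the real else branch in the patch handles this case
```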
@@ -650,7 +670,35 @@ def forward_chunk(
                 )
 
         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
+        #     x,
+        #     token_selected_slots,
+        #     token_final_scales,
+        #     w3_w1_weight.view(weight_dtype),
+        #     None,  # w3_w1_bias
+        #     w2_weight.view(weight_dtype),
+        #     None,  # w2_bias
+        #     output_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=x_sf,
+        #     swizzled_input_sf=False,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=ep_size,
+        #     ep_rank=ep_rank,
+        #     cluster_size=cluster_size,
+        #     cluster_rank=cluster_rank,
+        #     enable_alltoall=use_all_to_all,
+        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        #     use_w4_group_scaling=use_w4_group_scaling,
+        #     min_latency_mode=False,
+        #     tune_max_num_tokens=self.tune_max_num_tokens,
+        #     tuner_num_tokens=tuner_num_tokens,
+        #     tuner_top_k=tuner_top_k,
+        # )
+
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
             x,
             token_selected_slots,
             token_final_scales,
@@ -675,35 +723,12 @@ def forward_chunk(
             tune_max_num_tokens=self.tune_max_num_tokens,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
         )
 
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
-        #     x,
-        #     token_selected_slots,
-        #     token_final_scales,
-        #     w3_w1_weight.view(weight_dtype),
-        #     None,  # w3_w1_bias
-        #     w2_weight.view(weight_dtype),
-        #     None,  # w2_bias
-        #     output_dtype,
-        #     quant_scales=quant_scales,
-        #     input_sf=x_sf,
-        #     swizzled_input_sf=False,
-        #     tp_size=self.tp_size,
-        #     tp_rank=self.tp_rank,
-        #     ep_size=ep_size,
-        #     ep_rank=ep_rank,
-        #     cluster_size=cluster_size,
-        #     cluster_rank=cluster_rank,
-        #     enable_alltoall=use_all_to_all,
-        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-        #     use_w4_group_scaling=use_w4_group_scaling,
-        #     min_latency_mode=False,
-        #     tune_max_num_tokens=self.tune_max_num_tokens,
-        #     tuner_num_tokens=tuner_num_tokens,
-        #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
+        # print(
+        #     f"xxi shape 4 after moe backend : {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states, 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}"
         # )
 
         if self.layer_load_balancer and is_last_call:
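
Editor's note: the two hunks above swap the direct `torch.ops.trtllm.fused_moe` call for `self.moe_backend.run_moe(...)` with an extra `module=self` handle, keeping the old call commented out as a reference. The patch itself does not show the backend, so the following is only a hedged sketch, under the assumption that a compatible backend could simply forward to the preserved op; the real `moe_backend` in this PR may be implemented quite differently.

```python
# Hedged sketch of a backend compatible with the call site above. The keyword
# list mirrors the (commented-out) fused_moe reference call in the diff.
import torch


class FusedMoEFallbackBackend:
    """Forwards run_moe(...) to torch.ops.trtllm.fused_moe, ignoring `module`."""

    def run_moe(self, x, token_selected_slots, token_final_scales, w3_w1_weight,
                w3_w1_bias, w2_weight, w2_bias, output_dtype, *, module=None,
                **kwargs):
        # `module` lets a smarter backend read attributes of the calling layer
        # (e.g. tune_max_num_tokens); this fallback does not need it.
        return torch.ops.trtllm.fused_moe(x, token_selected_slots,
                                          token_final_scales, w3_w1_weight,
                                          w3_w1_bias, w2_weight, w2_bias,
                                          output_dtype, **kwargs)
```

Wiring `self.moe_backend = FusedMoEFallbackBackend()` during module init would be expected to keep behaviour identical to the commented reference call.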
@@ -784,6 +809,10 @@ def forward(
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
+            # Print all the info on one line
+            # print(
+            #     f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
+            # )
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,