@@ -428,6 +428,8 @@ def forward_chunk(
                 if not self.use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                    padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                        x, None, recv_topk_idx, token_final_scales)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 if not self.use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
@@ -559,6 +561,8 @@ def forward_chunk(
                     x_sf = x_sf.view(torch.float32)
                 (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                     self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                padded, x, x_sf, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    x, x_sf, recv_topk_idx, token_final_scales)
                 if x_sf is not None:
                     x_sf = x_sf.view(x_sf_dtype)
                     if self.has_nvfp4:
@@ -644,20 +648,6 @@ def forward_chunk(
             mask = token_selected_slots == -1
             token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
             token_selected_slots[mask] = self.num_slots
-            num_recv_token_is_zero = x.shape[0] == 0
-            if x.shape[0] == 0:
-                x = torch.zeros((1, x.shape[1]),
-                                dtype=x.dtype,
-                                device=x.device)
-                token_selected_slots = torch.full(
-                    (1, token_selected_slots.shape[1]),
-                    self.num_slots,
-                    dtype=token_selected_slots.dtype,
-                    device=token_selected_slots.device)
-                token_final_scales = torch.ones(
-                    (1, token_final_scales.shape[1]),
-                    dtype=token_final_scales.dtype,
-                    device=token_final_scales.device)

         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
@@ -698,8 +688,8 @@ def forward_chunk(
                 final_hidden_states = self.alltoall_combine(
                     final_hidden_states, alltoall_info, token_count)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if num_recv_token_is_zero:
-                    final_hidden_states = final_hidden_states[:0]
+                final_hidden_states = self.unpad_tensors(
+                    padded, final_hidden_states)
                 final_hidden_states = self.deep_ep_buffer.combine(
                     final_hidden_states, deep_ep_handle)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -972,6 +962,40 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,

         return final_hidden_states

+    def pad_empty_recv_tensors(
+            self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
+            recv_topk_idx: torch.Tensor, token_final_scales: torch.Tensor
+    ) -> Tuple[bool, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
+               torch.Tensor]:
+        """
+        Pad the output of DeepEP `dispatch` if the output length is zero.
+        We can remove the adapter if both the `fused_moe` op and `swizzle_sf`
+        accept zero-length inputs.
+        """
+        if x.shape[0] == 0:
+            padded = True
+            x = torch.zeros((1, x.shape[1]), dtype=x.dtype, device=x.device)
+            if x_sf is not None:
+                x_sf = torch.zeros((1, x_sf.shape[1]),
+                                   dtype=x_sf.dtype,
+                                   device=x_sf.device)
+            recv_topk_idx = torch.full((1, recv_topk_idx.shape[1]),
+                                       self.num_slots,
+                                       dtype=recv_topk_idx.dtype,
+                                       device=recv_topk_idx.device)
+            token_final_scales = torch.ones((1, token_final_scales.shape[1]),
+                                            dtype=token_final_scales.dtype,
+                                            device=token_final_scales.device)
+        else:
+            padded = False
+        return padded, x, x_sf, recv_topk_idx, token_final_scales
+
+    def unpad_tensors(self, padded: bool,
+                      final_hidden_states: torch.Tensor) -> torch.Tensor:
+        if padded:
+            final_hidden_states = final_hidden_states[:0]
+        return final_hidden_states
+
     def register_parameter_weight_slot_fn(self, weight_name: str,
                                           local_slot_id: int):
         assert hasattr(
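
Taken together, the new pair wraps the zero-token case around the existing dispatch/combine path: pad to a single dummy token routed to the invalid slot before the fused MoE kernel, then slice the dummy row off before `deep_ep_buffer.combine`. Below is a minimal standalone sketch of that contract, without DeepEP or the class; the free functions mirroring the methods, the `num_slots` value, the tensor shapes, and the `clone()` standing in for the fused MoE kernel are illustrative assumptions, not the real TRT-LLM call path.

import torch
from typing import Optional, Tuple


def pad_empty_recv_tensors(
        x: torch.Tensor, x_sf: Optional[torch.Tensor],
        recv_topk_idx: torch.Tensor, token_final_scales: torch.Tensor,
        num_slots: int
) -> Tuple[bool, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
           torch.Tensor]:
    # If dispatch delivered zero tokens, substitute one dummy token routed to
    # the invalid slot `num_slots` so downstream kernels see non-empty inputs.
    if x.shape[0] == 0:
        x = torch.zeros((1, x.shape[1]), dtype=x.dtype, device=x.device)
        if x_sf is not None:
            x_sf = torch.zeros((1, x_sf.shape[1]),
                               dtype=x_sf.dtype,
                               device=x_sf.device)
        recv_topk_idx = torch.full((1, recv_topk_idx.shape[1]), num_slots,
                                   dtype=recv_topk_idx.dtype,
                                   device=recv_topk_idx.device)
        token_final_scales = torch.ones((1, token_final_scales.shape[1]),
                                        dtype=token_final_scales.dtype,
                                        device=token_final_scales.device)
        return True, x, x_sf, recv_topk_idx, token_final_scales
    return False, x, x_sf, recv_topk_idx, token_final_scales


def unpad_tensors(padded: bool,
                  final_hidden_states: torch.Tensor) -> torch.Tensor:
    # Drop the dummy row again so combine sends back zero tokens for this rank.
    return final_hidden_states[:0] if padded else final_hidden_states


if __name__ == "__main__":
    hidden, top_k, num_slots = 64, 2, 8
    # Emulate a rank that received no tokens from DeepEP dispatch.
    x = torch.empty((0, hidden), dtype=torch.bfloat16)
    recv_topk_idx = torch.empty((0, top_k), dtype=torch.int64)
    token_final_scales = torch.empty((0, top_k), dtype=torch.float32)

    padded, x, _, recv_topk_idx, token_final_scales = pad_empty_recv_tensors(
        x, None, recv_topk_idx, token_final_scales, num_slots)
    assert padded and x.shape == (1, hidden)
    assert recv_topk_idx[0, 0].item() == num_slots  # dummy token -> invalid slot

    out = x.clone()  # stand-in for the fused MoE kernel, which needs >= 1 token
    out = unpad_tensors(padded, out)
    assert out.shape[0] == 0  # back to zero tokens before the combine step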