Commit 4a26bd6

Fix: pad DeepEP fp4 recv tensors if empty (#6048)
Signed-off-by: Tailing Yuan <[email protected]>
Parent: 9ebc3ab

1 file changed: +40 −16 lines


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (40 additions, 16 deletions)
@@ -428,6 +428,8 @@ def forward_chunk(
                 if not self.use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                    padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                        x, None, recv_topk_idx, token_final_scales)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 if not self.use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
@@ -559,6 +561,8 @@ def forward_chunk(
                     x_sf = x_sf.view(torch.float32)
                 (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                     self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                padded, x, x_sf, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    x, x_sf, recv_topk_idx, token_final_scales)
                 if x_sf is not None:
                     x_sf = x_sf.view(x_sf_dtype)
                     if self.has_nvfp4:
@@ -644,20 +648,6 @@ def forward_chunk(
             mask = token_selected_slots == -1
             token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
             token_selected_slots[mask] = self.num_slots
-            num_recv_token_is_zero = x.shape[0] == 0
-            if x.shape[0] == 0:
-                x = torch.zeros((1, x.shape[1]),
-                                dtype=x.dtype,
-                                device=x.device)
-                token_selected_slots = torch.full(
-                    (1, token_selected_slots.shape[1]),
-                    self.num_slots,
-                    dtype=token_selected_slots.dtype,
-                    device=token_selected_slots.device)
-                token_final_scales = torch.ones(
-                    (1, token_final_scales.shape[1]),
-                    dtype=token_final_scales.dtype,
-                    device=token_final_scales.device)
 
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
@@ -698,8 +688,8 @@ def forward_chunk(
             final_hidden_states = self.alltoall_combine(
                 final_hidden_states, alltoall_info, token_count)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-            if num_recv_token_is_zero:
-                final_hidden_states = final_hidden_states[:0]
+            final_hidden_states = self.unpad_tensors(
+                padded, final_hidden_states)
             final_hidden_states = self.deep_ep_buffer.combine(
                 final_hidden_states, deep_ep_handle)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
@@ -972,6 +962,40 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
 
         return final_hidden_states
 
+    def pad_empty_recv_tensors(
+            self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
+            recv_topk_idx: torch.Tensor, token_final_scales: torch.Tensor
+    ) -> Tuple[bool, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
+               torch.Tensor]:
+        """
+        Pad the output of DeepEP `dispatch` if the output length is zero.
+        We can remove the adapter if both `fused_moe` op and `swizzle_sf`
+        accept zero-length inputs.
+        """
+        if x.shape[0] == 0:
+            padded = True
+            x = torch.zeros((1, x.shape[1]), dtype=x.dtype, device=x.device)
+            if x_sf is not None:
+                x_sf = torch.zeros((1, x_sf.shape[1]),
+                                   dtype=x_sf.dtype,
+                                   device=x_sf.device)
+            recv_topk_idx = torch.full((1, recv_topk_idx.shape[1]),
+                                       self.num_slots,
+                                       dtype=recv_topk_idx.dtype,
+                                       device=recv_topk_idx.device)
+            token_final_scales = torch.ones((1, token_final_scales.shape[1]),
+                                            dtype=token_final_scales.dtype,
+                                            device=token_final_scales.device)
+        else:
+            padded = False
+        return padded, x, x_sf, recv_topk_idx, token_final_scales
+
+    def unpad_tensors(self, padded: bool,
+                      final_hidden_states: torch.Tensor) -> torch.Tensor:
+        if padded:
+            final_hidden_states = final_hidden_states[:0]
+        return final_hidden_states
+
     def register_parameter_weight_slot_fn(self, weight_name: str,
                                           local_slot_id: int):
         assert hasattr(
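
For reference, the pad/unpad round trip these helpers implement can be exercised standalone. The sketch below is an illustration only, not the TensorRT-LLM code path: it uses a hypothetical free function with an explicit num_slots argument in place of the class's self.num_slots, CPU tensors in place of the CUDA buffers DeepEP actually returns, and made-up shapes and num_slots=256.

    from typing import Optional, Tuple

    import torch


    def pad_empty_recv_tensors(
            x: torch.Tensor, x_sf: Optional[torch.Tensor],
            recv_topk_idx: torch.Tensor, token_final_scales: torch.Tensor,
            num_slots: int
    ) -> Tuple[bool, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
               torch.Tensor]:
        # Pad a zero-length dispatch output with one dummy token routed to the
        # out-of-range slot `num_slots`, so downstream kernels that cannot
        # accept zero-length inputs still run (as a no-op for this rank).
        if x.shape[0] == 0:
            x = torch.zeros((1, x.shape[1]), dtype=x.dtype, device=x.device)
            if x_sf is not None:
                x_sf = torch.zeros((1, x_sf.shape[1]),
                                   dtype=x_sf.dtype,
                                   device=x_sf.device)
            recv_topk_idx = torch.full((1, recv_topk_idx.shape[1]), num_slots,
                                       dtype=recv_topk_idx.dtype,
                                       device=recv_topk_idx.device)
            token_final_scales = torch.ones((1, token_final_scales.shape[1]),
                                            dtype=token_final_scales.dtype,
                                            device=token_final_scales.device)
            return True, x, x_sf, recv_topk_idx, token_final_scales
        return False, x, x_sf, recv_topk_idx, token_final_scales


    # Round trip on an empty receive buffer: pad to one dummy token before the
    # expert computation, then drop it again before combine.
    x = torch.empty((0, 16), dtype=torch.bfloat16)
    recv_topk_idx = torch.empty((0, 8), dtype=torch.int64)
    token_final_scales = torch.empty((0, 8), dtype=torch.float32)
    padded, x, _, recv_topk_idx, token_final_scales = pad_empty_recv_tensors(
        x, None, recv_topk_idx, token_final_scales, num_slots=256)
    assert padded and x.shape[0] == 1

    final_hidden_states = torch.zeros((1, 16))  # stand-in for the fused_moe output
    if padded:  # mirrors unpad_tensors
        final_hidden_states = final_hidden_states[:0]
    assert final_hidden_states.shape[0] == 0

The design choice worth noting is that the padding is applied once, right after dispatch, and tracked with a single `padded` flag that is consumed right before combine, instead of the previous inline special-casing inside forward_chunk.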
