Improve Prefill Performance by Removing Redundant Padding and Optimizing Alltoall Communication #948

Closed
wants to merge 2 commits into from
11 changes: 1 addition & 10 deletions vllm_ascend/attention/mla_v1.py
@@ -521,18 +521,9 @@
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                dim=-1)

Check failure on line 524 in vllm_ascend/attention/mla_v1.py (GitHub Actions / ruff (3.12)): Ruff F841 at vllm_ascend/attention/mla_v1.py:524:13: Local variable `key` is assigned to but never used
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)

attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
else:
raise RuntimeError(
148 changes: 123 additions & 25 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -28,6 +28,59 @@

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

def process_topk_ids(
    topk_ids: torch.Tensor,
    expert_num: int,
    ep_size: int,
    max_row_per_ep_rank: int,
    num_tokens: int,
    top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad top-k expert ids (grouped by EP rank) into a fixed-size layout.

    Each EP rank gets exactly ``max_row_per_ep_rank`` slots; entries beyond the
    cap are dropped and unused slots are filled with the sentinel value
    ``expert_num``. Returns the padded ids of length
    ``ep_size * max_row_per_ep_rank`` and, for every original entry, its
    position in the compacted kept order, or -1 if it was dropped.
    """
    original_total_elements = num_tokens * top_k
    device = topk_ids.device
    original_dtype = topk_ids.dtype

    if original_total_elements == 0:
        output_len = ep_size * max_row_per_ep_rank
        topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device)
        unpad_indices = torch.full((original_total_elements,), -1, dtype=torch.long, device=device)
        return topk_ids_pad, unpad_indices

    experts_per_ep_rank_val = expert_num // ep_size
    if experts_per_ep_rank_val == 0:
        raise ValueError("expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
                         "Ensure expert_num >= ep_size.")

    # EP rank that owns each expert id, and each entry's position within its
    # rank's contiguous segment.
    assigned_ep_rank = (topk_ids.float() / experts_per_ep_rank_val).to(original_dtype)
    indices_arange = torch.arange(topk_ids.shape[0], device=device)

    is_new_segment = torch.cat((
        torch.tensor([True], device=device),
        assigned_ep_rank[1:] != assigned_ep_rank[:-1]
    ))
    temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype)
    temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
    start_offset_for_each_token = torch.cummax(temp_start_markers.float(), dim=0)[0].to(temp_start_markers.dtype)
    token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
    # Keep at most max_row_per_ep_rank entries per rank and record where each
    # kept entry lands in the compacted order (-1 for dropped entries).
    is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
    cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
    indices_in_rec_cond_list_for_all = cumsum_kept - 1
    unpad_indices = torch.where(
        is_kept_mask,
        indices_in_rec_cond_list_for_all,
        torch.tensor(-1, device=device, dtype=torch.long)
    )
    output_len = ep_size * max_row_per_ep_rank
    topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device)
    if topk_ids.shape[0] > 0:
        # Scatter kept ids into their rank-local slots; dropped ids go to an
        # extra overflow slot that is cut off afterwards.
        all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
        temp_pad_buffer = torch.full((output_len + 1,), expert_num, dtype=original_dtype, device=device)
        output_len_tensor = torch.tensor(output_len, dtype=torch.long, device=device)
        scatter_indices = torch.where(is_kept_mask, all_destination_indices, output_len_tensor)
        temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
        topk_ids_pad = temp_pad_buffer[:output_len]
    return topk_ids_pad, unpad_indices
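A minimal sketch of what process_topk_ids produces, using assumed toy values (4 experts across 2 EP ranks, 2 tokens with top_k=2, at most 3 rows per rank) rather than anything taken from this PR:

import torch

topk_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int32)  # already grouped by expert / EP rank
topk_ids_pad, unpad_indices = process_topk_ids(
    topk_ids, expert_num=4, ep_size=2, max_row_per_ep_rank=3,
    num_tokens=2, top_k=2)
# topk_ids_pad  -> [0, 1, 4, 2, 3, 4]   (4 == expert_num marks padding slots)
# unpad_indices -> [0, 1, 2, 3]         (-1 would mark entries dropped by the per-rank cap)

Each EP rank owns a fixed window of max_row_per_ep_rank slots, which is what lets the later all_to_all_single calls exchange equal-sized buffers.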


def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
@@ -236,28 +289,52 @@ def fused_experts_with_all2all(
expert_idx=topk_ids,
active_num=num_tokens)

global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)

gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
local_buffer_rows = (num_tokens // ep_group.world_size + 1) * ep_group.world_size * top_k * 2
max_row_per_ep_rank = local_buffer_rows // ep_group.world_size
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
expanded_expert_idx,
global_num_experts,
ep_group.world_size,
max_row_per_ep_rank,
num_tokens,
top_k
)
hidden_states_pad_idx = torch.zeros(
expert_idx_buffer_scatter.shape,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device
)
non_pad_len = torch.sum((expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = torch.arange(
non_pad_len,
dtype=expert_idx_buffer_scatter.dtype,
device=hidden_states.device
)
expert_idx_buffer_gather = torch.empty_like(
expert_idx_buffer_scatter,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device
)
dist.all_to_all_single(expert_idx_buffer_gather,
expert_idx_buffer_scatter,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()

expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)

sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]

hidden_states_buffer_gather = torch.empty_like(
hidden_states_buffer_scatter,
dtype=hidden_states_buffer_scatter.dtype,
device=hidden_states_buffer_scatter.device
)

dist.all_to_all_single(hidden_states_buffer_gather,
hidden_states_buffer_scatter,
group=ep_group.device_group)
mask = expert_idx_buffer_gather != global_num_experts
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (global_num_experts // ep_group.world_size)
hidden_states = hidden_states_buffer_gather[mask]
idx_type = local_expert_idx.dtype
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
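The lines above replace the old size-exchange pattern (bincount of expert ids, an all_to_all_single of per-rank counts, and scatter/gather size lists) with fixed-size buffers, so dist.all_to_all_single runs without first communicating row counts. A rough sketch of the sizing arithmetic with assumed toy numbers, not values from this PR:

# Worst-case rows reserved per EP rank (toy numbers: 9 tokens, 4 ranks, top_k=8).
num_tokens, world_size, top_k = 9, 4, 8
local_buffer_rows = (num_tokens // world_size + 1) * world_size * top_k * 2   # 192
max_row_per_ep_rank = local_buffer_rows // world_size                         # 48
# Every rank sends and receives exactly world_size * max_row_per_ep_rank rows,
# padded with the expert_num sentinel where fewer real rows exist.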

@@ -293,12 +370,33 @@ def fused_experts_with_all2all(
group_list_type=group_list_type)

if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
idx_type = sorted_idx.dtype
resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)

hidden_states_scatter = torch.zeros(
(mask.shape[0], hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device
)
hidden_states_scatter[mask] = hidden_states
hidden_states_gatter = torch.empty_like(
hidden_states_scatter,
dtype=hidden_states_scatter.dtype,
device=hidden_states_scatter.device
)
dist.all_to_all_single(hidden_states_gatter,
hidden_states_scatter,
group=ep_group.device_group)
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != global_num_experts]
if hidden_states_gatter.shape[0] != row_idx_len:
hidden_states = torch.zeros(
(row_idx_len, hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device
)
hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
hidden_states = hidden_states_gatter
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
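For reference, a small sketch of the unpadding step in the last hunk (hidden_states[unpad_indices != -1] = hidden_states_gatter), with assumed toy shapes rather than values from this PR; rows dropped by the per-rank cap simply stay zero:

import torch

row_idx_len, hidden_dim = 4, 2               # num_tokens * top_k in the real code
unpad_indices = torch.tensor([0, 1, -1, 2])  # the third entry was dropped by the cap
gathered = torch.arange(3 * hidden_dim, dtype=torch.float32).view(3, hidden_dim)

restored = torch.zeros((row_idx_len, hidden_dim))
restored[unpad_indices != -1] = gathered
# Rows 0, 1 and 3 of `restored` now hold the gathered rows; row 2 stays zeros.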