diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index d987eaba9..aa2d8f805 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -523,16 +523,7 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 0f54b012f..19556a54c 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -28,6 +28,59 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
+
+def process_topk_ids(
+    topk_ids: torch.Tensor,
+    expert_num: int,
+    ep_size: int,
+    max_row_per_ep_rank: int,
+    num_tokens: int,
+    top_k: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    original_total_elements = num_tokens * top_k
+    device = topk_ids.device
+    original_dtype = topk_ids.dtype
+
+    if original_total_elements == 0:
+        output_len = ep_size * max_row_per_ep_rank
+        topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device)
+        unpad_indices = torch.full((original_total_elements,), -1, dtype=torch.long, device=device)
+        return topk_ids_pad, unpad_indices
+
+    experts_per_ep_rank_val = expert_num // ep_size
+    if experts_per_ep_rank_val == 0:
+        raise ValueError("expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
+                         "Ensure expert_num >= ep_size.")
+
+    assigned_ep_rank = (topk_ids.float() / experts_per_ep_rank_val).to(original_dtype)
+    indices_arange = torch.arange(topk_ids.shape[0], device=device)
+
+    is_new_segment = torch.cat((
+        torch.tensor([True], device=device),
+        assigned_ep_rank[1:] != assigned_ep_rank[:-1]
+    ))
+    temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype)
+    temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
+    start_offset_for_each_token = torch.cummax(temp_start_markers.float(), dim=0)[0].to(temp_start_markers.dtype)
+    token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
+    is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
+    cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
+    indices_in_rec_cond_list_for_all = cumsum_kept - 1
+    unpad_indices = torch.where(
+        is_kept_mask,
+        indices_in_rec_cond_list_for_all,
+        torch.tensor(-1, device=device, dtype=torch.long)
+    )
+    output_len = ep_size * max_row_per_ep_rank
+    topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device)
+    if topk_ids.shape[0] > 0:
+        all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
+        temp_pad_buffer = torch.full((output_len + 1,), expert_num, dtype=original_dtype, device=device)
+        output_len_tensor = torch.tensor(output_len, dtype=torch.long, device=device)
+        scatter_indices = torch.where(is_kept_mask, all_destination_indices, output_len_tensor)
+        temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
+        topk_ids_pad = temp_pad_buffer[:output_len]
+    return topk_ids_pad, unpad_indices
+
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w1: torch.Tensor,
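Note on `process_topk_ids` (added above): it takes the expanded per-token expert ids, already sorted so that ids destined for the same EP rank are contiguous, and packs them into a fixed-size buffer of `ep_size * max_row_per_ep_rank` slots. Each rank owns a fixed window of that buffer, unused slots carry the sentinel value `expert_num`, rows beyond a rank's budget are dropped, and `unpad_indices` records where each original row landed (-1 if it was dropped). A rough CPU-only sketch of that contract with made-up sizes (the only non-standard name is `process_topk_ids` itself, assumed to be in scope):

```python
import torch

# Hypothetical toy sizes: 4 experts split over 2 EP ranks, a budget of 3 buffer
# rows per rank, and 5 tokens with top_k = 1. topk_ids is assumed to be sorted by
# expert id, as the output of npu_moe_init_routing is.
topk_ids = torch.tensor([0, 0, 1, 3, 3], dtype=torch.int32)
topk_ids_pad, unpad_indices = process_topk_ids(topk_ids,
                                               expert_num=4,
                                               ep_size=2,
                                               max_row_per_ep_rank=3,
                                               num_tokens=5,
                                               top_k=1)
# Rank 0 owns slots 0..2, rank 1 owns slots 3..5; the one unused slot holds the
# sentinel value expert_num (= 4), and no row was dropped in this example.
print(topk_ids_pad)   # tensor([0, 0, 1, 3, 3, 4], dtype=torch.int32)
print(unpad_indices)  # tensor([0, 1, 2, 3, 4])
```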
" + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat(( + torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1] + )) + temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers.float(), dim=0)[0].to(temp_start_markers.dtype) + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, + indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long) + ) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1,), expert_num, dtype=original_dtype, device=device) + output_len_tensor = torch.tensor(output_len, dtype=torch.long, device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices + def apply_mlp(hidden_states_wrapper: List[torch.Tensor], w1: torch.Tensor, @@ -236,28 +289,52 @@ def fused_experts_with_all2all( expert_idx=topk_ids, active_num=num_tokens) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, + local_buffer_rows = (num_tokens // ep_group.world_size + 1) * ep_group.world_size * top_k * 2 + max_row_per_ep_rank = local_buffer_rows // ep_group.world_size + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, + global_num_experts, + ep_group.world_size, + max_row_per_ep_rank, + num_tokens, + top_k + ) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device + ) + non_pad_len = torch.sum((expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device + ) + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device + ) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) - - 
@@ -293,12 +370,33 @@ def fused_experts_with_all2all(
         group_list_type=group_list_type)
 
     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
+        idx_type = sorted_idx.dtype
+        resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type)
         hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
-
+        hidden_states_scatter = torch.zeros(
+            (mask.shape[0], hidden_states.shape[1]),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device
+        )
+        hidden_states_scatter[mask] = hidden_states
+        hidden_states_gatter = torch.empty_like(
+            hidden_states_scatter,
+            dtype=hidden_states_scatter.dtype,
+            device=hidden_states_scatter.device
+        )
+        dist.all_to_all_single(hidden_states_gatter,
+                               hidden_states_scatter,
+                               group=ep_group.device_group)
+        hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != global_num_experts]
+        if hidden_states_gatter.shape[0] != row_idx_len:
+            hidden_states = torch.zeros(
+                (row_idx_len, hidden_states.shape[1]),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device
+            )
+            hidden_states[unpad_indices != -1] = hidden_states_gatter
+        else:
+            hidden_states = hidden_states_gatter
     final_hidden_states = torch_npu.npu_moe_finalize_routing(
         hidden_states,
         skip1=None,
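The return path mirrors the forward one: the processed rows are scattered back into a fixed-size buffer via `mask`, exchanged with `dist.all_to_all_single`, and the send-side sentinel positions (`expert_idx_buffer_scatter != global_num_experts`) strip the padding. If any rows were dropped by the per-rank budget on the way out, `unpad_indices` re-expands the result back to `row_idx_len` rows, with dropped rows zero-filled, before `npu_moe_finalize_routing`. A minimal CPU sketch of just that last re-expansion step, with made-up shapes (`hidden_states_gatter` is spelled as in the diff):

```python
import torch

# Hypothetical sizes: 6 expanded rows are expected (row_idx_len), but one row was
# dropped by the per-rank budget on the way out, so only 5 rows came back.
row_idx_len, hidden = 6, 4
unpad_indices = torch.tensor([0, 1, 2, -1, 3, 4])  # -1 marks the dropped row
hidden_states_gatter = torch.randn(5, hidden)      # surviving rows, original order

if hidden_states_gatter.shape[0] != row_idx_len:
    # Dropped rows come back as zeros so npu_moe_finalize_routing still sees one
    # row per expanded (token, top_k) slot.
    hidden_states = torch.zeros(row_idx_len, hidden, dtype=hidden_states_gatter.dtype)
    hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
    hidden_states = hidden_states_gatter

assert hidden_states.shape == (row_idx_len, hidden)
assert torch.all(hidden_states[3] == 0)  # the dropped row is zero-filled
```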