[MLA][Graph] Improve assertion on Graph mode with MLA #933

Open · wants to merge 1 commit into base: main
10 changes: 10 additions & 0 deletions docs/source/faqs.md
@@ -115,3 +115,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
- **Adjust `--gpu-memory-utilization`**: If unspecified, the default value of `0.9` is used. You can decrease this parameter to reserve more memory and reduce fragmentation risks. See the notes in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).

- **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can run `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable the virtual memory feature, which mitigates memory fragmentation caused by frequent dynamic memory size adjustments at runtime. See the notes in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html). A combined usage sketch follows this list.
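
For reference, a minimal launch sketch combining both knobs; the model name and the `0.85` value are illustrative assumptions, not recommendations:

```bash
# Enable the expandable-segments virtual memory feature before starting the server.
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
# Leave a bit more headroom than the default 0.9 utilization (value is illustrative).
vllm serve Qwen/Qwen2.5-7B-Instruct --gpu-memory-utilization 0.85
```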

### 15. Failed to enable NPU graph mode when running DeepSeek?
You may encounter the following error when running DeepSeek with NPU graph mode enabled. When both MLA and graph mode are enabled, the number of queries per KV head must be one of {32, 64, 128}. **This is therefore not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. NPU graph mode support for DeepSeek-V2-Lite will be added in the future.

If you are using DeepSeek-V3 or DeepSeek-R1, please make sure that after the tensor parallel split, `num_heads / num_kv_heads` is in {32, 64, 128}.

```bash
[rank0]: RuntimeError: EZ9999: Inner Error!
[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
```
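
As a quick sanity check before enabling graph mode, you can reason about the head ratio yourself. A minimal sketch, assuming MLA-style attention where the per-rank ratio is `num_attention_heads // num_kv_heads` after the tensor-parallel split; the head counts in the comments are illustrative:

```python
# A minimal sketch, not the kernel's exact accounting: it assumes the per-rank
# query/KV head ratio is what the MLA kernel checks against {32, 64, 128}.
_ALLOWED_NUM_QUERIES_PER_KV = {32, 64, 128}

def graph_mode_supported(num_attention_heads: int, num_kv_heads: int,
                         tensor_parallel_size: int) -> bool:
    heads_per_rank = num_attention_heads // tensor_parallel_size
    kv_heads_per_rank = max(num_kv_heads // tensor_parallel_size, 1)
    return heads_per_rank // kv_heads_per_rank in _ALLOWED_NUM_QUERIES_PER_KV

# DeepSeek-V2-Lite has only 16 attention heads, so the ratio can never reach 32.
print(graph_mode_supported(16, 1, 1))    # False
# DeepSeek-V3 / R1 (128 attention heads, assumed MLA num_kv_heads of 1) with TP=4.
print(graph_mode_supported(128, 1, 4))   # True
```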
10 changes: 10 additions & 0 deletions vllm_ascend/attention/attention.py
@@ -40,6 +40,8 @@
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)

_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]


def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
# Construct lower triangle matrix.
@@ -1005,6 +1007,14 @@ def __init__(
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
if self.enable_graph_mode:
assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
("The allowed number of queries per kv when enabling both MLA and Graph mode"
" only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite,"
" as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1,"
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
"{32, 64, 128}.")

def exec_kv(
self,
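For context, a hedged sketch of how a deployment might reach this assertion path; whether your vllm-ascend build accepts `additional_config` on the `LLM` entrypoint, and the model/TP figures, are assumptions:

```python
from vllm import LLM

# Illustrative only: tensor_parallel_size is chosen so that the per-rank
# num_heads / num_kv_heads falls in {32, 64, 128}, which the assertion above enforces.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",
    tensor_parallel_size=4,
    additional_config={"enable_graph_mode": True},
)
```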
2 changes: 1 addition & 1 deletion vllm_ascend/worker/multi_step_worker.py
@@ -119,7 +119,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
execute_model_req.last_sampled_token_ids.npu())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
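
The change above swaps the CUDA-specific device move for its Ascend counterpart. A minimal sketch of the same idea, assuming `torch` and `torch_npu` are installed so that tensors gain an `.npu()` helper analogous to `.cuda()`:

```python
import torch
import torch_npu  # registers the "npu" device and the Tensor.npu() helper

last_sampled = torch.tensor([[42]])
last_sampled_npu = last_sampled.npu()   # same pattern as .cuda(), but targets the NPU
print(last_sampled_npu.device)          # e.g. npu:0
```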