Commit 943d296

Author: zhuohuan
[feat]: initially support multistream overlap(dbo) for deepseek
1 parent 7aa4f85 commit 943d296

File tree

11 files changed: +898 −5 lines
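Reviewer note: "dbo" here is dual-batch overlap: a batch is split into two micro-batches so that one micro-batch's communication (e.g. the expert all-to-all in DeepSeek's MoE layers) can run on a second NPU stream while the other micro-batch computes. Below is a minimal sketch of that scheduling idea only, not the code in this commit; it assumes the torch_npu stream/event API and hypothetical run_compute/run_comm helpers.

import torch
import torch_npu  # assumption: Ascend backend exposing torch.npu streams/events

def dbo_step(micro_batches, comm_stream, run_compute, run_comm):
    """Overlap one micro-batch's communication with the next one's compute."""
    main_stream = torch.npu.current_stream()
    done = []
    for mb in micro_batches:
        out = run_compute(mb)                 # kernels enqueue on the main stream
        comm_stream.wait_stream(main_stream)  # comm must see this micro-batch's output
        with torch.npu.stream(comm_stream):
            run_comm(out)                     # e.g. expert all-to-all for DeepSeek MoE
            ev = torch.npu.Event()
            ev.record()                       # analogous to after_comm_event in the diff
        done.append(ev)
        # The host loop continues immediately, so the next run_compute
        # overlaps with the communication just enqueued on comm_stream.
    for ev in done:
        main_stream.wait_event(ev)            # consume results only after comm finishes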

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 4 deletions
@@ -18,6 +18,9 @@
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -100,6 +103,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    query_lens: list[int] = None
+    seq_lens: torch.Tensor = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
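These two new fields carry the per-request lengths the metadata splitter needs to cut a batch at a request boundary. Purely as illustration (no names below come from this diff), locating such a split point might look like:

import numpy as np

def find_split_index(query_lens: list[int], target_tokens: int) -> int:
    """Illustrative only: return the first request boundary at or past
    target_tokens, so each micro-batch holds whole requests."""
    cum = np.cumsum(query_lens)
    return int(np.searchsorted(cum, target_tokens))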
@@ -118,6 +123,16 @@ def __post_init__(self):
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")
 
+    def split_metadata_for_multistream(
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
+    ) -> list["AscendMLAMetadata"]:
+        """Split metadata for multi-stream with AscendMLAMetadata."""
+        return model_input_split_v1_mla_attn(
+            ms_split_config=ms_split_config,
+            attn_metadata=self,
+            _metadata_cls=AscendMLAMetadata,
+        )
 
 M = TypeVar("M", bound=AscendMLAMetadata)
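A hedged sketch of how this hook is presumably driven from the multistream runner. The real MSAttentionMetadataSplitConfig lives in vllm_ascend/multistream/base.py elsewhere in this commit; the num_micro_batches field used here is an assumption for illustration.

# Hypothetical caller; the config's field names are assumptions.
ms_config = MSAttentionMetadataSplitConfig(num_micro_batches=2)
pieces = attn_metadata.split_metadata_for_multistream(ms_config)
assert len(pieces) == 2  # one AscendMLAMetadata per micro-batch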

@@ -315,6 +330,8 @@ def build(self,
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
+            seq_lens=seq_lens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,
@@ -783,16 +800,34 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            # FIX: the aicore move/copy should also be placed on the comm stream in dbo,
+            # otherwise it may affect accuracy or disturb the overlap of the next stage.
+            # TODO: find a more elegant way to avoid this.
+            output_prefill = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                 attn_metadata)
+            from vllm_ascend.multistream.context import get_multistream_comm_context
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is not None:
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.after_comm_event.record()
+            else:
+                output[num_decode_tokens:] = output_prefill
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
+                from vllm_ascend.multistream.context import get_multistream_comm_context
+                current_ms_metadata = get_multistream_comm_context()
+                output_decode = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                else:
+                    output[:num_decode_tokens] = output_decode
         return output_padded
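Both branches above follow the same pattern: compute on the current stream, then stage the copy into output on the comm stream and record an event for the next stage to wait on. Distilled into a standalone helper (a sketch, not this commit's code; the explicit wait_stream is an addition for safety when read in isolation):

import torch
import torch_npu  # assumption: Ascend backend exposing torch.npu streams

def copy_on_comm_stream(dst, src, ms_ctx):
    """Route the result copy through the DBO comm stream when one is active."""
    if ms_ctx is not None:
        ms_ctx.comm_stream.wait_stream(torch.npu.current_stream())  # src must be ready
        with torch.npu.stream(ms_ctx.comm_stream):
            dst.copy_(src)                    # the aicore copy runs on the comm stream
            ms_ctx.after_comm_event.record()  # downstream stages wait on this event
    else:
        dst.copy_(src)                        # no overlap context: plain copy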
