 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -100,6 +103,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    query_lens: list[int] = None
+    seq_lens: torch.Tensor = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -118,6 +123,16 @@ def __post_init__(self):
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")
 
+    def split_metadata_for_multistream(
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
+    ) -> list["AscendMLAMetadata"]:
+        """Split metadata for multi-stream with AscendMLAMetadata"""
+        return model_input_split_v1_mla_attn(
+            ms_split_config=ms_split_config,
+            attn_metadata=self,
+            _metadata_cls=AscendMLAMetadata,
+        )
 
 M = TypeVar("M", bound=AscendMLAMetadata)
 
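For context, here is a minimal sketch of how the new `split_metadata_for_multistream` hook could be driven. The `MSAttentionMetadataSplitConfig` field name (`num_micro_batches`) and the helper function name are assumptions for illustration, not necessarily the exact API in `vllm_ascend/multistream/base.py`:

```python
# Sketch only: splitting one batch's MLA attention metadata into per-micro-batch
# metadata for dual-batch overlap. The field name `num_micro_batches` is assumed.
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig


def split_for_dbo(attn_metadata: "AscendMLAMetadata") -> list["AscendMLAMetadata"]:
    ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=2)
    # Returns one AscendMLAMetadata per micro-batch, each covering a
    # contiguous slice of the original batch's tokens.
    return attn_metadata.split_metadata_for_multistream(ms_split_config)
```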
@@ -315,6 +330,8 @@ def build(self,
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
+            seq_lens=seq_lens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,
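Storing `query_lens` and `seq_lens` on the metadata gives the splitter the per-request lengths it needs to pick a request boundary near the desired token split point. A purely illustrative sketch of that idea follows; the real logic lives in `vllm_ascend/multistream/ms_split.py` and may differ:

```python
# Illustrative only: locate the first request at which the cumulative query
# token count reaches a target, i.e. where a micro-batch boundary can be drawn.
import itertools


def find_split_request(query_lens: list[int], target_tokens: int) -> int:
    for i, total in enumerate(itertools.accumulate(query_lens)):
        if total >= target_tokens:
            return i + 1
    return len(query_lens)


# e.g. query_lens=[4, 4, 8, 16], target_tokens=16 -> boundary after request 3
```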
@@ -783,16 +800,34 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            # FIX: the aicore move/copy should also be placed on the comm stream in dbo,
+            # otherwise it may affect accuracy or disturb the overlap of the next stage
+            # TODO: use a more elegant way to avoid this
+            output_prefill = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                 attn_metadata)
+            from vllm_ascend.multistream.context import get_multistream_comm_context
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is not None:
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.after_comm_event.record()
+            else:
+                output[num_decode_tokens:] = output_prefill
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
+                from vllm_ascend.multistream.context import get_multistream_comm_context
+                current_ms_metadata = get_multistream_comm_context()
+                output_decode = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                else:
+                    output[:num_decode_tokens] = output_decode
         return output_padded
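The copies into `output` are routed onto the multistream comm stream (with an event recorded after the prefill copy) so the scatter does not serialize with the compute kernels of the other micro-batch on the default stream, and so downstream work can synchronize on `after_comm_event` before reading the buffer. A generic sketch of that secondary-stream pattern, assuming `torch_npu` is installed (it mirrors the `torch.cuda` stream/event API under `torch.npu`); the names here are illustrative, not the PR's objects:

```python
# Sketch of the side-stream copy pattern used above (illustrative names).
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

comm_stream = torch.npu.Stream()
after_comm_event = torch.npu.Event()


def copy_on_comm_stream(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Enqueue the copy on the side stream so it overlaps with compute kernels
    # of the other micro-batch that remain on the default stream.
    with torch.npu.stream(comm_stream):
        dst.copy_(src)
        after_comm_event.record()


def wait_before_consuming() -> None:
    # Consumers must wait on the event before reading `dst`, otherwise they
    # may observe a partially written buffer.
    torch.npu.current_stream().wait_event(after_comm_event)
```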