
Commit 4e35808

[feat]: improve overlap performance
Signed-off-by: zhuohuan <[email protected]>
1 parent 5220270 commit 4e35808

9 files changed (+180 / -89 lines)


vllm_ascend/attention/mla_v1.py

Lines changed: 53 additions & 39 deletions
@@ -4,6 +4,12 @@
 import numpy as np
 import torch
 import torch_npu
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
@@ -14,13 +20,6 @@
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
-
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -575,7 +574,18 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        return self.o_proj(attn_output)[0]
+
+        # A better way is to modify the communication ops or RowParallel Layer in vllm;
+        from vllm_ascend.multistream.context import \
+            get_multistream_comm_context
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self.o_proj(attn_output)[0]
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self.o_proj(attn_output)[0]
 
     def exec_kv(
         self,
@@ -675,7 +685,17 @@ def _forward_decode(
             context_lens=attn_metadata.decode.seq_lens,  # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
-        return self._v_up_proj_and_o_proj(attn_output)
+        from vllm_ascend.multistream.context import \
+            get_multistream_comm_context
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self._v_up_proj_and_o_proj(attn_output)
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self._v_up_proj_and_o_proj(attn_output)
+
 
     def forward(
         self,
@@ -800,24 +820,21 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            # FIX: aicore move/copy should be also placed on the comm stream in dbo,
-            # otherwise it may affect the accuracy or disturb the overlap of next stage
-            # TODO: use an elegant way here to avoid it
-            from vllm_ascend.multistream.context import get_multistream_comm_context
+            # FIX: aicore move should be also placed on the comm stream in dbo,
+            # otherwise it may affect the accuracy
+            # TODO: use an elegant way to overlap
+            from vllm_ascend.multistream.context import \
+                get_multistream_comm_context
+            output_prefill = self._forward_prefill(
+                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
-            if current_ms_metadata is None:
-                output[num_decode_tokens:] = self._forward_prefill(
-                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                    attn_metadata)
-            else:
-                current_ms_metadata.before_comm_event.record()
+            if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    output[num_decode_tokens:] = self._forward_prefill(
-                        prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                        attn_metadata)
+                    output[num_decode_tokens:] = output_prefill
                     current_ms_metadata.after_comm_event.record()
-
+            else:
+                output[num_decode_tokens:] = output_prefill
 
 
         if has_decode:
@@ -826,21 +843,18 @@ def forward(
                 decode_k_nope, decode_k_pe,
                 kv_cache, attn_metadata)
             else:
-
-                from vllm_ascend.multistream.context import get_multistream_comm_context
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is None:
-                    output[:num_decode_tokens] = self._forward_decode(
-                        decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                        kv_cache, attn_metadata)
-                else:
-                    current_ms_metadata.before_comm_event.record()
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        current_ms_metadata.before_comm_event.wait()
-                        output[:num_decode_tokens] = self._forward_decode(
-                            decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                            kv_cache, attn_metadata)
-                    current_ms_metadata.after_comm_event.record()
+                from vllm_ascend.multistream.context import \
+                    get_multistream_comm_context
+                output_decode = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                        current_ms_metadata.after_comm_event.record()
+                else:
+                    output[:num_decode_tokens] = output_decode
 
 
         return output_padded
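
Both hunks in this file follow the same pattern: record an event on the default (compute) stream, switch to the dedicated communication stream, wait on that event, and only then run the op that feeds the next collective, so the compute stream is immediately free to start the other micro-batch. A minimal, self-contained sketch of the pattern, assuming the torch.npu namespace registered by torch_npu mirrors the torch.cuda Stream/Event API; the helper name and the projection module below are illustrative, not part of this commit:

import torch
import torch_npu  # noqa: F401  (registers the torch.npu device namespace)

comm_stream = torch.npu.Stream()
before_comm_event = torch.npu.Event()

def overlapped_projection(attn_output: torch.Tensor,
                          o_proj: torch.nn.Linear) -> torch.Tensor:
    # attention math has already finished on the default (compute) stream;
    # mark the point after which the comm stream may safely read attn_output
    before_comm_event.record()
    # run the output projection on the side stream so the default stream is
    # immediately free to launch the other micro-batch's kernels
    with torch.npu.stream(comm_stream):
        before_comm_event.wait()
        return o_proj(attn_output)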

vllm_ascend/envs.py

Lines changed: 2 additions & 2 deletions
@@ -66,8 +66,8 @@
     lambda: os.getenv("C_COMPILER", None),
     "VLLM_VERSION":
     lambda: os.getenv("VLLM_VERSION", None),
-    "VLLM_ENABLE_MS":
-    lambda: bool(int(os.getenv("VLLM_ENABLE_MS", '0'))),
+    "VLLM_ENABLE_DBO":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_DBO", '0'))),
 }
 
 # end-env-vars-definition
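
The renamed switch is still read from a plain environment variable. A short usage sketch, assuming the lambda table above is resolved lazily on attribute access, as in vLLM's own envs module:

import os

# must be set before the flag is read by the engine process
os.environ["VLLM_ENABLE_DBO"] = "1"

import vllm_ascend.envs as envs_ascend

assert envs_ascend.VLLM_ENABLE_DBO is True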

vllm_ascend/models/deepseek_v2.py

Lines changed: 98 additions & 33 deletions
@@ -30,9 +30,23 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-import vllm.envs as envs
+import vllm_ascend.envs as envs_ascend
 from torch import nn
 from transformers import PretrainedConfig
+from vllm_ascend.multistream.base import MSEventKey
+from vllm_ascend.multistream.context import (
+    advance_step_multistream_layer_context, get_multistream_comm_context,
+    get_multistream_layer_context, set_multistream_context)
+from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
+                                            MultiStreamPreTransformerLayer)
+from vllm_ascend.multistream.metadata import (MultiStreamConfig,
+                                              MultiStreamStepMetadata,
+                                              make_multistream_metadata_ds)
+from vllm_ascend.multistream.ms_split import compute_split_seq_index
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+
+import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
@@ -65,19 +79,8 @@
                                                maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
-import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-
-from vllm_ascend.multistream.context import (set_multistream_context,get_multistream_layer_context,
-                                             advance_step_multistream_layer_context, get_multistream_comm_context)
-from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer)
-from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig
-from vllm_ascend.multistream.base import MSEventKey
-from vllm_ascend.multistream.ms_split import compute_split_seq_index
-
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-VLLM_ENABLE_MS: bool = envs_ascend.VLLM_ENABLE_MS
+VLLM_ENABLE_DBO: bool = envs_ascend.VLLM_ENABLE_DBO
 
 
 class CustomDeepseekV2MLP(nn.Module):
@@ -149,6 +152,50 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+    def _forward_ms_mlp(self, x):
+        current_ms_metadata = get_multistream_comm_context()
+        assert current_ms_metadata is not None
+        if self.is_dynamic_quant:
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+            return x
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        current_ms_metadata.before_comm_event.record()
+        with torch.npu.stream(current_ms_metadata.comm_stream):
+            current_ms_metadata.before_comm_event.wait()
+            x, _ = self.down_proj(x)
+            current_ms_metadata.after_comm_event.record()
+        return x
+
 
 class CustomDeepseekV2MoE(nn.Module):
 
@@ -282,7 +329,7 @@ def _forward_ms_op_shared_expert(
         self,
         hidden_states: torch.Tensor,
     ):
-        shared_output = self.shared_experts(hidden_states)
+        shared_output = self.shared_experts._forward_ms_mlp(hidden_states)
         return shared_output
 
     def _forward_ms_op_gate(
@@ -293,7 +340,7 @@ def _forward_ms_op_gate(
         router_logits, _ = self.gate(hidden_states)
         return router_logits
 
-    def _forward_ms_op_tp_allreduce(
+    def _forward_ms_op_tp_allgather(
         self,
         hidden_states: torch.Tensor,
         shared_output: torch.Tensor,
@@ -303,13 +350,26 @@ def _forward_ms_op_tp_allreduce(
     ):
 
         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            #if num_tokens < self.tp_size:
-            #    final_hidden_states = final_hidden_states[:num_tokens]
-            if num_tokens > 0:
-                final_hidden_states = final_hidden_states[:-num_tokens]
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is None:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                #if num_tokens < self.tp_size:
+                #    final_hidden_states = final_hidden_states[:num_tokens]
+                if num_tokens > 0:
+                    final_hidden_states = final_hidden_states[:-num_tokens]
+            else:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                    self.tp_group)
+                    final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                    #if num_tokens < self.tp_size:
+                    #    final_hidden_states = final_hidden_states[:num_tokens]
+                    if num_tokens > 0:
+                        final_hidden_states = final_hidden_states[:-num_tokens]
+
         else:
             final_hidden_states = hidden_states
 
@@ -650,25 +710,24 @@ def _forward_ms_layer(
 
             # input layernorm
             hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i])
-            # attention and tp allreducea
+            # attention and tp allreduce
             hidden_states[i], residual[i] = self._forward_ms_op_attn(positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i])
 
         ''' block 3 : shared experts
         if there is an allreduce ops in shared expert, we can overlap it with the computation of the
         shared expert for next microbatch or moe gating
         '''
         for i in range(num_micro_batchs):
+            ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH],
             )
             with set_multistream_context(context, i):
                 # compute shared expert after finishing ATTN AR
-                ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
                 hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(hidden_states[i], residual[i])
 
-
             num_token, hidden_dim = hidden_states[i].shape
             hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
             #num_tokens.append(num_token)
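
Note that the shared-expert call prepared here only works inside the context block: _forward_ms_mlp() (added earlier in this file) asserts that get_multistream_comm_context() returns a metadata object, which is exactly what set_multistream_context installs. A hedged sketch of that wiring, reusing the names from the hunks above; the wrapper function itself is illustrative, not code from this commit:

from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import set_multistream_context
from vllm_ascend.multistream.metadata import MultiStreamStepMetadata

def run_shared_expert(ms_metadata, layer_index, i, shared_experts, hidden_states):
    # wait for this micro-batch's attention all-reduce before touching its activations
    ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
    context = MultiStreamStepMetadata(
        comm_stream=ms_metadata.communicate_stream,
        before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH],
        after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH],
    )
    with set_multistream_context(context, i):
        # _forward_ms_mlp() asserts a comm context is present, so it must only
        # be called inside this block
        return shared_experts._forward_ms_mlp(hidden_states)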
@@ -740,10 +799,14 @@ def _forward_ms_layer(
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_COM_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM],
             )
-            with set_multistream_context(context, i):
+            context.before_comm_event.record()
+            with torch.npu.stream(ms_metadata.communicate_stream):
+                #with set_multistream_context(context, i):
+                context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (self.mlp.experts.tp_size > 1 or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
+                context.after_comm_event.record()
             # check here
             hidden_states[i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
@@ -752,7 +815,7 @@ def _forward_ms_layer(
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH],
             )
             with set_multistream_context(context, i):
-                hidden_states[i] = self.mlp._forward_ms_op_tp_allreduce(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
+                hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
             with torch.npu.stream(ms_metadata.communicate_stream):
                 # last
                 if isinstance(
@@ -764,6 +827,7 @@ def _forward_ms_layer(
                     # The scaling of DeepseekV2MOE output would be done in the forward
                     # of DeepseekV2MOE
                     hidden_states[i] *= 1. / self.routed_scaling_factor
+                context.after_comm_event.record()
         return hidden_states, residual
     # should split ops in Decoder Layer
     def _forward_ms_op_input_layernorm(
@@ -861,7 +925,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 ["hidden_states", "residual"], config.hidden_size))
 
         # tbo related members
-        if VLLM_ENABLE_MS:
+        if VLLM_ENABLE_DBO:
             self.multistream_config = MultiStreamConfig()
         else:
             self.multistream_config = None
@@ -934,7 +998,7 @@ def forward(
     def can_run_ms(self):
         # currently we only enable prefill overlap
         attn_metadata = get_forward_context().attn_metadata
-        dp_metadata = get_forward_context().dp_metadata
+        # dp_metadata = get_forward_context().dp_metadata
         # profile run
         if self.multistream_config is None or attn_metadata is None:
             return False
@@ -944,16 +1008,17 @@ def can_run_ms(self):
         # disable decode dbo
         if attn_metadata.num_prefills == 0:
             return False
-        num_microbatchs = self.multistream_config.num_micro_batches
         # check whether there is a dp rank that not use dual batch
-        '''if dp_metadata is not None:
+        '''
+        num_microbatchs = self.multistream_config.num_micro_batches
+        if dp_metadata is not None:
             for i in range(num_microbatchs):
                 cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
                 if torch.any(cu_tokens == 0).item():
                     return False
+        '''
         [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens,
-            attn_metadata.attn_state, attn_metadata.num_decode_tokens)
-        '''
+                                                           attn_metadata.attn_state, attn_metadata.num_decode_tokens)
         if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens):
             return False
         # check whether the total tokens exceed the threshold
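
can_run_ms() now depends on compute_split_seq_index to find a sequence boundary that splits the prefill tokens into two micro-batches. A simplified, illustrative stand-in for that helper; the real one in vllm_ascend.multistream.ms_split also accounts for attn_state and the decode-token prefix, which this sketch ignores:

from typing import List, Tuple

def split_seq_index(query_lens: List[int]) -> Tuple[int, int]:
    # aim for two micro-batches with roughly equal token counts, cutting only
    # on a sequence boundary
    total = sum(query_lens)
    half = total // 2
    token_index = 0
    for seq_index, q_len in enumerate(query_lens):
        if token_index + q_len > half and token_index > 0:
            return token_index, seq_index
        token_index += q_len
    return 0, 0  # no valid split; caller falls back to the single-batch path

can_run_ms() then rejects the split when either index is 0 or when the split consumes every sequence, because one micro-batch would otherwise be empty.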
