
Commit 9053dd1

[feat]: improve overlap performance

1 parent 43f5388

File tree: 3 files changed, +122 −51 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 41 additions & 30 deletions
@@ -575,7 +575,17 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        return self.o_proj(attn_output)[0]
+
+        # A better way would be to modify the communication ops or the RowParallel layer in vllm
+        from vllm_ascend.multistream.context import get_multistream_comm_context
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self.o_proj(attn_output)[0]
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self.o_proj(attn_output)[0]
 
     def exec_kv(
         self,
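The handshake above is the pattern this commit applies throughout: record an event at the end of the compute work, switch to the dedicated communication stream, wait on that event there, and only then run the op that triggers the TP collective (o_proj is a row-parallel layer, so it ends in an all-reduce). A minimal sketch of the same pattern, assuming an Ascend PyTorch build where torch.npu exposes streams and events; run_on_comm_stream and metadata are illustrative names, not part of the commit:

import torch

def run_on_comm_stream(metadata, fn, *args):
    # No dbo context active: run inline on the current stream.
    if metadata is None:
        return fn(*args)
    # Mark the end of the compute-stream work ...
    metadata.before_comm_event.record()
    with torch.npu.stream(metadata.comm_stream):
        # ... and make the comm stream wait for it before the collective.
        metadata.before_comm_event.wait()
        return fn(*args)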
@@ -675,7 +685,16 @@ def _forward_decode(
             context_lens=attn_metadata.decode.seq_lens, # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
-        return self._v_up_proj_and_o_proj(attn_output)
+        from vllm_ascend.multistream.context import get_multistream_comm_context
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self._v_up_proj_and_o_proj(attn_output)
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self._v_up_proj_and_o_proj(attn_output)
+
 
     def forward(
         self,
@@ -800,24 +819,20 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            # FIX: aicore move/copy should be also placed on the comm stream in dbo,
-            # otherwise it may affect the accuracy or disturb the overlap of next stage
-            # TODO: use an elegant way here to avoid it
+            # FIX: the aicore move should also be placed on the comm stream in dbo,
+            # otherwise it may affect the accuracy
+            # TODO: use an elegant way to overlap
            from vllm_ascend.multistream.context import get_multistream_comm_context
+            output_prefill = self._forward_prefill(
+                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
-            if current_ms_metadata is None:
-                output[num_decode_tokens:] = self._forward_prefill(
-                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                    attn_metadata)
-            else:
-                current_ms_metadata.before_comm_event.record()
+            if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    output[num_decode_tokens:] = self._forward_prefill(
-                        prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                        attn_metadata)
+                    output[num_decode_tokens:] = output_prefill
                     current_ms_metadata.after_comm_event.record()
-
+            else:
+                output[num_decode_tokens:] = output_prefill
 
 
         if has_decode:
@@ -826,21 +841,17 @@ def forward(
                                                   decode_k_nope, decode_k_pe,
                                                   kv_cache, attn_metadata)
             else:
-
                 from vllm_ascend.multistream.context import get_multistream_comm_context
-                current_ms_metadata = get_multistream_comm_context()
-                if current_ms_metadata is None:
-                    output[:num_decode_tokens] = self._forward_decode(
-                        decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                        kv_cache, attn_metadata)
-                else:
-                    current_ms_metadata.before_comm_event.record()
-                    with torch.npu.stream(current_ms_metadata.comm_stream):
-                        current_ms_metadata.before_comm_event.wait()
-                        output[:num_decode_tokens] = self._forward_decode(
-                            decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                            kv_cache, attn_metadata)
-                        current_ms_metadata.after_comm_event.record()
+                output_decode = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                        current_ms_metadata.after_comm_event.record()
+                else:
+                    output[:num_decode_tokens] = output_decode
 
 
         return output_padded
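Net effect of the two forward() hunks: _forward_prefill and _forward_decode are now always launched from the default stream (each handles its own comm-stream handoff internally, per the earlier hunks), and only the copy of the result into the shared output buffer is routed through the comm stream, since per the FIX comment the aicore copy must follow the comm stream under dbo. A condensed sketch of the new shape, where compute() and start:end stand in for the real calls and slice bounds:

output_chunk = compute()  # attention + projection, default stream
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
    with torch.npu.stream(current_ms_metadata.comm_stream):
        output[start:end] = output_chunk  # copy rides the comm stream
        current_ms_metadata.after_comm_event.record()
else:
    output[start:end] = output_chunk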

vllm_ascend/models/deepseek_v2.py

Lines changed: 81 additions & 19 deletions
@@ -149,6 +149,50 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+    def _forward_ms_mlp(self, x):
+        current_ms_metadata = get_multistream_comm_context()
+        assert current_ms_metadata is not None
+        if self.is_dynamic_quant:
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+            return x
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        current_ms_metadata.before_comm_event.record()
+        with torch.npu.stream(current_ms_metadata.comm_stream):
+            current_ms_metadata.before_comm_event.wait()
+            x, _ = self.down_proj(x)
+            current_ms_metadata.after_comm_event.record()
+        return x
+
 
 class CustomDeepseekV2MoE(nn.Module):
 
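_forward_ms_mlp mirrors CustomDeepseekV2MLP.forward but pushes the reducing op (down_proj on the unquantized path, the explicit tensor_model_parallel_all_reduce on the dynamic-quant path) onto the comm stream. The assert means a multistream context must already be active when it runs; a hedged usage sketch, where mlp, step_metadata, and micro_batch_idx are illustrative names:

from vllm_ascend.multistream.context import set_multistream_context

# Inside the context, get_multistream_comm_context() returns the step
# metadata, so the assert in _forward_ms_mlp holds.
with set_multistream_context(step_metadata, micro_batch_idx):
    shared_output = mlp._forward_ms_mlp(hidden_states)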
@@ -282,7 +326,7 @@ def _forward_ms_op_shared_expert(
         self,
         hidden_states: torch.Tensor,
     ):
-        shared_output = self.shared_experts(hidden_states)
+        shared_output = self.shared_experts._forward_ms_mlp(hidden_states)
         return shared_output
 
     def _forward_ms_op_gate(
@@ -293,7 +337,7 @@ def _forward_ms_op_gate(
         router_logits, _ = self.gate(hidden_states)
         return router_logits
 
-    def _forward_ms_op_tp_allreduce(
+    def _forward_ms_op_tp_allgather(
         self,
         hidden_states: torch.Tensor,
         shared_output: torch.Tensor,
@@ -303,13 +347,26 @@ def _forward_ms_op_tp_allreduce(
     ):
 
         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            #if num_tokens < self.tp_size:
-            #    final_hidden_states = final_hidden_states[:num_tokens]
-            if num_tokens > 0:
-                final_hidden_states = final_hidden_states[:-num_tokens]
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is None:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                #if num_tokens < self.tp_size:
+                #    final_hidden_states = final_hidden_states[:num_tokens]
+                if num_tokens > 0:
+                    final_hidden_states = final_hidden_states[:-num_tokens]
+            else:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                    self.tp_group)
+                    final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                    #if num_tokens < self.tp_size:
+                    #    final_hidden_states = final_hidden_states[:num_tokens]
+                    if num_tokens > 0:
+                        final_hidden_states = final_hidden_states[:-num_tokens]
+
         else:
             final_hidden_states = hidden_states
 
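The rename from _forward_ms_op_tp_allreduce to _forward_ms_op_tp_allgather matches what the body actually does: a dist.all_gather across TP ranks followed by a cat, with the [:-num_tokens] slice trimming trailing rows (num_tokens appears to carry the pad length here, judging by that slice). A CPU-only toy illustration of the gather-and-trim step, with made-up shapes:

import torch

tp_size, pad_rows = 2, 1  # hypothetical values
# What dist.all_gather(list(chunks), x, group) followed by torch.cat
# produces on each rank:
chunks = [torch.randn(4, 8) for _ in range(tp_size)]  # padded shards
gathered = torch.cat(chunks, dim=0)                   # shape (8, 8)
if pad_rows > 0:
    gathered = gathered[:-pad_rows]                   # drop trailing padding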
@@ -650,25 +707,24 @@ def _forward_ms_layer(
 
             # input layernorm
             hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i])
-            # attention and tp allreducea
+            # attention and tp allreduce
             hidden_states[i], residual[i] = self._forward_ms_op_attn(positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i])
 
         ''' block 3 : shared experts
         if there is an allreduce ops in shared expert, we can overlap it with the computation of the
         shared expert for next microbatch or moe gating
         '''
         for i in range(num_micro_batchs):
+            ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH],
             )
             with set_multistream_context(context, i):
                 # compute shared expert after finishing ATTN AR
-                ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
                 hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(hidden_states[i], residual[i])
 
-
             num_token, hidden_dim = hidden_states[i].shape
             hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
             #num_tokens.append(num_token)
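Hoisting the try_wait_event(..., ATTN_AR_FINISH) call out of set_multistream_context makes each microbatch wait for its attention all-reduce before any of the shared-expert step is staged, rather than from inside the context. Schematically (an illustrative condensation of the loop above, not new code):

for i in range(num_micro_batchs):
    # Wait for this microbatch's attention all-reduce first ...
    ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
    # ... then stage the shared-expert compute/comm under the context.
    with set_multistream_context(context, i):
        hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(
            hidden_states[i], residual[i])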
@@ -740,10 +796,14 @@ def _forward_ms_layer(
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_COM_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM],
             )
-            with set_multistream_context(context, i):
+            context.before_comm_event.record()
+            with torch.npu.stream(ms_metadata.communicate_stream):
+                #with set_multistream_context(context, i):
+                context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (self.mlp.experts.tp_size > 1 or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
+                context.after_comm_event.record()
             # check here
             hidden_states[i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
@@ -752,7 +812,7 @@ def _forward_ms_layer(
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH],
             )
             with set_multistream_context(context, i):
-                hidden_states[i] = self.mlp._forward_ms_op_tp_allreduce(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
+                hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
             with torch.npu.stream(ms_metadata.communicate_stream):
                 # last
                 if isinstance(
@@ -764,6 +824,7 @@ def _forward_ms_layer(
                     # The scaling of DeepseekV2MOE output would be done in the forward
                     # of DeepseekV2MOE
                     hidden_states[i] *= 1. / self.routed_scaling_factor
+                context.after_comm_event.record()
         return hidden_states, residual
     # should split ops in Decoder Layer
     def _forward_ms_op_input_layernorm(
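In the MoE all-reduce hunk above, the set_multistream_context wrapper is replaced by the inlined event protocol, presumably so that after_comm_event is recorded immediately after the collective rather than at context exit. A simplified restatement of the inlined pattern (the guard on reduce_results/tp_size/ep_size is elided):

context.before_comm_event.record()                  # end of FFN compute
with torch.npu.stream(ms_metadata.communicate_stream):
    context.before_comm_event.wait()
    hidden_states[i] = tensor_model_parallel_all_reduce(hidden_states[i])
    context.after_comm_event.record()               # right after the collective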
@@ -934,7 +995,7 @@ def forward(
     def can_run_ms(self):
         # currently we only enable prefill overlap
         attn_metadata = get_forward_context().attn_metadata
-        dp_metadata = get_forward_context().dp_metadata
+        # dp_metadata = get_forward_context().dp_metadata
         # profile run
         if self.multistream_config is None or attn_metadata is None:
             return False
@@ -944,16 +1005,17 @@ def can_run_ms(self):
         # disable decode dbo
         if attn_metadata.num_prefills == 0:
             return False
-        num_microbatchs = self.multistream_config.num_micro_batches
         # check whether there is a dp rank that not use dual batch
-        '''if dp_metadata is not None:
+        '''
+        num_microbatchs = self.multistream_config.num_micro_batches
+        if dp_metadata is not None:
             for i in range(num_microbatchs):
                 cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
                 if torch.any(cu_tokens == 0).item():
                     return False
+        '''
         [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens,
-            attn_metadata.attn_state, attn_metadata.num_decode_tokens)
-        '''
+                attn_metadata.attn_state, attn_metadata.num_decode_tokens)
         if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens):
             return False
         # check whether the total tokens exceed the threshold
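After this hunk, can_run_ms() gates dual-batch overlap on three conditions: a multistream config is present, the batch contains at least one prefill, and compute_split_seq_index finds a usable split (the per-DP-rank dual-batch check is commented out along with its dp_metadata lookup). A condensed, hedged sketch of the resulting logic; the trailing total-token threshold check is elided:

def can_run_ms_sketch(attn_metadata, multistream_config):
    # Profile run, or multistream disabled.
    if multistream_config is None or attn_metadata is None:
        return False
    # Decode-only dbo is disabled.
    if attn_metadata.num_prefills == 0:
        return False
    token_index, seq_index = compute_split_seq_index(
        attn_metadata.query_lens, attn_metadata.attn_state,
        attn_metadata.num_decode_tokens)
    return not (token_index == 0 or seq_index == 0
                or seq_index == len(attn_metadata.query_lens))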

vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 2 deletions
@@ -37,8 +37,6 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 
-from vllm_ascend.multistream.base import MSEventKey
-from vllm_ascend.multistream.metadata import MultiStreamStepMetadata, MultiStreamMetadata
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 