30
30
import torch
31
31
import torch .distributed as dist
32
32
import torch_npu
33
- import vllm .envs as envs
33
+ import vllm_ascend .envs as envs_ascend
34
34
from torch import nn
35
35
from transformers import PretrainedConfig
36
+ from vllm_ascend .multistream .base import MSEventKey
37
+ from vllm_ascend .multistream .context import (
38
+ advance_step_multistream_layer_context , get_multistream_comm_context ,
39
+ get_multistream_layer_context , set_multistream_context )
40
+ from vllm_ascend .multistream .layers import (MultiStreamPostTransformerLayer ,
41
+ MultiStreamPreTransformerLayer )
42
+ from vllm_ascend .multistream .metadata import (MultiStreamConfig ,
43
+ MultiStreamStepMetadata ,
44
+ make_multistream_metadata_ds )
45
+ from vllm_ascend .multistream .ms_split import compute_split_seq_index
46
+ from vllm_ascend .ops .fused_moe import AscendFusedMoE
47
+ from vllm_ascend .quantization .w8a8_dynamic import AscendW8A8DynamicLinearMethod
48
+
49
+ import vllm .envs as envs
36
50
from vllm .attention import Attention , AttentionMetadata
37
51
from vllm .config import (CacheConfig , ModelConfig , VllmConfig ,
38
52
get_current_vllm_config )
65
79
maybe_prefix )
66
80
from vllm .sequence import IntermediateTensors
67
81
68
- import vllm_ascend .envs as envs_ascend
69
- from vllm_ascend .ops .fused_moe import AscendFusedMoE
70
- from vllm_ascend .quantization .w8a8_dynamic import AscendW8A8DynamicLinearMethod
71
-
72
- from vllm_ascend .multistream .context import (set_multistream_context ,get_multistream_layer_context ,
73
- advance_step_multistream_layer_context , get_multistream_comm_context )
74
- from vllm_ascend .multistream .layers import (MultiStreamPreTransformerLayer , MultiStreamPostTransformerLayer )
75
- from vllm_ascend .multistream .metadata import make_multistream_metadata_ds , MultiStreamStepMetadata , MultiStreamConfig
76
- from vllm_ascend .multistream .base import MSEventKey
77
- from vllm_ascend .multistream .ms_split import compute_split_seq_index
78
-
79
82
VLLM_ENABLE_MC2 : bool = envs_ascend .VLLM_ENABLE_MC2
80
- VLLM_ENABLE_MS : bool = envs_ascend .VLLM_ENABLE_MS
83
+ VLLM_ENABLE_DBO : bool = envs_ascend .VLLM_ENABLE_DBO
81
84
82
85
83
86
class CustomDeepseekV2MLP (nn .Module ):
@@ -149,6 +152,50 @@ def forward(self, x):
149
152
x , _ = self .down_proj (x )
150
153
return x
151
154
155
+ def _forward_ms_mlp (self , x ):
156
+ current_ms_metadata = get_multistream_comm_context ()
157
+ assert current_ms_metadata is not None
158
+ if self .is_dynamic_quant :
159
+ x , dynamic_scale = torch_npu .npu_dynamic_quant (x )
160
+ x = torch_npu .npu_quant_matmul (
161
+ x ,
162
+ self .gate_up_proj .weight ,
163
+ self .gate_up_proj .weight_scale ,
164
+ output_dtype = torch .int32 ,
165
+ )
166
+ x , dynamic_scale = torch_npu .npu_dequant_swiglu_quant (
167
+ x = x ,
168
+ weight_scale = self .gate_up_proj .weight_scale_fp32 ,
169
+ activation_scale = dynamic_scale ,
170
+ bias = None ,
171
+ quant_scale = None ,
172
+ quant_offset = None ,
173
+ group_index = None ,
174
+ activate_left = True ,
175
+ quant_mode = 1 )
176
+ x = torch_npu .npu_quant_matmul (
177
+ x ,
178
+ self .down_proj .weight ,
179
+ self .down_proj .weight_scale ,
180
+ pertoken_scale = dynamic_scale ,
181
+ output_dtype = torch .bfloat16 ,
182
+ )
183
+ if self .down_proj .reduce_results and self .down_proj .tp_size > 1 :
184
+ current_ms_metadata .before_comm_event .record ()
185
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
186
+ current_ms_metadata .before_comm_event .wait ()
187
+ x = tensor_model_parallel_all_reduce (x )
188
+ current_ms_metadata .after_comm_event .record ()
189
+ return x
190
+ gate_up , _ = self .gate_up_proj (x )
191
+ x = self .act_fn (gate_up )
192
+ current_ms_metadata .before_comm_event .record ()
193
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
194
+ current_ms_metadata .before_comm_event .wait ()
195
+ x , _ = self .down_proj (x )
196
+ current_ms_metadata .after_comm_event .record ()
197
+ return x
198
+
152
199
153
200
class CustomDeepseekV2MoE (nn .Module ):
154
201
@@ -282,7 +329,7 @@ def _forward_ms_op_shared_expert(
282
329
self ,
283
330
hidden_states : torch .Tensor ,
284
331
):
285
- shared_output = self .shared_experts (hidden_states )
332
+ shared_output = self .shared_experts . _forward_ms_mlp (hidden_states )
286
333
return shared_output
287
334
288
335
def _forward_ms_op_gate (
@@ -293,7 +340,7 @@ def _forward_ms_op_gate(
293
340
router_logits , _ = self .gate (hidden_states )
294
341
return router_logits
295
342
296
- def _forward_ms_op_tp_allreduce (
343
+ def _forward_ms_op_tp_allgather (
297
344
self ,
298
345
hidden_states : torch .Tensor ,
299
346
shared_output : torch .Tensor ,
@@ -303,13 +350,26 @@ def _forward_ms_op_tp_allreduce(
303
350
):
304
351
305
352
if self .tp_size > 1 :
306
- dist .all_gather (list (chunk_hidden_states ), hidden_states ,
307
- self .tp_group )
308
- final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
309
- #if num_tokens < self.tp_size:
310
- # final_hidden_states = final_hidden_states[:num_tokens]
311
- if num_tokens > 0 :
312
- final_hidden_states = final_hidden_states [:- num_tokens ]
353
+ current_ms_metadata = get_multistream_comm_context ()
354
+ if current_ms_metadata is None :
355
+ dist .all_gather (list (chunk_hidden_states ), hidden_states ,
356
+ self .tp_group )
357
+ final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
358
+ #if num_tokens < self.tp_size:
359
+ # final_hidden_states = final_hidden_states[:num_tokens]
360
+ if num_tokens > 0 :
361
+ final_hidden_states = final_hidden_states [:- num_tokens ]
362
+ else :
363
+ current_ms_metadata .before_comm_event .record ()
364
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
365
+ dist .all_gather (list (chunk_hidden_states ), hidden_states ,
366
+ self .tp_group )
367
+ final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
368
+ #if num_tokens < self.tp_size:
369
+ # final_hidden_states = final_hidden_states[:num_tokens]
370
+ if num_tokens > 0 :
371
+ final_hidden_states = final_hidden_states [:- num_tokens ]
372
+
313
373
else :
314
374
final_hidden_states = hidden_states
315
375
@@ -650,25 +710,24 @@ def _forward_ms_layer(
650
710
651
711
# input layernorm
652
712
hidden_states [i ], residual [i ] = self ._forward_ms_op_input_layernorm (hidden_states [i ], residual [i ])
653
- # attention and tp allreducea
713
+ # attention and tp allreduce
654
714
hidden_states [i ], residual [i ] = self ._forward_ms_op_attn (positions [i ], hidden_states [i ], residual [i ], kv_cache , attn_metadata [i ])
655
715
656
716
''' block 3 : shared experts
657
717
if there is an allreduce ops in shared expert, we can overlap it with the computation of the
658
718
shared expert for next microbatch or moe gating
659
719
'''
660
720
for i in range (num_micro_batchs ):
721
+ ms_metadata .try_wait_event (layer_index , i , MSEventKey .ATTN_AR_FINISH )
661
722
context = MultiStreamStepMetadata (
662
723
comm_stream = ms_metadata .communicate_stream ,
663
724
before_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_SE_COMP_FINISH ],
664
725
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_SE_COMM_FINISH ],
665
726
)
666
727
with set_multistream_context (context , i ):
667
728
# compute shared expert after finishing ATTN AR
668
- ms_metadata .try_wait_event (layer_index , i , MSEventKey .ATTN_AR_FINISH )
669
729
hidden_states [i ], residual [i ] = self ._forward_ms_op_post_attn_layernorm (hidden_states [i ], residual [i ])
670
730
671
-
672
731
num_token , hidden_dim = hidden_states [i ].shape
673
732
hidden_states [i ] = hidden_states [i ].view (- 1 , hidden_dim )
674
733
#num_tokens.append(num_token)
@@ -740,10 +799,14 @@ def _forward_ms_layer(
740
799
before_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .FFN_COM_FINISH ],
741
800
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_AFTER_COMM ],
742
801
)
743
- with set_multistream_context (context , i ):
802
+ context .before_comm_event .record ()
803
+ with torch .npu .stream (ms_metadata .communicate_stream ):
804
+ #with set_multistream_context(context, i):
805
+ context .before_comm_event .wait ()
744
806
if self .mlp .experts .reduce_results and (self .mlp .experts .tp_size > 1 or self .mlp .experts .ep_size > 1 ):
745
807
hidden_states [i ] = tensor_model_parallel_all_reduce (
746
808
hidden_states [i ])
809
+ context .after_comm_event .record ()
747
810
# check here
748
811
hidden_states [i ] = hidden_states [i ] * self .mlp .routed_scaling_factor
749
812
context = MultiStreamStepMetadata (
@@ -752,7 +815,7 @@ def _forward_ms_layer(
752
815
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .FFN_AR_FINISH ],
753
816
)
754
817
with set_multistream_context (context , i ):
755
- hidden_states [i ] = self .mlp ._forward_ms_op_tp_allreduce (hidden_states [i ], shared_outputs [i ], chunk_hidden_states [i ], num_tokens [i ], hidden_dims [i ])
818
+ hidden_states [i ] = self .mlp ._forward_ms_op_tp_allgather (hidden_states [i ], shared_outputs [i ], chunk_hidden_states [i ], num_tokens [i ], hidden_dims [i ])
756
819
with torch .npu .stream (ms_metadata .communicate_stream ):
757
820
# last
758
821
if isinstance (
@@ -764,6 +827,7 @@ def _forward_ms_layer(
764
827
# The scaling of DeepseekV2MOE output would be done in the forward
765
828
# of DeepseekV2MOE
766
829
hidden_states [i ] *= 1. / self .routed_scaling_factor
830
+ context .after_comm_event .record ()
767
831
return hidden_states , residual
768
832
# should split ops in Decoder Layer
769
833
def _forward_ms_op_input_layernorm (
@@ -861,7 +925,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
861
925
["hidden_states" , "residual" ], config .hidden_size ))
862
926
863
927
# tbo related members
864
- if VLLM_ENABLE_MS :
928
+ if VLLM_ENABLE_DBO :
865
929
self .multistream_config = MultiStreamConfig ()
866
930
else :
867
931
self .multistream_config = None
@@ -934,7 +998,7 @@ def forward(
934
998
def can_run_ms (self ):
935
999
# currently we only enable prefill overlap
936
1000
attn_metadata = get_forward_context ().attn_metadata
937
- dp_metadata = get_forward_context ().dp_metadata
1001
+ # dp_metadata = get_forward_context().dp_metadata
938
1002
# profile run
939
1003
if self .multistream_config is None or attn_metadata is None :
940
1004
return False
@@ -944,16 +1008,17 @@ def can_run_ms(self):
944
1008
# disable decode dbo
945
1009
if attn_metadata .num_prefills == 0 :
946
1010
return False
947
- num_microbatchs = self .multistream_config .num_micro_batches
948
1011
# check whether there is a dp rank that not use dual batch
949
- '''if dp_metadata is not None:
1012
+ '''
1013
+ num_microbatchs = self.multistream_config.num_micro_batches
1014
+ if dp_metadata is not None:
950
1015
for i in range(num_microbatchs):
951
1016
cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
952
1017
if torch.any(cu_tokens == 0).item():
953
1018
return False
1019
+ '''
954
1020
[token_index , seq_index ] = compute_split_seq_index (attn_metadata .query_lens ,
955
- attn_metadata.attn_state, attn_metadata.num_decode_tokens)
956
- '''
1021
+ attn_metadata .attn_state , attn_metadata .num_decode_tokens )
957
1022
if token_index == 0 or seq_index == 0 or seq_index == len (attn_metadata .query_lens ):
958
1023
return False
959
1024
# check whether the total tokens exceed the threshold
0 commit comments