@@ -149,6 +149,50 @@ def forward(self, x):
149
149
x , _ = self .down_proj (x )
150
150
return x
151
151
152
+ def _forward_ms_mlp (self , x ):
153
+ current_ms_metadata = get_multistream_comm_context ()
154
+ assert current_ms_metadata is not None
155
+ if self .is_dynamic_quant :
156
+ x , dynamic_scale = torch_npu .npu_dynamic_quant (x )
157
+ x = torch_npu .npu_quant_matmul (
158
+ x ,
159
+ self .gate_up_proj .weight ,
160
+ self .gate_up_proj .weight_scale ,
161
+ output_dtype = torch .int32 ,
162
+ )
163
+ x , dynamic_scale = torch_npu .npu_dequant_swiglu_quant (
164
+ x = x ,
165
+ weight_scale = self .gate_up_proj .weight_scale_fp32 ,
166
+ activation_scale = dynamic_scale ,
167
+ bias = None ,
168
+ quant_scale = None ,
169
+ quant_offset = None ,
170
+ group_index = None ,
171
+ activate_left = True ,
172
+ quant_mode = 1 )
173
+ x = torch_npu .npu_quant_matmul (
174
+ x ,
175
+ self .down_proj .weight ,
176
+ self .down_proj .weight_scale ,
177
+ pertoken_scale = dynamic_scale ,
178
+ output_dtype = torch .bfloat16 ,
179
+ )
180
+ if self .down_proj .reduce_results and self .down_proj .tp_size > 1 :
181
+ current_ms_metadata .before_comm_event .record ()
182
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
183
+ current_ms_metadata .before_comm_event .wait ()
184
+ x = tensor_model_parallel_all_reduce (x )
185
+ current_ms_metadata .after_comm_event .record ()
186
+ return x
187
+ gate_up , _ = self .gate_up_proj (x )
188
+ x = self .act_fn (gate_up )
189
+ current_ms_metadata .before_comm_event .record ()
190
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
191
+ current_ms_metadata .before_comm_event .wait ()
192
+ x , _ = self .down_proj (x )
193
+ current_ms_metadata .after_comm_event .record ()
194
+ return x
195
+
152
196
153
197
class CustomDeepseekV2MoE (nn .Module ):
154
198
@@ -282,7 +326,7 @@ def _forward_ms_op_shared_expert(
282
326
self ,
283
327
hidden_states : torch .Tensor ,
284
328
):
285
- shared_output = self .shared_experts (hidden_states )
329
+ shared_output = self .shared_experts . _forward_ms_mlp (hidden_states )
286
330
return shared_output
287
331
288
332
def _forward_ms_op_gate (
@@ -293,7 +337,7 @@ def _forward_ms_op_gate(
293
337
router_logits , _ = self .gate (hidden_states )
294
338
return router_logits
295
339
296
- def _forward_ms_op_tp_allreduce (
340
+ def _forward_ms_op_tp_allgather (
297
341
self ,
298
342
hidden_states : torch .Tensor ,
299
343
shared_output : torch .Tensor ,
@@ -303,13 +347,26 @@ def _forward_ms_op_tp_allreduce(
303
347
):
304
348
305
349
if self .tp_size > 1 :
306
- dist .all_gather (list (chunk_hidden_states ), hidden_states ,
307
- self .tp_group )
308
- final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
309
- #if num_tokens < self.tp_size:
310
- # final_hidden_states = final_hidden_states[:num_tokens]
311
- if num_tokens > 0 :
312
- final_hidden_states = final_hidden_states [:- num_tokens ]
350
+ current_ms_metadata = get_multistream_comm_context ()
351
+ if current_ms_metadata is None :
352
+ dist .all_gather (list (chunk_hidden_states ), hidden_states ,
353
+ self .tp_group )
354
+ final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
355
+ #if num_tokens < self.tp_size:
356
+ # final_hidden_states = final_hidden_states[:num_tokens]
357
+ if num_tokens > 0 :
358
+ final_hidden_states = final_hidden_states [:- num_tokens ]
359
+ else :
360
+ current_ms_metadata .before_comm_event .record ()
361
+ with torch .npu .stream (current_ms_metadata .comm_stream ):
362
+ dist .all_gather (list (chunk_hidden_states ), hidden_states ,
363
+ self .tp_group )
364
+ final_hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
365
+ #if num_tokens < self.tp_size:
366
+ # final_hidden_states = final_hidden_states[:num_tokens]
367
+ if num_tokens > 0 :
368
+ final_hidden_states = final_hidden_states [:- num_tokens ]
369
+
313
370
else :
314
371
final_hidden_states = hidden_states
315
372
@@ -650,25 +707,24 @@ def _forward_ms_layer(
650
707
651
708
# input layernorm
652
709
hidden_states [i ], residual [i ] = self ._forward_ms_op_input_layernorm (hidden_states [i ], residual [i ])
653
- # attention and tp allreducea
710
+ # attention and tp allreduce
654
711
hidden_states [i ], residual [i ] = self ._forward_ms_op_attn (positions [i ], hidden_states [i ], residual [i ], kv_cache , attn_metadata [i ])
655
712
656
713
''' block 3 : shared experts
657
714
if there is an allreduce ops in shared expert, we can overlap it with the computation of the
658
715
shared expert for next microbatch or moe gating
659
716
'''
660
717
for i in range (num_micro_batchs ):
718
+ ms_metadata .try_wait_event (layer_index , i , MSEventKey .ATTN_AR_FINISH )
661
719
context = MultiStreamStepMetadata (
662
720
comm_stream = ms_metadata .communicate_stream ,
663
721
before_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_SE_COMP_FINISH ],
664
722
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_SE_COMM_FINISH ],
665
723
)
666
724
with set_multistream_context (context , i ):
667
725
# compute shared expert after finishing ATTN AR
668
- ms_metadata .try_wait_event (layer_index , i , MSEventKey .ATTN_AR_FINISH )
669
726
hidden_states [i ], residual [i ] = self ._forward_ms_op_post_attn_layernorm (hidden_states [i ], residual [i ])
670
727
671
-
672
728
num_token , hidden_dim = hidden_states [i ].shape
673
729
hidden_states [i ] = hidden_states [i ].view (- 1 , hidden_dim )
674
730
#num_tokens.append(num_token)
@@ -740,10 +796,14 @@ def _forward_ms_layer(
740
796
before_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .FFN_COM_FINISH ],
741
797
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .MOE_AFTER_COMM ],
742
798
)
743
- with set_multistream_context (context , i ):
799
+ context .before_comm_event .record ()
800
+ with torch .npu .stream (ms_metadata .communicate_stream ):
801
+ #with set_multistream_context(context, i):
802
+ context .before_comm_event .wait ()
744
803
if self .mlp .experts .reduce_results and (self .mlp .experts .tp_size > 1 or self .mlp .experts .ep_size > 1 ):
745
804
hidden_states [i ] = tensor_model_parallel_all_reduce (
746
805
hidden_states [i ])
806
+ context .after_comm_event .record ()
747
807
# check here
748
808
hidden_states [i ] = hidden_states [i ] * self .mlp .routed_scaling_factor
749
809
context = MultiStreamStepMetadata (
@@ -752,7 +812,7 @@ def _forward_ms_layer(
752
812
after_comm_event = ms_metadata .ms_events [layer_index ][i ][MSEventKey .FFN_AR_FINISH ],
753
813
)
754
814
with set_multistream_context (context , i ):
755
- hidden_states [i ] = self .mlp ._forward_ms_op_tp_allreduce (hidden_states [i ], shared_outputs [i ], chunk_hidden_states [i ], num_tokens [i ], hidden_dims [i ])
815
+ hidden_states [i ] = self .mlp ._forward_ms_op_tp_allgather (hidden_states [i ], shared_outputs [i ], chunk_hidden_states [i ], num_tokens [i ], hidden_dims [i ])
756
816
with torch .npu .stream (ms_metadata .communicate_stream ):
757
817
# last
758
818
if isinstance (
@@ -764,6 +824,7 @@ def _forward_ms_layer(
764
824
# The scaling of DeepseekV2MOE output would be done in the forward
765
825
# of DeepseekV2MOE
766
826
hidden_states [i ] *= 1. / self .routed_scaling_factor
827
+ context .after_comm_event .record ()
767
828
return hidden_states , residual
768
829
# should split ops in Decoder Layer
769
830
def _forward_ms_op_input_layernorm (
@@ -934,7 +995,7 @@ def forward(
934
995
def can_run_ms (self ):
935
996
# currently we only enable prefill overlap
936
997
attn_metadata = get_forward_context ().attn_metadata
937
- dp_metadata = get_forward_context ().dp_metadata
998
+ # dp_metadata = get_forward_context().dp_metadata
938
999
# profile run
939
1000
if self .multistream_config is None or attn_metadata is None :
940
1001
return False
@@ -944,16 +1005,17 @@ def can_run_ms(self):
944
1005
# disable decode dbo
945
1006
if attn_metadata .num_prefills == 0 :
946
1007
return False
947
- num_microbatchs = self .multistream_config .num_micro_batches
948
1008
# check whether there is a dp rank that not use dual batch
949
- '''if dp_metadata is not None:
1009
+ '''
1010
+ num_microbatchs = self.multistream_config.num_micro_batches
1011
+ if dp_metadata is not None:
950
1012
for i in range(num_microbatchs):
951
1013
cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
952
1014
if torch.any(cu_tokens == 0).item():
953
1015
return False
1016
+ '''
954
1017
[token_index , seq_index ] = compute_split_seq_index (attn_metadata .query_lens ,
955
- attn_metadata.attn_state, attn_metadata.num_decode_tokens)
956
- '''
1018
+ attn_metadata .attn_state , attn_metadata .num_decode_tokens )
957
1019
if token_index == 0 or seq_index == 0 or seq_index == len (attn_metadata .query_lens ):
958
1020
return False
959
1021
# check whether the total tokens exceed the threshold
0 commit comments