@@ -6,7 +6,7 @@
 use_deepep = False

 import os
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any

 import torch
 import torch.distributed as dist
@@ -156,6 +156,8 @@ def __init__(
         self.token_probs = None
         # Handle used for combine operation
         self.handle = None
+        # shared experts
+        self.shared_experts = None

         # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
         # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -212,30 +214,44 @@ def dispatch_yield(
         num_experts: int,
         previous_event=None,
         num_max_dispatch_tokens_per_rank: int = 128,
+        is_prefill: bool = False,
+        is_decoding: bool = False
     ):
         self.hidden_shape = hidden_states.shape
+        # yield for attn1, dis (+share)
+        yield
+        previous_event = self.buffer_normal.capture()
         (
-            hidden_states,
-            topk_idx,
-            topk_weights,
+            recv_hidden_states,
+            recv_topk_idx,
+            recv_topk_weights,
             num_recv_tokens_per_expert_list,
             handle,
             event,
-        ) = yield from self.dispatch_normal_yield(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        ) = self.dispatch_normal_async(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event, True
         )
+        if is_decoding and self.shared_experts is not None:
+            shared_states = self.shared_experts(hidden_states)
+        else:
+            shared_states = None
+        # yield for dis (+share), dis_wait
+        yield
+        event.current_stream_wait()
+        # yield for dis_wait, moe
+        yield
         self.tokens_per_expert = torch.tensor(
             num_recv_tokens_per_expert_list,
             device=hidden_states.device,
             dtype=torch.int64,
         )
         tokens_per_expert = self.get_number_of_tokens_per_expert()
         self.handle = handle
-        self.topk_idx = topk_idx
-        self.topk_weights = topk_weights
-        if hidden_states.shape[0] > 0:
-            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
-        return hidden_states, topk_idx, topk_weights, tokens_per_expert
+        self.topk_idx = recv_topk_idx
+        self.topk_weights = recv_topk_weights
+        if recv_hidden_states.shape[0] > 0:
+            recv_hidden_states = self.get_permuted_hidden_states_by_experts(recv_hidden_states)
+        return recv_hidden_states, recv_topk_idx, recv_topk_weights, tokens_per_expert, shared_states

     def dispatch_normal(
         self,
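The three `yield` points added above split `dispatch_yield` into stages (pre-dispatch, async dispatch plus optional shared experts, event wait, then permutation) so a scheduler can interleave another micro-batch's work between them, matching the "attn1 / dis (+share) / dis_wait / moe" comments. Below is a minimal, hedged sketch of such a driver; `dispatcher`, the input tensors, and `num_experts` are hypothetical stand-ins for whatever the serving engine actually provides, not names defined in this patch.

```python
# Hedged sketch: driving the generator-based dispatch_yield above.
# `dispatcher`, the tensors, and `num_experts` are illustrative placeholders.
def run_dispatch(dispatcher, hidden_states, topk_idx, topk_weights, num_experts):
    gen = dispatcher.dispatch_yield(
        hidden_states, topk_idx, topk_weights, num_experts, is_decoding=True
    )
    next(gen)  # stop 1: nothing issued yet; run the other micro-batch's attention here
    next(gen)  # stop 2: async dispatch issued (+ shared experts when decoding); overlap more work
    next(gen)  # stop 3: dispatch event waited on; permutation runs on the final resume
    try:
        next(gen)
    except StopIteration as stop:
        # (recv_hidden_states, recv_topk_idx, recv_topk_weights, tokens_per_expert, shared_states)
        return stop.value
```

In a two-micro-batch schedule, each `next()` for one micro-batch would be interleaved with the corresponding stage of the other, so communication issued here overlaps with that batch's compute.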
@@ -288,7 +304,7 @@ def dispatch_normal(
             event,
         )

-    def dispatch_normal_yield(
+    def dispatch_normal_async(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -297,8 +313,6 @@ def dispatch_normal_yield(
         previous_event=None,
         async_finish=True
     ):
-        yield
-        previous_event = self.buffer_normal.capture() if async_finish else None
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -333,9 +347,6 @@ def dispatch_normal_yield(
             allocate_on_comm_stream=previous_event is not None and async_finish,
         )

-        yield
-        if async_finish:
-            event.current_stream_wait()
         return (
             recv_x,
             recv_topk_idx,
@@ -357,15 +368,34 @@ def combine(
         return hidden_states.view(self.hidden_shape)

     def combine_yield(
-        self, hidden_states: torch.Tensor
+        self,
+        out_states: torch.Tensor,
+        hidden_states: torch.Tensor,
+        is_prefill: bool = False,
+        is_decoding: bool = False
     ):
-        if hidden_states.shape[0] > 0:
-            hidden_states = self.get_restored_hidden_states_by_experts(
-                hidden_states
+        if out_states.shape[0] > 0:
+            out_states = self.get_restored_hidden_states_by_experts(
+                out_states
             )
-        hidden_states, event = yield from self.combine_normal_yield(hidden_states, self.handle)
+        # yield for moe, comb
+        yield
+        previous_event = self.buffer_normal.capture()
+        out_states, event = self.combine_normal_async(out_states,
+                                                       self.handle,
+                                                       previous_event=previous_event,
+                                                       async_finish=True)
+        # yield for comb, (+share) comb_wait
+        yield
+        if is_prefill and self.shared_experts is not None:
+            shared_states = self.shared_experts(hidden_states)
+        else:
+            shared_states = None
+        event.current_stream_wait()
+        # yield for (+share) comb_wait, (+share) attn0
+        yield
         self.handle = None
-        return hidden_states.view(self.hidden_shape)
+        return out_states.view(self.hidden_shape), shared_states

     def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         combined_x, _, event = self.buffer_normal.combine(
@@ -377,20 +407,14 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         )
         return combined_x, event

-    def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
-        yield
-        previous_event = self.buffer_normal.capture() if async_finish else None
+    def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
         combined_x, _, event = self.buffer_normal.combine(
             x,
             handle,
             async_finish=async_finish,
             previous_event=previous_event,
             allocate_on_comm_stream=previous_event is not None and async_finish,
         )
-
-        yield
-        if async_finish:
-            event.current_stream_wait()
         return combined_x, event

     def _indices_to_multihot(self, indices, probs):
@@ -456,3 +480,11 @@ def get_restored_hidden_states_by_experts(
             fused=self.permute_fusion,
         )
         return hidden_states.to(input_dtype)
+
+    def set_shared_experts(self, shared_experts: Any = None):
+        if shared_experts is not None:
+            self.shared_experts = shared_experts
+        return self.shared_experts
+
+    def get_shared_experts(self):
+        return self.shared_experts
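For completeness, here is a similarly hedged sketch of the combine side together with the new shared-experts hook. Again, `dispatcher`, `expert_output`, `hidden_states`, and `shared_expert_module` are hypothetical names standing in for the engine's real objects, not part of this patch.

```python
# Hedged sketch: registering shared experts and driving combine_yield.
# All names below are illustrative stand-ins, not part of this patch.
def run_combine(dispatcher, expert_output, hidden_states, shared_expert_module):
    dispatcher.set_shared_experts(shared_expert_module)  # used by dispatch (decode) / combine (prefill)
    gen = dispatcher.combine_yield(expert_output, hidden_states, is_prefill=True)
    next(gen)  # stop 1: expert outputs un-permuted; async combine not yet issued
    next(gen)  # stop 2: async combine issued; overlap other work before waiting
    next(gen)  # stop 3: shared experts (prefill) computed and combine event waited on
    try:
        next(gen)
    except StopIteration as stop:
        combined_states, shared_states = stop.value  # shared_states is None if shared experts did not run here
        return combined_states, shared_states
```

In practice the shared-experts module would presumably be registered once at initialization rather than on every call; it is shown inline here only to illustrate how the setter feeds the decode-time (dispatch) and prefill-time (combine) shared-expert overlap.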