@@ -181,7 +181,6 @@ def dispatch(
181181 num_max_dispatch_tokens_per_rank : int = 128 ,
182182 ) -> Tuple [torch .Tensor , torch .Tensor ]:
183183 self .hidden_shape = hidden_states .shape
184- topk_idx = topk_idx .to (torch .int64 )
185184 (
186185 hidden_states ,
187186 topk_idx ,
@@ -205,6 +204,39 @@ def dispatch(
205204 hidden_states = self .get_permuted_hidden_states_by_experts (hidden_states )
206205 return hidden_states , topk_idx , topk_weights , tokens_per_expert
207206
def dispatch_yield(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    previous_event=None,
    num_max_dispatch_tokens_per_rank: int = 128,
):
    """Generator variant of ``dispatch``: runs the all-to-all token dispatch
    in stages so the caller can interleave other work between launching the
    communication and waiting for it to complete.

    Drive this with ``yield from`` (or manual ``send``); the staged
    communication itself is delegated to ``dispatch_normal_yield``.

    Args:
        hidden_states: input activations; dim 0 is the local token count.
        topk_idx: per-token expert indices chosen by the router.
        topk_weights: per-token routing weights matching ``topk_idx``.
        num_experts: total number of experts across all ranks.
        previous_event: optional event forwarded to ``dispatch_normal_yield``.
        num_max_dispatch_tokens_per_rank: kept for signature parity with
            ``dispatch``; not used in this code path.

    Returns (as the generator's ``StopIteration`` value):
        ``(hidden_states, topk_idx, topk_weights, tokens_per_expert)``, with
        ``hidden_states`` permuted into expert order when any tokens were
        received on this rank.
    """
    self.hidden_shape = hidden_states.shape
    (
        hidden_states,
        topk_idx,
        topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    ) = yield from self.dispatch_normal_yield(
        hidden_states, topk_idx, topk_weights, num_experts, previous_event
    )
    # Cache the per-expert receive counts on the same device as the data.
    self.tokens_per_expert = torch.tensor(
        num_recv_tokens_per_expert_list,
        device=hidden_states.device,
        dtype=torch.int64,
    )
    tokens_per_expert = self.get_number_of_tokens_per_expert()
    # Stash the communication state needed by the matching combine step.
    self.handle = handle
    self.topk_idx = topk_idx
    self.topk_weights = topk_weights
    # Permute into expert order only when this rank actually received tokens.
    if hidden_states.shape[0] > 0:
        hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
    return hidden_states, topk_idx, topk_weights, tokens_per_expert
239+
208240 def dispatch_normal (
209241 self ,
210242 x : torch .Tensor ,
@@ -256,6 +288,61 @@ def dispatch_normal(
256288 event ,
257289 )
258290
def dispatch_normal_yield(
    self,
    x: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    previous_event=None,
    async_finish=True
):
    """Generator form of ``dispatch_normal``: launches the dispatch-layout
    query and the all-to-all dispatch, yields once so the caller can overlap
    other work with the in-flight communication, then waits on the
    communication event and returns the received tensors.

    NOTE(review): the ``previous_event`` parameter is immediately overwritten
    by ``buffer_normal.capture()`` below, so the caller-supplied value is
    never used — confirm whether that is intentional.

    Args:
        x: local tokens to dispatch; dim 0 is the token count.
        topk_idx: per-token expert indices.
        topk_weights: per-token routing weights.
        num_experts: total number of experts across all ranks.
        previous_event: currently ignored (see note above).
        async_finish: launch asynchronously and wait on the returned event
            after the yield point.

    Returns (as the generator's ``StopIteration`` value):
        ``(recv_x, recv_topk_idx, recv_topk_weights,
        num_recv_tokens_per_expert_list, handle, event)``.
    """
    # Capture the current stream so the comm kernels are ordered after
    # any compute already enqueued (only meaningful in async mode).
    previous_event = self.buffer_normal.capture() if async_finish else None
    (
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
        is_token_in_rank,
        previous_event,
    ) = self.buffer_normal.get_dispatch_layout(
        topk_idx,
        num_experts,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=previous_event is not None and async_finish,
    )

    (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    ) = self.buffer_normal.dispatch(
        x,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        is_token_in_rank=is_token_in_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=previous_event is not None and async_finish,
    )

    # Hand control back to the caller while the dispatch runs asynchronously.
    yield
    if async_finish:
        # Make the current stream wait for the async communication to finish.
        event.current_stream_wait()
    return (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    )
259346
260347 def combine (
261348 self , hidden_states : torch .Tensor
@@ -268,6 +355,17 @@ def combine(
268355 self .handle = None
269356 return hidden_states .view (self .hidden_shape )
270357
def combine_yield(
    self, hidden_states: torch.Tensor
):
    """Generator counterpart of ``combine``: restores expert-permuted
    activations to token order, runs the staged combine via
    ``combine_normal_yield``, and returns the result reshaped to the
    shape recorded at dispatch time.
    """
    has_tokens = hidden_states.shape[0] > 0
    if has_tokens:
        hidden_states = self.get_restored_hidden_states_by_experts(
            hidden_states
        )
    combined, _event = yield from self.combine_normal_yield(hidden_states, self.handle)
    # The communication handle is single-use; drop it once consumed.
    self.handle = None
    return combined.view(self.hidden_shape)
368+
271369 def combine_normal (self , x : torch .Tensor , handle : Tuple , previous_event = None ):
272370 combined_x , _ , event = self .buffer_normal .combine (
273371 x ,
@@ -278,6 +376,22 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
278376 )
279377 return combined_x , event
280378
def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
    """Generator form of ``combine_normal``: yields once before launching the
    all-to-all combine (stage 1), launches it, yields again while the
    communication is in flight (stage 2), then waits on the event and
    returns the combined result.

    NOTE(review): the ``previous_event`` parameter is overwritten by
    ``buffer_normal.capture()`` before use, so the caller-supplied value is
    ignored — confirm whether that is intentional.

    Args:
        x: expert-order activations to combine back to token order.
        handle: communication handle produced by the matching dispatch.
        previous_event: currently ignored (see note above).
        async_finish: launch asynchronously and wait on the returned event.

    Returns (as the generator's ``StopIteration`` value):
        ``(combined_x, event)``.
    """
    # Stage barrier: let the caller run work before the combine is launched.
    yield
    # Capture the current stream so the comm kernels are ordered after
    # any compute already enqueued (only meaningful in async mode).
    previous_event = self.buffer_normal.capture() if async_finish else None
    combined_x, _, event = self.buffer_normal.combine(
        x,
        handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=previous_event is not None and async_finish,
    )

    # Hand control back while the combine communication is in flight.
    yield
    if async_finish:
        # Make the current stream wait for the async communication to finish.
        event.current_stream_wait()
    return combined_x, event
394+
281395 def _indices_to_multihot (self , indices , probs ):
282396 batch_size = indices .shape [0 ]
283397 multihot_routing_map = torch .zeros (
0 commit comments