@@ -181,7 +181,6 @@ def dispatch(
181181 num_max_dispatch_tokens_per_rank : int = 128 ,
182182 ) -> Tuple [torch .Tensor , torch .Tensor ]:
183183 self .hidden_shape = hidden_states .shape
184- topk_idx = topk_idx .to (torch .int64 )
185184 (
186185 hidden_states ,
187186 topk_idx ,
@@ -205,6 +204,39 @@ def dispatch(
205204 hidden_states = self .get_permuted_hidden_states_by_experts (hidden_states )
206205 return hidden_states , topk_idx , topk_weights , tokens_per_expert
207206
def dispatch_yield(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    previous_event=None,
    num_max_dispatch_tokens_per_rank: int = 128,
):
    """Generator form of ``dispatch`` built on ``dispatch_normal_yield``.

    The method records the incoming shape on ``self.hidden_shape`` and then
    delegates to ``self.dispatch_normal_yield`` with ``yield from``, so it
    suspends exactly where that generator suspends.  Once the delegate
    finishes, the received per-expert token counts, the handle and the
    received routing tensors are stored on ``self``; received tokens are
    passed through ``self.get_permuted_hidden_states_by_experts`` only when
    at least one token was received.

    Args:
        hidden_states: Tokens to dispatch.
        topk_idx: Expert indices selected for each token.
        topk_weights: Routing weights matching ``topk_idx``.
        num_experts: Forwarded unchanged to ``dispatch_normal_yield``.
        previous_event: Forwarded unchanged to ``dispatch_normal_yield``.
        num_max_dispatch_tokens_per_rank: Not read by this method.

    Returns:
        The generator's return value (delivered through ``StopIteration``):
        ``(hidden_states, topk_idx, topk_weights, tokens_per_expert)``.
    """
    self.hidden_shape = hidden_states.shape

    dispatched = yield from self.dispatch_normal_yield(
        hidden_states, topk_idx, topk_weights, num_experts, previous_event
    )
    # The trailing event is already consumed inside the delegate; it is
    # unpacked here only to keep the six-element contract explicit.
    (
        recv_hidden,
        recv_topk_idx,
        recv_topk_weights,
        recv_counts_per_expert,
        dispatch_handle,
        _event,
    ) = dispatched

    self.tokens_per_expert = torch.tensor(
        recv_counts_per_expert,
        device=recv_hidden.device,
        dtype=torch.int64,
    )
    tokens_per_expert = self.get_number_of_tokens_per_expert()

    # State that must be in place before the permutation helper runs and
    # that is read back later by the combine path.
    self.handle = dispatch_handle
    self.topk_idx = recv_topk_idx
    self.topk_weights = recv_topk_weights

    received_any = recv_hidden.shape[0] > 0
    if received_any:
        recv_hidden = self.get_permuted_hidden_states_by_experts(recv_hidden)
    return recv_hidden, recv_topk_idx, recv_topk_weights, tokens_per_expert
239+
208240 def dispatch_normal (
209241 self ,
210242 x : torch .Tensor ,
@@ -256,6 +288,62 @@ def dispatch_normal(
256288 event ,
257289 )
258290
def dispatch_normal_yield(
    self,
    x: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    previous_event=None,
    async_finish=True,
):
    """Generator that issues a dispatch on ``self.buffer_normal`` in steps.

    Suspension points:
        1. Immediately on entry, before any work is issued.
        2. After ``buffer_normal.dispatch`` has been called and before the
           returned event is waited on.

    NOTE(review): the ``previous_event`` argument is overwritten right after
    the first resume (with ``buffer_normal.capture()`` when ``async_finish``
    is truthy, otherwise ``None``), so a caller-supplied value is never
    used.  Confirm whether that is intended.

    Args:
        x: Tokens handed to ``buffer_normal.dispatch``.
        topk_idx: Expert indices per token.
        topk_weights: Routing weights per token.
        num_experts: Passed to ``buffer_normal.get_dispatch_layout``.
        previous_event: Accepted for signature parity; see the note above.
        async_finish: Forwarded to both buffer calls; when truthy the
            final event is waited on via ``event.current_stream_wait()``.

    Returns:
        The generator's return value: ``(recv_x, recv_topk_idx,
        recv_topk_weights, num_recv_tokens_per_expert_list, handle, event)``.
    """
    yield

    if async_finish:
        previous_event = self.buffer_normal.capture()
    else:
        previous_event = None

    on_comm_stream = previous_event is not None and async_finish
    layout = self.buffer_normal.get_dispatch_layout(
        topk_idx,
        num_experts,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=on_comm_stream,
    )
    # The fifth element replaces ``previous_event`` for the dispatch call.
    (
        tokens_per_rank,
        tokens_per_rdma_rank,
        tokens_per_expert,
        token_in_rank,
        previous_event,
    ) = layout

    # Re-evaluated on purpose: ``previous_event`` was just rebound.
    on_comm_stream = previous_event is not None and async_finish
    received = self.buffer_normal.dispatch(
        x,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_tokens_per_rank=tokens_per_rank,
        num_tokens_per_rdma_rank=tokens_per_rdma_rank,
        is_token_in_rank=token_in_rank,
        num_tokens_per_expert=tokens_per_expert,
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=on_comm_stream,
    )
    (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    ) = received

    yield

    if async_finish:
        event.current_stream_wait()
    return (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    )
259347
260348 def combine (
261349 self , hidden_states : torch .Tensor
@@ -268,6 +356,17 @@ def combine(
268356 self .handle = None
269357 return hidden_states .view (self .hidden_shape )
270358
def combine_yield(self, hidden_states: torch.Tensor):
    """Generator form of ``combine`` built on ``combine_normal_yield``.

    Non-empty input is first passed through
    ``self.get_restored_hidden_states_by_experts``.  The method then
    delegates to ``self.combine_normal_yield`` with ``self.handle`` using
    ``yield from`` (suspending wherever that generator suspends), clears
    ``self.handle`` and returns the combined tensor viewed as
    ``self.hidden_shape``.

    Args:
        hidden_states: Expert outputs to be combined.

    Returns:
        The generator's return value: the combined tensor reshaped with
        ``.view(self.hidden_shape)``.
    """
    has_tokens = hidden_states.shape[0] > 0
    if has_tokens:
        hidden_states = self.get_restored_hidden_states_by_experts(hidden_states)

    # The event is already consumed inside the delegate generator.
    combined, _event = yield from self.combine_normal_yield(
        hidden_states, self.handle
    )

    self.handle = None
    return combined.view(self.hidden_shape)
369+
271370 def combine_normal (self , x : torch .Tensor , handle : Tuple , previous_event = None ):
272371 combined_x , _ , event = self .buffer_normal .combine (
273372 x ,
@@ -278,6 +377,22 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
278377 )
279378 return combined_x , event
280379
def combine_normal_yield(
    self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True
):
    """Generator that issues a combine on ``self.buffer_normal`` in steps.

    Suspension points:
        1. Immediately on entry, before any work is issued.
        2. After ``buffer_normal.combine`` has been called and before the
           returned event is waited on.

    NOTE(review): the ``previous_event`` argument is overwritten right after
    the first resume (with ``buffer_normal.capture()`` when ``async_finish``
    is truthy, otherwise ``None``), so a caller-supplied value is never
    used.  Confirm whether that is intended.

    Args:
        x: Tensor handed to ``buffer_normal.combine``.
        handle: Handle handed to ``buffer_normal.combine``.
        previous_event: Accepted for signature parity; see the note above.
        async_finish: Forwarded to ``buffer_normal.combine``; when truthy
            the returned event is waited on via
            ``event.current_stream_wait()``.

    Returns:
        The generator's return value: ``(combined_x, event)``.
    """
    yield

    if async_finish:
        previous_event = self.buffer_normal.capture()
    else:
        previous_event = None

    on_comm_stream = previous_event is not None and async_finish
    combined_x, _, event = self.buffer_normal.combine(
        x,
        handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=on_comm_stream,
    )

    yield

    if async_finish:
        event.current_stream_wait()
    return combined_x, event
395+
281396 def _indices_to_multihot (self , indices , probs ):
282397 batch_size = indices .shape [0 ]
283398 multihot_routing_map = torch .zeros (
0 commit comments