@@ -58,6 +58,7 @@ class RequestQueueItem:
5858 id : int
5959 request : Optional [ExecutorRequest ] = None
6060 query : Optional [list ] = None # only used in `StarAttention`
61+ child_req_ids : Optional [List [int ]] = None # for num_return_sequences > 1
6162
def is_shutdown_request(self):
    """Report whether this queue item is the shutdown sentinel.

    Returns:
        True when ``self.id`` equals ``SHUTDOWN_REQUEST_ID``, else False.
    """
    sentinel_id = SHUTDOWN_REQUEST_ID
    return self.id == sentinel_id
@@ -319,6 +320,11 @@ def __enter__(self):
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
    """Shut the executor down when a ``with`` block is left.

    The context-manager protocol invokes ``__exit__`` with three positional
    arguments describing the exception (or three ``None`` values on a clean
    exit). A signature that accepts only ``self`` raises ``TypeError`` as
    soon as the ``with`` block ends, so the arguments are accepted here.
    They default to ``None`` so a direct ``obj.__exit__()`` call keeps
    working.

    Args:
        exc_type: Exception class raised inside the ``with`` body, or None.
        exc_value: Exception instance raised inside the body, or None.
        traceback: Traceback associated with the exception, or None.

    Returns:
        None, so an exception raised inside the ``with`` body is never
        suppressed.
    """
    self.shutdown()
321322
def _get_request_id(self):
    """Allocate the next request id.

    Reads ``self.next_req_id``, advances the counter by one and returns the
    value that was read.

    Returns:
        The id that was current before the counter was advanced.
    """
    allocated = self.next_req_id
    self.next_req_id = allocated + 1
    return allocated
327+
def enqueue_requests(self, requests: List["ExecutorRequest"]):
    """
    Enqueue new requests and return the ids assigned to them.

    One parent id is allocated per request. When a request uses
    ``beam_width == 1`` and asks for more than one returned sequence,
    ``num_return_sequences - 1`` additional ids are reserved up front and
    attached to the queue item as ``child_req_ids``.

    Args:
        requests: Requests to put on ``self.request_queue``.

    Returns:
        The parent request ids, in the same order as ``requests``. Reserved
        child ids are not part of the returned list.

    Raises:
        AssertionError: If the executor is no longer active.
    """
    req_ids = []
    # The lock is held for the whole batch so id allocation and queue
    # insertion cannot interleave with another enqueue call; the ``with``
    # form releases it on every exit path.
    with self.enqueue_lock:
        assert self.active, "PyExecutor has already been shutdown."
        for request in requests:
            req_id = self._get_request_id()

            if self.enable_iter_perf_stats:
                self.start_times[req_id] = time.time()

            sampling_config = request.sampling_config
            # Unset values (None/0) fall back to 1, the same defaulting that
            # enqueue_request applies; comparing None with ``> 1`` would
            # otherwise raise TypeError.
            beam_width = sampling_config.beam_width or 1
            num_return_sequences = sampling_config.num_return_sequences or 1

            # Reserve request ids for child requests if needed.
            child_req_ids = None
            if beam_width == 1 and num_return_sequences > 1:
                child_req_ids = []
                for _ in range(num_return_sequences - 1):
                    child_req_id = self._get_request_id()
                    if self.enable_iter_perf_stats:
                        self.start_times[child_req_id] = time.time()
                    child_req_ids.append(child_req_id)

            self.request_queue.put(
                RequestQueueItem(req_id,
                                 request,
                                 query=None,
                                 child_req_ids=child_req_ids))
            req_ids.append(req_id)
    return req_ids
@@ -370,6 +396,12 @@ def cancel_request(self, id: int):
370396 """
371397 self .canceled_req_ids .insert (id )
372398
399+ # Also cancel all child requests if this is a parent request
400+ # Look through active requests to find child requests
401+ for request in self .active_requests :
402+ if request .py_parent_request_id == id :
403+ self .canceled_req_ids .insert (request .py_request_id )
404+
373405 def shutdown (self ):
374406 """
375407 Signals the server to shutdown.
@@ -438,15 +470,34 @@ def enqueue_request(self,
438470 try :
439471 self .enqueue_lock .acquire ()
440472 assert self .active , "PyExecutor has already been shutdown."
441- req_id = self .next_req_id
473+ # Allocate the main request ID first
474+ req_id = self ._get_request_id ()
475+
442476 if self .enable_iter_perf_stats :
443477 self .start_times [req_id ] = time .time ()
444478
445- if query is not None :
446- self .request_queue .put (RequestQueueItem (req_id , request , query ))
447- else :
448- self .request_queue .put (RequestQueueItem (req_id , request ))
449- self .next_req_id += 1
479+ # Generate child request IDs if needed
480+ child_req_ids = None
481+ sampling_config = request .sampling_config
482+ beam_width = (sampling_config .beam_width
483+ if sampling_config .beam_width else 1 )
484+ num_return_sequences = (sampling_config .num_return_sequences if
485+ sampling_config .num_return_sequences else 1 )
486+
487+ # Only create child requests if beam_width == 1 and num_return_sequences > 1
488+ if beam_width == 1 and num_return_sequences > 1 :
489+ child_req_ids = []
490+ for i in range (num_return_sequences - 1 ):
491+ child_req_id = self ._get_request_id ()
492+ if self .enable_iter_perf_stats :
493+ self .start_times [child_req_id ] = time .time ()
494+ child_req_ids .append (child_req_id )
495+
496+ self .request_queue .put (
497+ RequestQueueItem (req_id ,
498+ request ,
499+ query = query ,
500+ child_req_ids = child_req_ids ))
450501 finally :
451502 self .enqueue_lock .release ()
452503 return req_id
@@ -1396,6 +1447,9 @@ def _merge_star_attention_requests(self,
13961447 new_requests : list [RequestQueueItem ]):
13971448 result = []
13981449 for req_item in new_requests :
1450+ assert req_item .child_req_ids is None , (
1451+ "Star attention does not yet support sampling_config.num_return_sequences > 1"
1452+ )
13991453 req_id , exe_req , query_token_ids = req_item .id , req_item .request , req_item .query
14001454 ctx_len0 = len (exe_req .input_token_ids )
14011455 ctx_blocks , position_blocks , last_block_padding_num = [
@@ -1461,12 +1515,18 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
14611515 else :
14621516 raise NotImplementedError (f'unsupport cp type { cp_type } ' )
14631517 else :
1464- return [
1465- executor_request_to_llm_request (
1518+ llm_reqs = []
1519+ for req_item in new_requests :
1520+ llm_req = executor_request_to_llm_request (
14661521 req_item .id , req_item .request ,
14671522 self ._should_exclude_last_generation_logits ())
1468- for req_item in new_requests
1469- ]
1523+ if req_item .child_req_ids :
1524+ # Create subrequests for n-returns using pre-generated child request ids.
1525+ for child_req_id in req_item .child_req_ids :
1526+ child_req = llm_req .create_child_request (child_req_id )
1527+ llm_reqs .append (child_req )
1528+ llm_reqs .append (llm_req )
1529+ return llm_reqs
14701530
14711531 @nvtx_range ("_schedule" )
14721532 def _schedule (self ):
@@ -1982,7 +2042,7 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
19822042 if req_id in self .responses .keys ():
19832043 self .responses [req_id ].append (resp )
19842044 else :
1985- self .responses . update ({ req_id : [resp ]})
2045+ self .responses [ req_id ] = [resp ]
19862046 self .response_cv .notify_all ()
19872047
19882048 @nvtx_range ("_handle_first_token_response" )
@@ -2013,7 +2073,7 @@ def _handle_responses(self):
20132073 requests_to_terminate .append (request )
20142074 continue
20152075
2016- if request .is_generation_only_request () :
2076+ if request .is_generation_only_request :
20172077 # If request is in transmission, so we don't need to emit a response
20182078 # Also, for the first iteration with overlap, we should skip since first
20192079 # token has already been emitted previously
@@ -2033,7 +2093,7 @@ def _handle_responses(self):
20332093 if self .model_engine .iter_counter % self .stream_interval == 0 or request .is_finished :
20342094 response = request .create_response (False , self .dist .rank )
20352095 if response :
2036- request_done = response .result .is_final
2096+ request_done = response .result .is_sequence_final
20372097 new_responses .update ({req_id : response })
20382098
20392099 if request_done :
0 commit comments