Commit 653f456

fix: Enable num_return_sequences (n) support in PyTorch backend

This PR enables the `n` parameter (num_return_sequences) in the PyTorch backend, which is the default path for the LLM API. The feature was already implemented in the TRT backend via the C++ Executor but was missing in the PyExecutor. This PR closes the gap by adding the necessary APIs to the pybind bindings of the `LlmRequest` class.

Changes:
- Added a `create_child_request` method to `pyexecutor.LlmRequest` that wraps the C++ `createChildRequest` method, so parent and child requests can properly track each other's state.
- Updated the C++ `LlmRequest` and related Python bindings to expose additional properties required by the PyTorch backend.
- Enhanced `PyExecutor` to create child requests, ensuring requests are handled correctly when `num_return_sequences > 1`.

Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 5bc3a15 commit 653f456
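
With this change, asking for multiple sampled sequences per prompt (`n` > 1) should work on the PyTorch path without switching backends. A minimal usage sketch, assuming the high-level LLM API and a `SamplingParams` with an `n` field; the model path and prompt below are placeholders, not taken from this commit:

    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="/path/to/model")  # the LLM API defaults to the PyTorch backend
    # n maps to num_return_sequences: two independently sampled sequences per prompt.
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32, n=2, temperature=0.8))
    for output in outputs:
        for seq in output.outputs:  # one entry per returned sequence
            print(seq.text)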

File tree: 5 files changed (+170, -35 lines)


cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 15 additions & 0 deletions
@@ -472,6 +472,11 @@ class GenericLlmRequest
         mExcludeInputFromOutput = exclude;
     }

+    bool getExcludeInputFromOutput()
+    {
+        return mExcludeInputFromOutput;
+    }
+
     /// @brief Get the params of the context
     /// @return The params of the context
     [[nodiscard]] std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
@@ -769,6 +774,11 @@ class GenericLlmRequest
         return mParentRequestId;
     }

+    [[nodiscard]] SizeType32 getSequenceIndex() const
+    {
+        return mSequenceIndex;
+    }
+
     /// @brief Return a vector of the last-generated tokens of shape [num_beams]
     [[nodiscard]] VecTokens const& getLastTokens()
     {
@@ -1856,6 +1866,11 @@ class GenericLlmRequest
     // current position of the prompt tuning table (only used in chunked prefill mode)
     SizeType32 mPtableCurrentPosition{0};

+    [[nodiscard]] std::shared_ptr<std::vector<bool>> getSequenceFinalVec() const
+    {
+        return mSequenceFinalVec;
+    }
+
 protected:
     bool mIsStreaming;

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 3 additions & 0 deletions
@@ -113,6 +113,8 @@ void initBindings(pybind11::module_& m)
         .def("set_generated_tokens", &GenLlmReq::setGeneratedTokens, py::arg("generated_beam_tokens"))
         .def("pause", &GenLlmReq::pause, py::arg("max_input_len"))
         .def_property("max_sent_token_len", &GenLlmReq::getMaxSentTokenLen, &GenLlmReq::setMaxSentTokenLen)
+        .def_property(
+            "exclude_input_from_output", &GenLlmReq::getExcludeInputFromOutput, &GenLlmReq::setExcludeInputFromOutput)
         .def_property_readonly("prompt_embedding_table", &GenLlmReq::getPromptEmbeddingTable)
         .def_property_readonly("multimodal_embedding", &GenLlmReq::getMultimodalEmbedding)
         .def_property_readonly("mrope_rotary_cos_sin", &GenLlmReq::getMropeRotaryCosSin)
@@ -377,6 +379,7 @@ void initBindings(pybind11::module_& m)
         .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
         .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
         .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("request_id"))
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
         .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"));

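On the Python side, these two new bindings are the surface the PyTorch backend builds on. A rough sketch of how they are exercised, assuming `req` is an already constructed bound `LlmRequest` (construction and the exact import path of the binding module are not shown in this diff):

    # req: a bound tensorrt_llm LlmRequest instance (construction omitted here)
    req.exclude_input_from_output = True           # new read/write property
    assert req.exclude_input_from_output is True   # getter added in llmRequest.h

    # new method: clones the request under a fresh id and links it to the parent
    child = req.create_child_request(request_id=req.request_id + 1)
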
tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 71 additions & 15 deletions
@@ -229,19 +229,15 @@ class LlmResult:
     py_result_properties = frozenset(
         ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs'))

-    def __init__(self,
-                 result: Union[bytes, tensorrt_llm.bindings.executor.Result],
-                 py_result: PyResult,
-                 is_final: bool = False):
+    def __init__(self, result: Union[bytes,
+                                     tensorrt_llm.bindings.executor.Result],
+                 py_result: PyResult):
         self._result = result
         self._py_result = py_result
-        self.is_final = is_final

     def __getattr__(self, item):
         if item in self.py_result_properties:
             return getattr(self._py_result, item)
-        if item == 'is_final':
-            return object.__getattribute__(self, 'is_final')
         result = object.__getattribute__(self, '_result')
         return getattr(result, item)
@@ -316,6 +312,7 @@ def __init__(
         self.py_return_logits_device_memory = return_logits_device_memory
         self.py_is_draft = is_draft
         self.py_seq_slot = None
+        self.py_exclude_last_generation_logits = exclude_last_generation_logits

         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
@@ -327,19 +324,78 @@
             return_generation_logits,
             exclude_last_generation_logits)

-    def is_generation_only_request(self):
-        return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
+    def create_child_request(self, request_id: int):
+        # Create the child request via the C++ API so that parent and child
+        # track each other's state.
+        child_request = super().create_child_request(request_id)
+
+        # Copy Python-specific attributes from parent to child
+        child_request.py_client_id = self.py_client_id
+        child_request.py_parent_request_id = self.py_request_id
+        child_request.py_request_id = child_request.request_id
+        child_request.py_llm_request_type = child_request.llm_request_type
+        child_request.py_end_id = child_request.end_id
+        child_request.py_prompt_len = child_request.prompt_len
+        child_request.py_orig_prompt_len = child_request.orig_prompt_len
+        child_request.py_max_new_tokens = child_request.max_new_tokens
+
+        # Input tokens are already cloned in create_child_request.
+        child_request.py_tokens = child_request.get_tokens()
+
+        # Copy Python-specific configuration from parent
+        child_request.py_return_log_probs = self.py_return_log_probs
+        child_request.py_return_context_logits = self.py_return_context_logits
+        child_request.py_return_generation_logits = self.py_return_generation_logits
+        child_request.py_return_logits_device_memory = self.py_return_logits_device_memory
+        child_request.py_exclude_last_generation_logits = self.py_exclude_last_generation_logits
+        child_request.py_stop_words_list = self.py_stop_words_list
+        child_request.py_logits_post_processors = self.py_logits_post_processors
+        child_request.py_rewind_len = self.py_rewind_len
+        child_request.py_decoding_iter = self.py_decoding_iter
+        child_request.py_draft_tokens = self.py_draft_tokens.copy(
+        ) if self.py_draft_tokens else []
+        child_request.py_last_draft_tokens = self.py_last_draft_tokens.copy(
+        ) if self.py_last_draft_tokens else None
+        child_request.py_num_accepted_draft_tokens = self.py_num_accepted_draft_tokens
+        child_request.py_lora_task_layer_module_configs = self.py_lora_task_layer_module_configs
+
+        # Initialize Python-specific runtime state
+        child_request.py_batch_idx = None
+        child_request.is_attention_dp_dummy = self.is_attention_dp_dummy
+        child_request.is_cuda_graph_dummy = self.is_cuda_graph_dummy
+
+        # Create PyResult for child
+        child_request.py_result = PyResult(
+            prompt_len=child_request.py_prompt_len,
+            max_new_tokens=child_request.py_max_new_tokens,
+            use_device_memory=self.py_return_logits_device_memory,
+            streaming=child_request.streaming,
+            return_log_probs=self.py_return_log_probs,
+            return_context_logits=self.py_return_context_logits,
+            return_generation_logits=self.py_return_generation_logits,
+            exclude_last_generation_logits=self.
+            py_exclude_last_generation_logits)
+
+        # Note: This mimics the behavior of the original LlmRequest.
+        # We need to ensure the child request behaves like the parent
+        # LlmRequest by copying any additional Python-specific attributes that
+        # might be needed for proper request handling and response generation.
+        child_request.is_dummy = self.is_dummy
+
+        return child_request

     def create_response(
             self,
             use_fast_logits=False,
             mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
-        result, is_final = super().create_serialized_result(
-            use_fast_logits, mpi_world_rank)
-        return LlmResponse(
-            request_id=self.py_request_id,
-            result=LlmResult(result, self.py_result, is_final),
-            client_id=self.py_client_id) if len(result) > 0 else None
+
+        result = super().create_result(use_fast_logits, mpi_world_rank)
+
+        if result is None:
+            return None
+        else:
+            return LlmResponse(request_id=self.py_request_id,
+                               result=LlmResult(result, self.py_result),
+                               client_id=self.py_client_id)

     @property
     def is_dummy(self):

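To make the parent/child relationship concrete, here is a hedged sketch of how the wrapper above is meant to be used; `parent` and the id values are hypothetical, only `create_child_request`, `py_request_id`, and `py_parent_request_id` come from the code in this file:

    # parent: an LlmRequest produced by executor_request_to_llm_request(...) with id 100;
    # ids 101 and 102 were reserved up front by the executor for its children.
    children = [parent.create_child_request(rid) for rid in (101, 102)]

    for child in children:
        # Each child shares the parent's prompt and sampling configuration but
        # reports its results under its own request id.
        assert child.py_parent_request_id == parent.py_request_id
        assert child.py_request_id != parent.py_request_id
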
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 78 additions & 18 deletions
@@ -58,6 +58,7 @@ class RequestQueueItem:
     id: int
     request: Optional[ExecutorRequest] = None
     query: Optional[list] = None  # only used in `StarAttention`
+    child_req_ids: Optional[List[int]] = None  # for num_return_sequences > 1

     def is_shutdown_request(self):
         return self.id == SHUTDOWN_REQUEST_ID
@@ -319,6 +320,11 @@ def __enter__(self):
     def __exit__(self):
         self.shutdown()

+    def _get_request_id(self):
+        req_id = self.next_req_id
+        self.next_req_id += 1
+        return req_id
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         """
        Enqueue new requests
@@ -327,13 +333,33 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
-            start_time = time.time()
             for request in requests:
-                self.start_times[self.next_req_id] = start_time
+                req_id = self._get_request_id()
+
+                if self.enable_iter_perf_stats:
+                    self.start_times[req_id] = time.time()
+
+                # Generate child request IDs if needed
+                child_req_ids = None
+                sampling_config = request.sampling_config
+                beam_width = sampling_config.beam_width
+                num_return_sequences = sampling_config.num_return_sequences
+
+                if beam_width == 1 and num_return_sequences > 1:
+                    # Reserve request ids for child requests.
+                    child_req_ids = []
+                    for _ in range(num_return_sequences - 1):
+                        child_req_id = self._get_request_id()
+                        if self.enable_iter_perf_stats:
+                            self.start_times[child_req_id] = time.time()
+                        child_req_ids.append(child_req_id)
+
                 self.request_queue.put(
-                    RequestQueueItem(self.next_req_id, request))
-                req_ids.append(self.next_req_id)
-                self.next_req_id += 1
+                    RequestQueueItem(req_id,
+                                     request,
+                                     query=None,
+                                     child_req_ids=child_req_ids))
+                req_ids.append(req_id)
         finally:
             self.enqueue_lock.release()
         return req_ids
@@ -370,6 +396,12 @@ def cancel_request(self, id: int):
         """
         self.canceled_req_ids.insert(id)

+        # Also cancel all child requests if this is a parent request.
+        # Look through active requests to find child requests.
+        for request in self.active_requests:
+            if request.py_parent_request_id == id:
+                self.canceled_req_ids.insert(request.py_request_id)
+
     def shutdown(self):
         """
        Signals the server to shutdown.
@@ -438,15 +470,34 @@ def enqueue_request(self,
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
-            req_id = self.next_req_id
+            # Allocate the main request ID first
+            req_id = self._get_request_id()
+
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()

-            if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
-            else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
-            self.next_req_id += 1
+            # Generate child request IDs if needed
+            child_req_ids = None
+            sampling_config = request.sampling_config
+            beam_width = (sampling_config.beam_width
+                          if sampling_config.beam_width else 1)
+            num_return_sequences = (sampling_config.num_return_sequences if
+                                    sampling_config.num_return_sequences else 1)
+
+            # Only create child requests if beam_width == 1 and num_return_sequences > 1
+            if beam_width == 1 and num_return_sequences > 1:
+                child_req_ids = []
+                for i in range(num_return_sequences - 1):
+                    child_req_id = self._get_request_id()
+                    if self.enable_iter_perf_stats:
+                        self.start_times[child_req_id] = time.time()
+                    child_req_ids.append(child_req_id)
+
+            self.request_queue.put(
+                RequestQueueItem(req_id,
+                                 request,
+                                 query=query,
+                                 child_req_ids=child_req_ids))
         finally:
             self.enqueue_lock.release()
         return req_id
@@ -1396,6 +1447,9 @@ def _merge_star_attention_requests(self,
                                        new_requests: list[RequestQueueItem]):
         result = []
         for req_item in new_requests:
+            assert req_item.child_req_ids is None, (
+                "Star attention does not yet support sampling_config.num_return_sequences > 1"
+            )
             req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
             ctx_len0 = len(exe_req.input_token_ids)
             ctx_blocks, position_blocks, last_block_padding_num = [
@@ -1461,12 +1515,18 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
             else:
                 raise NotImplementedError(f'unsupport cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
+            llm_reqs = []
+            for req_item in new_requests:
+                llm_req = executor_request_to_llm_request(
                     req_item.id, req_item.request,
                     self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                if req_item.child_req_ids:
+                    # Create subrequests for n-returns using pre-generated child request ids.
+                    for child_req_id in req_item.child_req_ids:
+                        child_req = llm_req.create_child_request(child_req_id)
+                        llm_reqs.append(child_req)
+                llm_reqs.append(llm_req)
+            return llm_reqs

     @nvtx_range("_schedule")
     def _schedule(self):
@@ -1982,7 +2042,7 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
             if req_id in self.responses.keys():
                 self.responses[req_id].append(resp)
             else:
-                self.responses.update({req_id: [resp]})
+                self.responses[req_id] = [resp]
             self.response_cv.notify_all()

     @nvtx_range("_handle_first_token_response")
@@ -2013,7 +2073,7 @@ def _handle_responses(self):
                 requests_to_terminate.append(request)
                 continue

-            if request.is_generation_only_request():
+            if request.is_generation_only_request:
                 # If request is in transmission, so we don't need to emit a response
                 # Also, for the first iteration with overlap, we should skip since first
                 # token has already been emitted previously
@@ -2033,7 +2093,7 @@ def _handle_responses(self):
             if self.model_engine.iter_counter % self.stream_interval == 0 or request.is_finished:
                 response = request.create_response(False, self.dist.rank)
                 if response:
-                    request_done = response.result.is_final
+                    request_done = response.result.is_sequence_final
                     new_responses.update({req_id: response})

                 if request_done:

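The id-reservation scheme used above is worth spelling out: for a request with `beam_width == 1` and `num_return_sequences == n`, the executor consumes n consecutive ids at enqueue time, one for the parent and n-1 for the children, so that `_merge_requests` can attach the children later without touching the id counter. A standalone sketch of that bookkeeping (the function name and return shape below are illustrative, not part of this PR):

    from typing import List, Optional, Tuple

    def reserve_request_ids(next_id: int, beam_width: int,
                            num_return_sequences: int
                            ) -> Tuple[int, Optional[List[int]], int]:
        """Mirror the allocation done in enqueue_request / enqueue_requests."""
        parent_id = next_id
        next_id += 1
        child_ids = None
        if beam_width == 1 and num_return_sequences > 1:
            # Reserve one extra id per additional returned sequence.
            child_ids = list(range(next_id, next_id + num_return_sequences - 1))
            next_id += num_return_sequences - 1
        return parent_id, child_ids, next_id

    # e.g. n=3 starting at id 7 -> parent 7, children [8, 9], counter advances to 10
    print(reserve_request_ids(7, beam_width=1, num_return_sequences=3))
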
tensorrt_llm/executor/result.py

Lines changed: 3 additions & 2 deletions
@@ -299,8 +299,9 @@ def _handle_response(self,
                 handler(response.error_msg)

             response_result = response.result
-            if hasattr(response_result, "_result"):
-                response_result.deserialize()
+            # TODO(jaedeokk): Need to check. Why do we need to deserialize the result? Is it for disaggregated serving?
+            # if hasattr(response_result, "_result"):
+            #     response_result.deserialize()

             self._done = response_result.is_final
             context_phase_params = response_result.context_phase_params
