Skip to content

Commit e8ce36b

Browse files
fix child request's create_response
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 653f456 commit e8ce36b

File tree

2 files changed

+25
-11
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from functools import partial
23
from typing import List, Optional, Union
34

45
import torch
@@ -257,6 +258,23 @@ def has_error(self):
257258
return self.error_msg is not None
258259

259260

261+
def create_response(
262+
request: Union['LlmRequest',
263+
tensorrt_llm.bindings.internal.batch_manager.LlmRequest],
264+
use_fast_logits=False,
265+
mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
266+
""" Create a response for a given request. """
267+
268+
result = request.create_result(use_fast_logits, mpi_world_rank)
269+
270+
if result is None:
271+
return None
272+
else:
273+
return LlmResponse(request_id=request.py_request_id,
274+
result=LlmResult(result, request.py_result),
275+
client_id=request.py_client_id)
276+
277+
260278
class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
261279
"""LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
262280
but detour some features to Python implementation"""
@@ -375,27 +393,23 @@ def create_child_request(self, request_id: int):
375393
exclude_last_generation_logits=self.
376394
py_exclude_last_generation_logits)
377395

378-
# Note: This mimics the behavior of the original LlmRequest.
396+
# Note: The below mimics the behavior of the original LlmRequest.
397+
379398
# We need to ensure the child request behaves like the parent
380399
# LlmRequest by copying any additional Python-specific attributes that
381400
# might be needed for proper request handling and response generation.
382401
child_request.is_dummy = self.is_dummy
383402

403+
# Override create_response to return the child request
404+
child_request.create_response = partial(create_response, child_request)
405+
384406
return child_request
385407

386408
def create_response(
387409
self,
388410
use_fast_logits=False,
389411
mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
390-
391-
result = super().create_result(use_fast_logits, mpi_world_rank)
392-
393-
if result is None:
394-
return None
395-
else:
396-
return LlmResponse(request_id=self.py_request_id,
397-
result=LlmResult(result, self.py_result),
398-
client_id=self.py_client_id)
412+
return create_response(self, use_fast_logits, mpi_world_rank)
399413

400414
@property
401415
def is_dummy(self):

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
343343
child_req_ids = None
344344
sampling_config = request.sampling_config
345345
beam_width = sampling_config.beam_width
346-
num_return_sequences = sampling_config.num_return_sequences
346+
num_return_sequences = sampling_config.num_return_sequences or beam_width
347347

348348
if beam_width == 1 and num_return_sequences > 1:
349349
# Reserve request ids for child requests.

0 commit comments

Comments (0)