Skip to content

Commit 28d16ac

Browse files
simplify create_child_request logic
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 7fbd931 commit 28d16ac

File tree

2 files changed

+32
-52
lines changed

2 files changed

+32
-52
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from dataclasses import dataclass
23
from functools import partial
34
from typing import List, Optional, Union
@@ -230,15 +231,19 @@ class LlmResult:
230231
py_result_properties = frozenset(
231232
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs'))
232233

233-
def __init__(self, result: Union[bytes,
234-
tensorrt_llm.bindings.executor.Result],
235-
py_result: PyResult):
234+
def __init__(self,
235+
result: Union[bytes, tensorrt_llm.bindings.executor.Result],
236+
py_result: PyResult,
237+
is_final: bool = False):
236238
self._result = result
237239
self._py_result = py_result
240+
self.is_final = is_final
238241

239242
def __getattr__(self, item):
240243
if item in self.py_result_properties:
241244
return getattr(self._py_result, item)
245+
if item == 'is_final':
246+
return object.__getattribute__(self, 'is_final')
242247
result = object.__getattribute__(self, '_result')
243248
return getattr(result, item)
244249

@@ -343,58 +348,33 @@ def __init__(
343348
exclude_last_generation_logits)
344349

345350
def create_child_request(self, request_id: int):
346-
# Create a child request by C++'s API to track the states each other.
351+
""" Create a child request.
352+
353+
NOTE: This function generates a child request by C++'s API to track the
354+
states of each other, and returns an object of type batch_manager.LlmRequest,
355+
which is not an llm_request.LlmRequest. As a workaround, to ensure the
356+
child request behaves like its parent, this function mimics the behavior
357+
of the original LlmRequest by dynamically adding the required attributes
358+
of the parent request to the child request. This function will be
359+
simplified away once LlmRequest becomes pure Python.
360+
361+
See: https://github.com/NVIDIA/TensorRT-LLM/issues/3034
362+
"""
363+
347364
child_request = super().create_child_request(request_id)
348365

349-
# Copy Python-specific attributes from parent to child
350-
child_request.py_client_id = self.py_client_id
366+
# Copy all py_* attributes from parent to child
367+
for attr_name, attr_value in self.__dict__.items():
368+
if attr_name.startswith('py_'):
369+
attr_value = getattr(self, attr_name)
370+
setattr(child_request, attr_name, copy.deepcopy(attr_value))
371+
372+
# Override specific attributes that should use child_request values.
351373
child_request.py_parent_request_id = self.py_request_id
352374
child_request.py_request_id = child_request.request_id
353-
child_request.py_llm_request_type = child_request.llm_request_type
354-
child_request.py_end_id = child_request.end_id
355-
child_request.py_prompt_len = child_request.prompt_len
356-
child_request.py_orig_prompt_len = child_request.orig_prompt_len
357-
child_request.py_max_new_tokens = child_request.max_new_tokens
358-
359-
# Copy Python-specific configuration from parent
360-
child_request.py_return_log_probs = self.py_return_log_probs
361-
child_request.py_return_context_logits = self.py_return_context_logits
362-
child_request.py_return_generation_logits = self.py_return_generation_logits
363-
child_request.py_return_logits_device_memory = self.py_return_logits_device_memory
364-
child_request.py_exclude_last_generation_logits = self.py_exclude_last_generation_logits
365-
child_request.py_stop_words_list = self.py_stop_words_list
366-
child_request.py_logits_post_processors = self.py_logits_post_processors
367-
child_request.py_rewind_len = self.py_rewind_len
368-
child_request.py_decoding_iter = self.py_decoding_iter
369-
child_request.py_draft_tokens = self.py_draft_tokens.copy(
370-
) if self.py_draft_tokens else []
371-
child_request.py_last_draft_tokens = self.py_last_draft_tokens.copy(
372-
) if self.py_last_draft_tokens else None
373-
child_request.py_num_accepted_draft_tokens = self.py_num_accepted_draft_tokens
374-
child_request.py_lora_task_layer_module_configs = self.py_lora_task_layer_module_configs
375-
376-
# Initialize Python-specific runtime state
377375
child_request.py_batch_idx = None
378376
child_request.is_attention_dp_dummy = self.is_attention_dp_dummy
379377
child_request.is_cuda_graph_dummy = self.is_cuda_graph_dummy
380-
381-
# Create PyResult for child
382-
child_request.py_result = PyResult(
383-
prompt_len=child_request.py_prompt_len,
384-
max_new_tokens=child_request.py_max_new_tokens,
385-
use_device_memory=self.py_return_logits_device_memory,
386-
streaming=child_request.streaming,
387-
return_log_probs=self.py_return_log_probs,
388-
return_context_logits=self.py_return_context_logits,
389-
return_generation_logits=self.py_return_generation_logits,
390-
exclude_last_generation_logits=self.
391-
py_exclude_last_generation_logits)
392-
393-
# Note: The below mimics the behavior of the original LlmRequest.
394-
395-
# We need to ensure the child request behaves like the parent
396-
# LlmRequest by copying any additional Python-specific attributes that
397-
# might be needed for proper request handling and response generation.
398378
child_request.is_dummy = self.is_dummy
399379

400380
# Override create_response to return the child request

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __init__(self,
206206
# enqueue and _fetch_new_requests used data
207207
self.enqueue_lock = threading.Lock()
208208
self.active = True
209-
self.next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests
209+
self._next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests
210210
self.max_beam_width = max_beam_width
211211
self.max_draft_tokens = max_draft_tokens
212212
self.print_log = model_engine.pytorch_backend_config.print_iter_log
@@ -321,9 +321,9 @@ def __exit__(self):
321321
self.shutdown()
322322

323323
def _get_request_id(self):
324-
req_id = self.next_req_id
325-
self.next_req_id += 1
326-
return req_id
324+
# Wrap around at 2**64 (uint64 overflow semantics): (next_req_id + 1) mod 2**64
325+
self._next_req_id = (self._next_req_id + 1) & ((1 << 64) - 1)
326+
return self._next_req_id
327327

328328
def enqueue_requests(self, requests: List[ExecutorRequest]):
329329
"""

0 commit comments

Comments
 (0)