|
1 | 1 | from dataclasses import dataclass |
| 2 | +from functools import partial |
2 | 3 | from typing import List, Optional, Union |
3 | 4 |
|
4 | 5 | import torch |
@@ -257,6 +258,23 @@ def has_error(self): |
257 | 258 | return self.error_msg is not None |
258 | 259 |
|
259 | 260 |
|
def create_response(
        request: Union['LlmRequest',
                       tensorrt_llm.bindings.internal.batch_manager.LlmRequest],
        use_fast_logits=False,
        mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
    """Build the response object for ``request``.

    Calls ``request.create_result(use_fast_logits, mpi_world_rank)`` and, when
    that produces a result, wraps it together with the Python-side fields read
    from the request (``py_request_id``, ``py_result``, ``py_client_id``).

    Args:
        request: The request to build a response for. It must provide
            ``create_result`` plus the ``py_request_id``, ``py_result`` and
            ``py_client_id`` attributes.
        use_fast_logits: Forwarded unchanged to ``request.create_result``.
        mpi_world_rank: Forwarded unchanged to ``request.create_result``.

    Returns:
        ``None`` when ``create_result`` returns ``None``; otherwise an
        ``LlmResponse`` whose ``result`` is an ``LlmResult`` built from the
        created result and ``request.py_result``.

    NOTE(review): the return annotation names ``executor.Response`` while the
    object constructed here is an ``LlmResponse`` -- confirm that is intended.
    """
    raw_result = request.create_result(use_fast_logits, mpi_world_rank)

    # Nothing to report for this request: hand the "no response" signal back.
    if raw_result is None:
        return None

    response_id = request.py_request_id
    wrapped_result = LlmResult(raw_result, request.py_result)
    response_client_id = request.py_client_id
    return LlmResponse(request_id=response_id,
                       result=wrapped_result,
                       client_id=response_client_id)
| 276 | + |
| 277 | + |
260 | 278 | class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): |
261 | 279 | """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest` |
262 | 280 | but detours some features to the Python implementation"""
@@ -375,27 +393,23 @@ def create_child_request(self, request_id: int): |
375 | 393 | exclude_last_generation_logits=self. |
376 | 394 | py_exclude_last_generation_logits) |
377 | 395 |
|
378 | | - # Note: This mimics the behavior of the original LlmRequest. |
| 396 | + # Note: The below mimics the behavior of the original LlmRequest. |
| 397 | + |
379 | 398 | # We need to ensure the child request behaves like the parent |
380 | 399 | # LlmRequest by copying any additional Python-specific attributes that |
381 | 400 | # might be needed for proper request handling and response generation. |
382 | 401 | child_request.is_dummy = self.is_dummy |
383 | 402 |
|
| 403 | + # Bind the module-level create_response helper to the child request
| 404 | + child_request.create_response = partial(create_response, child_request) |
| 405 | + |
384 | 406 | return child_request |
385 | 407 |
|
def create_response(
        self,
        use_fast_logits=False,
        mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
    """Create the response for this request.

    Thin wrapper that delegates to the module-level ``create_response``
    function, passing this request as its ``request`` argument.

    Args:
        use_fast_logits: Forwarded unchanged to the module-level helper.
        mpi_world_rank: Forwarded unchanged to the module-level helper.

    Returns:
        Whatever the module-level ``create_response`` returns for this
        request (``None`` or a response object).
    """
    # Inside a method body the bare name resolves to the module-level
    # function, not to this method, so this call does not recurse.
    return create_response(self,
                           use_fast_logits=use_fast_logits,
                           mpi_world_rank=mpi_world_rank)
399 | 413 |
|
400 | 414 | @property |
401 | 415 | def is_dummy(self): |
|
0 commit comments