Skip to content

Commit 54ec2c1

Browse files
authored
[None][opt] Add batch wait timeout in fetching requests (#6923)
Signed-off-by: Shunkang <[email protected]>
1 parent 636c622 commit 54ec2c1

File tree

7 files changed

+120
-2
lines changed

7 files changed

+120
-2
lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
self.pytorch_backend_config.attention_dp_enable_balance = False
137137
self.pytorch_backend_config.attention_dp_time_out_iters = 50
138138
self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
139+
self.pytorch_backend_config.batch_wait_timeout_ms = 0
139140
self.iter_counter = 0
140141

141142
# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class PyTorchConfig:
5050
attention_dp_time_out_iters: int = 50
5151
attention_dp_batching_wait_iters: int = 10
5252

53+
batch_wait_timeout_ms: float = 0
54+
5355
attn_backend: str = 'TRTLLM'
5456
moe_backend: str = 'CUTLASS'
5557

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ class ExecutorRequestQueue:
4545
def __init__(self, dist: Distributed, enable_attention_dp: bool,
4646
max_batch_size: int, max_beam_width: int,
4747
max_num_active_requests: int, enable_iter_perf_stats: bool,
48-
is_disaggregated: bool):
48+
batch_wait_timeout_ms: float, is_disaggregated: bool):
4949
self.dist = dist
5050
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
5151
self.waiting_queue: deque[RequestQueueItem] = deque()
5252
self.canceled_req_ids = []
5353
self.enable_attention_dp = enable_attention_dp
54+
self.max_batch_size = max_batch_size
5455
self.max_beam_width = max_beam_width
5556
self.max_num_active_requests = max_num_active_requests
5657
self.is_disaggregated = is_disaggregated
@@ -59,6 +60,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
5960
self.enable_iter_perf_stats = enable_iter_perf_stats
6061
self.start_times = {}
6162
self.active = True
63+
self.batch_wait_timeout_ms = batch_wait_timeout_ms
6264

6365
# State tracking
6466
self.num_fetch_requests = 0
@@ -74,6 +76,7 @@ def _get_from_request_queue(
7476

7577
items = []
7678
timeout_secs = timeout.total_seconds() if timeout is not None else None
79+
7780
try:
7881
if self.request_queue.empty() and (timeout_secs is None
7982
or timeout_secs > 0):
@@ -86,6 +89,26 @@ def _get_from_request_queue(
8689
items.append(queue_item)
8790
except queue.Empty:
8891
pass
92+
93+
if self.batch_wait_timeout_ms == 0:
94+
return items
95+
96+
if len(items) >= self.max_batch_size:
97+
return items
98+
99+
deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0
100+
while len(items) < self.max_batch_size:
101+
remaining_timeout = deadline - time.monotonic()
102+
103+
if remaining_timeout <= 0:
104+
break
105+
106+
try:
107+
item = self.request_queue.get(timeout=remaining_timeout)
108+
items.append(item)
109+
except queue.Empty:
110+
break
111+
89112
return items
90113

91114
@staticmethod

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def __init__(self,
186186
self.attention_dp_enable_balance = model_engine.pytorch_backend_config.attention_dp_enable_balance
187187
self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
188188
self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
189+
self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
189190
self.num_fetch_requests_cur_rank = 0
190191
self.num_fetch_requests = 0
191192
self.shutdown_event = threading.Event()
@@ -239,6 +240,7 @@ def __init__(self,
239240
max_beam_width=self.max_beam_width,
240241
max_num_active_requests=self.max_num_active_requests,
241242
enable_iter_perf_stats=self.enable_iter_perf_stats,
243+
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
242244
is_disaggregated=kv_cache_transceiver is not None,
243245
)
244246
self.executor_request_queue.set_exclude_last_generation_logits(

tensorrt_llm/llmapi/llm_args.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2098,6 +2098,12 @@ class TorchLlmArgs(BaseLlmArgs):
20982098
description="Print iteration logs.",
20992099
status="beta")
21002100

2101+
batch_wait_timeout_ms: float = Field(
2102+
default=0,
2103+
description=
2104+
"If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
2105+
status="prototype")
2106+
21012107
torch_compile_config: Optional[TorchCompileConfig] = Field(
21022108
default=None, description="Torch compile config.", status="prototype")
21032109

@@ -2344,6 +2350,13 @@ def validate_attention_dp_config(self) -> 'TorchLlmArgs':
23442350
)
23452351
return self
23462352

2353+
@model_validator(mode='after')
def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
    """Validate that batch_wait_timeout_ms is non-negative.

    Returns:
        The validated TorchLlmArgs instance.

    Raises:
        ValueError: If batch_wait_timeout_ms is negative.
    """
    # 0 is a valid value (it disables batch waiting per the field
    # description), so only negative values are rejected — the error
    # message must say "greater than or equal to 0", not "greater than 0".
    if self.batch_wait_timeout_ms < 0:
        raise ValueError(
            "batch_wait_timeout_ms must be greater than or equal to 0")
    return self
23472360
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
23482361
def get_pytorch_backend_config(self) -> "PyTorchConfig":
23492362
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2409,7 +2422,8 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
24092422
AttentionDpConfig.model_fields['timeout_iters'].default,
24102423
attention_dp_batching_wait_iters=self.attention_dp_config.
24112424
batching_wait_iters if self.attention_dp_config is not None else
2412-
AttentionDpConfig.model_fields['batching_wait_iters'].default)
2425+
AttentionDpConfig.model_fields['batching_wait_iters'].default,
2426+
batch_wait_timeout_ms=self.batch_wait_timeout_ms)
24132427

24142428

24152429
def update_llm_args_with_extra_dict(

tests/unittest/_torch/test_executor_request_queue.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def executor_queue(mock_dist):
4040
max_beam_width=1,
4141
max_num_active_requests=16,
4242
enable_iter_perf_stats=True,
43+
batch_wait_timeout_ms=0.0,
4344
is_disaggregated=False)
4445

4546

@@ -52,6 +53,7 @@ def integration_queue(mock_dist):
5253
max_beam_width=2,
5354
max_num_active_requests=8,
5455
enable_iter_perf_stats=True,
56+
batch_wait_timeout_ms=0.0,
5557
is_disaggregated=False)
5658

5759

@@ -215,6 +217,75 @@ def test_get_from_request_queue_with_timeout(executor_queue):
215217
assert elapsed < 0.2 # Should finish within timeout
216218

217219

220+
def test_get_from_request_queue_async_behavior(executor_queue):
    """Requests arriving after the fetch begins are only collected when
    batch_wait_timeout_ms lets _get_from_request_queue keep waiting."""
    import threading

    def enqueue_after(delay, count):
        # Sleep *delay* seconds, then push *count* requests with ids 10, 11, ...
        time.sleep(delay)
        for offset in range(count):
            executor_queue.request_queue.put(
                RequestQueueItem(offset + 10, Mock()))

    # --- Case 1: timeout disabled -> only already-queued requests come back.
    executor_queue.batch_wait_timeout_ms = 0.0

    num_initial = 3
    for req_id in range(num_initial):
        executor_queue.request_queue.put(RequestQueueItem(req_id, Mock()))

    producer = threading.Thread(target=enqueue_after, args=(0.05, 2))
    producer.start()

    # Fetch immediately — the delayed requests must not be included.
    began = time.time()
    fetched = executor_queue._get_from_request_queue(None)
    took = time.time() - began

    assert len(fetched) == num_initial
    assert took < 0.1
    assert all(req.id < 10 for req in fetched)

    producer.join()

    # --- Case 2: timeout enabled -> the fetch waits for late arrivals.
    executor_queue.batch_wait_timeout_ms = 200.0

    # Drain whatever Case 1's producer left behind before re-seeding.
    while not executor_queue.request_queue.empty():
        try:
            executor_queue.request_queue.get_nowait()
        except queue.Empty:
            break

    num_initial = 2
    for req_id in range(num_initial):
        executor_queue.request_queue.put(RequestQueueItem(req_id + 20, Mock()))

    producer = threading.Thread(target=enqueue_after, args=(0.05, 3))
    producer.start()

    # Fetch with the timeout active — should block long enough to pick up
    # both the immediate and the delayed requests.
    began = time.time()
    fetched = executor_queue._get_from_request_queue(None)
    took = time.time() - began

    assert len(fetched) == num_initial + 3
    assert 0.05 <= took < 0.3

    assert len({req.id for req in fetched if 20 <= req.id < 30}) == num_initial
    assert len({req.id for req in fetched if 10 <= req.id < 20}) == 3

    producer.join()
218289
def test_get_from_waiting_queue(executor_queue):
219290
"""Test getting items from waiting queue."""
220291
# Add items to waiting queue
@@ -371,6 +442,7 @@ def attention_dp_queue(mock_dist_attention_dp):
371442
max_beam_width=2,
372443
max_num_active_requests=8,
373444
enable_iter_perf_stats=True,
445+
batch_wait_timeout_ms=0.0,
374446
is_disaggregated=False)
375447
# Initialize all_ranks_num_active_requests
376448
return queue

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ methods:
123123
annotation: bool
124124
default: False
125125
status: prototype
126+
batch_wait_timeout_ms:
127+
annotation: float
128+
default: 0
129+
status: prototype
126130
print_iter_log:
127131
annotation: bool
128132
default: False

0 commit comments

Comments
 (0)