Commit bb5b16f

feat: Return context response immediately when stream_interval > 1 (#5836)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent: 3079e8c · commit: bb5b16f

File tree

2 files changed, +36 -1 lines changed
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 1 deletion

@@ -2037,7 +2037,8 @@ def _handle_responses(self):
                 request.update_perf_metrics(self.model_engine.iter_counter)
 
             request_done = False
-            if self.model_engine.iter_counter % self.stream_interval == 0 or request.is_finished:
+            if request.py_decoding_iter == 1 or request.is_finished or \
+                    request.py_decoding_iter % self.stream_interval == 0:
                 response = request.create_response(False, self.dist.rank)
                 if response:
                     request_done = response.result.is_final
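
In effect, the old condition gated streaming responses on the engine-wide iteration counter, so a request's first (context-phase) response could wait up to stream_interval iterations before being returned; the new condition keys off the per-request decoding iteration and always emits the first and final responses. A minimal sketch of the new gating, with stand-in names (decoding_iter, stream_interval, finished) for request.py_decoding_iter, self.stream_interval, and request.is_finished rather than the executor's own state:

# Hedged sketch of the gating introduced above, not the executor code itself.
def should_emit_response(decoding_iter: int, stream_interval: int, finished: bool) -> bool:
    # Always emit the first response (the context-phase result) and the final one;
    # in between, emit every `stream_interval` decoding iterations.
    return decoding_iter == 1 or finished or decoding_iter % stream_interval == 0

# With stream_interval=4 and 13 generated tokens, responses land at decoding
# iterations 1, 4, 8, 12, and 13, i.e. per-step token deltas of [1, 3, 4, 4, 1],
# which is exactly what the new test below asserts.
emitted = [i for i in range(1, 14) if should_emit_response(i, 4, finished=(i == 13))]
assert emitted == [1, 4, 8, 12, 13]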

tests/unittest/llmapi/test_llm.py

Lines changed: 34 additions & 0 deletions

@@ -534,6 +534,40 @@ async def main():
     test_non_streaming_usage_wait()
 
 
+@pytest.mark.parametrize("chunked", [True, False])
+@pytest.mark.part0
+def test_llm_generate_async_with_stream_interval(chunked):
+    model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
+    max_num_tokens = 256
+    with LLM_torch(model_path,
+                   max_num_tokens=max_num_tokens,
+                   stream_interval=4,
+                   enable_chunked_prefill=chunked) as llm:
+        sampling_params = SamplingParams(max_tokens=13,
+                                         ignore_eos=True,
+                                         detokenize=False)
+        step = 0
+        last_step_len = 0
+        prompt = "The capital of France is "
+        if chunked:
+            prompt = prompt * max_num_tokens
+        for output in llm.generate_async(prompt,
+                                         sampling_params=sampling_params,
+                                         streaming=True):
+            current_step_len = len(output.outputs[0].token_ids)
+            # The output lens of each step need to be [1, 3, 4, 4, 1]
+            if step == 0:
+                assert current_step_len == 1
+            elif step == 1:
+                assert current_step_len - last_step_len == 3
+            elif step == 2 or step == 3:
+                assert current_step_len - last_step_len == 4
+            else:
+                assert current_step_len - last_step_len == 1
+            step += 1
+            last_step_len = current_step_len
+
+
 @pytest.fixture(scope="module")
 def llm_for_sampling_params():
     build_config = BuildConfig(max_beam_width=3)
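
For orientation, here is a minimal consumer of the same streaming API, restricted to the call pattern the test exercises (LLM_torch with stream_interval, SamplingParams, generate_async with streaming=True). The import paths, model path, and prompt are illustrative assumptions, not part of the commit:

# Usage sketch based only on the API surface exercised by the test above.
# The import paths and argument values here are assumptions for illustration.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM as LLM_torch  # assumed import path for the PyTorch-backend LLM

def stream_with_interval(model_path: str, prompt: str) -> None:
    with LLM_torch(model_path, stream_interval=4) as llm:
        params = SamplingParams(max_tokens=13, ignore_eos=True, detokenize=False)
        last_len = 0
        for output in llm.generate_async(prompt, sampling_params=params, streaming=True):
            token_ids = output.outputs[0].token_ids
            # Each streamed update carries all tokens generated so far; with this
            # commit the first update arrives right after the context phase instead
            # of waiting for a full stream_interval of engine iterations.
            print(f"step delta: {len(token_ids) - last_len} tokens")
            last_len = len(token_ids)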
