Skip to content

Commit 7fbd931

Browse files
add llm request unittest
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 5c0d5ba commit 7fbd931

File tree

2 files changed

+248
-3
lines changed

2 files changed

+248
-3
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,6 @@ def create_child_request(self, request_id: int):
356356
child_request.py_orig_prompt_len = child_request.orig_prompt_len
357357
child_request.py_max_new_tokens = child_request.max_new_tokens
358358

359-
# input_toknes are already cloned in create_child_request.
360-
child_request.py_tokens = child_request.get_tokens()
361-
362359
# Copy Python-specific configuration from parent
363360
child_request.py_return_log_probs = self.py_return_log_probs
364361
child_request.py_return_context_logits = self.py_return_context_logits
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from unittest.mock import MagicMock
17+
18+
import pytest
19+
20+
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, LlmResponse,
21+
SamplingConfig)
22+
23+
24+
def create_sampling_config():
    """Build the SamplingConfig shared by the tests in this module.

    Returns:
        A ``SamplingConfig`` whose ``top_p`` is ``[0.9]`` and whose
        ``num_return_sequences`` is ``2``.
    """
    sampling_config = SamplingConfig()
    # The fields are written through setattr on purpose: the original author
    # notes that the binding is incompatible with a more direct approach.
    overrides = {'top_p': [0.9], 'num_return_sequences': 2}
    for field_name, field_value in overrides.items():
        setattr(sampling_config, field_name, field_value)
    return sampling_config
31+
32+
33+
def test_create_request():
    """Check the py_* attributes of an LlmRequest built with required args only."""
    request = LlmRequest(
        request_id=1,
        max_new_tokens=10,
        input_tokens=[1, 2, 3],
        sampling_config=create_sampling_config(),
        is_streaming=False,
    )

    # Attributes derived from the constructor arguments.
    assert request.py_request_id == 1
    assert request.py_max_new_tokens == 10
    assert request.py_prompt_len == 3
    assert request.py_orig_prompt_len == 3
    assert request.py_client_id is None

    # Flags that are expected to be off when not requested.
    disabled_by_default = (
        'py_return_log_probs',
        'py_return_context_logits',
        'py_return_generation_logits',
        'py_is_draft',
        'py_exclude_last_generation_logits',
    )
    for flag_name in disabled_by_default:
        assert not getattr(request, flag_name), flag_name

    # The one flag that is expected to be on by default.
    assert request.py_return_logits_device_memory

    # A result object must have been attached at construction time.
    assert request.py_result is not None
61+
62+
63+
def test_create_request_with_optional_params():
    """Check that optional constructor arguments surface as py_* attributes."""
    optional_kwargs = dict(
        client_id=100,
        return_log_probs=True,
        return_context_logits=True,
        return_generation_logits=True,
        return_logits_device_memory=False,
        exclude_last_generation_logits=True,
        is_draft=True,
    )
    request = LlmRequest(
        request_id=2,
        max_new_tokens=20,
        input_tokens=[10, 20, 30, 40],
        sampling_config=create_sampling_config(),
        is_streaming=False,
        **optional_kwargs,
    )

    assert request.py_client_id == 100

    # Every flag that was switched on must read back as truthy.
    enabled_flags = (
        'py_return_log_probs',
        'py_return_context_logits',
        'py_return_generation_logits',
        'py_exclude_last_generation_logits',
        'py_is_draft',
    )
    for flag_name in enabled_flags:
        assert getattr(request, flag_name), flag_name

    # The only flag that was explicitly switched off.
    assert not request.py_return_logits_device_memory
89+
90+
91+
def test_create_child_request():
    """Test that create_child_request derives a correctly configured child.

    Covers the child's identifiers, the configuration copied from the parent,
    the child's fresh runtime state, and the cases in which creating a further
    child must raise ``RuntimeError``.
    """
    sampling_config = create_sampling_config()
    # Create parent request with various attributes
    parent_request = LlmRequest(
        request_id=1,
        max_new_tokens=10,
        input_tokens=[1, 2, 3],
        sampling_config=sampling_config,
        is_streaming=False,
        client_id=50,
        return_log_probs=True,
        return_context_logits=True,
    )

    # Create child request
    child_request = parent_request.create_child_request(2)

    # Verify child request attributes
    assert child_request.request_id == 2
    assert child_request.py_request_id == 2
    assert child_request.py_parent_request_id == 1

    # Verify copied configuration
    assert child_request.py_client_id == 50
    assert child_request.py_max_new_tokens == 10
    assert child_request.get_tokens() == [[1, 2, 3]]
    assert child_request.py_return_log_probs
    assert child_request.py_return_context_logits

    # Verify runtime state
    assert child_request.py_batch_idx is None  # Reset to None

    # Verify PyResult is new instance
    assert child_request.py_result is not None
    assert child_request.py_result is not parent_request.py_result

    # Cannot create child request more than num_return_sequences.
    # The sampling config requests 2 sequences and one child already exists,
    # so asking the *parent* for another child is expected to be rejected.
    # The original test only exercised the child here, which never checked
    # the limit that this comment describes.
    with pytest.raises(RuntimeError):
        parent_request.create_child_request(3)

    # Kept from the original test: asking the child itself for a child
    # raises RuntimeError as well.
    with pytest.raises(RuntimeError):
        child_request.create_child_request(3)
131+
132+
133+
def test_child_inherits_parent_attributes():
    """Check that a child request mirrors its parent's configuration."""
    # Parent with every inherited option set to a non-default value.
    parent_request = LlmRequest(
        request_id=100,
        max_new_tokens=20,
        input_tokens=[1, 2, 3, 4],
        sampling_config=create_sampling_config(),
        is_streaming=True,
        client_id=2000,
        return_log_probs=True,
        return_context_logits=True,
        return_generation_logits=True,
    )

    child = parent_request.create_child_request(2)

    # Each listed attribute must compare equal between child and parent.
    inherited_attributes = (
        'py_client_id',
        'py_max_new_tokens',
        'py_return_log_probs',
        'py_return_context_logits',
        'py_return_generation_logits',
    )
    for attribute in inherited_attributes:
        child_value = getattr(child, attribute)
        parent_value = getattr(parent_request, attribute)
        assert child_value == parent_value, attribute
157+
158+
159+
def test_parent_child_independence():
    """Test that parent and child requests are independent.

    The parent and its child start from the same prompt. After each receives
    a different generated token, the first beam of each must hold exactly its
    own prompt-plus-token sequence, and the two must not share a PyResult.
    """
    sampling_config = create_sampling_config()
    input_tokens = [1, 2, 3]
    parent_request = LlmRequest(request_id=100,
                                max_new_tokens=10,
                                input_tokens=input_tokens,
                                sampling_config=sampling_config,
                                is_streaming=False,
                                client_id=1000)

    # Create child requests
    child_request = parent_request.create_child_request(2)

    # Verify initial tokens are the same content but different objects
    assert parent_request.get_tokens() == child_request.get_tokens()
    # NOTE(review): each get_tokens() call presumably returns a freshly built
    # list, so this identity check on its own may pass trivially; the
    # exact-content assertions below are what establish independence.
    assert (parent_request.get_tokens() is not
            child_request.get_tokens()), \
        "Parent and child should have independent token lists"

    # Test token generation independence
    # Add a distinct new token to each request
    parent_request.add_new_token(10, beam=0)
    child_request.add_new_token(20, beam=0)

    # Verify tokens are updated independently in the first beam. Exact
    # equality pins both the presence of a request's own token and the
    # absence of the other request's token (or any other stray token).
    assert parent_request.get_tokens()[0] == [1, 2, 3, 10]
    assert child_request.get_tokens()[0] == [1, 2, 3, 20]

    # Test that each has independent PyResult
    assert parent_request.py_result is not child_request.py_result
195+
196+
197+
def test_create_response():
    """Check create_response for a parent request and for its child."""
    request = LlmRequest(
        request_id=1,
        max_new_tokens=10,
        input_tokens=[1, 2, 3],
        sampling_config=create_sampling_config(),
        is_streaming=False,
        client_id=100,
    )

    # The child's response is built before the parent's, as in the
    # original ordering of this test.
    child_request = request.create_child_request(2)
    child_response = child_request.create_response()
    response = request.create_response(use_fast_logits=True, mpi_world_rank=1)

    # (response, request supplying the expected client id, sequence index)
    cases = (
        (response, request, 0),
        (child_response, child_request, 1),
    )
    for produced, owner, expected_sequence_index in cases:
        assert produced is not None
        # Both responses are expected to carry the parent's request id.
        assert produced.request_id == request.py_request_id
        assert produced.client_id == owner.py_client_id
        assert produced.error_msg is None
        assert produced.result is not None
        assert produced.result.sequence_index == expected_sequence_index

    # Only the parent's response is type-checked.
    assert isinstance(response, LlmResponse)
228+
229+
230+
def test_creates_none_response_when_result_is_none():
    """Check that create_response yields None when create_result yields None."""
    request = LlmRequest(
        request_id=1,
        max_new_tokens=10,
        input_tokens=[1, 2, 3],
        sampling_config=create_sampling_config(),
        is_streaming=False,
        client_id=100,
    )

    # Replace create_result on this instance with a stub that always
    # reports "no result".
    result_stub = MagicMock(return_value=None)
    request.create_result = result_stub

    assert request.create_response() is None

0 commit comments

Comments
 (0)