
Commit 3c7ad49

feat: enable n > 1 in OpenAI API with PyTorch backend
Signed-off-by: ixlmar <[email protected]>
1 parent 5154556 commit 3c7ad49

4 files changed: +98 −21 lines

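With the server-side guard removed, a PyTorch-backend trtllm-serve instance now accepts n > 1 (and best_of) on chat and completion requests. A minimal client-side sketch of what this commit enables, assuming a server already running at http://localhost:8000/v1 and the placeholder model name "tensorrt_llm_model":

# Sketch: requesting n > 1 chat completions from a trtllm-serve instance
# running the PyTorch backend. Endpoint, API key, and model name below are
# placeholder assumptions, not part of this commit.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

chat_completion = client.chat.completions.create(
    model="tensorrt_llm_model",
    messages=[
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ],
    max_completion_tokens=10,
    n=2,                         # return two responses for this request
    extra_body=dict(best_of=4),  # sample four candidates, keep the best two
)

for i, choice in enumerate(chat_completion.choices):
    print(f"choice {i}: {choice.message.content}")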

tensorrt_llm/serve/chat_utils.py

Lines changed: 0 additions & 6 deletions
@@ -217,12 +217,6 @@ def parse_chat_messages_coroutines(
     ), mm_placeholder_counts
 
 
-def check_multiple_response(n: int, backend: Optional[str]):
-    if n > 1 and backend == "pytorch":
-        raise ValueError(
-            "Multiple response is not supported in PyTorch workflow")
-
-
 def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     if id_type == "kimi_k2":
         return f"functions.{func_name}:{idx}"

tensorrt_llm/serve/openai_server.py

Lines changed: 1 addition & 5 deletions
@@ -34,8 +34,7 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import (check_multiple_response,
-                                           parse_chat_messages_coroutines)
+from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -484,7 +483,6 @@ async def create_chat_response(
             return chat_response
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             conversation: List[ConversationMessage] = []
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -595,7 +593,6 @@ async def create_mm_embedding_response(promise: RequestOutput):
         )
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             conversation: List[ConversationMessage] = []
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -730,7 +727,6 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
                 yield "data: [DONE]\n\n"
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             if isinstance(request.prompt, str) or \
                     (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
                 prompts = [request.prompt]

tests/unittest/llmapi/apps/_test_openai_chat.py

Lines changed: 74 additions & 3 deletions
@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
            temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
     if backend == "trt":
         args.extend(["--max_beam_width", "4"])
     if extra_llm_api_options:
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
         yield remote_server
 
 
+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            extra_llm_api_options: bool,
+                            temp_extra_llm_api_options_file: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
+    if extra_llm_api_options:
+        args.extend(
+            ["--extra_llm_api_options", temp_extra_llm_api_options_file])
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
 @pytest.fixture(scope="module")
 def client(server: RemoteOpenAIServer):
     return server.get_client()
 
 
+@pytest.fixture(scope="module")
+def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_client()
+
+
 @pytest.fixture(scope="module")
 def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
@@ -176,11 +201,32 @@ def test_multi_turn_dialogue(client: openai.OpenAI, model_name: str):
     assert message.content is not None and len(message.content) >= 0
 
 
-def test_multiple_responses(client: openai.OpenAI, model_name: str,
-                            backend: str):
+def test_multiple_responses(client: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test n and best_of
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(best_of=4),
+    )
+    assert len(chat_completion.choices) == 2
+
+
+def test_multiple_responses_and_beam_search(client: openai.OpenAI,
+                                            model_name: str, backend: str):
     if backend == "pytorch":
         pytest.skip(
-            "Multiple responses are not supported in PyTorch backend yet")
+            "Mixing beam search and regular requests is not supported in PyTorch backend"
+        )
 
     messages = [{
         "role": "system",
@@ -202,6 +248,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert chat_completion.choices[
         0].message.content != chat_completion.choices[
             1].message.content, "beam search should be different"
+
     # test n and best_of
     chat_completion = client.chat.completions.create(
         model=model_name,
@@ -214,6 +261,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert len(chat_completion.choices) == 2
 
 
+def test_multiple_responses_with_beam_search(
+        client_with_beam_search: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test beam search
+    chat_completion = client_with_beam_search.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(use_beam_search=True),
+    )
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[
+            1].message.content, "beam search should be different"
+
+
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_streaming(async_client: openai.AsyncOpenAI,
                               model_name: str):
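As test_multiple_responses_with_beam_search above shows, beam search still requires a server launched with a sufficient --max_beam_width (2 in the new fixture); the request itself only differs in the use_beam_search flag passed through extra_body. A standalone sketch under the same placeholder endpoint and model-name assumptions as the earlier example:

# Sketch of a beam-search chat request. Assumes the server was started with
# --max_beam_width 2 (or larger); endpoint and model name are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

chat_completion = client.chat.completions.create(
    model="tensorrt_llm_model",
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_completion_tokens=10,
    n=2,
    temperature=0.0,
    extra_body=dict(use_beam_search=True),  # return the top two beams
)
assert len(chat_completion.choices) == 2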

tests/unittest/llmapi/apps/_test_openai_completions.py

Lines changed: 23 additions & 7 deletions
@@ -33,8 +33,21 @@ def num_postprocess_workers(request):
 def server(model_name: str, backend: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
-    if backend == "trt":
-        args.extend(["--max_beam_width", "4"])
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
     args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
     with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server
@@ -50,6 +63,11 @@ def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
 
 
+@pytest.fixture(scope="module")
+def async_client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_async_client()
+
+
 def test_single_completion(client: openai.OpenAI, model_name):
     completion = client.completions.create(
         model=model_name,
@@ -145,12 +163,10 @@ async def test_batch_completions(async_client: openai.AsyncOpenAI, model_name,
 @pytest.mark.asyncio(loop_scope="module")
 @pytest.mark.parametrize("prompts",
                          [["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2])
-async def test_batch_completions_beam_search(async_client: openai.AsyncOpenAI,
-                                             model_name, prompts, backend):
+async def test_batch_completions_beam_search(
+        async_client_with_beam_search: openai.AsyncOpenAI, model_name, prompts):
     # test beam search
-    if backend == 'pytorch':
-        pytest.skip("Beam search is not supported in PyTorch backend yet")
-    batch = await async_client.completions.create(
+    batch = await async_client_with_beam_search.completions.create(
         model=model_name,
         prompt=prompts,
         n=2,
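For the completions endpoint, the updated test now routes through the beam-search server fixture instead of skipping on PyTorch. Outside pytest, an equivalent call looks roughly like the sketch below (async client; endpoint, model name, and max_tokens are placeholder assumptions, not taken from the truncated hunk above):

# Rough async counterpart of test_batch_completions_beam_search.
# Assumes a server launched with --max_beam_width 2, reachable at
# http://localhost:8000/v1, serving a model registered as "tensorrt_llm_model".
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    batch = await client.completions.create(
        model="tensorrt_llm_model",
        prompt=["Hello, my name is"] * 2,  # two prompts in one batch
        n=2,                               # two beams per prompt
        max_tokens=10,
        extra_body=dict(use_beam_search=True),
    )
    for i, choice in enumerate(batch.choices):
        print(i, choice.text)


asyncio.run(main())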
