diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py
index acda26b511c..7534a0d22ad 100644
--- a/tensorrt_llm/serve/chat_utils.py
+++ b/tensorrt_llm/serve/chat_utils.py
@@ -217,12 +217,6 @@ def parse_chat_messages_coroutines(
     ), mm_placeholder_counts
 
 
-def check_multiple_response(n: int, backend: Optional[str]):
-    if n > 1 and backend == "pytorch":
-        raise ValueError(
-            "Multiple response is not supported in PyTorch workflow")
-
-
 def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     if id_type == "kimi_k2":
         return f"functions.{func_name}:{idx}"
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index aad50e9c7d8..05facda203a 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -34,8 +34,7 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import (check_multiple_response,
-                                           parse_chat_messages_coroutines)
+from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -484,7 +483,6 @@ async def create_chat_response(
             return chat_response
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             conversation: List[ConversationMessage] = []
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -595,7 +593,6 @@ async def create_mm_embedding_response(promise: RequestOutput):
             )
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             conversation: List[ConversationMessage] = []
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -730,7 +727,6 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
             yield "data: [DONE]\n\n"
 
         try:
-            check_multiple_response(request.n, self.llm.args.backend)
             if isinstance(request.prompt, str) or \
                 (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
                 prompts = [request.prompt]
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat.py b/tests/unittest/llmapi/apps/_test_openai_chat.py
index 532f11c0c4e..d35935b1e03 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat.py
@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
            temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
     if backend == "trt":
         args.extend(["--max_beam_width", "4"])
     if extra_llm_api_options:
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
         yield remote_server
 
 
+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            extra_llm_api_options: bool,
+                            temp_extra_llm_api_options_file: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
+    if extra_llm_api_options:
+        args.extend(
+            ["--extra_llm_api_options", temp_extra_llm_api_options_file])
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
 @pytest.fixture(scope="module")
 def client(server: RemoteOpenAIServer):
     return server.get_client()
 
 
+@pytest.fixture(scope="module")
+def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_client()
+
+
 @pytest.fixture(scope="module")
 def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
@@ -180,7 +205,33 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
                             backend: str):
     if backend == "pytorch":
         pytest.skip(
-            "Multiple responses are not supported in PyTorch backend yet")
+            "'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
+        )
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test n and best_of
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(best_of=4),
+    )
+    assert len(chat_completion.choices) == 2
+
+
+def test_multiple_responses_and_beam_search(client: openai.OpenAI,
+                                            model_name: str, backend: str):
+    if backend == "pytorch":
+        pytest.skip(
+            "Mixing beam search and regular requests is not supported in PyTorch backend"
+        )
 
     messages = [{
         "role": "system",
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert chat_completion.choices[
         0].message.content != chat_completion.choices[
             1].message.content, "beam search should be different"
+
     # test n and best_of
     chat_completion = client.chat.completions.create(
         model=model_name,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert len(chat_completion.choices) == 2
 
 
+def test_multiple_responses_with_beam_search(
+        client_with_beam_search: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test beam search
+    chat_completion = client_with_beam_search.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(use_beam_search=True),
+    )
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[
+            1].message.content, "beam search should be different"
+
+
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_streaming(async_client: openai.AsyncOpenAI,
                               model_name: str):
diff --git a/tests/unittest/llmapi/apps/_test_openai_completions.py b/tests/unittest/llmapi/apps/_test_openai_completions.py
index 098b93c5cb5..e3e374ec1c8 100644
--- a/tests/unittest/llmapi/apps/_test_openai_completions.py
+++ b/tests/unittest/llmapi/apps/_test_openai_completions.py
@@ -33,8 +33,21 @@ def num_postprocess_workers(request):
 def server(model_name: str, backend: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
-    if backend == "trt":
-        args.extend(["--max_beam_width", "4"])
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
     args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
     with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server
@@ -50,6 +63,11 @@ def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
 
 
+@pytest.fixture(scope="module")
+def async_client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_async_client()
+
+
 def test_single_completion(client: openai.OpenAI, model_name):
     completion = client.completions.create(
         model=model_name,
@@ -145,12 +163,10 @@ async def test_batch_completions(async_client: openai.AsyncOpenAI, model_name,
 @pytest.mark.asyncio(loop_scope="module")
 @pytest.mark.parametrize("prompts",
                          [["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2])
-async def test_batch_completions_beam_search(async_client: openai.AsyncOpenAI,
-                                             model_name, prompts, backend):
+async def test_batch_completions_beam_search(
+        async_client_with_beam_search: openai.AsyncOpenAI, model_name, prompts):
     # test beam search
-    if backend == 'pytorch':
-        pytest.skip("Beam search is not supported in PyTorch backend yet")
-    batch = await async_client.completions.create(
+    batch = await async_client_with_beam_search.completions.create(
         model=model_name,
         prompt=prompts,
         n=2,