6 changes: 0 additions & 6 deletions tensorrt_llm/serve/chat_utils.py
@@ -217,12 +217,6 @@ def parse_chat_messages_coroutines(
), mm_placeholder_counts


def check_multiple_response(n: int, backend: Optional[str]):
if n > 1 and backend == "pytorch":
raise ValueError(
"Multiple response is not supported in PyTorch workflow")


def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
if id_type == "kimi_k2":
return f"functions.{func_name}:{idx}"
6 changes: 1 addition & 5 deletions tensorrt_llm/serve/openai_server.py
@@ -34,8 +34,7 @@
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.metrics.collector import MetricsCollector
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
parse_chat_messages_coroutines)
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -484,7 +483,6 @@ async def create_chat_response(
return chat_response

try:
check_multiple_response(request.n, self.llm.args.backend)
conversation: List[ConversationMessage] = []
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
@@ -595,7 +593,6 @@ async def create_mm_embedding_response(promise: RequestOutput):
)

try:
check_multiple_response(request.n, self.llm.args.backend)
conversation: List[ConversationMessage] = []
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
@@ -730,7 +727,6 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
yield "data: [DONE]\n\n"

try:
check_multiple_response(request.n, self.llm.args.backend)
if isinstance(request.prompt, str) or \
(isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
prompts = [request.prompt]
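
For context, here is a minimal sketch (not part of this diff) of the kind of request that the removed check_multiple_response() guard used to reject on the PyTorch backend. The base URL, API key, and served model name are placeholders for a locally running trtllm-serve instance; the request shape mirrors the updated tests below.

import openai

# Placeholder endpoint and model name for a locally running server; adjust as needed.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

chat_completion = client.chat.completions.create(
    model="<served-model-name>",
    messages=[
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ],
    max_completion_tokens=10,
    # n > 1 previously raised ValueError for backend == "pytorch"; after this
    # change the request is passed through to the sampler instead.
    n=2,
    # Non-greedy sampling; per the updated test's skip message, greedy
    # (temperature=0) n > 1 may still require TLLM_ALLOW_N_GREEDY_DECODING=1.
    temperature=0.8,
    extra_body=dict(best_of=4),
)
assert len(chat_completion.choices) == 2
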
78 changes: 77 additions & 1 deletion tests/unittest/llmapi/apps/_test_openai_chat.py
@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
model_path = get_model_path(model_name)
args = ["--backend", f"{backend}"]
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
if backend == "trt":
args.extend(["--max_beam_width", "4"])
if extra_llm_api_options:
@@ -78,11 +80,34 @@
yield remote_server


@pytest.fixture(scope="module")
def server_with_beam_search(model_name: str, backend: str,
extra_llm_api_options: bool,
temp_extra_llm_api_options_file: str,
num_postprocess_workers: int):
model_path = get_model_path(model_name)
args = ["--backend", f"{backend}"]
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
args.extend(["--max_beam_width", "2"])
if extra_llm_api_options:
args.extend(
["--extra_llm_api_options", temp_extra_llm_api_options_file])
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
return server.get_client()


@pytest.fixture(scope="module")
def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
return server_with_beam_search.get_client()


@pytest.fixture(scope="module")
def async_client(server: RemoteOpenAIServer):
return server.get_async_client()
@@ -180,7 +205,33 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
backend: str):
if backend == "pytorch":
pytest.skip(
"Multiple responses are not supported in PyTorch backend yet")
"'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
)
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# test n and best_of
chat_completion = client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
n=2,
temperature=0.0,
extra_body=dict(best_of=4),
)
assert len(chat_completion.choices) == 2


def test_multiple_responses_and_beam_search(client: openai.OpenAI,
model_name: str, backend: str):
if backend == "pytorch":
pytest.skip(
"Mixing beam search and regular requests is not supported in PyTorch backend"
)

messages = [{
"role": "system",
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
assert chat_completion.choices[
0].message.content != chat_completion.choices[
1].message.content, "beam search should be different"

# test n and best_of
chat_completion = client.chat.completions.create(
model=model_name,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
assert len(chat_completion.choices) == 2


def test_multiple_responses_with_beam_search(
client_with_beam_search: openai.OpenAI, model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# test beam search
chat_completion = client_with_beam_search.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
n=2,
temperature=0.0,
extra_body=dict(use_beam_search=True),
)
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[
1].message.content, "beam search should be different"


@pytest.mark.asyncio(loop_scope="module")
async def test_chat_streaming(async_client: openai.AsyncOpenAI,
model_name: str):
30 changes: 23 additions & 7 deletions tests/unittest/llmapi/apps/_test_openai_completions.py
@@ -33,8 +33,21 @@ def num_postprocess_workers(request):
def server(model_name: str, backend: str, num_postprocess_workers: int):
model_path = get_model_path(model_name)
args = ["--backend", f"{backend}"]
if backend == "trt":
args.extend(["--max_beam_width", "4"])
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def server_with_beam_search(model_name: str, backend: str,
num_postprocess_workers: int):
model_path = get_model_path(model_name)
args = ["--backend", f"{backend}"]
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
args.extend(["--max_beam_width", "2"])
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@@ -50,6 +63,11 @@ def async_client(server: RemoteOpenAIServer):
return server.get_async_client()


@pytest.fixture(scope="module")
def async_client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
return server_with_beam_search.get_async_client()


def test_single_completion(client: openai.OpenAI, model_name):
completion = client.completions.create(
model=model_name,
@@ -145,12 +163,10 @@ async def test_batch_completions(async_client: openai.AsyncOpenAI, model_name,
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("prompts",
[["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2])
async def test_batch_completions_beam_search(async_client: openai.AsyncOpenAI,
model_name, prompts, backend):
async def test_batch_completions_beam_search(
async_client_with_beam_search: openai.AsyncOpenAI, model_name, prompts):
# test beam search
if backend == 'pytorch':
pytest.skip("Beam search is not supported in PyTorch backend yet")
batch = await async_client.completions.create(
batch = await async_client_with_beam_search.completions.create(
model=model_name,
prompt=prompts,
n=2,
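
The remainder of this call is collapsed in the diff view. For reference, here is a minimal sketch (not part of this diff) of a beam-search completions request against a server started with --max_beam_width 2, as in the server_with_beam_search fixture. The base URL, API key, and model name are placeholders, and the use_beam_search flag is assumed from the pattern in the chat tests above.

import openai

# Placeholder endpoint and model name for a locally running server started
# with --max_beam_width 2; adjust as needed.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

batch = client.completions.create(
    model="<served-model-name>",
    prompt=["Hello, my name is"] * 2,
    n=2,  # assumed to map to the number of beams; should not exceed max_beam_width
    max_tokens=10,
    temperature=0.0,
    extra_body=dict(use_beam_search=True),  # same knob as the chat beam-search tests
)
# Each prompt should yield n choices, so a 2-prompt batch is expected to
# return 4 choices in total.
for choice in batch.choices:
    print(choice.index, choice.text)
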