From f9989d2fe1174b6532b67aee6821d5ac8ad435d6 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 30 Jul 2024 21:28:59 +0000 Subject: [PATCH] Add test --- tests/entrypoints/openai/test_serving_chat.py | 39 +++++++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 2 + 2 files changed, 41 insertions(+) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 464465494b714..168ba7ba888ef 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,7 +1,12 @@ import asyncio +from contextlib import suppress from dataclasses import dataclass +from unittest.mock import MagicMock +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" @@ -42,3 +47,37 @@ async def _async_serving_chat_init(): def test_async_serving_chat_init(): serving_completion = asyncio.run(_async_serving_chat_init()) assert serving_completion.chat_template == CHAT_TEMPLATE + + +def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + # AsyncLLMEngine.generate(inputs, sampling_params, ...) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + req.max_tokens = 10 + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1dee798858a16..a6fe538dda74f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -135,6 +135,8 @@ async def create_chat_completion( guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) + print(request) + print("prompt", prompt) prompt_inputs = self._tokenize_prompt_input( request, tokenizer,