Commit

Add test
zifeitong committed Jul 30, 2024
1 parent ff3e004 commit f9989d2
Showing 2 changed files with 41 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -1,7 +1,12 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
def test_async_serving_chat_init():
    serving_completion = asyncio.run(_async_serving_chat_init())
    assert serving_completion.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=AsyncLLMEngine)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     served_model_names=[MODEL_NAME],
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # AsyncLLMEngine.generate(inputs, sampling_params, ...)
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10
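For reference on where 93 comes from: when a request omits max_tokens, vLLM fills it in with the context budget left after the prompt, i.e. the model's max_model_len minus the prompt's token count. A minimal sketch of that rule, assuming the collapsed MockModelConfig sets max_model_len = 100 and the templated prompt tokenizes to 7 tokens (both values are inferred from the assertions, neither is shown in this diff):

from typing import Optional

# Minimal sketch, not vLLM's actual code: with no explicit max_tokens,
# the server budgets whatever context remains after the prompt.
def default_max_tokens(max_model_len: int, prompt_len: int,
                       requested: Optional[int] = None) -> int:
    return requested if requested is not None else max_model_len - prompt_len

assert default_max_tokens(100, 7) == 93                 # first assertion above
assert default_max_tokens(100, 7, requested=10) == 10   # explicit value wins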
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
@@ -135,6 +135,8 @@ async def create_chat_completion(
        guided_decode_logits_processor = (
            await self._guided_decode_logits_processor(request, tokenizer))

        print(request)
        print("prompt", prompt)
        prompt_inputs = self._tokenize_prompt_input(
            request,
            tokenizer,
