11# Adapted from
22# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
3+ import json
34import os
5+ import re
46import tempfile
57
8+ import jsonschema
69import openai
710import pytest
811import yaml
912
10- from ..test_llm import get_model_path , similar
13+ from ..test_llm import get_model_path
1114from .openai_server import RemoteOpenAIServer
1215
1316pytestmark = pytest .mark .threadleak (enabled = False )
1417
1518
16- @pytest .fixture (scope = "module" , ids = [ "TinyLlama-1.1B-Chat" ] )
19+ @pytest .fixture (scope = "module" )
1720def model_name ():
1821 return "llama-3.1-model/Llama-3.1-8B-Instruct"
1922
2023
2124@pytest .fixture (scope = "module" )
22- def temp_extra_llm_api_options_file (request ):
25+ def temp_extra_llm_api_options_file ():
2326 temp_dir = tempfile .gettempdir ()
2427 temp_file_path = os .path .join (temp_dir , "extra_llm_api_options.yaml" )
2528 try :
@@ -37,7 +40,12 @@ def temp_extra_llm_api_options_file(request):
3740@pytest .fixture (scope = "module" )
3841def server (model_name : str , temp_extra_llm_api_options_file : str ):
3942 model_path = get_model_path (model_name )
40- args = ["--extra_llm_api_options" , temp_extra_llm_api_options_file ]
43+
44+ # Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
45+ args = [
46+ "--max_batch_size=8" , "--max_seq_len=1024" , "--max_num_tokens=1024" ,
47+ f"--extra_llm_api_options={ temp_extra_llm_api_options_file } "
48+ ]
4149 with RemoteOpenAIServer (model_path , args ) as remote_server :
4250 yield remote_server
4351
@@ -112,12 +120,7 @@ def tool_get_current_date():
112120
113121def test_chat_structural_tag (client : openai .OpenAI , model_name : str ,
114122 tool_get_current_weather , tool_get_current_date ):
115- messages = [
116- {
117- "role" :
118- "system" ,
119- "content" :
120- f"""
123+ system_prompt = f"""
121124# Tool Instructions
122125- Always execute python code in messages that you share.
123126- When looking for real time information use relevant functions if available else fallback to brave_search
@@ -140,20 +143,24 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
140143- Only call one function at a time
141144- Put the entire function call reply on one line
142145- Always add your sources when using search results to answer the user query
143- You are a helpful assistant.""" ,
146+ You are a helpful assistant."""
147+ user_prompt = "You are in New York. Please get the current date and time, and the weather."
148+
149+ messages = [
150+ {
151+ "role" : "system" ,
152+ "content" : system_prompt ,
144153 },
145154 {
146- "role" :
147- "user" ,
148- "content" :
149- "You are in New York. Please get the current date and time, and the weather." ,
155+ "role" : "user" ,
156+ "content" : user_prompt ,
150157 },
151158 ]
152159
153160 chat_completion = client .chat .completions .create (
154161 model = model_name ,
155162 messages = messages ,
156- max_completion_tokens = 100 ,
163+ max_completion_tokens = 256 ,
157164 response_format = {
158165 "type" :
159166 "structural_tag" ,
@@ -173,11 +180,18 @@ def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
173180 "triggers" : ["<function=" ],
174181 },
175182 )
176- assert chat_completion .id is not None
177- assert len (chat_completion .choices ) == 1
183+
178184 message = chat_completion .choices [0 ].message
179185 assert message .content is not None
180186 assert message .role == "assistant"
181187
182- reference = '<function=get_current_date>{"timezone": "America/New_York"}</function>\n <function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>\n \n Sources:\n - get_current_date function\n - get_current_weather function'
183- assert similar (chat_completion .choices [0 ].message .content , reference )
188+ match = re .search (r'<function=get_current_weather>([\S\s]+?)</function>' ,
189+ message .content )
190+ params = json .loads (match .group (1 ))
191+ jsonschema .validate (params ,
192+ tool_get_current_weather ["function" ]["parameters" ])
193+
194+ match = re .search (r'<function=get_current_date>([\S\s]+?)</function>' ,
195+ message .content )
196+ params = json .loads (match .group (1 ))
197+ jsonschema .validate (params , tool_get_current_date ["function" ]["parameters" ])
0 commit comments