@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
            temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
     if backend == "trt":
         args.extend(["--max_beam_width", "4"])
     if extra_llm_api_options:
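
Both server fixtures now cap the KV-cache pool at 20% of free GPU memory so several test servers can share one device. This CLI flag corresponds to the LLM API's `KvCacheConfig`; a minimal in-process sketch of the same setting (the model name is a placeholder, not taken from this diff):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Claim only 20% of the *free* GPU memory for KV cache, leaving room
# for other server processes running on the same device.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.2)
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          kv_cache_config=kv_cache_config)
```
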
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
         yield remote_server
 
 
+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            extra_llm_api_options: bool,
+                            temp_extra_llm_api_options_file: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
+    if extra_llm_api_options:
+        args.extend(
+            ["--extra_llm_api_options", temp_extra_llm_api_options_file])
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
 @pytest.fixture(scope="module")
 def client(server: RemoteOpenAIServer):
     return server.get_client()
 
 
+@pytest.fixture(scope="module")
+def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_client()
+
+
 @pytest.fixture(scope="module")
 def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
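
The new fixture pair mirrors the existing `server`/`client` pattern, but only a synchronous client is added. If a streaming beam-search test were wanted later, a hypothetical async counterpart (not part of this diff) would follow the same one-liner pattern as the existing `async_client`:

```python
# Hypothetical async counterpart, mirroring the existing async_client
# fixture but bound to the beam-search server.
@pytest.fixture(scope="module")
def async_client_with_beam_search(
        server_with_beam_search: RemoteOpenAIServer):
    return server_with_beam_search.get_async_client()
```
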
@@ -176,11 +201,32 @@ def test_multi_turn_dialogue(client: openai.OpenAI, model_name: str):
     assert message.content is not None and len(message.content) >= 0
 
 
-def test_multiple_responses(client: openai.OpenAI, model_name: str,
-                            backend: str):
+def test_multiple_responses(client: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test n and best_of
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(best_of=4),
+    )
+    assert len(chat_completion.choices) == 2
+
+
+def test_multiple_responses_and_beam_search(client: openai.OpenAI,
+                                            model_name: str, backend: str):
     if backend == "pytorch":
         pytest.skip(
-            "Multiple responses are not supported in PyTorch backend yet")
+            "Mixing beam search and regular requests is not supported in PyTorch backend"
+        )
 
     messages = [{
         "role": "system",
@@ -202,6 +248,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert chat_completion.choices[
         0].message.content != chat_completion.choices[
             1].message.content, "beam search should be different"
+
     # test n and best_of
     chat_completion = client.chat.completions.create(
         model=model_name,
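
`best_of` is not part of the official OpenAI chat schema, which is why it travels in `extra_body`: following the common convention, the server generates `best_of` candidates and returns the top `n`. A standalone sketch of the same request against a locally running server (endpoint and model name are assumptions, not from this diff):

```python
import openai

# Assumed endpoint; in the tests RemoteOpenAIServer picks a free port.
client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="tensorrt_llm")

chat_completion = client.chat.completions.create(
    model="TinyLlama-1.1B-Chat-v1.0",  # placeholder model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_completion_tokens=10,
    n=2,                         # choices returned to the caller
    extra_body=dict(best_of=4),  # candidates generated server-side
)
assert len(chat_completion.choices) == 2
```
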
@@ -214,6 +261,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert len(chat_completion.choices) == 2
 
 
+def test_multiple_responses_with_beam_search(
+        client_with_beam_search: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test beam search
+    chat_completion = client_with_beam_search.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(use_beam_search=True),
+    )
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[
+            1].message.content, "beam search should be different"
+
+
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_streaming(async_client: openai.AsyncOpenAI,
                               model_name: str):
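
The `openai` client merges `extra_body` entries into the top level of the request JSON, so `use_beam_search` arrives at the server as an ordinary field next to `n` and `temperature`. A raw-HTTP sketch of the equivalent request (URL and model name are assumptions; the fixtures bind to a dynamically chosen port):

```python
import requests

payload = {
    "model": "TinyLlama-1.1B-Chat-v1.0",  # placeholder model name
    "messages": [{"role": "user", "content": "what is 1+1?"}],
    "max_completion_tokens": 10,
    "n": 2,
    "temperature": 0.0,
    "use_beam_search": True,  # extra_body fields land at the top level
}
resp = requests.post("http://localhost:8000/v1/chat/completions",
                     json=payload)
resp.raise_for_status()
assert len(resp.json()["choices"]) == 2
```
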