
Commit

add test
ngxson committed Dec 25, 2024
1 parent 90889fd commit 3603399
Showing 3 changed files with 43 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/server/tests/unit/test_chat_completion.py
@@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
 def test_chat_completion_with_openai_library():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         messages=[
@@ -170,7 +170,7 @@ def test_chat_completion_with_timings_per_token():
 def test_logprobs():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
@@ -197,7 +197,7 @@ def test_logprobs():
 def test_logprobs_stream():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
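Note on the `/v1` suffix: the OpenAI Python client resolves route suffixes such as `chat/completions` relative to `base_url`, so pointing it at the server root plus `/v1` makes requests land on llama.cpp's OpenAI-compatible endpoints (plain `/completions` without the prefix is the server's native API, which returns a different body, as the `res_non_stream.body["content"]` assertion in test_completion.py below shows). A minimal sketch, assuming a server already listening on 127.0.0.1:8080:

from openai import OpenAI

# base_url must end in /v1 so the client's "chat/completions" suffix
# resolves to the server's OpenAI-compatible /v1/chat/completions route
client = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080/v1")
res = client.chat.completions.create(
    model="gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=8,
)
print(res.choices[0].message.content)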
35 changes: 35 additions & 0 deletions examples/server/tests/unit/test_completion.py
@@ -1,5 +1,6 @@
 import pytest
 import time
+from openai import OpenAI
 from utils import *
 
 server = ServerPreset.tinyllama2()
@@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
     assert content_stream == res_non_stream.body["content"]
 
 
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+    )
+    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
+    assert res.choices[0].finish_reason == "length"
+    assert res.choices[0].text is not None
+    assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+def test_completion_stream_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+        stream=True,
+    )
+    output_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            assert choice.text is not None
+            output_text += choice.text
+    assert match_regex("(going|bed)+", output_text)
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
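For reference, a sketch of the raw request behind the new non-streaming test; the OpenAI client issues an equivalent POST to the OAI-compatible route (the host and port here are assumptions):

import requests

# Equivalent raw request to the /v1/completions route the new test exercises
res = requests.post(
    "http://127.0.0.1:8080/v1/completions",
    json={
        "model": "davinci-002",
        "prompt": "I believe the meaning of life is",
        "max_tokens": 8,
    },
)
body = res.json()
# max_tokens was reached, so the OAI-format response reports "length"
assert body["choices"][0]["finish_reason"] == "length"
print(body["choices"][0]["text"])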
7 changes: 5 additions & 2 deletions examples/server/utils.hpp
@@ -570,8 +570,11 @@ static json oaicompat_completion_params_parse(const json & body) {
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    if (body.contains("best_of")) {
-        throw std::runtime_error("Unsupported param: best_of");
+    static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
+    for (const auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
     }
 
     // Copy remaining properties to llama_params
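Client-side, the generalized check rejects `echo` and `suffix` the same way `best_of` was rejected before, with the offending parameter named in the error message. A sketch of the expected behavior against the same assumed local server; the exact exception class the client raises for the server's error response is also an assumption:

import openai
from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080/v1")
try:
    client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
        echo=True,  # now rejected server-side: "Unsupported param: echo"
    )
except openai.APIError as err:
    # expected to surface the server's "Unsupported param" message
    print(err)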
