From 46828872c31b25df16169cbbf5c2225fa9cb0675 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 18 Dec 2024 09:55:09 +0100 Subject: [PATCH] server : (embeddings) using same format for "input" and "content" (#10872) * server : (embeddings) using same format for "input" and "content" * fix test case * handle empty input case * fix test --- examples/server/server.cpp | 20 +++++++---- examples/server/tests/unit/test_embedding.py | 35 ++++++++++++++++++-- examples/server/utils.hpp | 1 + 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 436170a034fde..71566b94e61bb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) { const json body = json::parse(req.body); bool oaicompat = false; - // an input prompt can be a string or a list of tokens (integer) + // for the shape of input/content, see tokenize_input_prompts() json prompt; - if (body.count("input") != 0) { + if (body.contains("input")) { oaicompat = true; prompt = body.at("input"); - } else if (body.count("content") != 0) { - // with "content", we only support single prompt - prompt = std::vector{body.at("content")}; + } else if (body.contains("content")) { + oaicompat = false; + prompt = body.at("content"); } else { res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + // create and queue the task json responses = json::array(); bool error = false; { std::vector tasks; - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true); for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = ctx_server.queue_tasks.get_new_id(); diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index fea1d6510c89e..4f4e9dcf087fa 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -45,6 +45,35 @@ def test_embedding_multiple(): assert len(d['embedding']) > 1 +@pytest.mark.parametrize( + "content,is_multi_prompt", + [ + # single prompt + ("string", False), + ([12, 34, 56], False), + ([12, 34, "string", 56, 78], False), + # multiple prompts + (["string1", "string2"], True), + (["string1", [12, 34, 56]], True), + ([[12, 34, 56], [12, 34, 56]], True), + ([[12, 34, 56], [12, "string", 34, 56]], True), + ] +) +def test_embedding_mixed_input(content, is_multi_prompt: bool): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={"content": content}) + assert res.status_code == 200 + if is_multi_prompt: + assert len(res.body) == len(content) + for d in res.body: + assert 'embedding' in d + assert len(d['embedding']) > 1 + else: + assert 'embedding' in res.body + assert len(res.body['embedding']) > 1 + + def test_embedding_openai_library_single(): global server server.start() @@ -102,8 +131,8 @@ def test_same_prompt_give_same_result(): @pytest.mark.parametrize( "content,n_tokens", [ - ("I believe the meaning of life is", 7), - ("This is a test", 4), + ("I believe the meaning of life is", 9), + ("This is a test", 6), ] ) def test_embedding_usage_single(content, n_tokens): @@ -126,4 +155,4 @@ def test_embedding_usage_multiple(): }) assert res.status_code == 200 assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] - assert res.body['usage']['prompt_tokens'] == 2 * 7 + assert res.body['usage']['prompt_tokens'] == 2 * 9 diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 8fffe484aec12..ffdffe904308f 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_ * and multiple prompts (multi-tasks): * - "prompt": ["string1", "string2"] * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] */ static std::vector tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {