From ba27e98582c25a791fcf542adcb4b1199db21c0c Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 29 Jan 2025 18:29:18 +0000 Subject: [PATCH] Unify llama 3.x chat handling again (allow `{"type": "function", "name": ...` prefix) --- common/chat-handler.cpp | 111 +++++++------------ examples/server/server.cpp | 7 +- examples/server/tests/unit/test_tool_call.py | 38 +++---- src/llama-grammar.cpp | 3 + tests/test-chat-handler.cpp | 16 +-- 5 files changed, 73 insertions(+), 102 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index 069db4bee785a..aaef05dfddaf9 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -344,7 +344,7 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } -static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { +static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_data data; data.grammar_lazy = params.tool_choice != "required"; @@ -379,24 +379,31 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c return true; }; + auto has_function = false; foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime - if (handle_builtin_tool(name, parameters)) { + if (allow_python_tag_builtin_tools && handle_builtin_tool(name, parameters)) { return; } builder.resolve_refs(parameters); tool_rules.push_back( builder.add_rule( name + "-call", - "\"{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + has_function = true; }); + if (has_function) { + data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true}); + } if (!builtin_tools.empty()) { data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); } @@ -407,79 +414,44 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c {"tools_in_user_message", false}, {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, }); - data.format = "llama 3.1 tool calls"; - data.parser = [params](const std::string & input) -> common_chat_msg { - static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": "); + data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? 
" (w/ builtin tools)" : ""); + data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg { + static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); static std::regex close_regex("\\}"); static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); - std::smatch match; - if (std::regex_match(input, match, builtin_call_regex)) { - auto name = match[1].str(); - auto raw_args = match[2].str(); + if (allow_python_tag_builtin_tools && !builtin_tools.empty()) { + std::smatch match; + if (std::regex_match(input, match, builtin_call_regex)) { + auto name = match[1].str(); + auto raw_args = match[2].str(); - // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. - auto it_eq = raw_args.find('='); - auto arg_name = raw_args.substr(0, it_eq); - auto arg_value_str = raw_args.substr(it_eq + 1); - auto arg_value = json::parse(arg_value_str); + // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. + auto it_eq = raw_args.find('='); + auto arg_name = raw_args.substr(0, it_eq); + auto arg_value_str = raw_args.substr(it_eq + 1); + auto arg_value = json::parse(arg_value_str); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ match[1], - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ match[1], + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }, }, - }, - }; + }; + } } return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); }; return data; } -static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) { - common_chat_data data; - - data.grammar_lazy = params.tool_choice != "required"; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - - foreach_function(params.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - builder.resolve_refs(parameters); - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"{\" " - // " ( \"\\\"type\\\": \\\"function\\\", \" | space ) " - "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + - " \"}\"")); - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); - }); - - builder.add_rule("root", string_join(tool_rules, " | ")); - }, grammar_options); - data.additional_stops.push_back("<|eom_id|>"); - data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, params.add_generation_prompt, {});
-    data.format = "llama 3.2 tool calls";
-    data.parser = [params](const std::string & input) {
-        static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-        static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-        return res;
-    };
-    return data;
-}
-
 static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
@@ -559,8 +531,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_data data;
-    data.grammar_lazy = params.tool_choice != "required";
     if (!params.tools.is_null() && !params.tools.empty()) {
+        data.grammar_lazy = params.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
@@ -806,13 +778,8 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
         return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
-        auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
-
-        if (uses_python_tag) {
-            return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
-        } else {
-            return common_chat_init_llama_3_2_tool_calls(tmpl, params);
-        }
+        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
+        return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
     }
     if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
         return common_chat_init_deepseek_r1_tool_call(tmpl, params);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0658cbdb6171f..c5ba7c2b2e033 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3800,6 +3800,8 @@ int main(int argc, char ** argv) {
                 /* .grammar = */ json_value(data, "grammar", std::string("")),
             });
             LOG_INF("Chat format: %s\n", chat_data.format.c_str());
+            LOG_DBG("Prompt: %s\n", chat_data.prompt.get().c_str());
+            LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
             if (data.contains("grammar")) {
                 if (!chat_data.grammar.empty()) {
                     throw std::runtime_error("Cannot provide grammar and tools");
@@ -3841,11 +3843,11 @@ int main(int argc, char ** argv) {
             for (const auto & trigger : chat_data.grammar_triggers) {
                 auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                 if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                    LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                     task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                     continue;
                 }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                 task.params.sampling.grammar_trigger_words.push_back(trigger);
             }
             task.params.antiprompt = chat_data.additional_stops;
@@ -4021,6 +4023,7 @@ int main(int argc, char ** argv) {
     };

     const auto handle_chat_completions =
[&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+        LOG_DBG("request: %s\n", req.body.c_str());
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py
index 69d0b63bc9c43..b65255ea284ef 100644
--- a/examples/server/tests/unit/test_tool_call.py
+++ b/examples/server/tests/unit/test_tool_call.py
@@ -58,7 +58,7 @@ def create_server():
       "required":["location"]
     }
   }
-}# TODO: fix this crash
+}


 def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
@@ -132,8 +132,8 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,

 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
-    (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
@@ -231,7 +231,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
 @pytest.mark.slow
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     # TODO: fix this crash
-    # ("meetkai-functionary-medium-v3.2", 256, [], None),
+    ("meetkai-functionary-medium-v3.2", 256, [], None),
     ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
     ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
     ("meetkai-functionary-medium-v3.1", 256, [], None),
@@ -247,9 +247,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t

 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
-    # TODO: fix these
-    # ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
     ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
@@ -259,6 +257,8 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai/functionary-medium-v3.2", None)),
     ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
     ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    # TODO: fix this (times out)
+    #
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), ]) def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): global server @@ -276,7 +276,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[ res = server.make_request("POST", "/chat/completions", data={ "max_tokens": 256, "messages": [ - # {"role": "system", "content": "Use tools as appropriate."}, {"role": "user", "content": "What is the weather in Istanbul?"}, ], "tools": [WEATHER_TOOL], @@ -295,21 +294,21 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[ @pytest.mark.slow -@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [ - # TODO: fix these - # ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), - # (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), +@pytest.mark.parametrize("expected_arguments_override,hf_repo,hf_file,template_override", [ + (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), - (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), - (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None), + # TODO: fix this (times out) + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), ]) -def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): +def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True @@ -319,7 +318,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_ server.model_hf_file = hf_file if template_override: (template_hf_repo, template_variant) = template_override - server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja" + server.chat_template_file = 
f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." server.start(timeout_seconds=15*60) res = server.make_request("POST", "/chat/completions", data={ @@ -327,7 +326,6 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_ "messages": [ {"role": "system", "content": "You are a coding assistant."}, {"role": "user", "content": "say hello world with python"}, - # {"role": "user", "content": "Print a hello world message with python"}, ], "tools": [PYTHON_TOOL], # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test. @@ -342,8 +340,8 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_ tool_call = tool_calls[0] assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] actual_arguments = tool_call["function"]["arguments"] - if expected_arguments is not None: - assert actual_arguments == expected_arguments + if expected_arguments_override is not None: + assert actual_arguments == expected_arguments_override else: actual_arguments = json.loads(actual_arguments) assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index cd57987736b8f..6be5cbe0e76fd 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1170,6 +1170,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.awaiting_trigger = false; grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { // TODO: consider a smarter incremental substring search algorithm (store last position to search from). 
@@ -1181,9 +1182,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                 auto constrained_str = grammar.trigger_buffer.substr(pos);
                 grammar.trigger_buffer.clear();
                 llama_grammar_accept_str(grammar, constrained_str);
+                LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
                 return;
             }
         }
+        LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
         return;
     }
 }
diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp
index 1beb2fa5c8faa..ecdd02bc80ca5 100644
--- a/tests/test-chat-handler.cpp
+++ b/tests/test-chat-handler.cpp
@@ -372,8 +372,8 @@ static void test_template_output_parsers() {
         const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", "");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""), tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""), tools_params));

         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -384,26 +384,26 @@ static void test_template_output_parsers() {
             "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", "");
+        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", "");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));

         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
-            "{\"arg1\": 1}");
+            "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", "");
+        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", "");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));

         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
-            "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+            "{\"arg1\": 1}");
     }
     {
         const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "", "");