Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Tool Implementation improvements™️ #193

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ templates/*
templates/tool_calls/*
!templates/tool_calls
!templates/tool_calls/chatml_with_headers.jinja
!templates/tool_calls/hermes3_jitsys.jinja

# Sampler overrides folder
sampler_overrides/*
Expand Down
3 changes: 2 additions & 1 deletion backends/exllamav2/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
build_token_enforcer_tokenizer_data,
)
from loguru import logger
from typing import List
from typing import List, Type
from functools import lru_cache
from pydantic import BaseModel


class OutlinesTokenizerWrapper:
Expand Down
2 changes: 1 addition & 1 deletion common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, Field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Type

from common.utils import unwrap, prune_dict

Expand Down
20 changes: 9 additions & 11 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
from endpoints.OAI.types.tools import ToolCall

from endpoints.OAI.utils.tools import(
postprocess_tool_call,
generate_strict_schemas
)


def _create_response(
Expand Down Expand Up @@ -433,8 +437,11 @@ async def generate_tool_calls(

# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
if data.tools:
strict_schema = generate_strict_schemas(data)
tool_data = deepcopy(data)
tool_data.json_schema = tool_data.tool_call_schema
#tool_data.json_schema = tool_data.tool_call_schema
tool_data.json_schema = strict_schema # needs strict flag
gen_params = tool_data.to_gen_params()

for idx, gen in enumerate(generations):
Expand Down Expand Up @@ -464,12 +471,3 @@ async def generate_tool_calls(
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]

return generations


def postprocess_tool_call(call_str: str) -> List[ToolCall]:
    """Decode generated tool-call JSON into ToolCall objects.

    Args:
        call_str: JSON text that decodes to a list of tool-call mappings,
            each holding a "function" mapping with an "arguments" entry.

    Returns:
        One ToolCall per decoded entry, built after each entry's
        "arguments" value has been re-serialized to a JSON string.
    """
    parsed_calls = json.loads(call_str)

    for entry in parsed_calls:
        function_payload = entry["function"]
        # Arguments are decoded as a JSON value; store them back as text.
        function_payload["arguments"] = json.dumps(function_payload["arguments"])

    return [ToolCall(**entry) for entry in parsed_calls]
100 changes: 100 additions & 0 deletions endpoints/OAI/utils/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Support functions to enable tool calling"""

from typing import List, Dict
import json

from endpoints.OAI.types.tools import ToolCall
from endpoints.OAI.types.chat_completion import ChatCompletionRequest

def postprocess_tool_call(call_str: str) -> List[ToolCall]:
    """Parse the raw tool-call text produced by the model into ToolCall objects.

    Args:
        call_str: JSON text that decodes to a list of tool-call mappings.
            Each mapping must contain a "function" mapping with an
            "arguments" entry.

    Returns:
        One ToolCall per decoded entry, in the order they were generated.

    Raises:
        json.JSONDecodeError: If call_str is not valid JSON.
        KeyError: If an entry lacks "function" or "arguments".
    """
    # The leftover debug print() calls were dropped: they wrote the raw model
    # output to stdout on every request.
    tool_calls = json.loads(call_str)

    for tool_call in tool_calls:
        # Re-serialize the decoded arguments so the field holds JSON text.
        # NOTE(review): presumably ToolCall expects "arguments" as a string,
        # as in the OpenAI response format - confirm against the model.
        tool_call["function"]["arguments"] = json.dumps(
            tool_call["function"]["arguments"]
        )

    return [ToolCall(**tool_call) for tool_call in tool_calls]

def generate_strict_schemas(data: ChatCompletionRequest) -> Dict:
    """Build a JSON schema for an array of calls to the request's tools.

    Args:
        data: The chat completion request; only its ``tools`` attribute is
            read here.

    Returns:
        A schema whose "$defs" come from generate_defs(data.tools) and whose
        root is an array of "#/$defs/ModelItem" items.
    """
    return {
        # Shared definitions first, then the root array that references them.
        "$defs": generate_defs(data.tools),
        "type": "array",
        "items": {"$ref": "#/$defs/ModelItem"},
    }

def generate_defs(tools: List) -> Dict:
    """Create the "$defs" section describing every tool in ``tools``.

    For each tool three definitions are added: an arguments schema, a
    constant holding the tool's name, and a function object tying the two
    together. The first tool uses the bare keys "Arguments", "Name" and
    "Function"; later tools get a 1-based numeric suffix ("Function2", ...).
    "ModelItem" and "Type" definitions are appended last.

    Args:
        tools: Tool objects exposing ``function.name`` and
            ``function.parameters``.

    Returns:
        The mapping of definition name to schema.
    """
    definitions = {}

    for index, tool in enumerate(tools):
        suffix = str(index + 1) if index > 0 else ""
        function_key = f"Function{suffix}"
        arguments_key = f"Arguments{suffix}"
        name_key = f"Name{suffix}"

        # Schema for the arguments object accepted by this tool.
        definitions[arguments_key] = generate_arguments_schema(
            tool.function.parameters
        )

        # The tool's name is pinned to a single allowed string.
        definitions[name_key] = {
            "const": tool.function.name,
            "title": name_key,
            "type": "string",
        }

        # The function object references the two definitions above.
        definitions[function_key] = {
            "type": "object",
            "properties": {
                "name": {"$ref": f"#/$defs/{name_key}"},
                "arguments": {"$ref": f"#/$defs/{arguments_key}"},
            },
            "required": ["name", "arguments"],
        }

    # Definitions shared by all tools.
    definitions["ModelItem"] = generate_model_item_schema(len(tools))
    definitions["Type"] = {"const": "function", "type": "string"}

    return definitions

def generate_arguments_schema(parameters: Dict) -> Dict:
    """Reduce a tool's parameter schema to a minimal object schema.

    Each declared property that has a "type" is reduced to just that type;
    other keywords on such a property (description, enum, ...) are not
    copied. A property without a "type" key (for example one declared only
    through "enum" or "anyOf") is passed through as a shallow copy instead
    of raising KeyError.

    Args:
        parameters: The tool's parameter schema. May be None or empty, in
            which case an object schema with no properties is returned.

    Returns:
        A dict with "type": "object", the reduced "properties" and the
        "required" list (empty when none is declared).
    """
    # Tools may be declared without parameters; treat that as "no arguments".
    parameters = parameters or {}
    declared = parameters.get("properties") or {}

    properties = {}
    for name, details in declared.items():
        if "type" in details:
            properties[name] = {"type": details["type"]}
        else:
            # No "type" to extract: keep the declared constraint as written.
            properties[name] = dict(details)

    return {
        "type": "object",
        "properties": properties,
        "required": parameters.get("required") or [],
    }

def generate_model_item_schema(num_functions: int) -> Dict:
    """Build the schema for a single tool-call item.

    Args:
        num_functions: How many function definitions to reference. The
            first is "#/$defs/Function"; the rest are "#/$defs/Function2",
            "#/$defs/Function3", and so on.

    Returns:
        An object schema requiring "id", "function" and "type", where
        "function" may match any one of the referenced definitions.
    """
    function_refs = []
    for position in range(num_functions):
        # Only the first definition is unsuffixed; suffixes are 1-based.
        suffix = "" if position == 0 else str(position + 1)
        function_refs.append({"$ref": f"#/$defs/Function{suffix}"})

    item_properties = {
        "id": {"type": "string"},
        "function": {"anyOf": function_refs},
        "type": {"$ref": "#/$defs/Type"},
    }

    return {
        "type": "object",
        "properties": item_properties,
        "required": ["id", "function", "type"],
    }
4 changes: 3 additions & 1 deletion templates/tool_calls/chatml_with_headers.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
{%- set content = message['content'] | default('', true) | trim -%}
{%- if loop.first -%}
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{ inital_system_prompt }}
{%- if 'tools_json' in message and message['tools_json'] -%}
{{ inital_system_prompt }}
{%- endif -%}

{{ content }}{{ eos_token }}
{%- endif -%}
Expand Down
68 changes: 68 additions & 0 deletions templates/tool_calls/hermes3_jitsys.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{# Metadata
   Top-level variables: the stop strings, the accepted message roles, and the
   tags that wrap tool calls and tool responses. sys_prompt starts empty.
   NOTE(review): presumably the host application reads some of these after
   loading the template - confirm against the loader. #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<tool_call>" %}
{% set tool_end = "</tool_call>" %}
{% set response_start = "<tool_response>" %}
{% set response_end = "</tool_response>" %}
{% set sys_prompt = "" %}


{%- set inital_system_prompt -%}
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {{ tools_json }} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
{{ tool_start }}
{"name": <function-name>, "arguments": <args-dict>}
{{ tool_end }}{{ eos_token }}
{%- endset -%}

{%- set tool_reminder -%}Available Tools:
{{ tools_json }}

Tool Call Format Example:
{{ tool_start }}{{ example_tool_call }}

Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
{%- endset -%}

{# Template
   Renders every message after the first as its role followed by its content
   (tool messages are wrapped in the response tags), appends any
   tool_calls_json inside the tool-call tags, then emits the tool system
   prompt and sys_prompt after the loop, followed by the assistant turn.
   Names such as tools_json, example_tool_call, tool_precursor, eos_token and
   raise_exception are not defined in this file; presumably the render
   context supplies them - TODO confirm.
   NOTE(review): Jinja scopes a set statement to the enclosing for loop, so
   the sys_prompt assigned inside the loop is not visible after endfor and
   the value rendered after the loop is the top-level empty string. The same
   applies to role after the loop. A namespace() object would be needed to
   carry these values out - confirm the intended behavior. #}

{%- for message in messages -%}
{%- set role = message['role'] | lower -%}
{%- if role not in message_roles -%}
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
{%- endif -%}
{%- set content = message['content'] | default('', true) | trim -%}
{%- if loop.first and role == "system" -%}
{%- set sys_prompt = content -%}
{%- endif -%}
{%- if not loop.first -%}
{%- if role == 'tool' -%}
{{ role }}
{{ response_start}}
{{ content }}
{{ response_end}}
{%- else -%}
{{ role }}
{{ content }}
{%- endif -%}
{%- endif -%}
{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
{{ tool_start }}
{{ message['tool_calls_json']}}
{{ tool_end }}
{%- endif -%}
{{ eos_token }}
{%- endfor -%}

{{ role }}
{{ inital_system_prompt }}

{{ sys_prompt }}{{ eos_token }}

{%- if tool_precursor -%}
assistant
{{ tool_precursor }}{{ tool_start }}
{%- else -%}
assistant
{%- endif -%}
Loading