Skip to content

Commit

Permalink
Tool calling support. Part I
Browse files Browse the repository at this point in the history
Add tool calling support for the following providers:
OpenAI, Groq, Anthropic, AWS, and Mistral.
OpenAI-compatible SDKs need no changes for
tool calling support.

Add a ToolManager utility so users can easily
supply tools and parse the model's requests for
tool usage.
  • Loading branch information
rohitcpbot committed Nov 28, 2024
1 parent 1b5da0e commit 9946507
Show file tree
Hide file tree
Showing 12 changed files with 865 additions and 92 deletions.
1 change: 1 addition & 0 deletions aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import Client
from .framework.message import Message
1 change: 1 addition & 0 deletions aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
from .message import Message
6 changes: 5 additions & 1 deletion aisuite/framework/choice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from aisuite.framework.message import Message
from typing import Literal, Optional


class Choice:
    """A single completion choice, mirroring OpenAI's `choices[i]` entry."""

    def __init__(self):
        # Filled in by providers: why generation stopped ("stop" or "tool_calls").
        self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None
        # Default assistant message; providers overwrite it with real content.
        # (The stale bare `Message()` call is removed: the pydantic Message
        # requires its fields, so it would raise a ValidationError, and the
        # value was immediately overwritten anyway.)
        self.message = Message(
            content=None, tool_calls=None, role="assistant", refusal=None
        )
22 changes: 19 additions & 3 deletions aisuite/framework/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""

from pydantic import BaseModel
from typing import Literal, Optional

class Function(BaseModel):
    """The function a model asked to call: its name and JSON-encoded arguments."""

    arguments: str
    name: str


class ChatCompletionMessageToolCall(BaseModel):
    """A single tool call requested by the model, matching OpenAI's shape."""

    id: str
    function: Function
    type: Literal["function"]


class Message(BaseModel):
    """OpenAI-style chat message used to normalize non-OpenAI provider responses.

    All fields default to None so `Message()` remains constructible, as the
    earlier plain-class version of Message allowed.
    """

    content: Optional[str] = None
    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
    role: Optional[Literal["user", "assistant", "system"]] = None
    refusal: Optional[str] = None
216 changes: 202 additions & 14 deletions aisuite/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import anthropic
import json
from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096

# Links:
# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use


class AnthropicProvider(Provider):
    """Adapter that maps the framework's OpenAI-style chat API onto Anthropic's Messages API."""

    # Maps Anthropic `stop_reason` values to OpenAI-style finish reasons
    # used by the framework (see _get_finish_reason).
    FINISH_REASON_MAPPING = {
        "end_turn": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
        # Add more mappings as needed
    }

    # Role constants shared by the message-conversion helpers below.
    ROLE_USER = "user"
    ROLE_ASSISTANT = "assistant"
    ROLE_TOOL = "tool"
    ROLE_SYSTEM = "system"

def __init__(self, **config):
"""
Initialize the Anthropic provider with the given configuration.
Expand All @@ -15,26 +34,195 @@ def __init__(self, **config):

self.client = anthropic.Anthropic(**config)

def convert_request(self, messages):
"""Convert framework messages to Anthropic format."""
return [self._convert_single_message(msg) for msg in messages]

def _convert_single_message(self, msg):
"""Convert a single message to Anthropic format."""
if isinstance(msg, dict):
return self._convert_dict_message(msg)
return self._convert_message_object(msg)

def _convert_dict_message(self, msg):
"""Convert a dictionary message to Anthropic format."""
if msg["role"] == self.ROLE_TOOL:
return self._create_tool_result_message(msg["tool_call_id"], msg["content"])
elif msg["role"] == self.ROLE_ASSISTANT and "tool_calls" in msg:
return self._create_assistant_tool_message(
msg["content"], msg["tool_calls"]
)
return {"role": msg["role"], "content": msg["content"]}

def _convert_message_object(self, msg):
"""Convert a Message object to Anthropic format."""
if msg.role == self.ROLE_TOOL:
return self._create_tool_result_message(msg.tool_call_id, msg.content)
elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls:
return self._create_assistant_tool_message(msg.content, msg.tool_calls)
return {"role": msg.role, "content": msg.content}

def _create_tool_result_message(self, tool_call_id, content):
"""Create a tool result message in Anthropic format."""
return {
"role": self.ROLE_USER,
"content": [
{
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content,
}
],
}

def _create_assistant_tool_message(self, content, tool_calls):
"""Create an assistant message with tool calls in Anthropic format."""
message_content = []
if content:
message_content.append({"type": "text", "text": content})

for tool_call in tool_calls:
tool_input = (
tool_call["function"]["arguments"]
if isinstance(tool_call, dict)
else tool_call.function.arguments
)
message_content.append(
{
"type": "tool_use",
"id": (
tool_call["id"] if isinstance(tool_call, dict) else tool_call.id
),
"name": (
tool_call["function"]["name"]
if isinstance(tool_call, dict)
else tool_call.function.name
),
"input": json.loads(tool_input),
}
)

return {"role": self.ROLE_ASSISTANT, "content": message_content}

def chat_completions_create(self, model, messages, **kwargs):
    """Create a chat completion using the Anthropic API.

    Extracts any leading system message, prepares kwargs (default
    max_tokens, tool-spec conversion), converts the messages to
    Anthropic's format, and normalizes the response back into the
    framework's OpenAI-style shape.
    """
    system_message = self._extract_system_message(messages)
    kwargs = self._prepare_kwargs(kwargs)
    converted_messages = self.convert_request(messages)

    response = self.client.messages.create(
        model=model, system=system_message, messages=converted_messages, **kwargs
    )
    return self.convert_response(response)

def _extract_system_message(self, messages):
"""Extract system message if present, otherwise return empty list."""
if messages and messages[0]["role"] == "system":
system_message = messages[0]["content"]
messages = messages[1:]
else:
system_message = []
messages.pop(0)
return system_message
return []

def _prepare_kwargs(self, kwargs):
    """Return a copy of kwargs prepared for the Anthropic API call."""
    # Copy so the caller's dict is never mutated.
    kwargs = kwargs.copy()
    # Anthropic requires max_tokens; supply the default when absent.
    kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

    # Translate OpenAI-style tool specs into Anthropic's input_schema form.
    if "tools" in kwargs:
        kwargs["tools"] = self._convert_tool_spec(kwargs["tools"])

    return kwargs

def convert_response_with_tool_use(self, response):
    """Convert an Anthropic tool-use response to the framework's Message.

    Returns None when the response contains no `tool_use` content block.
    """
    # Find the first tool_use content block, if any.
    tool_call = next(
        (content for content in response.content if content.type == "tool_use"),
        None,
    )

    if tool_call:
        # Anthropic supplies arguments as a parsed object; the framework's
        # OpenAI-style Function carries them as a JSON string.
        function = Function(
            name=tool_call.name, arguments=json.dumps(tool_call.input)
        )
        tool_call_obj = ChatCompletionMessageToolCall(
            id=tool_call.id, function=function, type="function"
        )
        # Any text block emitted alongside the tool call.
        text_content = next(
            (
                content.text
                for content in response.content
                if content.type == "text"
            ),
            "",
        )

        return Message(
            content=text_content or None,
            tool_calls=[tool_call_obj],
            role="assistant",
            refusal=None,
        )
    return None

def convert_response(self, response):
    """Normalize the Anthropic API response to OpenAI's response format."""
    normalized_response = ChatCompletionResponse()

    normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
    normalized_response.usage = self._get_usage_stats(response)
    # _get_message handles both plain-text and tool-use responses.
    normalized_response.choices[0].message = self._get_message(response)

    return normalized_response

def _get_finish_reason(self, response):
"""Get the normalized finish reason."""
return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")

def _get_usage_stats(self, response):
"""Get the usage statistics."""
return {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
}

def _get_message(self, response):
    """Pick the framework Message for this response (tool call or plain text)."""
    if response.stop_reason == "tool_use":
        tool_message = self.convert_response_with_tool_use(response)
        if tool_message is not None:
            return tool_message

    # Plain-text response (or no tool_use block was found).
    return Message(
        content=response.content[0].text,
        role="assistant",
        tool_calls=None,
        refusal=None,
    )

def _convert_tool_spec(self, openai_tools):
"""Convert OpenAI tool specification to Anthropic format."""
anthropic_tools = []

for tool in openai_tools:
# Only handle function-type tools from OpenAI
if tool.get("type") != "function":
continue

function = tool["function"]

anthropic_tool = {
"name": function["name"],
"description": function["description"],
"input_schema": {
"type": "object",
"properties": function["parameters"]["properties"],
"required": function["parameters"].get("required", []),
},
}

anthropic_tools.append(anthropic_tool)

return anthropic_tools
Loading

0 comments on commit 9946507

Please sign in to comment.