diff --git a/aisuite/__init__.py b/aisuite/__init__.py index 3ff722bf..92e57f18 100644 --- a/aisuite/__init__.py +++ b/aisuite/__init__.py @@ -1 +1,2 @@ from .client import Client +from .framework.message import Message diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py index aad7ebd2..bc7d71c4 100644 --- a/aisuite/framework/__init__.py +++ b/aisuite/framework/__init__.py @@ -1,2 +1,3 @@ from .provider_interface import ProviderInterface from .chat_completion_response import ChatCompletionResponse +from .message import Message diff --git a/aisuite/framework/choice.py b/aisuite/framework/choice.py index 3542da57..4557f46b 100644 --- a/aisuite/framework/choice.py +++ b/aisuite/framework/choice.py @@ -1,6 +1,10 @@ from aisuite.framework.message import Message +from typing import Literal, Optional class Choice: def __init__(self): - self.message = Message() + self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None + self.message = Message( + content=None, tool_calls=None, role="assistant", refusal=None + ) diff --git a/aisuite/framework/message.py b/aisuite/framework/message.py index 5aa7f822..26be291f 100644 --- a/aisuite/framework/message.py +++ b/aisuite/framework/message.py @@ -1,6 +1,22 @@ """Interface to hold contents of api responses when they do not conform to the OpenAI style response""" +from pydantic import BaseModel +from typing import Literal, Optional -class Message: - def __init__(self): - self.content = None + +class Function(BaseModel): + arguments: str + name: str + + +class ChatCompletionMessageToolCall(BaseModel): + id: str + function: Function + type: Literal["function"] + + +class Message(BaseModel): + content: Optional[str] + tool_calls: Optional[list[ChatCompletionMessageToolCall]] + role: Optional[Literal["user", "assistant", "system"]] + refusal: Optional[str] diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py index f63c054c..3e1776c4 100644 --- a/aisuite/providers/anthropic_provider.py +++ b/aisuite/providers/anthropic_provider.py @@ -1,12 +1,31 @@ import anthropic +import json from aisuite.provider import Provider from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function # Define a constant for the default max_tokens value DEFAULT_MAX_TOKENS = 4096 +# Links: +# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use + class AnthropicProvider(Provider): + # Add these at the class level, after the class definition + FINISH_REASON_MAPPING = { + "end_turn": "stop", + "max_tokens": "length", + "tool_use": "tool_calls", + # Add more mappings as needed + } + + # Role constants + ROLE_USER = "user" + ROLE_ASSISTANT = "assistant" + ROLE_TOOL = "tool" + ROLE_SYSTEM = "system" + def __init__(self, **config): """ Initialize the Anthropic provider with the given configuration. 
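For reference — an illustrative sketch, not part of the patch — the pydantic models introduced in aisuite/framework/message.py can be constructed directly. The tool-call id and argument values below are made up; get_current_temperature is the mock tool used later in examples/simple_tool_calling.ipynb.

# Sketch: building a tool-call Message with the new models (values are illustrative).
import json
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

tool_call = ChatCompletionMessageToolCall(
    id="call_123",  # hypothetical id, normally assigned by the provider
    type="function",
    function=Function(
        name="get_current_temperature",
        arguments=json.dumps({"location": "San Francisco", "unit": "celsius"}),
    ),
)

msg = Message(content=None, tool_calls=[tool_call], role="assistant", refusal=None)
print(msg.model_dump(mode="json"))

Because Message is now a pydantic BaseModel, providers can round-trip it with model_dump(mode="json"), which is what the Groq and Mistral providers below rely on.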
@@ -15,26 +34,195 @@ def __init__(self, **config): self.client = anthropic.Anthropic(**config) + def convert_request(self, messages): + """Convert framework messages to Anthropic format.""" + return [self._convert_single_message(msg) for msg in messages] + + def _convert_single_message(self, msg): + """Convert a single message to Anthropic format.""" + if isinstance(msg, dict): + return self._convert_dict_message(msg) + return self._convert_message_object(msg) + + def _convert_dict_message(self, msg): + """Convert a dictionary message to Anthropic format.""" + if msg["role"] == self.ROLE_TOOL: + return self._create_tool_result_message(msg["tool_call_id"], msg["content"]) + elif msg["role"] == self.ROLE_ASSISTANT and "tool_calls" in msg: + return self._create_assistant_tool_message( + msg["content"], msg["tool_calls"] + ) + return {"role": msg["role"], "content": msg["content"]} + + def _convert_message_object(self, msg): + """Convert a Message object to Anthropic format.""" + if msg.role == self.ROLE_TOOL: + return self._create_tool_result_message(msg.tool_call_id, msg.content) + elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls: + return self._create_assistant_tool_message(msg.content, msg.tool_calls) + return {"role": msg.role, "content": msg.content} + + def _create_tool_result_message(self, tool_call_id, content): + """Create a tool result message in Anthropic format.""" + return { + "role": self.ROLE_USER, + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": content, + } + ], + } + + def _create_assistant_tool_message(self, content, tool_calls): + """Create an assistant message with tool calls in Anthropic format.""" + message_content = [] + if content: + message_content.append({"type": "text", "text": content}) + + for tool_call in tool_calls: + tool_input = ( + tool_call["function"]["arguments"] + if isinstance(tool_call, dict) + else tool_call.function.arguments + ) + message_content.append( + { + "type": "tool_use", + "id": ( + tool_call["id"] if isinstance(tool_call, dict) else tool_call.id + ), + "name": ( + tool_call["function"]["name"] + if isinstance(tool_call, dict) + else tool_call.function.name + ), + "input": json.loads(tool_input), + } + ) + + return {"role": self.ROLE_ASSISTANT, "content": message_content} + def chat_completions_create(self, model, messages, **kwargs): - # Check if the fist message is a system message - if messages[0]["role"] == "system": + """Create a chat completion using the Anthropic API.""" + system_message = self._extract_system_message(messages) + kwargs = self._prepare_kwargs(kwargs) + converted_messages = self.convert_request(messages) + + response = self.client.messages.create( + model=model, system=system_message, messages=converted_messages, **kwargs + ) + return self.convert_response(response) + + def _extract_system_message(self, messages): + """Extract system message if present, otherwise return empty list.""" + if messages and messages[0]["role"] == "system": system_message = messages[0]["content"] - messages = messages[1:] - else: - system_message = [] + messages.pop(0) + return system_message + return [] - # kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS) - if "max_tokens" not in kwargs: - kwargs["max_tokens"] = DEFAULT_MAX_TOKENS + def _prepare_kwargs(self, kwargs): + """Prepare kwargs for the API call.""" + kwargs = kwargs.copy() # Create a copy to avoid modifying the original + kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS) - return self.normalize_response( - 
self.client.messages.create( - model=model, system=system_message, messages=messages, **kwargs - ) + if "tools" in kwargs: + kwargs["tools"] = self._convert_tool_spec(kwargs["tools"]) + + return kwargs + + def convert_response_with_tool_use(self, response): + """Convert Anthropic tool use response to the framework's format.""" + # Find the tool_use content + tool_call = next( + (content for content in response.content if content.type == "tool_use"), + None, ) - def normalize_response(self, response): + if tool_call: + function = Function( + name=tool_call.name, arguments=json.dumps(tool_call.input) + ) + tool_call_obj = ChatCompletionMessageToolCall( + id=tool_call.id, function=function, type="function" + ) + # Get the text content if any + text_content = next( + ( + content.text + for content in response.content + if content.type == "text" + ), + "", + ) + + return Message( + content=text_content or None, + tool_calls=[tool_call_obj] if tool_call else None, + role="assistant", + refusal=None, + ) + return None + + def convert_response(self, response): """Normalize the response from the Anthropic API to match OpenAI's response format.""" normalized_response = ChatCompletionResponse() - normalized_response.choices[0].message.content = response.content[0].text + + normalized_response.choices[0].finish_reason = self._get_finish_reason(response) + normalized_response.usage = self._get_usage_stats(response) + normalized_response.choices[0].message = self._get_message(response) + return normalized_response + + def _get_finish_reason(self, response): + """Get the normalized finish reason.""" + return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop") + + def _get_usage_stats(self, response): + """Get the usage statistics.""" + return { + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + } + + def _get_message(self, response): + """Get the appropriate message based on response type.""" + if response.stop_reason == "tool_use": + tool_message = self.convert_response_with_tool_use(response) + if tool_message: + return tool_message + + return Message( + content=response.content[0].text, + role="assistant", + tool_calls=None, + refusal=None, + ) + + def _convert_tool_spec(self, openai_tools): + """Convert OpenAI tool specification to Anthropic format.""" + anthropic_tools = [] + + for tool in openai_tools: + # Only handle function-type tools from OpenAI + if tool.get("type") != "function": + continue + + function = tool["function"] + + anthropic_tool = { + "name": function["name"], + "description": function["description"], + "input_schema": { + "type": "object", + "properties": function["parameters"]["properties"], + "required": function["parameters"].get("required", []), + }, + } + + anthropic_tools.append(anthropic_tool) + + return anthropic_tools diff --git a/aisuite/providers/aws_provider.py b/aisuite/providers/aws_provider.py index 10f48afe..70e52e94 100644 --- a/aisuite/providers/aws_provider.py +++ b/aisuite/providers/aws_provider.py @@ -1,93 +1,243 @@ import os +import json +from typing import List, Dict, Any, Tuple, Optional import boto3 from aisuite.provider import Provider, LLMError from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message +import botocore -class AwsProvider(Provider): +class BedrockConfig: + INFERENCE_PARAMETERS = ["maxTokens", "temperature", "topP", "stopSequences"] + def __init__(self, 
**config): - """ - Initialize the AWS Bedrock provider with the given configuration. + self.region_name = config.get( + "region_name", os.getenv("AWS_REGION_NAME", "us-west-2") + ) - This class uses the AWS Bedrock converse API, which provides a consistent interface - for all Amazon Bedrock models that support messages. Examples include: - - anthropic.claude-v2 - - meta.llama3-70b-instruct-v1:0 - - mistral.mixtral-8x7b-instruct-v0:1 + def create_client(self): + return boto3.client("bedrock-runtime", region_name=self.region_name) - The model value can be a baseModelId for on-demand throughput or a provisionedModelArn - for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API. - For more information on model IDs, see: - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +class BedrockMessageConverter: + @staticmethod + def convert_request( + messages: List[Dict[str, Any]] + ) -> Tuple[List[Dict], List[Dict]]: + """Convert messages to AWS Bedrock format.""" + # Convert all messages to dicts if they're Message objects + messages = [ + message.model_dump() if hasattr(message, "model_dump") else message + for message in messages + ] - Note: - - The Anthropic Bedrock client uses default AWS credential providers, such as - ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables. - - If the region is not set, it defaults to us-west-1, which may lead to a - "Could not connect to the endpoint URL" error. - - The client constructor does not accept additional parameters. + # Handle system message + system_message = [] + if messages and messages[0]["role"] == "system": + system_message = [{"text": messages[0]["content"]}] + messages = messages[1:] - Args: - **config: Configuration options for the provider. 
+ formatted_messages = [] + for message in messages: + # Skip any additional system messages + if message["role"] == "system": + continue - """ - self.region_name = config.get( - "region_name", os.getenv("AWS_REGION_NAME", "us-west-2") - ) - self.client = boto3.client("bedrock-runtime", region_name=self.region_name) - self.inference_parameters = [ - "maxTokens", - "temperature", - "topP", - "stopSequences", - ] + if message["role"] == "tool": + bedrock_message = BedrockMessageConverter.convert_tool_result(message) + if bedrock_message: + formatted_messages.append(bedrock_message) + elif message["role"] == "assistant": + bedrock_message = BedrockMessageConverter.convert_assistant(message) + if bedrock_message: + formatted_messages.append(bedrock_message) + else: # user messages + formatted_messages.append( + { + "role": message["role"], + "content": [{"text": message["content"]}], + } + ) + + return system_message, formatted_messages + + @staticmethod + def convert_response_tool_call( + response: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """Convert AWS Bedrock tool call response to OpenAI format.""" + if response.get("stopReason") != "tool_use": + return None + + tool_calls = [] + for content in response["output"]["message"]["content"]: + if "toolUse" in content: + tool = content["toolUse"] + tool_calls.append( + { + "type": "function", + "id": tool["toolUseId"], + "function": { + "name": tool["name"], + "arguments": json.dumps(tool["input"]), + }, + } + ) + + if not tool_calls: + return None + + return { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + "refusal": None, + } + + @staticmethod + def convert_tool_result(message: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert OpenAI tool result format to AWS Bedrock format.""" + if message["role"] != "tool" or "content" not in message: + return None + + tool_call_id = message.get("tool_call_id") + if not tool_call_id: + raise LLMError("Tool result message must include tool_call_id") + + try: + content_json = json.loads(message["content"]) + content = [{"json": content_json}] + except json.JSONDecodeError: + content = [{"text": message["content"]}] - def normalize_response(self, response): + return { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": tool_call_id, "content": content}} + ], + } + + @staticmethod + def convert_assistant(message: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert OpenAI assistant format to AWS Bedrock format.""" + if message["role"] != "assistant": + return None + + content = [] + + if message.get("content"): + content.append({"text": message["content"]}) + + if message.get("tool_calls"): + for tool_call in message["tool_calls"]: + if tool_call["type"] == "function": + try: + input_json = json.loads(tool_call["function"]["arguments"]) + except json.JSONDecodeError: + input_json = tool_call["function"]["arguments"] + + content.append( + { + "toolUse": { + "toolUseId": tool_call["id"], + "name": tool_call["function"]["name"], + "input": input_json, + } + } + ) + + return {"role": "assistant", "content": content} if content else None + + +class AwsProvider(Provider): + def __init__(self, **config): + """Initialize the AWS Bedrock provider with the given configuration.""" + self.config = BedrockConfig(**config) + self.client = self.config.create_client() + self.transformer = BedrockMessageConverter() + + def convert_response(self, response: Dict[str, Any]) -> ChatCompletionResponse: """Normalize the response from the Bedrock API to match OpenAI's response 
format.""" norm_response = ChatCompletionResponse() + + # Check if the model is requesting tool use + if response.get("stopReason") == "tool_use": + tool_message = self.transformer.convert_response_tool_call(response) + if tool_message: + norm_response.choices[0].message = Message(**tool_message) + norm_response.choices[0].finish_reason = "tool_calls" + return norm_response + + # Handle regular text response norm_response.choices[0].message.content = response["output"]["message"][ "content" ][0]["text"] return norm_response - def chat_completions_create(self, model, messages, **kwargs): - # Any exception raised by Anthropic will be returned to the caller. - # Maybe we should catch them and raise a custom LLMError. - # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html - system_message = [] - if messages[0]["role"] == "system": - system_message = [{"text": messages[0]["content"]}] - messages = messages[1:] + def _convert_tool_spec(self, kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert tool specifications to Bedrock format.""" + if "tools" not in kwargs: + return None - formatted_messages = [] - for message in messages: - # QUIETLY Ignore any "system" messages except the first system message. - if message["role"] != "system": - formatted_messages.append( - {"role": message["role"], "content": [{"text": message["content"]}]} - ) + tool_config = { + "tools": [ + { + "toolSpec": { + "name": tool["function"]["name"], + "description": tool["function"].get("description", " "), + "inputSchema": {"json": tool["function"]["parameters"]}, + } + } + for tool in kwargs["tools"] + ] + } + return tool_config - # Maintain a list of Inference Parameters which Bedrock supports. - # These fields need to be passed using inferenceConfig. - # Rest all other fields are passed as additionalModelRequestFields. - inference_config = {} - additional_model_request_fields = {} + def _prepare_request_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Prepare the configuration for the Bedrock API request.""" + # Convert tools and remove from kwargs + tool_config = self._convert_tool_spec(kwargs) + kwargs.pop("tools", None) # Remove tools from kwargs if present - # Iterate over the kwargs and separate the inference parameters and additional model request fields. - for key, value in kwargs.items(): - if key in self.inference_parameters: - inference_config[key] = value + inference_config = { + key: kwargs[key] + for key in BedrockConfig.INFERENCE_PARAMETERS + if key in kwargs + } + + additional_fields = { + key: value + for key, value in kwargs.items() + if key not in BedrockConfig.INFERENCE_PARAMETERS + } + + return { + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_fields, + "toolConfig": tool_config, + } + + def chat_completions_create( + self, model: str, messages: List[Dict[str, Any]], **kwargs + ) -> ChatCompletionResponse: + """Create a chat completion request to AWS Bedrock.""" + system_message, formatted_messages = self.transformer.convert_request(messages) + request_config = self._prepare_request_config(kwargs) + + try: + response = self.client.converse( + modelId=model, + messages=formatted_messages, + system=system_message, + **request_config + ) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + error_message = e.response["Error"]["Message"] + raise LLMError(error_message) else: - additional_model_request_fields[key] = value - - # Call the Bedrock Converse API. 
- response = self.client.converse( - modelId=model, # baseModelId or provisionedModelArn - messages=formatted_messages, - system=system_message, - inferenceConfig=inference_config, - additionalModelRequestFields=additional_model_request_fields, - ) - return self.normalize_response(response) + raise + + return self.convert_response(response) diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py index 73a4bb4f..a25fcc51 100644 --- a/aisuite/providers/groq_provider.py +++ b/aisuite/providers/groq_provider.py @@ -2,6 +2,23 @@ import groq from aisuite.provider import Provider +from aisuite.framework.message import Message + +# Implementation of Groq provider. +# Groq's message format is same as OpenAI's. +# Tool calling specification is also exactly the same as OpenAI's. +# Links: +# https://console.groq.com/docs/tool-use +# Groq supports tool calling for the following models, as of 16th Nov 2024: +# llama3-groq-70b-8192-tool-use-preview +# llama3-groq-8b-8192-tool-use-preview +# llama-3.1-70b-versatile +# llama-3.1-8b-instant +# llama3-70b-8192 +# llama3-8b-8192 +# mixtral-8x7b-32768 (parallel tool use not supported) +# gemma-7b-it (parallel tool use not supported) +# gemma2-9b-it (parallel tool use not supported) class GroqProvider(Provider): @@ -19,8 +36,22 @@ def __init__(self, **config): self.client = groq.Groq(**config) def chat_completions_create(self, model, messages, **kwargs): + transformed_messages = [] + for message in messages: + if isinstance(message, Message): + transformed_messages.append(self.transform_from_messages(message)) + else: + transformed_messages.append(message) return self.client.chat.completions.create( model=model, - messages=messages, + messages=transformed_messages, **kwargs # Pass any additional arguments to the Groq API ) + + # Transform framework Message to a format that Groq understands. + def transform_from_messages(self, message: Message): + return message.model_dump(mode="json") + + # Transform Groq message (dict) to a format that the framework Message understands. + def transform_to_message(self, message_dict: dict): + return Message(**message_dict) diff --git a/aisuite/providers/mistral_provider.py b/aisuite/providers/mistral_provider.py index 6e96f1cf..08c1804a 100644 --- a/aisuite/providers/mistral_provider.py +++ b/aisuite/providers/mistral_provider.py @@ -1,10 +1,15 @@ import os from mistralai import Mistral +from aisuite.framework.message import Message from aisuite.provider import Provider +# Implementation of Mistral provider. +# Mistral's message format is same as OpenAI's. Just different class names, but fully cross-compatible. +# Links: +# https://docs.mistral.ai/capabilities/function_calling/ class MistralProvider(Provider): def __init__(self, **config): """ @@ -20,4 +25,22 @@ def __init__(self, **config): self.client = Mistral(**config) def chat_completions_create(self, model, messages, **kwargs): - return self.client.chat.complete(model=model, messages=messages, **kwargs) + # If message is of type Message, transform it to a format that Mistral understands. + transformed_messages = [] + for message in messages: + if isinstance(message, Message): + transformed_messages.append(self.transform_from_messages(message)) + else: + transformed_messages.append(message) + + return self.client.chat.complete( + model=model, messages=transformed_messages, **kwargs + ) + + # Transform framework Message to a format that Mistral understands. 
+ def transform_from_messages(self, message: Message): + return message.model_dump(mode="json") + + # Transform Mistral message (dict) to a format that the framework Message understands. + def transform_to_message(self, message_dict: dict): + return Message(**message_dict) diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py index b8fe2fb8..fdd32d49 100644 --- a/aisuite/providers/openai_provider.py +++ b/aisuite/providers/openai_provider.py @@ -26,8 +26,10 @@ def __init__(self, **config): def chat_completions_create(self, model, messages, **kwargs): # Any exception raised by OpenAI will be returned to the caller. # Maybe we should catch them and raise a custom LLMError. - return self.client.chat.completions.create( + response = self.client.chat.completions.create( model=model, messages=messages, **kwargs # Pass any additional arguments to the OpenAI API ) + + return response diff --git a/aisuite/utils/tool_manager.py b/aisuite/utils/tool_manager.py new file mode 100644 index 00000000..f12ca343 --- /dev/null +++ b/aisuite/utils/tool_manager.py @@ -0,0 +1,169 @@ +from typing import Callable, Dict, Any, Type, Optional +from pydantic import BaseModel, create_model, Field, ValidationError +import inspect +import json + + +class ToolManager: + def __init__(self): + self._tools = {} + + # Add a tool function with or without a Pydantic model. + def add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = None): + """Register a tool function with metadata. If no param_model is provided, infer from function signature.""" + if param_model: + tool_spec = self._convert_to_tool_spec(func, param_model) + else: + tool_spec, param_model = self._infer_from_signature(func) + + self._tools[func.__name__] = { + "function": func, + "param_model": param_model, + "spec": tool_spec, + } + + def tools(self, format="openai") -> list: + """Return tools in the specified format (default OpenAI).""" + if format == "openai": + return self._convert_to_openai_format() + return [tool["spec"] for tool in self._tools.values()] + + def _convert_to_tool_spec( + self, func: Callable, param_model: Type[BaseModel] + ) -> Dict[str, Any]: + """Convert the function and its Pydantic model to a unified tool specification.""" + type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean"} + + properties = {} + for field_name, field in param_model.model_fields.items(): + field_type = field.annotation + + # Handle enum types + if hasattr(field_type, "__members__"): # Check if it's an enum + enum_values = [ + member.value if hasattr(member, "value") else member.name + for member in field_type + ] + properties[field_name] = { + "type": "string", + "enum": enum_values, + "description": field.description or "", + } + # Convert enum default value to string if it exists + if str(field.default) != "PydanticUndefined": + properties[field_name]["default"] = ( + field.default.value + if hasattr(field.default, "value") + else field.default + ) + else: + properties[field_name] = { + "type": type_mapping.get(field_type, str(field_type)), + "description": field.description or "", + } + # Add default if it exists and isn't PydanticUndefined + if str(field.default) != "PydanticUndefined": + properties[field_name]["default"] = field.default + + return { + "name": func.__name__, + "description": func.__doc__ or "", + "parameters": { + "type": "object", + "properties": properties, + "required": [ + name + for name, field in param_model.model_fields.items() + if field.is_required and 
str(field.default) == "PydanticUndefined" + ], + }, + } + + def _infer_from_signature( + self, func: Callable + ) -> tuple[Dict[str, Any], Type[BaseModel]]: + """Infer parameters(required and optional) and requirements directly from the function signature.""" + signature = inspect.signature(func) + fields = {} + required_fields = [] + + # Get function's docstring + docstring = inspect.getdoc(func) or " " + + for param_name, param in signature.parameters.items(): + # Check if a type annotation is missing + if param.annotation == inspect._empty: + raise TypeError( + f"Parameter '{param_name}' in function '{func.__name__}' must have a type annotation." + ) + + # Determine field type and optionality + param_type = param.annotation + if param.default == inspect._empty: + fields[param_name] = (param_type, ...) + required_fields.append(param_name) + else: + fields[param_name] = (param_type, Field(default=param.default)) + + # Dynamically create a Pydantic model based on inferred fields + param_model = create_model(f"{func.__name__.capitalize()}Params", **fields) + + # Convert inferred model to a tool spec format + tool_spec = self._convert_to_tool_spec(func, param_model) + + # Update the tool spec with the docstring + tool_spec["description"] = docstring + + return tool_spec, param_model + + def _convert_to_openai_format(self) -> list: + """Convert tools to OpenAI's format.""" + return [ + {"type": "function", "function": tool["spec"]} + for tool in self._tools.values() + ] + + def execute_tool(self, tool_calls) -> tuple[list, list]: + """Executes registered tools based on the tool calls from the model. + + Args: + tool_calls: List of tool calls from the model + + Returns: + List of tuples containing (result, result_message) for each tool call + """ + results = [] + messages = [] + + # Handle single tool call or list of tool calls + if not isinstance(tool_calls, list): + tool_calls = [tool_calls] + + for tool_call in tool_calls: + tool_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + + if tool_name not in self._tools: + raise ValueError(f"Tool '{tool_name}' not registered.") + + tool = self._tools[tool_name] + tool_func = tool["function"] + param_model = tool["param_model"] + + # Validate and parse the arguments with Pydantic if a model exists + try: + validated_args = param_model(**arguments) + result = tool_func(**validated_args.model_dump()) + results.append(result) + messages.append( + { + "role": "tool", + "name": tool_name, + "content": json.dumps(result), + "tool_call_id": tool_call.id, # Include the tool call ID in the response + } + ) + except ValidationError as e: + raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") + + return results, messages diff --git a/examples/client.ipynb b/examples/client.ipynb index e99f2f50..ef789de3 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "initial_id", "metadata": { "ExecuteTime": { @@ -36,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "f75736ee", "metadata": {}, "outputs": [], @@ -69,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "4de3a24f", "metadata": { "ExecuteTime": { @@ -231,6 +231,71 @@ "response = client.chat.completions.create(model=togetherai_model, messages=messages, temperature=0.75, top_p=0.7, top_k=50)\n", "print(response.choices[0].message.content)" ] + }, + { + "cell_type": 
"code", + "execution_count": 5, + "id": "aadad67e-a098-4aed-b940-6ebd7174edbe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sending messages to Mistral: [{'role': 'system', 'content': 'Respond in Pirate English. Always try to include the phrase - No rum No fun.'}, {'role': 'user', 'content': 'Tell me a joke about Captain Jack Sparrow'}]\n", + "Arr matey, did ye hear the one about Cap'n Jack Sparrow? He walked into a tavern and ordered a rum, but the barkeep said they were out. Jack just smiled and said, \"No rum? No fun. But ye know what? If I be havin' no fun, then neither be ye!\" And with a wink and a smile, he walked out, leavin' the tavern in a sudden, mysterious shortage of ale too! Savvy?\n" + ] + } + ], + "source": [ + "mistral_large = \"mistral:mistral-large-latest\"\n", + "response = client.chat.completions.create(model=mistral_large, messages=messages)\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "546b5991-2cca-4b13-b8ff-19ee03088783", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "id='67dc3e5f92e84207b755998c5ce8c654' object='chat.completion' model='mistral-large-latest' usage=UsageInfo(prompt_tokens=34, completion_tokens=106, total_tokens=140) created=1731797496 choices=[ChatCompletionChoice(index=0, message=AssistantMessage(content='Arr matey, did ye hear the one about Cap\\'n Jack Sparrow? He walked into a tavern and ordered a rum, but the barkeep said they were out. Jack just smiled and said, \"No rum? No fun. But ye know what? If I be havin\\' no fun, then neither be ye!\" And with a wink and a smile, he walked out, leavin\\' the tavern in a sudden, mysterious shortage of ale too! 
Savvy?', tool_calls=None, prefix=False, role='assistant'), finish_reason='stop')]\n" + ] + } + ], + "source": [ + "print(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2e214e97-0d81-43d0-b168-722c2a06c438", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "print(type(response.choices[0].message))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5394294e-5958-438b-a860-26a0af365a6b", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -249,7 +314,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/examples/simple_tool_calling.ipynb b/examples/simple_tool_calling.ipynb new file mode 100644 index 00000000..579bab66 --- /dev/null +++ b/examples/simple_tool_calling.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import sys\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "sys.path.append('../../aisuite')\n", + "\n", + "# Load from .env file if available\n", + "load_dotenv(find_dotenv())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import aisuite as ai\n", + "from aisuite.utils.tool_manager import ToolManager # Import your ToolManager class\n", + "\n", + "client = ai.Client()\n", + "tool_manager = ToolManager()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mock tool functions.\n", + "def get_current_temperature(location: str, unit: str):\n", + " # Simulate fetching temperature from an API\n", + " return {\"location\": location, \"unit\": unit, \"temperature\": 72}\n", + "\n", + "tool_manager.add_tool(get_current_temperature)\n", + "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model = \"anthropic:claude-3-5-sonnet-20240620\"\n", + "# model = openai:gpt-4o\n", + "# model = mistral:mistral-large-latest\n", + "model = \"aws:anthropic.claude-3-haiku-20240307-v1:0\"\n", + "# model = \"groq:llama-3.1-70b-versatile\"\n", + "response = client.chat.completions.create(\n", + " model=model, messages=messages, tools=tool_manager.tools())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "pprint(vars(response.choices[0].message))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if response.choices[0].message.tool_calls:\n", + " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", + " messages.append(response.choices[0].message) # Model's function call message\n", + " messages.append(result_as_message[0])\n", + "\n", + " final_response = client.chat.completions.create(\n", + " model=model, messages=messages, tools=tool_manager.tools())\n", + " print(final_response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now, test without tool calling.\n", + "messages = [{\"role\": \"user\", 
\"content\": \"What is the capital of California?\"}]\n", + "response = client.chat.completions.create(\n", + " model=model, messages=messages)\n", + "print(response.choices[0].message.content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}