Skip to content

Commit

Permalink
Tool calling support in xAI, Mistral, Together, Cohere, Groq, Sambanova.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitprasad15 committed Jan 24, 2025
1 parent 1a7be38 commit ca67c72
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 194 deletions.
158 changes: 142 additions & 16 deletions aisuite/providers/cohere_provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,133 @@
import os
import cohere

import json
from aisuite.framework import ChatCompletionResponse
from aisuite.provider import Provider
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function
from aisuite.provider import Provider, LLMError


class CohereMessageConverter:
"""
Cohere-specific message converter
"""

def convert_request(self, messages):
"""Convert framework messages to Cohere format."""
converted_messages = []

for message in messages:
if isinstance(message, dict):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_plan = message.get("tool_plan")
else:
role = message.role
content = message.content
tool_calls = message.tool_calls
tool_plan = getattr(message, "tool_plan", None)

# Convert to Cohere's format
if role == "tool":
# Handle tool response messages
converted_message = {
"role": role,
"tool_call_id": (
message.get("tool_call_id")
if isinstance(message, dict)
else message.tool_call_id
),
"content": self._convert_tool_content(content),
}
elif role == "assistant" and tool_calls:
# Handle assistant messages with tool calls
converted_message = {
"role": role,
"tool_calls": [
{
"id": tc.id if not isinstance(tc, dict) else tc["id"],
"function": {
"name": (
tc.function.name
if not isinstance(tc, dict)
else tc["function"]["name"]
),
"arguments": (
tc.function.arguments
if not isinstance(tc, dict)
else tc["function"]["arguments"]
),
},
"type": "function",
}
for tc in tool_calls
],
"tool_plan": tool_plan,
}
if content:
converted_message["content"] = content
else:
# Handle regular messages
converted_message = {"role": role, "content": content}

converted_messages.append(converted_message)

return converted_messages

def _convert_tool_content(self, content):
"""Convert tool response content to Cohere's expected format."""
if isinstance(content, str):
try:
# Try to parse as JSON first
data = json.loads(content)
return [{"type": "document", "document": {"data": json.dumps(data)}}]
except json.JSONDecodeError:
# If not JSON, return as plain text
return content
elif isinstance(content, list):
# If content is already in Cohere's format, return as is
return content
else:
# For other types, convert to string
return str(content)

@staticmethod
def convert_response(response_data) -> ChatCompletionResponse:
"""Convert Cohere's response to our standard format."""
normalized_response = ChatCompletionResponse()

# Set usage information
normalized_response.usage = {
"prompt_tokens": response_data.usage.tokens.input_tokens,
"completion_tokens": response_data.usage.tokens.output_tokens,
"total_tokens": response_data.usage.tokens.input_tokens
+ response_data.usage.tokens.output_tokens,
}

# Handle tool calls
if response_data.finish_reason == "TOOL_CALL":
tool_call = response_data.message.tool_calls[0]
function = Function(
name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_call_obj = ChatCompletionMessageToolCall(
id=tool_call.id, function=function, type="function"
)
normalized_response.choices[0].message = Message(
content=response_data.message.tool_plan, # Use tool_plan as content
tool_calls=[tool_call_obj],
role="assistant",
refusal=None,
)
normalized_response.choices[0].finish_reason = "tool_calls"
else:
# Handle regular text response
normalized_response.choices[0].message.content = (
response_data.message.content[0].text
)
normalized_response.choices[0].finish_reason = "stop"

return normalized_response


class CohereProvider(Provider):
    """Provider adapter for the Cohere ChatV2 API."""

    def __init__(self, **config):
        """
        Initialize the Cohere provider.

        The API key comes from ``config`` or the CO_API_KEY environment
        variable; the remaining config is passed to the Cohere client.
        """
        config.setdefault("api_key", os.getenv("CO_API_KEY"))
        if not config["api_key"]:
            raise ValueError(
                "Cohere API key is missing. Please provide it in the config or set the CO_API_KEY environment variable."
            )
        self.client = cohere.ClientV2(**config)
        self.transformer = CohereMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to Cohere using the official client.
        """
        try:
            # Transform messages using converter
            transformed_messages = self.transformer.convert_request(messages)

            # Make the request to Cohere
            response = self.client.chat(
                model=model, messages=transformed_messages, **kwargs
            )

            return self.transformer.convert_response(response)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise LLMError(f"An error occurred: {e}") from e
55 changes: 30 additions & 25 deletions aisuite/providers/groq_provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os

import groq
from aisuite.provider import Provider
from aisuite.framework.message import Message
from aisuite.provider import Provider, LLMError
from aisuite.providers.message_converter import OpenAICompliantMessageConverter

# Implementation of Groq provider.
# Groq's message format is same as OpenAI's.
Expand All @@ -21,37 +20,43 @@
# gemma2-9b-it (parallel tool use not supported)


class GroqMessageConverter(OpenAICompliantMessageConverter):
    """Message converter for Groq.

    Groq's message format is the same as OpenAI's, so the OpenAI-compliant
    base behavior is used unchanged; this subclass exists as a hook for any
    future Groq-specific tweaks.
    """


class GroqProvider(Provider):
    """Provider adapter for the Groq chat completions API."""

    def __init__(self, **config):
        """
        Initialize the Groq provider with the given configuration.
        Pass the entire configuration dictionary to the Groq client constructor.
        """
        # Ensure API key is provided either in config or via environment variable
        self.api_key = config.get("api_key", os.getenv("GROQ_API_KEY"))
        if not self.api_key:
            raise ValueError(
                "Groq API key is missing. Please provide it in the config or set the GROQ_API_KEY environment variable."
            )
        config["api_key"] = self.api_key
        self.client = groq.Groq(**config)
        self.transformer = GroqMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Groq chat completions endpoint using the official client.
        """
        try:
            # Transform messages using converter
            transformed_messages = self.transformer.convert_request(messages)

            response = self.client.chat.completions.create(
                model=model,
                messages=transformed_messages,
                **kwargs,  # Pass any additional arguments to the Groq API
            )
            # Groq responses are OpenAI-shaped; normalize from a plain dict.
            return self.transformer.convert_response(response.model_dump())
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise LLMError(f"An error occurred: {e}") from e
57 changes: 57 additions & 0 deletions aisuite/providers/message_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall


class OpenAICompliantMessageConverter:
    """
    Base class for message converters that are compatible with OpenAI's API.
    """

    # Class variable that derived classes can override: when True, tool
    # result content is coerced to str before being sent to the provider.
    tool_results_as_strings = False

    @classmethod
    def convert_request(cls, messages):
        """Convert messages to OpenAI-compatible format.

        Accepts framework Message objects or plain dicts; returns a list of
        plain dicts.
        """
        transformed_messages = []
        for message in messages:
            if isinstance(message, Message):
                tmsg = message.model_dump(mode="json")
                tmsg.pop("refusal", None)  # Remove refusal field if present
            else:
                tmsg = message
            # Read the flag via `cls` (not the base class) so subclass
            # overrides of tool_results_as_strings actually take effect.
            if tmsg["role"] == "tool" and cls.tool_results_as_strings:
                tmsg["content"] = str(tmsg["content"])

            transformed_messages.append(tmsg)
        return transformed_messages

    @staticmethod
    def convert_response(response_data) -> ChatCompletionResponse:
        """Normalize an OpenAI-shaped response dict to ChatCompletionResponse."""
        completion_response = ChatCompletionResponse()
        choice = response_data["choices"][0]
        message = choice["message"]

        # Set basic message content
        completion_response.choices[0].message.content = message["content"]
        completion_response.choices[0].message.role = message.get("role", "assistant")

        # Handle tool calls if present
        if message.get("tool_calls") is not None:
            tool_calls = []
            for tool_call in message["tool_calls"]:
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.get("id"),
                        type="function",  # Always set to "function" as it's the only valid value
                        function=tool_call.get("function"),
                    )
                )
            completion_response.choices[0].message.tool_calls = tool_calls

        return completion_response
50 changes: 20 additions & 30 deletions aisuite/providers/mistral_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from mistralai import Mistral
from aisuite.framework.message import Message
from aisuite.framework import ChatCompletionResponse
from aisuite.provider import Provider, LLMError
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


# Implementation of Mistral provider.
Expand All @@ -12,36 +12,19 @@
# https://docs.mistral.ai/capabilities/function_calling/


class MistralMessageConverter(OpenAICompliantMessageConverter):
    """
    Mistral-specific message converter.

    Requests use the OpenAI-compliant base conversion unchanged; responses
    are Mistral client objects and must be dumped to plain dicts before the
    base normalization can be applied.
    """

    @staticmethod
    def convert_response(response_data) -> ChatCompletionResponse:
        """Convert Mistral's response to our standard format."""
        # Convert Mistral's pydantic response object to the dict shape the
        # OpenAI-compliant base converter expects.
        response_dict = response_data.model_dump()
        # Direct base call — convert_response is a staticmethod, so this is
        # equivalent to (and clearer than) the super(C, C) form.
        return OpenAICompliantMessageConverter.convert_response(response_dict)


# Function calling is available for the following models:
Expand All @@ -55,6 +38,10 @@ def convert_response(response) -> ChatCompletionResponse:
# Mixtral 8x22B
# Mistral Nemo
class MistralProvider(Provider):
"""
Mistral AI Provider using the official Mistral client.
"""

def __init__(self, **config):
"""
Initialize the Mistral provider with the given configuration.
Expand All @@ -64,12 +51,15 @@ def __init__(self, **config):
config.setdefault("api_key", os.getenv("MISTRAL_API_KEY"))
if not config["api_key"]:
raise ValueError(
" API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable."
"Mistral API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable."
)
self.client = Mistral(**config)
self.transformer = MistralMessageConverter()

def chat_completions_create(self, model, messages, **kwargs):
"""
Makes a request to Mistral using the official client.
"""
try:
# Transform messages using converter
transformed_messages = self.transformer.convert_request(messages)
Expand Down
Loading

0 comments on commit ca67c72

Please sign in to comment.