diff --git a/aisuite/providers/cohere_provider.py b/aisuite/providers/cohere_provider.py index 5886f24..9639a3e 100644 --- a/aisuite/providers/cohere_provider.py +++ b/aisuite/providers/cohere_provider.py @@ -1,8 +1,133 @@ import os import cohere - +import json from aisuite.framework import ChatCompletionResponse -from aisuite.provider import Provider +from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function +from aisuite.provider import Provider, LLMError + + +class CohereMessageConverter: + """ + Cohere-specific message converter + """ + + def convert_request(self, messages): + """Convert framework messages to Cohere format.""" + converted_messages = [] + + for message in messages: + if isinstance(message, dict): + role = message.get("role") + content = message.get("content") + tool_calls = message.get("tool_calls") + tool_plan = message.get("tool_plan") + else: + role = message.role + content = message.content + tool_calls = message.tool_calls + tool_plan = getattr(message, "tool_plan", None) + + # Convert to Cohere's format + if role == "tool": + # Handle tool response messages + converted_message = { + "role": role, + "tool_call_id": ( + message.get("tool_call_id") + if isinstance(message, dict) + else message.tool_call_id + ), + "content": self._convert_tool_content(content), + } + elif role == "assistant" and tool_calls: + # Handle assistant messages with tool calls + converted_message = { + "role": role, + "tool_calls": [ + { + "id": tc.id if not isinstance(tc, dict) else tc["id"], + "function": { + "name": ( + tc.function.name + if not isinstance(tc, dict) + else tc["function"]["name"] + ), + "arguments": ( + tc.function.arguments + if not isinstance(tc, dict) + else tc["function"]["arguments"] + ), + }, + "type": "function", + } + for tc in tool_calls + ], + "tool_plan": tool_plan, + } + if content: + converted_message["content"] = content + else: + # Handle regular messages + converted_message = {"role": role, "content": content} + + converted_messages.append(converted_message) + + return converted_messages + + def _convert_tool_content(self, content): + """Convert tool response content to Cohere's expected format.""" + if isinstance(content, str): + try: + # Try to parse as JSON first + data = json.loads(content) + return [{"type": "document", "document": {"data": json.dumps(data)}}] + except json.JSONDecodeError: + # If not JSON, return as plain text + return content + elif isinstance(content, list): + # If content is already in Cohere's format, return as is + return content + else: + # For other types, convert to string + return str(content) + + @staticmethod + def convert_response(response_data) -> ChatCompletionResponse: + """Convert Cohere's response to our standard format.""" + normalized_response = ChatCompletionResponse() + + # Set usage information + normalized_response.usage = { + "prompt_tokens": response_data.usage.tokens.input_tokens, + "completion_tokens": response_data.usage.tokens.output_tokens, + "total_tokens": response_data.usage.tokens.input_tokens + + response_data.usage.tokens.output_tokens, + } + + # Handle tool calls + if response_data.finish_reason == "TOOL_CALL": + tool_call = response_data.message.tool_calls[0] + function = Function( + name=tool_call.function.name, arguments=tool_call.function.arguments + ) + tool_call_obj = ChatCompletionMessageToolCall( + id=tool_call.id, function=function, type="function" + ) + normalized_response.choices[0].message = Message( + content=response_data.message.tool_plan, # Use tool_plan as content + tool_calls=[tool_call_obj], + role="assistant", + refusal=None, + ) + normalized_response.choices[0].finish_reason = "tool_calls" + else: + # Handle regular text response + normalized_response.choices[0].message.content = ( + response_data.message.content[0].text + ) + normalized_response.choices[0].finish_reason = "stop" + + return normalized_response class CohereProvider(Provider): @@ -15,23 +140,24 @@ def __init__(self, **config): config.setdefault("api_key", os.getenv("CO_API_KEY")) if not config["api_key"]: raise ValueError( - " API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." + "Cohere API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." ) self.client = cohere.ClientV2(**config) + self.transformer = CohereMessageConverter() def chat_completions_create(self, model, messages, **kwargs): - response = self.client.chat( - model=model, - messages=messages, - **kwargs # Pass any additional arguments to the Cohere API - ) + """ + Makes a request to Cohere using the official client. + """ + try: + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) - return self.normalize_response(response) + # Make the request to Cohere + response = self.client.chat( + model=model, messages=transformed_messages, **kwargs + ) - def normalize_response(self, response): - """Normalize the reponse from Cohere API to match OpenAI's response format.""" - normalized_response = ChatCompletionResponse() - normalized_response.choices[0].message.content = response.message.content[ - 0 - ].text - return normalized_response + return self.transformer.convert_response(response) + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py index a25fcc5..07d29fd 100644 --- a/aisuite/providers/groq_provider.py +++ b/aisuite/providers/groq_provider.py @@ -1,8 +1,7 @@ import os - import groq -from aisuite.provider import Provider -from aisuite.framework.message import Message +from aisuite.provider import Provider, LLMError +from aisuite.providers.message_converter import OpenAICompliantMessageConverter # Implementation of Groq provider. # Groq's message format is same as OpenAI's. @@ -21,6 +20,14 @@ # gemma2-9b-it (parallel tool use not supported) +class GroqMessageConverter(OpenAICompliantMessageConverter): + """ + Groq-specific message converter if needed + """ + + pass + + class GroqProvider(Provider): def __init__(self, **config): """ @@ -28,30 +35,28 @@ def __init__(self, **config): Pass the entire configuration dictionary to the Groq client constructor. """ # Ensure API key is provided either in config or via environment variable - config.setdefault("api_key", os.getenv("GROQ_API_KEY")) - if not config["api_key"]: + self.api_key = config.get("api_key", os.getenv("GROQ_API_KEY")) + if not self.api_key: raise ValueError( - " API key is missing. Please provide it in the config or set the GROQ_API_KEY environment variable." + "Groq API key is missing. Please provide it in the config or set the GROQ_API_KEY environment variable." ) + config["api_key"] = self.api_key self.client = groq.Groq(**config) + self.transformer = GroqMessageConverter() def chat_completions_create(self, model, messages, **kwargs): - transformed_messages = [] - for message in messages: - if isinstance(message, Message): - transformed_messages.append(self.transform_from_messages(message)) - else: - transformed_messages.append(message) - return self.client.chat.completions.create( - model=model, - messages=transformed_messages, - **kwargs # Pass any additional arguments to the Groq API - ) - - # Transform framework Message to a format that Groq understands. - def transform_from_messages(self, message: Message): - return message.model_dump(mode="json") - - # Transform Groq message (dict) to a format that the framework Message understands. - def transform_to_message(self, message_dict: dict): - return Message(**message_dict) + """ + Makes a request to the Groq chat completions endpoint using the official client. + """ + try: + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + + response = self.client.chat.completions.create( + model=model, + messages=transformed_messages, + **kwargs, # Pass any additional arguments to the Groq API + ) + return self.transformer.convert_response(response.model_dump()) + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/aisuite/providers/message_converter.py b/aisuite/providers/message_converter.py new file mode 100644 index 0000000..2c5c9d2 --- /dev/null +++ b/aisuite/providers/message_converter.py @@ -0,0 +1,57 @@ +from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message, ChatCompletionMessageToolCall + + +class OpenAICompliantMessageConverter: + """ + Base class for message converters that are compatible with OpenAI's API. + """ + + # Class variable that derived classes can override + tool_results_as_strings = False + + @staticmethod + def convert_request(messages): + """Convert messages to OpenAI-compatible format.""" + transformed_messages = [] + for message in messages: + tmsg = None + if isinstance(message, Message): + message_dict = message.model_dump(mode="json") + message_dict.pop("refusal", None) # Remove refusal field if present + tmsg = message_dict + else: + tmsg = message + if tmsg["role"] == "tool": + if OpenAICompliantMessageConverter.tool_results_as_strings: + tmsg["content"] = str(tmsg["content"]) + + transformed_messages.append(tmsg) + return transformed_messages + + @staticmethod + def convert_response(response_data) -> ChatCompletionResponse: + """Normalize the response to match OpenAI's response format.""" + print(response_data) + completion_response = ChatCompletionResponse() + choice = response_data["choices"][0] + message = choice["message"] + + # Set basic message content + completion_response.choices[0].message.content = message["content"] + completion_response.choices[0].message.role = message.get("role", "assistant") + + # Handle tool calls if present + if "tool_calls" in message and message["tool_calls"] is not None: + tool_calls = [] + for tool_call in message["tool_calls"]: + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call.get("id"), + type="function", # Always set to "function" as it's the only valid value + function=tool_call.get("function"), + ) + ) + completion_response.choices[0].message.tool_calls = tool_calls + + return completion_response diff --git a/aisuite/providers/mistral_provider.py b/aisuite/providers/mistral_provider.py index a9f6943..4fc28fa 100644 --- a/aisuite/providers/mistral_provider.py +++ b/aisuite/providers/mistral_provider.py @@ -1,9 +1,9 @@ import os - from mistralai import Mistral from aisuite.framework.message import Message from aisuite.framework import ChatCompletionResponse from aisuite.provider import Provider, LLMError +from aisuite.providers.message_converter import OpenAICompliantMessageConverter # Implementation of Mistral provider. @@ -12,36 +12,19 @@ # https://docs.mistral.ai/capabilities/function_calling/ -class MistralMessageConverter: - @staticmethod - def convert_request(messages): - """Convert messages to Mistral format.""" - transformed_messages = [] - for message in messages: - if isinstance(message, Message): - message_dict = message.model_dump(mode="json") - message_dict.pop("refusal", None) # Remove refusal field if present - transformed_messages.append(message_dict) - else: - transformed_messages.append(message) - return transformed_messages +class MistralMessageConverter(OpenAICompliantMessageConverter): + """ + Mistral-specific message converter + """ @staticmethod - def convert_response(response) -> ChatCompletionResponse: - """Normalize the response from Mistral to match OpenAI's response format.""" - completion_response = ChatCompletionResponse() - choice = response.choices[0] - message = choice.message - - # Set basic message content - completion_response.choices[0].message.content = message.content - completion_response.choices[0].message.role = message.role - - # Handle tool calls if present - if hasattr(message, "tool_calls") and message.tool_calls: - completion_response.choices[0].message.tool_calls = message.tool_calls - - return completion_response + def convert_response(response_data) -> ChatCompletionResponse: + """Convert Mistral's response to our standard format.""" + # Convert Mistral's response object to dict format + response_dict = response_data.model_dump() + return super(MistralMessageConverter, MistralMessageConverter).convert_response( + response_dict + ) # Function calling is available for the following models: @@ -55,6 +38,10 @@ def convert_response(response) -> ChatCompletionResponse: # Mixtral 8x22B # Mistral Nemo class MistralProvider(Provider): + """ + Mistral AI Provider using the official Mistral client. + """ + def __init__(self, **config): """ Initialize the Mistral provider with the given configuration. @@ -64,12 +51,15 @@ def __init__(self, **config): config.setdefault("api_key", os.getenv("MISTRAL_API_KEY")) if not config["api_key"]: raise ValueError( - " API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable." + "Mistral API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable." ) self.client = Mistral(**config) self.transformer = MistralMessageConverter() def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to Mistral using the official client. + """ try: # Transform messages using converter transformed_messages = self.transformer.convert_request(messages) diff --git a/aisuite/providers/nebius_provider.py b/aisuite/providers/nebius_provider.py index c558a9c..568cca8 100644 --- a/aisuite/providers/nebius_provider.py +++ b/aisuite/providers/nebius_provider.py @@ -6,6 +6,7 @@ BASE_URL = "https://api.studio.nebius.ai/v1" +# TODO(rohitcp): This needs to be added to our internal testbed. Tool calling not tested. class NebiusProvider(Provider): def __init__(self, **config): """ diff --git a/aisuite/providers/sambanova_provider.py b/aisuite/providers/sambanova_provider.py index 75a9731..850fb15 100644 --- a/aisuite/providers/sambanova_provider.py +++ b/aisuite/providers/sambanova_provider.py @@ -1,30 +1,53 @@ import os -from aisuite.provider import Provider +from aisuite.provider import Provider, LLMError from openai import OpenAI +from aisuite.providers.message_converter import OpenAICompliantMessageConverter + + +class SambanovaMessageConverter(OpenAICompliantMessageConverter): + """ + SambaNova-specific message converter. + """ + + pass class SambanovaProvider(Provider): + """ + SambaNova Provider using OpenAI client for API calls. + """ + def __init__(self, **config): """ Initialize the SambaNova provider with the given configuration. Pass the entire configuration dictionary to the OpenAI client constructor. """ # Ensure API key is provided either in config or via environment variable - config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY")) - if not config["api_key"]: + self.api_key = config.get("api_key", os.getenv("SAMBANOVA_API_KEY")) + if not self.api_key: raise ValueError( "Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." ) + config["api_key"] = self.api_key config["base_url"] = "https://api.sambanova.ai/v1/" # Pass the entire config to the OpenAI client constructor self.client = OpenAI(**config) + self.transformer = SambanovaMessageConverter() def chat_completions_create(self, model, messages, **kwargs): - # Any exception raised by Sambanova will be returned to the caller. - # Maybe we should catch them and raise a custom LLMError. - return self.client.chat.completions.create( - model=model, - messages=messages, - **kwargs # Pass any additional arguments to the Sambanova API - ) + """ + Makes a request to the SambaNova chat completions endpoint using the OpenAI client. + """ + try: + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + + response = self.client.chat.completions.create( + model=model, + messages=transformed_messages, + **kwargs, # Pass any additional arguments to the Sambanova API + ) + return self.transformer.convert_response(response.model_dump()) + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/aisuite/providers/together_provider.py b/aisuite/providers/together_provider.py index 45cbce0..fdfb02a 100644 --- a/aisuite/providers/together_provider.py +++ b/aisuite/providers/together_provider.py @@ -1,49 +1,15 @@ import os import httpx from aisuite.provider import Provider, LLMError -from aisuite.framework import ChatCompletionResponse -from aisuite.framework.message import Message, ChatCompletionMessageToolCall +from aisuite.providers.message_converter import OpenAICompliantMessageConverter -class TogetherMessageConverter: - @staticmethod - def convert_request(messages): - """Convert messages to Together format.""" - transformed_messages = [] - for message in messages: - if isinstance(message, Message): - message_dict = message.model_dump(mode="json") - message_dict.pop("refusal", None) # Remove refusal field if present - transformed_messages.append(message_dict) - else: - transformed_messages.append(message) - return transformed_messages - - @staticmethod - def convert_response(response_data) -> ChatCompletionResponse: - """Normalize the response from Together to match OpenAI's response format.""" - completion_response = ChatCompletionResponse() - choice = response_data["choices"][0] - message = choice["message"] - - # Set basic message content - completion_response.choices[0].message.content = message["content"] - completion_response.choices[0].message.role = message.get("role", "assistant") - - # Handle tool calls if present - if "tool_calls" in message and message["tool_calls"] is not None: - tool_calls = [] - for tool_call in message["tool_calls"]: - tool_calls.append( - ChatCompletionMessageToolCall( - id=tool_call.get("id"), - type=tool_call.get("type"), - function=tool_call.get("function"), - ) - ) - completion_response.choices[0].message.tool_calls = tool_calls +class TogetherMessageConverter(OpenAICompliantMessageConverter): + """ + Together-specific message converter if needed + """ - return completion_response + pass class TogetherProvider(Provider): diff --git a/aisuite/providers/xai_provider.py b/aisuite/providers/xai_provider.py index 53e8d83..23effb0 100644 --- a/aisuite/providers/xai_provider.py +++ b/aisuite/providers/xai_provider.py @@ -2,6 +2,15 @@ import httpx from aisuite.provider import Provider, LLMError from aisuite.framework import ChatCompletionResponse +from aisuite.providers.message_converter import OpenAICompliantMessageConverter + + +class XaiMessageConverter(OpenAICompliantMessageConverter): + """ + xAI-specific message converter if needed + """ + + pass class XaiProvider(Provider): @@ -24,11 +33,15 @@ def __init__(self, **config): # Optionally set a custom timeout (default to 30s) self.timeout = config.get("timeout", 30) + self.transformer = XaiMessageConverter() def chat_completions_create(self, model, messages, **kwargs): """ Makes a request to the xAI chat completions endpoint using httpx. """ + # Transform messages using converter + transformed_messages = self.transformer.convert_request(messages) + headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", @@ -36,7 +49,7 @@ def chat_completions_create(self, model, messages, **kwargs): data = { "model": model, - "messages": messages, + "messages": transformed_messages, **kwargs, # Pass any additional arguments to the API } @@ -46,20 +59,8 @@ def chat_completions_create(self, model, messages, **kwargs): self.BASE_URL, json=data, headers=headers, timeout=self.timeout ) response.raise_for_status() + return self.transformer.convert_response(response.json()) except httpx.HTTPStatusError as http_err: raise LLMError(f"xAI request failed: {http_err}") except Exception as e: raise LLMError(f"An error occurred: {e}") - - # Return the normalized response - return self._normalize_response(response.json()) - - def _normalize_response(self, response_data): - """ - Normalize the response to a common format (ChatCompletionResponse). - """ - normalized_response = ChatCompletionResponse() - normalized_response.choices[0].message.content = response_data["choices"][0][ - "message" - ]["content"] - return normalized_response diff --git a/examples/simple_tool_calling.ipynb b/examples/simple_tool_calling.ipynb index 4487072..4a7e7cb 100644 --- a/examples/simple_tool_calling.ipynb +++ b/examples/simple_tool_calling.ipynb @@ -2,20 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import json\n", "import sys\n", @@ -37,20 +26,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "For model: xai:grok-2-latest\n", - "I'm here to help with your query! To provide the most accurate and up-to-date temperature for San Francisco, I would need to access real-time weather data. However, I don't have direct access to real-time data sources. \n", - "\n", - "I recommend checking a reliable weather website or app, such as AccuWeather, Weather.com, or the National Weather Service, to get the current temperature in San Francisco in Celsius. These platforms update frequently and will give you the most current information.\n" - ] - } - ], + "outputs": [], "source": [ "from aisuite import Client\n", "\n", @@ -111,7 +89,7 @@ "# Mock tool functions.\n", "def get_current_temperature(location: str, unit: str):\n", " # Simulate fetching temperature from an API\n", - " return {\"location\": location, \"unit\": unit, \"temperature\": 72}\n", + " return {\"temperature\": 72}\n", "\n", "def get_rain_probability(location: str):\n", " # Simulate fetching rain probability\n", @@ -197,23 +175,6 @@ " }" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = \"anthropic:claude-3-5-sonnet-20240620\"\n", - "model = \"huggingface:meta-llama/Llama-3.1-8B-Instruct\"\n", - "model = \"huggingface:meta-llama/Llama-3.3-70B-Instruct\"\n", - "# model = \"openai:gpt-4o\"\n", - "# model = mistral:mistral-large-latest\n", - "# model = \"aws:anthropic.claude-3-haiku-20240307-v1:0\"\n", - "# model = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n", - "# model = \"aws:meta.llama3-3-70b-instruct-v1:0\"\n", - "# model = \"groq:llama-3.1-70b-versatile\"" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -227,6 +188,16 @@ "metadata": {}, "outputs": [], "source": [ + "import json\n", + "import sys\n", + "from dotenv import load_dotenv, find_dotenv\n", + "import os\n", + "\n", + "sys.path.append('../../aisuite')\n", + "\n", + "# Load from .env file if available\n", + "load_dotenv(find_dotenv())\n", + "\n", "from aisuite import Client\n", "\n", "client = Client()\n", @@ -234,10 +205,23 @@ " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", " \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n", "}})\n", - "model = \"azure:aisuite-mistral-large-2407\"\n", - "model = \"fireworks:accounts/fireworks/models/llama-v3p1-405b-instruct\"\n", - "model = \"mistral:mistral-large-latest\"\n", - "model = \"together:meta-llama/Llama-3.3-70B-Instruct-Turbo\"\n", + "\n", + "# model = \"anthropic:claude-3-5-sonnet-20241022\"\n", + "# model = \"aws:mistral.mistral-7b-instruct-v0:2\"\n", + "# model = \"azure:aisuite-mistral-large\"\n", + "# model = \"cohere:command-r-plus\"\n", + "# model = \"deepseek:deepseek-chat\"\n", + "# model = \"fireworks:accounts/fireworks/models/llama-v3p1-405b-instruct\"\n", + "# model = \"google:gemini-1.5-pro-002\"\n", + "# model = \"groq:llama-3.3-70b-versatile\"\n", + "# model = \"huggingface:meta-llama/Llama-3.1-8B-Instruct\"\n", + "# model = \"mistral:mistral-large-latest\"\n", + "# model = \"nebius:\"\n", + "# model = \"ollama:\"\n", + "# model = \"sambanova:Meta-Llama-3.3-70B-Instruct\"\n", + "# model = \"together:meta-llama/Llama-3.3-70B-Instruct-Turbo\"\n", + "# model = \"watsonx:\"\n", + "model = \"xai:grok-2-latest\"\n", "\n", "messages = [{\n", " \"role\": \"user\",\n", @@ -269,16 +253,25 @@ "if response.choices[0].message.tool_calls:\n", " for tool_call in response.choices[0].message.tool_calls:\n", " tool_result = handle_tool_call(tool_call)\n", + " print(tool_result)\n", "\n", " messages.append(response.choices[0].message) # Model's function call message\n", " messages.append(create_tool_response_message(tool_call, tool_result))\n", " # Send the tool response back to the model\n", " final_response = client.chat.completions.create(\n", " model=model, messages=messages, tools=tools)\n", - "\n", + " print(final_response.choices[0].message)\n", + " \n", " # Output the final response from the model\n", " print(final_response.choices[0].message.content)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tests/providers/test_groq_provider.py b/tests/providers/test_groq_provider.py index 94e953f..e2e021e 100644 --- a/tests/providers/test_groq_provider.py +++ b/tests/providers/test_groq_provider.py @@ -22,9 +22,9 @@ def test_groq_provider(): provider = GroqProvider() mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message = MagicMock() - mock_response.choices[0].message.content = response_text_content + mock_response.model_dump.return_value = { + "choices": [{"message": {"content": response_text_content}}] + } with patch.object( provider.client.chat.completions, diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py index 8ecc51e..937da0e 100644 --- a/tests/providers/test_mistral_provider.py +++ b/tests/providers/test_mistral_provider.py @@ -21,9 +21,9 @@ def test_mistral_provider(): provider = MistralProvider() mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message = MagicMock() - mock_response.choices[0].message.content = response_text_content + mock_response.model_dump.return_value = { + "choices": [{"message": {"content": response_text_content}}] + } with patch.object( provider.client.chat, "complete", return_value=mock_response diff --git a/tests/providers/test_sambanova_provider.py b/tests/providers/test_sambanova_provider.py index b5c649e..01f5ace 100644 --- a/tests/providers/test_sambanova_provider.py +++ b/tests/providers/test_sambanova_provider.py @@ -22,9 +22,11 @@ def test_sambanova_provider(): provider = SambanovaProvider() mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message = MagicMock() - mock_response.choices[0].message.content = response_text_content + mock_response.model_dump.return_value = { + "choices": [ + {"message": {"content": response_text_content, "role": "assistant"}} + ] + } with patch.object( provider.client.chat.completions, diff --git a/tests/providers/test_watsonx_provider.py b/tests/providers/test_watsonx_provider.py index 8e7123a..710cb20 100644 --- a/tests/providers/test_watsonx_provider.py +++ b/tests/providers/test_watsonx_provider.py @@ -1,7 +1,11 @@ from unittest.mock import MagicMock, patch import pytest -from ibm_watsonx_ai.metanames import GenChatParamsMetaNames as GenChatParams + +try: + from ibm_watsonx_ai.metanames import GenChatParamsMetaNames as GenChatParams +except Exception as e: + pytest.skip(f"Skipping test due to import error: {e}", allow_module_level=True) from aisuite.providers.watsonx_provider import WatsonxProvider @@ -14,6 +18,7 @@ def set_api_key_env_var(monkeypatch): monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id") +@pytest.mark.skip(reason="Skipping due to version compatibility issue on python 3.11") def test_watsonx_provider(): """High-level test that the provider is initialized and chat completions are requested successfully."""