Async version of the client with corresponding providers support #179

Closed · wants to merge 6 commits
54 changes: 54 additions & 0 deletions aisuite/async_client.py
@@ -0,0 +1,54 @@
from .client import Client, Chat, Completions
from .provider import ProviderFactory


class AsyncClient(Client):
@property
def chat(self):
"""Return the async chat API interface."""
if not self._chat:
self._chat = AsyncChat(self)
return self._chat


class AsyncChat(Chat):
def __init__(self, client: "AsyncClient"):
self.client = client
self._completions = AsyncCompletions(self.client)


class AsyncCompletions(Completions):
async def create(self, model: str, messages: list, **kwargs):
"""
        Create an async chat completion based on the model, messages, and any extra arguments.
"""
# Check that correct format is used
if ":" not in model:
raise ValueError(
f"Invalid model format. Expected 'provider:model', got '{model}'"
)

# Extract the provider key from the model identifier, e.g., "google:gemini-xx"
provider_key, model_name = model.split(":", 1)

# Validate if the provider is supported
supported_providers = ProviderFactory.get_supported_providers()
if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

# Initialize provider if not already initialized
if provider_key not in self.client.providers:
config = self.client.provider_configs.get(provider_key, {})
self.client.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config
)

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for '{provider_key}'.")

# Delegate the chat completion to the correct provider's async implementation
return await provider.chat_completions_create_async(model_name, messages, **kwargs)
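
A minimal usage sketch of the new surface (not part of the diff; the API key and model string are placeholders, and configure() is the same entry point the sync Client already exposes, as used in the tests below):

# Illustrative only; hypothetical config values.
import asyncio

from aisuite import AsyncClient


async def main():
    client = AsyncClient()
    client.configure({"openai": {"api_key": "sk-placeholder"}})

    response = await client.chat.completions.create(
        "openai:gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
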
4 changes: 4 additions & 0 deletions aisuite/provider.py
Expand Up @@ -18,6 +18,10 @@ def chat_completions_create(self, model, messages):
"""Abstract method for chat completion calls, to be implemented by each provider."""
pass

@abstractmethod
def chat_completions_create_async(self, model, messages):
"""Abstract method for async chat completion calls, to be implemented by each provider."""
pass

class ProviderFactory:
"""Factory to dynamically load provider instances based on naming conventions."""
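
Because chat_completions_create_async is declared abstract, every concrete provider must implement it before it can be instantiated; providers not updated in this PR would break. A sketch of the contract with a hypothetical EchoProvider (not part of the diff):

# Hypothetical provider; illustrates the contract only.
from aisuite.provider import Provider


class EchoProvider(Provider):
    def chat_completions_create(self, model, messages):
        # Echo the last message back instead of calling a real API.
        return {"model": model, "echo": messages[-1]["content"]}

    async def chat_completions_create_async(self, model, messages):
        # A real provider would await its SDK's async client here.
        return self.chat_completions_create(model, messages)
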
11 changes: 11 additions & 0 deletions aisuite/providers/anthropic_provider.py
Expand Up @@ -201,6 +201,7 @@ class AnthropicProvider(Provider):
def __init__(self, **config):
"""Initialize the Anthropic provider with the given configuration."""
self.client = anthropic.Anthropic(**config)
self.async_client = anthropic.AsyncAnthropic(**config)
self.converter = AnthropicMessageConverter()

def chat_completions_create(self, model, messages, **kwargs):
Expand All @@ -213,6 +214,16 @@ def chat_completions_create(self, model, messages, **kwargs):
)
return self.converter.convert_response(response)

async def chat_completions_create_async(self, model, messages, **kwargs):
"""Create a chat completion using the async Anthropic API."""
kwargs = self._prepare_kwargs(kwargs)
system_message, converted_messages = self.converter.convert_request(messages)

response = await self.async_client.messages.create(
            model=model, system=system_message, messages=converted_messages, **kwargs
)
return self.converter.convert_response(response)

def _prepare_kwargs(self, kwargs):
"""Prepare kwargs for the API call."""
kwargs = kwargs.copy()
52 changes: 52 additions & 0 deletions aisuite/providers/fireworks_provider.py
Expand Up @@ -130,6 +130,58 @@ def chat_completions_create(self, model, messages, **kwargs):
except Exception as e:
raise LLMError(f"An error occurred: {e}")

async def chat_completions_create_async(self, model, messages, **kwargs):
"""
Makes an async request to the Fireworks AI chat completions endpoint.
"""
# Remove 'stream' from kwargs if present
kwargs.pop("stream", None)

# Transform messages using converter
transformed_messages = self.transformer.convert_request(messages)

# Prepare the request payload
data = {
"model": model,
"messages": transformed_messages,
}

# Add tools if provided
if "tools" in kwargs:
data["tools"] = kwargs["tools"]
kwargs.pop("tools")

# Add tool_choice if provided
if "tool_choice" in kwargs:
data["tool_choice"] = kwargs["tool_choice"]
kwargs.pop("tool_choice")

# Add remaining kwargs
data.update(kwargs)

headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}

async with httpx.AsyncClient() as client:
try:
# Make the async request to Fireworks AI endpoint
response = await client.post(
self.BASE_URL, json=data, headers=headers, timeout=self.timeout
)
response.raise_for_status()
return self.transformer.convert_response(response.json())
except httpx.HTTPStatusError as error:
                error_message = (
                    f"The request failed with status code: {error.response.status_code}\n"
                )
                error_message += f"Headers: {error.response.headers}\n"
                error_message += error.response.text
raise LLMError(error_message)
except Exception as e:
raise LLMError(f"An error occurred: {e}")

def _normalize_response(self, response_data):
"""
Normalize the response to a common format (ChatCompletionResponse).
16 changes: 16 additions & 0 deletions aisuite/providers/mistral_provider.py
Expand Up @@ -72,3 +72,19 @@ def chat_completions_create(self, model, messages, **kwargs):
return self.transformer.convert_response(response)
except Exception as e:
raise LLMError(f"An error occurred: {e}")

async def chat_completions_create_async(self, model, messages, **kwargs):
"""
        Makes an async request to Mistral using the official client.
"""
try:
# Transform messages using converter
transformed_messages = self.transformer.convert_request(messages)

            response = await self.client.chat.complete_async(
                model=model, messages=transformed_messages, **kwargs
            )

return self.transformer.convert_response(response)
except Exception as e:
raise LLMError(f"An error occurred: {e}")
12 changes: 12 additions & 0 deletions aisuite/providers/openai_provider.py
Expand Up @@ -22,6 +22,7 @@ def __init__(self, **config):

# Pass the entire config to the OpenAI client constructor
self.client = openai.OpenAI(**config)
self.async_client = openai.AsyncOpenAI(**config)

def chat_completions_create(self, model, messages, **kwargs):
# Any exception raised by OpenAI will be returned to the caller.
Expand All @@ -33,3 +34,14 @@ def chat_completions_create(self, model, messages, **kwargs):
)

return response

async def chat_completions_create_async(self, model, messages, **kwargs):
# Any exception raised by OpenAI will be returned to the caller.
# Maybe we should catch them and raise a custom LLMError.
response = await self.async_client.chat.completions.create(
model=model,
messages=messages,
**kwargs # Pass any additional arguments to the OpenAI API
)

return response
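
One payoff of the async surface, sketched below (not part of the diff; model string and config are placeholder values): multiple completions can be awaited concurrently with asyncio.gather instead of running back to back.

# Illustrative only; hypothetical config values.
import asyncio

from aisuite import AsyncClient


async def fan_out():
    client = AsyncClient()
    client.configure({"openai": {"api_key": "sk-placeholder"}})

    prompts = ["Tell me a joke.", "Summarize asyncio in one line."]
    tasks = [
        client.chat.completions.create(
            "openai:gpt-4o",
            messages=[{"role": "user", "content": prompt}],
        )
        for prompt in prompts
    ]
    # gather() awaits both requests concurrently on one event loop.
    responses = await asyncio.gather(*tasks)
    for response in responses:
        print(response.choices[0].message.content)


asyncio.run(fan_out())
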
80 changes: 79 additions & 1 deletion tests/client/test_client.py
Expand Up @@ -2,7 +2,7 @@

import pytest

from aisuite import Client
from aisuite import Client, AsyncClient


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -152,3 +152,81 @@ def test_invalid_model_format_in_create(monkeypatch):
ValueError, match=r"Invalid model format. Expected 'provider:model'"
):
client.chat.completions.create(invalid_model, messages=messages)


@pytest.mark.parametrize(
argnames=("patch_target", "provider", "model"),
argvalues=[
(
"aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create_async",
"openai",
"gpt-4o",
),
(
"aisuite.providers.mistral_provider.MistralProvider.chat_completions_create_async",
"mistral",
"mistral-model",
),
(
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create_async",
"anthropic",
"anthropic-model",
),
(
"aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create_async",
"fireworks",
"fireworks-model",
)
],
)
@pytest.mark.asyncio
async def test_async_client_chat_completions(
provider_configs: dict, patch_target: str, provider: str, model: str
):
expected_response = f"{patch_target}_{provider}_{model}"
with patch(patch_target) as mock_provider:
mock_provider.return_value = expected_response
client = AsyncClient()
client.configure(provider_configs)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
]

model_str = f"{provider}:{model}"
model_response = await client.chat.completions.create(model_str, messages=messages)
assert model_response == expected_response


@pytest.mark.asyncio
async def test_invalid_model_format_in_async_create(monkeypatch):
from aisuite.providers.openai_provider import OpenaiProvider

monkeypatch.setattr(
target=OpenaiProvider,
name="chat_completions_create_async",
value=Mock(),
)

# Valid provider configurations
provider_configs = {
"openai": {"api_key": "test_openai_api_key"},
}

# Initialize the client with valid provider
client = AsyncClient()
client.configure(provider_configs)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me a joke."},
]

# Invalid model format
invalid_model = "invalidmodel"

# Expect ValueError when calling create with invalid model format and verify message
with pytest.raises(
ValueError, match=r"Invalid model format. Expected 'provider:model'"
):
await client.chat.completions.create(invalid_model, messages=messages)
48 changes: 48 additions & 0 deletions tests/client/test_prerelease.py
Expand Up @@ -13,6 +13,12 @@ def setup_client() -> ai.Client:
return ai.Client()


def setup_async_client() -> ai.AsyncClient:
"""Initialize the async AI client with environment variables."""
load_dotenv(find_dotenv())
return ai.AsyncClient()


def get_test_models() -> List[str]:
"""Return a list of model identifiers to test."""
return [
Expand All @@ -26,6 +32,14 @@ def get_test_models() -> List[str]:
]


def get_test_async_models() -> List[str]:
    """Return a list of model identifiers to test with the async client."""
    return [
        "anthropic:claude-3-5-sonnet-20240620",
        "mistral:open-mistral-7b",
        "openai:gpt-3.5-turbo",
    ]


def get_test_messages() -> List[Dict[str, str]]:
"""Return the test messages to send to each model."""
return [
Expand Down Expand Up @@ -70,5 +84,39 @@ def test_model_pirate_response(model_id: str):
pytest.fail(f"Error testing model {model_id}: {str(e)}")


@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("model_id", get_test_models())
async def test_async_model_pirate_response(model_id: str):
"""
    Test that each model responds appropriately to the pirate prompt using the async client.

Args:
model_id: The provider:model identifier to test
"""
client = setup_async_client()
messages = get_test_messages()

try:
response = await client.chat.completions.create(
model=model_id, messages=messages, temperature=0.75
)

content = response.choices[0].message.content.lower()

# Check if either version of the required phrase is present
assert any(
phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
), f"Model {model_id} did not include required phrase 'No rum No fun'"

assert len(content) > 0, f"Model {model_id} returned empty response"
assert isinstance(
content, str
), f"Model {model_id} returned non-string response"

except Exception as e:
pytest.fail(f"Error testing model {model_id}: {str(e)}")


if __name__ == "__main__":
pytest.main([__file__, "-v"])
34 changes: 34 additions & 0 deletions tests/providers/test_mistral_provider.py
Expand Up @@ -41,3 +41,37 @@ def test_mistral_provider():
)

assert response.choices[0].message.content == response_text_content


@pytest.mark.asyncio
async def test_mistral_provider_async():
"""High-level test that the provider handles async chat completions successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = MistralProvider()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = response_text_content

with patch.object(
provider.client.chat, "complete_async", return_value=mock_response
) as mock_create:
response = await provider.chat_completions_create_async(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content