diff --git a/aisuite/async_client.py b/aisuite/async_client.py
new file mode 100644
index 0000000..dfff7ad
--- /dev/null
+++ b/aisuite/async_client.py
@@ -0,0 +1,54 @@
+from .client import Client, Chat, Completions
+from .provider import ProviderFactory
+
+
+class AsyncClient(Client):
+    @property
+    def chat(self):
+        """Return the async chat API interface."""
+        if not self._chat:
+            self._chat = AsyncChat(self)
+        return self._chat
+
+
+class AsyncChat(Chat):
+    def __init__(self, client: "AsyncClient"):
+        self.client = client
+        self._completions = AsyncCompletions(self.client)
+
+
+class AsyncCompletions(Completions):
+    async def create(self, model: str, messages: list, **kwargs):
+        """
+        Create an async chat completion based on the model, messages, and any extra arguments.
+        """
+        # Check that the correct format is used
+        if ":" not in model:
+            raise ValueError(
+                f"Invalid model format. Expected 'provider:model', got '{model}'"
+            )
+
+        # Extract the provider key from the model identifier, e.g., "google:gemini-xx"
+        provider_key, model_name = model.split(":", 1)
+
+        # Validate that the provider is supported
+        supported_providers = ProviderFactory.get_supported_providers()
+        if provider_key not in supported_providers:
+            raise ValueError(
+                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
+                "Make sure the model string is formatted correctly as 'provider:model'."
+            )
+
+        # Initialize the provider if not already initialized
+        if provider_key not in self.client.providers:
+            config = self.client.provider_configs.get(provider_key, {})
+            self.client.providers[provider_key] = ProviderFactory.create_provider(
+                provider_key, config
+            )
+
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            raise ValueError(f"Could not load provider for '{provider_key}'.")
+
+        # Delegate the chat completion to the correct provider's async implementation
+        return await provider.chat_completions_create_async(model_name, messages, **kwargs)
\ No newline at end of file
diff --git a/aisuite/provider.py b/aisuite/provider.py
index f53afe2..3ff6f2d 100644
--- a/aisuite/provider.py
+++ b/aisuite/provider.py
@@ -18,6 +18,10 @@ def chat_completions_create(self, model, messages):
         """Abstract method for chat completion calls, to be implemented by each provider."""
         pass
 
+    @abstractmethod
+    def chat_completions_create_async(self, model, messages):
+        """Abstract method for async chat completion calls, to be implemented by each provider."""
+        pass
 
 class ProviderFactory:
     """Factory to dynamically load provider instances based on naming conventions."""
diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
index b7edf71..7cdb8a4 100644
--- a/aisuite/providers/anthropic_provider.py
+++ b/aisuite/providers/anthropic_provider.py
@@ -201,6 +201,7 @@ class AnthropicProvider(Provider):
     def __init__(self, **config):
        """Initialize the Anthropic provider with the given configuration."""
         self.client = anthropic.Anthropic(**config)
+        self.async_client = anthropic.AsyncAnthropic(**config)
         self.converter = AnthropicMessageConverter()
 
     def chat_completions_create(self, model, messages, **kwargs):
@@ -213,6 +214,17 @@ def chat_completions_create(self, model, messages, **kwargs):
         )
         return self.converter.convert_response(response)
 
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        """Create a chat completion using the async Anthropic API."""
+        kwargs = self._prepare_kwargs(kwargs)
+        system_message, converted_messages = self.converter.convert_request(messages)
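+        # Anthropic's Messages API takes the system prompt as a separate top-level
+        # parameter, so the converter returns it apart from the converted messages.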
+
+        response = await self.async_client.messages.create(
+            model=model, system=system_message, messages=converted_messages, **kwargs
+        )
+        return self.converter.convert_response(response)
+
     def _prepare_kwargs(self, kwargs):
         """Prepare kwargs for the API call."""
         kwargs = kwargs.copy()
diff --git a/aisuite/providers/fireworks_provider.py b/aisuite/providers/fireworks_provider.py
index 10bea19..e3f383a 100644
--- a/aisuite/providers/fireworks_provider.py
+++ b/aisuite/providers/fireworks_provider.py
@@ -130,6 +130,58 @@ def chat_completions_create(self, model, messages, **kwargs):
         except Exception as e:
             raise LLMError(f"An error occurred: {e}")
 
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        """
+        Makes an async request to the Fireworks AI chat completions endpoint.
+        """
+        # Remove 'stream' from kwargs if present
+        kwargs.pop("stream", None)
+
+        # Transform messages using converter
+        transformed_messages = self.transformer.convert_request(messages)
+
+        # Prepare the request payload
+        data = {
+            "model": model,
+            "messages": transformed_messages,
+        }
+
+        # Add tools if provided
+        if "tools" in kwargs:
+            data["tools"] = kwargs["tools"]
+            kwargs.pop("tools")
+
+        # Add tool_choice if provided
+        if "tool_choice" in kwargs:
+            data["tool_choice"] = kwargs["tool_choice"]
+            kwargs.pop("tool_choice")
+
+        # Add remaining kwargs
+        data.update(kwargs)
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient() as client:
+            try:
+                # Make the async request to the Fireworks AI endpoint
+                response = await client.post(
+                    self.BASE_URL, json=data, headers=headers, timeout=self.timeout
+                )
+                response.raise_for_status()
+                return self.transformer.convert_response(response.json())
+            except httpx.HTTPStatusError as error:
+                error_message = (
+                    f"The request failed with status code: {error.response.status_code}\n"
+                )
+                error_message += f"Headers: {error.response.headers}\n"
+                error_message += error.response.text
+                raise LLMError(error_message)
+            except Exception as e:
+                raise LLMError(f"An error occurred: {e}")
+
     def _normalize_response(self, response_data):
         """
         Normalize the response to a common format (ChatCompletionResponse).
diff --git a/aisuite/providers/mistral_provider.py b/aisuite/providers/mistral_provider.py
index 4fc28fa..694e148 100644
--- a/aisuite/providers/mistral_provider.py
+++ b/aisuite/providers/mistral_provider.py
@@ -72,3 +72,20 @@ def chat_completions_create(self, model, messages, **kwargs):
             return self.transformer.convert_response(response)
         except Exception as e:
             raise LLMError(f"An error occurred: {e}")
+
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        """
+        Makes a request to Mistral using the official client.
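+        Awaits the SDK's native complete_async coroutine rather than wrapping the sync call.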
+        """
+        try:
+            # Transform messages using converter
+            transformed_messages = self.transformer.convert_request(messages)
+
+            response = await self.client.chat.complete_async(
+                model=model, messages=transformed_messages, **kwargs
+            )
+
+            return self.transformer.convert_response(response)
+        except Exception as e:
+            raise LLMError(f"An error occurred: {e}")
diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
index fdd32d4..4aa3a5d 100644
--- a/aisuite/providers/openai_provider.py
+++ b/aisuite/providers/openai_provider.py
@@ -22,6 +22,7 @@ def __init__(self, **config):
 
         # Pass the entire config to the OpenAI client constructor
         self.client = openai.OpenAI(**config)
+        self.async_client = openai.AsyncOpenAI(**config)
 
     def chat_completions_create(self, model, messages, **kwargs):
         # Any exception raised by OpenAI will be returned to the caller.
@@ -33,3 +34,14 @@ def chat_completions_create(self, model, messages, **kwargs):
         )
 
         return response
+
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        # Any exception raised by OpenAI will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        response = await self.async_client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )
+
+        return response
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 04d1243..ee58417 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from aisuite import Client
+from aisuite import Client, AsyncClient
 
 
 @pytest.fixture(scope="module")
@@ -152,3 +152,81 @@ def test_invalid_model_format_in_create(monkeypatch):
         ValueError, match=r"Invalid model format. Expected 'provider:model'"
     ):
         client.chat.completions.create(invalid_model, messages=messages)
+
+
+@pytest.mark.parametrize(
+    argnames=("patch_target", "provider", "model"),
+    argvalues=[
+        (
+            "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create_async",
+            "openai",
+            "gpt-4o",
+        ),
+        (
+            "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create_async",
+            "mistral",
+            "mistral-model",
+        ),
+        (
+            "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create_async",
+            "anthropic",
+            "anthropic-model",
+        ),
+        (
+            "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create_async",
+            "fireworks",
+            "fireworks-model",
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_async_client_chat_completions(
+    provider_configs: dict, patch_target: str, provider: str, model: str
+):
+    expected_response = f"{patch_target}_{provider}_{model}"
+    with patch(patch_target) as mock_provider:
+        mock_provider.return_value = expected_response
+        client = AsyncClient()
+        client.configure(provider_configs)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Who won the world series in 2020?"},
+        ]
+
+        model_str = f"{provider}:{model}"
+        model_response = await client.chat.completions.create(model_str, messages=messages)
+        assert model_response == expected_response
+
+
+@pytest.mark.asyncio
+async def test_invalid_model_format_in_async_create(monkeypatch):
+    from aisuite.providers.openai_provider import OpenaiProvider
+
+    monkeypatch.setattr(
+        target=OpenaiProvider,
+        name="chat_completions_create_async",
+        value=Mock(),
+    )
+
+    # Valid provider configurations
+    provider_configs = {
+        "openai": {"api_key": "test_openai_api_key"},
+    }
+
+    # Initialize the client with a valid provider
+    client = AsyncClient()
+    client.configure(provider_configs)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Tell me a joke."},
+    ]
+
+    # Invalid model format
+    invalid_model = "invalidmodel"
+
+    # Expect ValueError when calling create with an invalid model format and verify the message
+    with pytest.raises(
+        ValueError, match=r"Invalid model format. Expected 'provider:model'"
+    ):
+        await client.chat.completions.create(invalid_model, messages=messages)
diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py
index bb5f328..a1d9040 100644
--- a/tests/client/test_prerelease.py
+++ b/tests/client/test_prerelease.py
@@ -13,6 +13,12 @@ def setup_client() -> ai.Client:
     return ai.Client()
 
 
+def setup_async_client() -> ai.AsyncClient:
+    """Initialize the async AI client with environment variables."""
+    load_dotenv(find_dotenv())
+    return ai.AsyncClient()
+
+
 def get_test_models() -> List[str]:
     """Return a list of model identifiers to test."""
     return [
@@ -26,6 +32,14 @@ def get_test_models() -> List[str]:
     ]
 
+def get_test_async_models() -> List[str]:
+    """Return a list of model identifiers to test with the async client."""
+    return [
+        "anthropic:claude-3-5-sonnet-20240620",
+        "mistral:open-mistral-7b",
+        "openai:gpt-3.5-turbo",
+    ]
+
 
 def get_test_messages() -> List[Dict[str, str]]:
     """Return the test messages to send to each model."""
     return [
@@ -70,5 +84,39 @@ def test_model_pirate_response(model_id: str):
         pytest.fail(f"Error testing model {model_id}: {str(e)}")
 
 
+@pytest.mark.integration
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_id", get_test_async_models())
+async def test_async_model_pirate_response(model_id: str):
+    """
+    Test that each model responds appropriately to the pirate prompt using the async client.
+
+    Args:
+        model_id: The provider:model identifier to test
+    """
+    client = setup_async_client()
+    messages = get_test_messages()
+
+    try:
+        response = await client.chat.completions.create(
+            model=model_id, messages=messages, temperature=0.75
+        )
+
+        content = response.choices[0].message.content.lower()
+
+        # Check if either version of the required phrase is present
+        assert any(
+            phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
+        ), f"Model {model_id} did not include required phrase 'No rum No fun'"
+
+        assert len(content) > 0, f"Model {model_id} returned empty response"
+        assert isinstance(
+            content, str
+        ), f"Model {model_id} returned non-string response"
+
+    except Exception as e:
+        pytest.fail(f"Error testing model {model_id}: {str(e)}")
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py
index 937da0e..7056a67 100644
--- a/tests/providers/test_mistral_provider.py
+++ b/tests/providers/test_mistral_provider.py
@@ -41,3 +41,37 @@ def test_mistral_provider():
     )
 
     assert response.choices[0].message.content == response_text_content
+
+
+@pytest.mark.asyncio
+async def test_mistral_provider_async():
+    """High-level test that the provider handles async chat completions successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = MistralProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat, "complete_async", return_value=mock_response
+    ) as mock_create:
+        response = await provider.chat_completions_create_async(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
\ No newline at end of file
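
For reference, a minimal usage sketch of the new async surface (illustrative only, not part of the diff; it assumes AsyncClient is exported at the package root, as the tests' "from aisuite import Client, AsyncClient" import expects, and that an OPENAI_API_KEY is available in the environment):

import asyncio

import aisuite as ai


async def main():
    # AsyncClient mirrors the sync Client; create() is awaited and delegates
    # to the provider's chat_completions_create_async implementation.
    client = ai.AsyncClient()
    response = await client.chat.completions.create(
        model="openai:gpt-4o",  # any supported "provider:model" string
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())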