From 37cbad2a15559348bb609d1fc74a260cadce9e7c Mon Sep 17 00:00:00 2001
From: Stanislav Kapulkin
Date: Thu, 16 Jan 2025 19:31:45 +0200
Subject: [PATCH 1/4] [+] async client version with provider interface extension. Several provider implementations.

---
 aisuite/async_client.py                 | 54 +++++++++++++++++++++++++
 aisuite/provider.py                     |  4 ++
 aisuite/providers/anthropic_provider.py | 17 ++++++++
 aisuite/providers/fireworks_provider.py | 30 ++++++++++++++
 aisuite/providers/mistral_provider.py   |  3 ++
 aisuite/providers/openai_provider.py    | 11 +++++
 6 files changed, 119 insertions(+)
 create mode 100644 aisuite/async_client.py

diff --git a/aisuite/async_client.py b/aisuite/async_client.py
new file mode 100644
index 00000000..dfff7ad2
--- /dev/null
+++ b/aisuite/async_client.py
@@ -0,0 +1,54 @@
+from .client import Client, Chat, Completions
+from .provider import ProviderFactory
+
+
+class AsyncClient(Client):
+    @property
+    def chat(self):
+        """Return the async chat API interface."""
+        if not self._chat:
+            self._chat = AsyncChat(self)
+        return self._chat
+
+
+class AsyncChat(Chat):
+    def __init__(self, client: "AsyncClient"):
+        self.client = client
+        self._completions = AsyncCompletions(self.client)
+
+
+class AsyncCompletions(Completions):
+    async def create(self, model: str, messages: list, **kwargs):
+        """
+        Create an async chat completion based on the model, messages, and any extra arguments.
+        """
+        # Check that the correct format is used
+        if ":" not in model:
+            raise ValueError(
+                f"Invalid model format. Expected 'provider:model', got '{model}'"
+            )
+
+        # Extract the provider key from the model identifier, e.g., "google:gemini-xx"
+        provider_key, model_name = model.split(":", 1)
+
+        # Validate that the provider is supported
+        supported_providers = ProviderFactory.get_supported_providers()
+        if provider_key not in supported_providers:
+            raise ValueError(
+                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
+                "Make sure the model string is formatted correctly as 'provider:model'."
+            )
+
+        # Initialize the provider if not already initialized
+        if provider_key not in self.client.providers:
+            config = self.client.provider_configs.get(provider_key, {})
+            self.client.providers[provider_key] = ProviderFactory.create_provider(
+                provider_key, config
+            )
+
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            raise ValueError(f"Could not load provider for '{provider_key}'.")
+
+        # Delegate the chat completion to the correct provider's async implementation
+        return await provider.chat_completions_create_async(model_name, messages, **kwargs)
\ No newline at end of file
diff --git a/aisuite/provider.py b/aisuite/provider.py
index f53afe27..3ff6f2db 100644
--- a/aisuite/provider.py
+++ b/aisuite/provider.py
@@ -18,6 +18,10 @@ def chat_completions_create(self, model, messages):
         """Abstract method for chat completion calls, to be implemented by each provider."""
         pass
 
+    @abstractmethod
+    def chat_completions_create_async(self, model, messages):
+        """Abstract method for async chat completion calls, to be implemented by each provider."""
+        pass
 
 class ProviderFactory:
     """Factory to dynamically load provider instances based on naming conventions."""
diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
index f63c054c..0855a8f6 100644
--- a/aisuite/providers/anthropic_provider.py
+++ b/aisuite/providers/anthropic_provider.py
@@ -14,6 +14,7 @@ def __init__(self, **config):
         """
         self.client = anthropic.Anthropic(**config)
+        self.async_client = anthropic.AsyncAnthropic(**config)
 
     def chat_completions_create(self, model, messages, **kwargs):
         # Check if the first message is a system message
@@ -33,6 +34,22 @@ def chat_completions_create(self, model, messages, **kwargs):
             )
         )
 
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        # Check if the first message is a system message
+        if messages[0]["role"] == "system":
+            system_message = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_message = []
+
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
+
+        response = await self.async_client.messages.create(
+            model=model, system=system_message, messages=messages, **kwargs
+        )
+        return self.normalize_response(response)
+
     def normalize_response(self, response):
         """Normalize the response from the Anthropic API to match OpenAI's response format."""
         normalized_response = ChatCompletionResponse()
diff --git a/aisuite/providers/fireworks_provider.py b/aisuite/providers/fireworks_provider.py
index 183b0d10..87fe7fb8 100644
--- a/aisuite/providers/fireworks_provider.py
+++ b/aisuite/providers/fireworks_provider.py
@@ -54,6 +54,36 @@ def chat_completions_create(self, model, messages, **kwargs):
         # Return the normalized response
         return self._normalize_response(response.json())
 
+    async def chat_completions_create_async(self, model, messages, **kwargs):
+        """
+        Makes an async request to the Fireworks AI chat completions endpoint.
+ """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + data = { + "model": model, + "messages": messages, + **kwargs, # Pass any additional arguments to the API + } + + async with httpx.AsyncClient() as client: + try: + # Make the async request to Fireworks AI endpoint + response = await client.post( + self.BASE_URL, json=data, headers=headers, timeout=self.timeout + ) + response.raise_for_status() + except httpx.HTTPStatusError as http_err: + raise LLMError(f"Fireworks AI request failed: {http_err}") + except Exception as e: + raise LLMError(f"An error occurred: {e}") + + # Return the normalized response + return self._normalize_response(response.json()) + def _normalize_response(self, response_data): """ Normalize the response to a common format (ChatCompletionResponse). diff --git a/aisuite/providers/mistral_provider.py b/aisuite/providers/mistral_provider.py index 6e96f1cf..df8d7964 100644 --- a/aisuite/providers/mistral_provider.py +++ b/aisuite/providers/mistral_provider.py @@ -21,3 +21,6 @@ def __init__(self, **config): def chat_completions_create(self, model, messages, **kwargs): return self.client.chat.complete(model=model, messages=messages, **kwargs) + + async def chat_completions_create_async(self, model, messages, **kwargs): + return await self.client.chat.complete_async(model=model, messages=messages, **kwargs) \ No newline at end of file diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py index b8fe2fb8..4dda9551 100644 --- a/aisuite/providers/openai_provider.py +++ b/aisuite/providers/openai_provider.py @@ -22,6 +22,7 @@ def __init__(self, **config): # Pass the entire config to the OpenAI client constructor self.client = openai.OpenAI(**config) + self.async_client = openai.AsyncOpenAI(**config) def chat_completions_create(self, model, messages, **kwargs): # Any exception raised by OpenAI will be returned to the caller. @@ -31,3 +32,13 @@ def chat_completions_create(self, model, messages, **kwargs): messages=messages, **kwargs # Pass any additional arguments to the OpenAI API ) + + async def chat_completions_create_async(self, model, messages, **kwargs): + # Any exception raised by OpenAI will be returned to the caller. + # Maybe we should catch them and raise a custom LLMError. + return await self.async_client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the OpenAI API + ) + From d00c07865e65e95b1d654b1d36c18b5c1e8060b4 Mon Sep 17 00:00:00 2001 From: Stanislav Kapulkin Date: Thu, 16 Jan 2025 19:41:58 +0200 Subject: [PATCH 2/4] [+] tests for async client with providers, that have async implementation --- tests/client/test_client.py | 80 +++++++++++++++++++++++- tests/client/test_prerelease.py | 48 ++++++++++++++ tests/providers/test_mistral_provider.py | 34 ++++++++++ 3 files changed, 161 insertions(+), 1 deletion(-) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a94b139d..d5e1bf4a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,7 +2,7 @@ import pytest -from aisuite import Client +from aisuite import Client, AsyncClient @pytest.fixture(scope="module") @@ -162,3 +162,81 @@ def test_invalid_model_format_in_create(monkeypatch): ValueError, match=r"Invalid model format. 
     ):
         client.chat.completions.create(invalid_model, messages=messages)
+
+
+@pytest.mark.parametrize(
+    argnames=("patch_target", "provider", "model"),
+    argvalues=[
+        (
+            "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create_async",
+            "openai",
+            "gpt-4o",
+        ),
+        (
+            "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create_async",
+            "mistral",
+            "mistral-model",
+        ),
+        (
+            "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create_async",
+            "anthropic",
+            "anthropic-model",
+        ),
+        (
+            "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create_async",
+            "fireworks",
+            "fireworks-model",
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_async_client_chat_completions(
+    provider_configs: dict, patch_target: str, provider: str, model: str
+):
+    expected_response = f"{patch_target}_{provider}_{model}"
+    with patch(patch_target) as mock_provider:
+        mock_provider.return_value = expected_response
+        client = AsyncClient()
+        client.configure(provider_configs)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Who won the world series in 2020?"},
+        ]
+
+        model_str = f"{provider}:{model}"
+        model_response = await client.chat.completions.create(model_str, messages=messages)
+        assert model_response == expected_response
+
+
+@pytest.mark.asyncio
+async def test_invalid_model_format_in_async_create(monkeypatch):
+    from aisuite.providers.openai_provider import OpenaiProvider
+
+    monkeypatch.setattr(
+        target=OpenaiProvider,
+        name="chat_completions_create_async",
+        value=Mock(),
+    )
+
+    # Valid provider configurations
+    provider_configs = {
+        "openai": {"api_key": "test_openai_api_key"},
+    }
+
+    # Initialize the client with a valid provider
+    client = AsyncClient()
+    client.configure(provider_configs)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Tell me a joke."},
+    ]
+
+    # Invalid model format
+    invalid_model = "invalidmodel"
+
+    # Expect ValueError when calling create with an invalid model format and verify the message
+    with pytest.raises(
+        ValueError, match=r"Invalid model format. Expected 'provider:model'"
+    ):
+        await client.chat.completions.create(invalid_model, messages=messages)
diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py
index bb5f3285..a1d90409 100644
--- a/tests/client/test_prerelease.py
+++ b/tests/client/test_prerelease.py
@@ -13,6 +13,12 @@ def setup_client() -> ai.Client:
     return ai.Client()
 
 
+def setup_async_client() -> ai.AsyncClient:
+    """Initialize the async AI client with environment variables."""
+    load_dotenv(find_dotenv())
+    return ai.AsyncClient()
+
+
 def get_test_models() -> List[str]:
     """Return a list of model identifiers to test."""
     return [
@@ -26,6 +32,14 @@ def get_test_models() -> List[str]:
     ]
 
 
+def get_test_async_models() -> List[str]:
+    """Return a list of model identifiers to test with the async client."""
+    return [
+        "anthropic:claude-3-5-sonnet-20240620",
+        "mistral:open-mistral-7b",
+        "openai:gpt-3.5-turbo",
+    ]
+
 def get_test_messages() -> List[Dict[str, str]]:
     """Return the test messages to send to each model."""
     return [
@@ -70,5 +84,39 @@ def test_model_pirate_response(model_id: str):
         pytest.fail(f"Error testing model {model_id}: {str(e)}")
 
 
+@pytest.mark.integration
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_id", get_test_async_models())
+async def test_async_model_pirate_response(model_id: str):
+    """
+    Test that each model responds appropriately to the pirate prompt using the async client.
+
+    Args:
+        model_id: The provider:model identifier to test
+    """
+    client = setup_async_client()
+    messages = get_test_messages()
+
+    try:
+        response = await client.chat.completions.create(
+            model=model_id, messages=messages, temperature=0.75
+        )
+
+        content = response.choices[0].message.content.lower()
+
+        # Check if either version of the required phrase is present
+        assert any(
+            phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
+        ), f"Model {model_id} did not include required phrase 'No rum No fun'"
+
+        assert len(content) > 0, f"Model {model_id} returned empty response"
+        assert isinstance(
+            content, str
+        ), f"Model {model_id} returned non-string response"
+
+    except Exception as e:
+        pytest.fail(f"Error testing model {model_id}: {str(e)}")
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py
index 8ecc51e9..fb6d1200 100644
--- a/tests/providers/test_mistral_provider.py
+++ b/tests/providers/test_mistral_provider.py
@@ -41,3 +41,37 @@ def test_mistral_provider():
     )
 
     assert response.choices[0].message.content == response_text_content
+
+
+@pytest.mark.asyncio
+async def test_mistral_provider_async():
+    """High-level test that the provider handles async chat completions successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = MistralProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat, "complete_async", return_value=mock_response
+    ) as mock_create:
+        response = await provider.chat_completions_create_async(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+    mock_create.assert_called_with(
+        messages=message_history,
+        model=selected_model,
+        temperature=chosen_temperature,
+    )
+
+    assert response.choices[0].message.content == response_text_content
\ No newline at end of file

From 949bb366015b4cbb38cf6af967af963c34f1067e Mon Sep 17 00:00:00 2001
From: Stanislav Kapulkin
Date: Tue, 4 Feb 2025 16:43:57 +0200
Subject: [PATCH 3/4] [*] remove outdated comment

---
 aisuite/providers/anthropic_provider.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
index a7e84de0..7cdb8a43 100644
--- a/aisuite/providers/anthropic_provider.py
+++ b/aisuite/providers/anthropic_provider.py
@@ -216,7 +216,6 @@ def chat_completions_create(self, model, messages, **kwargs):
 
     async def chat_completions_create_async(self, model, messages, **kwargs):
         """Create a chat completion using the async Anthropic API."""
-        # Check if the first message is a system message
         kwargs = self._prepare_kwargs(kwargs)
         system_message, converted_messages = self.converter.convert_request(messages)
 

From 0e242a0efebe5581eaa534e19cfa0eed079e15d4 Mon Sep 17 00:00:00 2001
From: Stanislav Kapulkin
Date: Tue, 4 Feb 2025 16:58:47 +0200
Subject: [PATCH 4/4] [*] polish OpenAI provider code style

---
 aisuite/providers/openai_provider.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
index 7acfea94..4aa3a5d8 100644
--- a/aisuite/providers/openai_provider.py
+++ b/aisuite/providers/openai_provider.py
@@ -32,14 +32,16 @@ def chat_completions_create(self, model, messages, **kwargs):
             messages=messages,
             **kwargs  # Pass any additional arguments to the OpenAI API
         )
+        return response
 
     async def chat_completions_create_async(self, model, messages, **kwargs):
         # Any exception raised by OpenAI will be returned to the caller.
         # Maybe we should catch them and raise a custom LLMError.
-        return await self.async_client.chat.completions.create(
+        response = await self.async_client.chat.completions.create(
             model=model,
             messages=messages,
             **kwargs  # Pass any additional arguments to the OpenAI API
         )
+        return response
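
---
Usage sketch (not part of the patch series): a minimal driver for the AsyncClient introduced in PATCH 1/4, mirroring the tests above. The provider key, model name ("openai:gpt-4o"), and API key value are illustrative placeholders, not values prescribed by the patches.

    import asyncio

    from aisuite import AsyncClient


    async def main():
        client = AsyncClient()
        # Hypothetical config for illustration; substitute a real API key
        # or rely on environment variables.
        client.configure({"openai": {"api_key": "sk-..."}})
        messages = [{"role": "user", "content": "Tell me a joke."}]
        # The "provider:model" format is required; a bare model name raises ValueError.
        response = await client.chat.completions.create("openai:gpt-4o", messages=messages)
        print(response.choices[0].message.content)


    asyncio.run(main())

AsyncClient reuses the synchronous client's configuration and provider factory; only the final provider call is awaited, via each provider's chat_completions_create_async implementation.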