From 0c4a219a6219e4c7aae70989855caeb4ebae8fd4 Mon Sep 17 00:00:00 2001
From: yusefes
Date: Fri, 29 Nov 2024 17:59:36 +0330
Subject: [PATCH] Fix streaming support in client

Fixes #103

Add support for streaming responses in the Anthropic provider.

* Handle the `stream` parameter in the `chat_completions_create` method in `aisuite/providers/anthropic_provider.py`.
* Use `handle_streaming_response` method to process streaming responses.
* Add a test case for the `stream` parameter in `tests/client/test_client.py`.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/andrewyng/aisuite/issues/103?shareId=XXXX-XXXX-XXXX-XXXX).
---
 aisuite/providers/anthropic_provider.py | 25 +++++++++++++++++++++----
 tests/client/test_client.py             | 10 ++++++++++
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
index f63c054c..26066721 100644
--- a/aisuite/providers/anthropic_provider.py
+++ b/aisuite/providers/anthropic_provider.py
@@ -27,14 +27,31 @@ def chat_completions_create(self, model, messages, **kwargs):
         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
 
-        return self.normalize_response(
-            self.client.messages.create(
-                model=model, system=system_message, messages=messages, **kwargs
+        stream = kwargs.pop("stream", False)
+        if stream:
+            return self.handle_streaming_response(
+                self.client.messages.create(
+                    model=model, system=system_message, messages=messages, **kwargs
+                )
+            )
+        else:
+            return self.normalize_response(
+                self.client.messages.create(
+                    model=model, system=system_message, messages=messages, **kwargs
+                )
             )
-        )
 
     def normalize_response(self, response):
         """Normalize the response from the Anthropic API to match OpenAI's response format."""
         normalized_response = ChatCompletionResponse()
         normalized_response.choices[0].message.content = response.content[0].text
         return normalized_response
+
+    def handle_streaming_response(self, response):
+        """Handle streaming responses from the Anthropic API."""
+        normalized_response = ChatCompletionResponse()
+        content = ""
+        for chunk in response:
+            content += chunk.text
+        normalized_response.choices[0].message.content = content
+        return normalized_response
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 2e1949ac..d2aaf66b 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -139,6 +139,16 @@ def test_client_chat_completions(
         next_compl_instance = client.chat.completions
         assert compl_instance is next_compl_instance
 
+        # Test streaming response for Anthropic model
+        stream_anthropic_model = "anthropic" + ":" + "anthropic-model"
+        stream_anthropic_response = client.chat.completions.create(
+            stream_anthropic_model, messages=messages, stream=True
+        )
+        self.assertEqual(stream_anthropic_response, "Anthropic Response")
+        mock_anthropic.assert_called_with(
+            "anthropic-model", messages, stream=True
+        )
+
     def test_invalid_provider_in_client_config(self):
         # Testing an invalid provider name in the configuration
         invalid_provider_configs = {