diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
index f63c054c..26066721 100644
--- a/aisuite/providers/anthropic_provider.py
+++ b/aisuite/providers/anthropic_provider.py
@@ -27,14 +27,31 @@ def chat_completions_create(self, model, messages, **kwargs):
         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
 
-        return self.normalize_response(
-            self.client.messages.create(
-                model=model, system=system_message, messages=messages, **kwargs
+        stream = kwargs.pop("stream", False)
+        if stream:
+            return self.handle_streaming_response(
+                self.client.messages.create(
+                    model=model, system=system_message, messages=messages, **kwargs
+                )
+            )
+        else:
+            return self.normalize_response(
+                self.client.messages.create(
+                    model=model, system=system_message, messages=messages, **kwargs
+                )
             )
-        )
 
     def normalize_response(self, response):
         """Normalize the response from the Anthropic API to match OpenAI's response format."""
         normalized_response = ChatCompletionResponse()
         normalized_response.choices[0].message.content = response.content[0].text
         return normalized_response
+
+    def handle_streaming_response(self, response):
+        """Handle streaming responses from the Anthropic API."""
+        normalized_response = ChatCompletionResponse()
+        content = ""
+        for chunk in response:
+            content += chunk.text
+        normalized_response.choices[0].message.content = content
+        return normalized_response
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 2e1949ac..d2aaf66b 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -139,6 +139,16 @@ def test_client_chat_completions(
         next_compl_instance = client.chat.completions
         assert compl_instance is next_compl_instance
 
+        # Test streaming response for Anthropic model
+        stream_anthropic_model = "anthropic" + ":" + "anthropic-model"
+        stream_anthropic_response = client.chat.completions.create(
+            stream_anthropic_model, messages=messages, stream=True
+        )
+        self.assertEqual(stream_anthropic_response, "Anthropic Response")
+        mock_anthropic.assert_called_with(
+            "anthropic-model", messages, stream=True
+        )
+
     def test_invalid_provider_in_client_config(self):
         # Testing an invalid provider name in the configuration
         invalid_provider_configs = {
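
Below is a minimal usage sketch (not part of the diff) of how the new `stream=True` path would be exercised through the aisuite client. The model identifier and the presence of an `ANTHROPIC_API_KEY` environment variable are assumptions for illustration, not taken from the change itself.

```python
# Hypothetical usage sketch, assuming aisuite is installed and
# ANTHROPIC_API_KEY is set; the model name below is illustrative only.
import aisuite as ai

client = ai.Client()
messages = [{"role": "user", "content": "Tell me a short story."}]

# With stream=True, AnthropicProvider.chat_completions_create routes the call
# through handle_streaming_response, which accumulates the streamed chunks
# into a single ChatCompletionResponse before returning it.
response = client.chat.completions.create(
    "anthropic:claude-3-5-sonnet-20240620",  # assumed model identifier
    messages=messages,
    stream=True,
)
print(response.choices[0].message.content)
```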