Skip to content

Commit

Permalink
Fix streaming support in client
Browse files Browse the repository at this point in the history
Fixes andrewyng#103

Add support for streaming responses in the Anthropic provider.

* Handle the `stream` parameter in the `chat_completions_create` method in `aisuite/providers/anthropic_provider.py`.
* Use `handle_streaming_response` method to process streaming responses.
* Add a test case for the `stream` parameter in `tests/client/test_client.py`.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/andrewyng/aisuite/issues/103?shareId=XXXX-XXXX-XXXX-XXXX).
  • Loading branch information
yusefes committed Nov 29, 2024
1 parent 1b5da0e commit 0c4a219
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
25 changes: 21 additions & 4 deletions aisuite/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,31 @@ def chat_completions_create(self, model, messages, **kwargs):
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS

return self.normalize_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
stream = kwargs.pop("stream", False)
if stream:
return self.handle_streaming_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
)
)
else:
return self.normalize_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
)
)
)

def normalize_response(self, response):
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
normalized_response = ChatCompletionResponse()
normalized_response.choices[0].message.content = response.content[0].text
return normalized_response

def handle_streaming_response(self, response):
"""Handle streaming responses from the Anthropic API."""
normalized_response = ChatCompletionResponse()
content = ""
for chunk in response:
content += chunk.text
normalized_response.choices[0].message.content = content
return normalized_response
10 changes: 10 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ def test_client_chat_completions(
next_compl_instance = client.chat.completions
assert compl_instance is next_compl_instance

# Test streaming response for Anthropic model
stream_anthropic_model = "anthropic" + ":" + "anthropic-model"
stream_anthropic_response = client.chat.completions.create(
stream_anthropic_model, messages=messages, stream=True
)
self.assertEqual(stream_anthropic_response, "Anthropic Response")
mock_anthropic.assert_called_with(
"anthropic-model", messages, stream=True
)

def test_invalid_provider_in_client_config(self):
# Testing an invalid provider name in the configuration
invalid_provider_configs = {
Expand Down

0 comments on commit 0c4a219

Please sign in to comment.