-
Notifications
You must be signed in to change notification settings - Fork 220
Add baseten integration #389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Hey @philipkiely-baseten, That said, we absolutely want to support your integration! We'd recommend one of these approaches:
Either way, we'd be happy to feature you on our documentation page as a supported model provider, giving you visibility to our community. |
src/strands/models/baseten.py
Outdated
from typing_extensions import Unpack, override | ||
|
||
from ..types.content import Messages | ||
from ..types.models import OpenAIModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import path is out of date.
src/strands/models/baseten.py
Outdated
return cast(BasetenModel.BasetenConfig, self.config) | ||
|
||
@override | ||
def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here.
src/strands/models/baseten.py
Outdated
elif "base_url" in self.config: | ||
client_args["base_url"] = self.config["base_url"] | ||
|
||
self.client = openai.OpenAI(**client_args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've migrated to AsyncOpenAI in our implementation. Please verify this change is properly reflected throughout the codebase in your PR. Also, ensure you've pulled the most recent code before proceeding with your review.
src/strands/models/baseten.py
Outdated
Returns: | ||
An iterable of response events from the Baseten model. | ||
""" | ||
response = self.client.chat.completions.create(**request) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
async happens also here ^^
src/strands/models/baseten.py
Outdated
yield {"chunk_type": "metadata", "data": event.usage} | ||
|
||
@override | ||
def structured_output( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might want to update async here
cb5560d
to
2676b19
Compare
2676b19
to
5dfb0e0
Compare
@@ -69,6 +70,16 @@ def __init__(self): | |||
max_tokens=512, | |||
), | |||
) | |||
baseten = ProviderInfo( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you setup baseten providerInfo, you could reference other integration tests code style, using that pytestmark instead of pytest.skip()
"content": [cls.format_request_message_content(content) for content in contents], | ||
} | ||
|
||
def format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are defining this method to be an instance method but use it as static method in your tests.
def test_format_request_messages_simple(): | ||
"""Test formatting simple messages.""" | ||
messages = [{"role": "user", "content": [{"text": "Hello"}]}] | ||
result = BasetenModel.format_request_messages(messages) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format_request_messages as static method here. This test fails.
with unittest.mock.patch.object(strands.models.baseten.openai, "AsyncOpenAI") as mock_client_cls: | ||
yield mock_client_cls | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please double check your tests, 13 tests failed.
"api_key": os.getenv("BASETEN_API_KEY"), | ||
}, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 integration tests failed also.
... | ||
|
||
|
||
class BasetenModel(Model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can also reference https://strandsagents.com/latest/documentation/docs/user-guide/concepts/model-providers/cohere/
I think your previous revision was extending OpenAI model provider.
Description
Adds Baseten as a model provider
Related Issues
Documentation PR
strands-agents/docs#124
Type of Change
New feature
Testing
How have you tested the change? Verify that the changes do not break functionality or introduce warnings in consuming repositories: agents-docs, agents-tools, agents-cli
hatch run prepare
Checklist
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.