Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: handle error when model doesn't support async #1641

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion phi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,7 +2145,10 @@ async def _arun(
self.model = cast(Model, self.model)
if stream and self.is_streamable:
model_response = ModelResponse(content="")
model_response_stream = self.model.aresponse_stream(messages=messages_for_model)
if hasattr(self.model, "aresponse_stream"):
model_response_stream = self.model.aresponse_stream(messages=messages_for_model)
else:
raise NotImplementedError(f"{self.model.id} does not support streaming")
async for model_response_chunk in model_response_stream: # type: ignore
if model_response_chunk.event == ModelResponseEvent.assistant_response.value:
if model_response_chunk.content is not None and model_response.content is not None:
Expand Down
3 changes: 0 additions & 3 deletions phi/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse:
def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
raise NotImplementedError

async def aresponse_stream(self, messages: List[Message]) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather keep this and do a try catch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dirkbrnd I tried doing that but couldn't catch the NotImplemented error. It's not bubbling up the correct way

raise NotImplementedError

def _log_messages(self, messages: List[Message]) -> None:
"""
Log messages for debugging.
Expand Down
Loading