diff --git a/cookbook/models/azure/openai/async/__init__.py b/cookbook/models/azure/openai/async/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/cookbook/models/azure/openai/async/basic.py b/cookbook/models/azure/openai/async_basic.py similarity index 100% rename from cookbook/models/azure/openai/async/basic.py rename to cookbook/models/azure/openai/async_basic.py diff --git a/cookbook/models/azure/openai/async/basic_stream.py b/cookbook/models/azure/openai/async_basic_stream.py similarity index 100% rename from cookbook/models/azure/openai/async/basic_stream.py rename to cookbook/models/azure/openai/async_basic_stream.py diff --git a/cookbook/models/openai/async_tool_use.py b/cookbook/models/openai/async_tool_use.py new file mode 100644 index 0000000000..574c6e6f0c --- /dev/null +++ b/cookbook/models/openai/async_tool_use.py @@ -0,0 +1,14 @@ +"""Run `pip install duckduckgo-search` to install dependencies.""" + +import asyncio +from agno.agent import Agent +from agno.models.openai import OpenAIChat +from agno.tools.duckduckgo import DuckDuckGoTools + +agent = Agent( + model=OpenAIChat(id="gpt-4o"), + tools=[DuckDuckGoTools()], + show_tool_calls=True, + markdown=True, +) +asyncio.run(agent.aprint_response("Whats happening in France?", stream=True)) diff --git a/libs/agno/agno/agent/agent.py b/libs/agno/agno/agent/agent.py index f4a03bf873..b4efbd38fc 100644 --- a/libs/agno/agno/agent/agent.py +++ b/libs/agno/agno/agent/agent.py @@ -24,7 +24,7 @@ from pydantic import BaseModel from agno.agent.metrics import SessionMetrics -from agno.exceptions import AgentRunException, StopAgentRun +from agno.exceptions import AgentRunException, ModelProviderError, StopAgentRun from agno.knowledge.agent import AgentKnowledge from agno.media import Audio, AudioArtifact, Image, ImageArtifact, Video, VideoArtifact from agno.memory.agent import AgentMemory, AgentRun @@ -867,7 +867,7 @@ def run( **kwargs, ) return next(resp) - except Exception as e: + except 
ModelProviderError as e: logger.warning(f"Attempt {attempt + 1}/{num_attempts} failed: {str(e)}") if isinstance(e, StopAgentRun): raise e @@ -1267,7 +1267,7 @@ async def arun( **kwargs, ) return await resp.__anext__() - except Exception as e: + except ModelProviderError as e: logger.warning(f"Attempt {attempt + 1}/{num_attempts} failed: {str(e)}") if isinstance(e, StopAgentRun): raise e diff --git a/libs/agno/agno/exceptions.py b/libs/agno/agno/exceptions.py index fbb8b00baf..40645d721f 100644 --- a/libs/agno/agno/exceptions.py +++ b/libs/agno/agno/exceptions.py @@ -36,3 +36,12 @@ def __init__( super().__init__( exc, user_message=user_message, agent_message=agent_message, messages=messages, stop_execution=True ) + + +class ModelProviderError(Exception): + """Exception raised when a model provider returns an error.""" + + def __init__(self, exc, model_name: str, model_id: str): + super().__init__(exc) + self.model_name = model_name + self.model_id = model_id diff --git a/libs/agno/agno/models/anthropic/claude.py b/libs/agno/agno/models/anthropic/claude.py index d6f485a2a8..bb016dc160 100644 --- a/libs/agno/agno/models/anthropic/claude.py +++ b/libs/agno/agno/models/anthropic/claude.py @@ -3,6 +3,7 @@ from os import getenv from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from agno.exceptions import ModelProviderError from agno.media import Image from agno.models.base import Model from agno.models.message import Message @@ -309,16 +310,16 @@ def invoke(self, messages: List[Message]) -> AnthropicMessage: ) except APIConnectionError as e: logger.error(f"Connection error while calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except RateLimitError as e: logger.warning(f"Rate limit exceeded: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except APIStatusError as e: logger.error(f"Claude API error (status {e.status_code}): {str(e)}") - raise + raise 
ModelProviderError(e, self.name, self.id) from e except Exception as e: logger.error(f"Unexpected error calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e def invoke_stream(self, messages: List[Message]) -> Any: """ @@ -341,16 +342,16 @@ def invoke_stream(self, messages: List[Message]) -> Any: ).__enter__() except APIConnectionError as e: logger.error(f"Connection error while calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except RateLimitError as e: logger.warning(f"Rate limit exceeded: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except APIStatusError as e: logger.error(f"Claude API error (status {e.status_code}): {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except Exception as e: logger.error(f"Unexpected error calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e async def ainvoke(self, messages: List[Message]) -> AnthropicMessage: """ @@ -378,16 +379,16 @@ async def ainvoke(self, messages: List[Message]) -> AnthropicMessage: ) except APIConnectionError as e: logger.error(f"Connection error while calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except RateLimitError as e: logger.warning(f"Rate limit exceeded: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except APIStatusError as e: logger.error(f"Claude API error (status {e.status_code}): {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except Exception as e: logger.error(f"Unexpected error calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e async def ainvoke_stream(self, messages: List[Message]) -> Any: """ @@ -410,16 +411,16 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any: ).__aenter__() except APIConnectionError as e: logger.error(f"Connection error 
while calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except RateLimitError as e: logger.warning(f"Rate limit exceeded: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except APIStatusError as e: logger.error(f"Claude API error (status {e.status_code}): {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e except Exception as e: logger.error(f"Unexpected error calling Claude API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e # Overwrite the default from the base model def format_function_call_results( diff --git a/libs/agno/agno/models/aws/bedrock.py b/libs/agno/agno/models/aws/bedrock.py index 8bd09376d1..dd7bb9c404 100644 --- a/libs/agno/agno/models/aws/bedrock.py +++ b/libs/agno/agno/models/aws/bedrock.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional from agno.aws.api_client import AwsApiClient # type: ignore +from agno.exceptions import ModelProviderError from agno.models.base import Model from agno.models.message import Message from agno.models.response import ProviderResponse @@ -92,7 +93,7 @@ def invoke(self, messages: List[Message]) -> Dict[str, Any]: return self.get_client().converse(**body) except Exception as e: logger.error(f"Unexpected error calling Bedrock API: {str(e)}") - raise + raise ModelProviderError(e, self.name, self.id) from e def invoke_stream(self, messages: List[Message]) -> Iterator[Dict[str, Any]]: """ @@ -105,11 +106,15 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Dict[str, Any]]: Iterator[Dict[str, Any]]: The streamed response. 
""" body = self.format_messages(messages) - response = self.get_client().converse_stream(**body) - stream = response.get("stream") - if stream: - for event in stream: - yield event + try: + response = self.get_client().converse_stream(**body) + stream = response.get("stream") + if stream: + for event in stream: + yield event + except Exception as e: + logger.error(f"Unexpected error calling Bedrock API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e @abstractmethod def format_messages(self, messages: List[Message]) -> Dict[str, Any]: diff --git a/libs/agno/agno/models/base.py b/libs/agno/agno/models/base.py index e322eeb149..6d31c3025c 100644 --- a/libs/agno/agno/models/base.py +++ b/libs/agno/agno/models/base.py @@ -161,32 +161,6 @@ def set_functions(self, functions: Dict[str, Function]) -> None: if len(functions) > 0: self._functions = functions - # @staticmethod - # def _update_assistant_message_metrics(assistant_message: Message, metrics_for_run: Metrics = Metrics()) -> None: - # assistant_message.metrics["time"] = metrics_for_run.response_timer.elapsed - # if metrics_for_run.input_tokens is not None: - # assistant_message.metrics["input_tokens"] = metrics_for_run.input_tokens - # if metrics_for_run.output_tokens is not None: - # assistant_message.metrics["output_tokens"] = metrics_for_run.output_tokens - # if metrics_for_run.total_tokens is not None: - # assistant_message.metrics["total_tokens"] = metrics_for_run.total_tokens - # if metrics_for_run.time_to_first_token is not None: - # assistant_message.metrics["time_to_first_token"] = metrics_for_run.time_to_first_token - - # def _update_model_metrics( - # self, - # metrics_for_run: Metrics = Metrics(), - # ) -> None: - # self.metrics.setdefault("response_times", []).append(metrics_for_run.response_timer.elapsed) - # if metrics_for_run.input_tokens is not None: - # self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + metrics_for_run.input_tokens - # if 
metrics_for_run.output_tokens is not None: - #             self.metrics["output_tokens"] = self.metrics.get("output_tokens", 0) + metrics_for_run.output_tokens - #         if metrics_for_run.total_tokens is not None: - #             self.metrics["total_tokens"] = self.metrics.get("total_tokens", 0) + metrics_for_run.total_tokens - #         if metrics_for_run.time_to_first_token is not None: - #             self.metrics.setdefault("time_to_first_token", []).append(metrics_for_run.time_to_first_token) - def parse_tool_calls(self, tool_calls_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Parse the tool calls from the model provider into a list of tool calls. @@ -478,12 +452,24 @@ async def arun_function_calls( if additional_messages: function_call_results.extend(additional_messages) + def _show_tool_calls(self, function_calls_to_run: List[FunctionCall], model_response: ModelResponse): + """ + Show tool calls in the model response if show_tool_calls is enabled. + """ + if self.show_tool_calls and len(function_calls_to_run) == 1: + model_response.content += f" - Running: {function_calls_to_run[0].get_call_str()}\n\n" + elif self.show_tool_calls and len(function_calls_to_run) > 1: + model_response.content += "Running:" + for _f in function_calls_to_run: + model_response.content += f"\n - {_f.get_call_str()}" + model_response.content += "\n\n" + def _prepare_function_calls( self, assistant_message: Message, messages: List[Message], model_response: ModelResponse, - ) -> Tuple[List[FunctionCall], List[Message]]: + ) -> List[FunctionCall]: """ Prepare function calls from tool calls in the assistant message. 
@@ -492,26 +478,17 @@ def _prepare_function_calls( messages (List[Message]): The list of messages to append tool responses to model_response (ModelResponse): The model response to update Returns: - Tuple[List[FunctionCall], List[Message]]: Tuple of function calls to run and function call results + List[FunctionCall]: The function calls to run """ if model_response.content is None: model_response.content = "" if model_response.tool_calls is None: model_response.tool_calls = [] - function_call_results: List[Message] = [] function_calls_to_run: List[FunctionCall] = self.get_function_calls_to_run(assistant_message, messages) - if self.show_tool_calls: - if len(function_calls_to_run) == 1: - model_response.content += f" - Running: {function_calls_to_run[0].get_call_str()}\n\n" - elif len(function_calls_to_run) > 1: - model_response.content += "Running:" - for _f in function_calls_to_run: - model_response.content += f"\n - {_f.get_call_str()}" - model_response.content += "\n\n" - - return function_calls_to_run, function_call_results + self._show_tool_calls(function_calls_to_run, model_response) + return function_calls_to_run def format_function_call_results(self, messages: List[Message], function_call_results: List[Message], **kwargs) -> None: """ @@ -520,284 +497,14 @@ def format_function_call_results(self, messages: List[Message], function_call_re if len(function_call_results) > 0: messages.extend(function_call_results) - def handle_tool_calls( - self, - assistant_message: Message, - messages: List[Message], - model_response: ModelResponse, - **kwargs, - ) -> Optional[ModelResponse]: - """ - Handle tool calls in the assistant message. - - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - model_response (ModelResponse): The model response. - - Returns: - Optional[ModelResponse]: The model response after handling tool calls. 
- """ - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - function_calls_to_run, function_call_results = self._prepare_function_calls( - assistant_message=assistant_message, - messages=messages, - model_response=model_response, - ) - - for function_call_response in self.run_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results - ): - if ( - function_call_response.event == ModelResponseEvent.tool_call_completed.value - and function_call_response.tool_calls is not None - ): - model_response.tool_calls.extend(function_call_response.tool_calls) # type: ignore # model_response.tool_calls are initialized before calling this method - - self.format_function_call_results(messages=messages, function_call_results=function_call_results, **kwargs) - - return model_response - return None - - async def ahandle_tool_calls( - self, - assistant_message: Message, - messages: List[Message], - model_response: ModelResponse, - **kwargs, - ) -> Optional[ModelResponse]: - """ - Handle tool calls in the assistant message. - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - model_response (ModelResponse): The model response. - tool_role (str): The role of the tool call. Defaults to "tool". - - Returns: - Optional[ModelResponse]: The model response after handling tool calls. 
- """ - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - function_calls_to_run, function_call_results = self._prepare_function_calls( - assistant_message=assistant_message, - messages=messages, - model_response=model_response, - ) - - async for function_call_response in self.arun_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results - ): - if ( - function_call_response.event == ModelResponseEvent.tool_call_completed.value - and function_call_response.tool_calls is not None - ): - model_response.tool_calls.extend(function_call_response.tool_calls) # type: ignore # model_response.tool_calls are initialized before calling this method - - self.format_function_call_results(messages=messages, function_call_results=function_call_results, **kwargs) - - return model_response - return None - - def _prepare_stream_tool_calls( - self, - assistant_message: Message, - messages: List[Message], - ) -> Tuple[List[FunctionCall], List[Message]]: - """ - Prepare function calls from tool calls in the assistant message for streaming. 
- - Args: - assistant_message (Message): The assistant message containing tool calls - messages (List[Message]): The list of messages to append tool responses to - - Returns: - Tuple[List[FunctionCall], List[Message]]: Tuple of function calls to run and function call results - """ - function_calls_to_run: List[FunctionCall] = [] - function_call_results: List[Message] = [] - - for tool_call in assistant_message.tool_calls: # type: ignore # assistant_message.tool_calls are checked before calling this method - _tool_call_id = tool_call.get("id") - _function_call = get_function_call_for_tool_call(tool_call, self._functions) - if _function_call is None: - messages.append( - Message( - role=self.tool_message_role, - tool_call_id=_tool_call_id, - content="Could not find function to call.", - ) - ) - continue - if _function_call.error is not None: - messages.append( - Message( - role=self.tool_message_role, - tool_call_id=_tool_call_id, - content=_function_call.error, - ) - ) - continue - function_calls_to_run.append(_function_call) - - return function_calls_to_run, function_call_results - - def handle_stream_tool_calls( - self, - assistant_message: Message, - messages: List[Message], - **kwargs, - ) -> Iterator[ModelResponse]: - """ - Handle tool calls for response stream. - - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - - Returns: - Iterator[ModelResponse]: An iterator of the model response. 
- """ - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - - yield ModelResponse(content="\n\n") - - function_calls_to_run, function_call_results = self._prepare_stream_tool_calls( - assistant_message=assistant_message, - messages=messages, - ) - - if self.show_tool_calls: - if len(function_calls_to_run) == 1: - yield ModelResponse(content=f" - Running: {function_calls_to_run[0].get_call_str()}\n\n") - else: - yield ModelResponse(content="\nRunning:") - for _f in function_calls_to_run: - yield ModelResponse(content=f"\n - {_f.get_call_str()}") - yield ModelResponse(content="\n\n") - - for function_call_response in self.run_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results - ): - yield function_call_response - - self.format_function_call_results(messages=messages, function_call_results=function_call_results, **kwargs) - - async def ahandle_stream_tool_calls( - self, - assistant_message: Message, - messages: List[Message], - **kwargs, - ) -> AsyncIterator[ModelResponse]: - """ - Handle tool calls for response stream. - - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - tool_role (str): The role of the tool call. Defaults to "tool". - - Returns: - Iterator[ModelResponse]: An iterator of the model response. 
- """ - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - yield ModelResponse(content="\n\n") - - function_calls_to_run, function_call_results = self._prepare_stream_tool_calls( - assistant_message=assistant_message, - messages=messages, - ) - - if self.show_tool_calls: - if len(function_calls_to_run) == 1: - yield ModelResponse(content=f" - Running: {function_calls_to_run[0].get_call_str()}\n\n") - else: - yield ModelResponse(content="\nRunning:") - for _f in function_calls_to_run: - yield ModelResponse(content=f"\n - {_f.get_call_str()}") - yield ModelResponse(content="\n\n") - - async for function_call_response in self.arun_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results - ): - yield function_call_response - - self.format_function_call_results(messages=messages, function_call_results=function_call_results, **kwargs) - - - def _handle_response_after_tool_calls( - self, response_after_tool_calls: ModelResponse, model_response: ModelResponse - ): - if response_after_tool_calls.content is not None: - if model_response.content is None: - model_response.content = "" - model_response.content += response_after_tool_calls.content - if response_after_tool_calls.parsed is not None: - # bubble up the parsed object, so that the final response has the parsed object - # that is visible to the agent - model_response.parsed = response_after_tool_calls.parsed - if response_after_tool_calls.audio is not None: - # bubble up the audio, so that the final response has the audio - # that is visible to the agent - model_response.audio = response_after_tool_calls.audio - - def _handle_stop_after_tool_calls(self, last_message: Message, model_response: ModelResponse): - logger.debug("Stopping execution as stop_after_tool_call=True") - if ( - last_message.role == "assistant" - and last_message.content is not None - and isinstance(last_message.content, str) - ): - if model_response.content is None: - 
model_response.content = "" - model_response.content += last_message.content - - def handle_post_tool_call_messages(self, messages: List[Message], model_response: ModelResponse) -> ModelResponse: - last_message = messages[-1] - if last_message.stop_after_tool_call: - self._handle_stop_after_tool_calls(last_message, model_response) - else: - response_after_tool_calls = self.response(messages=messages) - self._handle_response_after_tool_calls(response_after_tool_calls, model_response) - return model_response - - async def ahandle_post_tool_call_messages( - self, messages: List[Message], model_response: ModelResponse - ) -> ModelResponse: - last_message = messages[-1] - if last_message.stop_after_tool_call: - self._handle_stop_after_tool_calls(last_message, model_response) - else: - response_after_tool_calls = await self.aresponse(messages=messages) - self._handle_response_after_tool_calls(response_after_tool_calls, model_response) - return model_response - - def handle_post_tool_call_messages_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: - last_message = messages[-1] - if last_message.stop_after_tool_call: - logger.debug("Stopping execution as stop_after_tool_call=True") - if ( - last_message.role == "assistant" - and last_message.content is not None - and isinstance(last_message.content, str) - ): - yield ModelResponse(content=last_message.content) + def _show_stream_tool_calls(self, function_calls_to_run: List[FunctionCall]) -> Iterator[ModelResponse]: + if len(function_calls_to_run) == 1: + yield ModelResponse(content=f" - Running: {function_calls_to_run[0].get_call_str()}\n\n") else: - yield from self.response_stream(messages=messages) - - async def ahandle_post_tool_call_messages_stream(self, messages: List[Message]) -> Any: - last_message = messages[-1] - if last_message.stop_after_tool_call: - logger.debug("Stopping execution as stop_after_tool_call=True") - if ( - last_message.role == "assistant" - and last_message.content is not None - 
and isinstance(last_message.content, str) - ): - yield ModelResponse(content=last_message.content) - else: - async for model_response in self.aresponse_stream(messages=messages): # type: ignore - yield model_response - + yield ModelResponse(content="\nRunning:") + for _f in function_calls_to_run: + yield ModelResponse(content=f"\n - {_f.get_call_str()}") + yield ModelResponse(content="\n\n") def get_system_message_for_model(self) -> Optional[str]: return self.system_prompt @@ -927,6 +634,10 @@ def populate_assistant_message( if provider_response.audio is not None: assistant_message.audio_output = provider_response.audio + # Add reasoning content to assistant message + if provider_response.reasoning_content is not None: + assistant_message.reasoning_content = provider_response.reasoning_content + # Add usage metrics if provided if provider_response.response_usage is not None: self.add_usage_metrics_to_assistant_message( @@ -936,20 +647,17 @@ def populate_assistant_message( return assistant_message - def response(self, messages: List[Message]) -> ModelResponse: + def _process_model_response( + self, + messages: List[Message], + model_response: ModelResponse, + ) -> Tuple[Message, bool]: """ - Generate a response from the model. - - Args: - messages: List of messages in the conversation + Process a single model response and return the assistant message and whether to continue. 
Returns: - ModelResponse: The model's response + Tuple[Message, bool]: (assistant_message, should_continue) """ - logger.debug(f"---------- {self.get_provider()} Response Start ----------") - self._log_messages(messages) - model_response = ModelResponse() - # Create assistant message assistant_message = Message(role=self.assistant_message_role) @@ -961,7 +669,7 @@ def response(self, messages: List[Message]) -> ModelResponse: # Parse provider response provider_response = self.parse_provider_response(response) - # Add parsed data to assistant message + # Add parsed data to model response if provider_response.parsed is not None: model_response.parsed = provider_response.parsed @@ -982,33 +690,85 @@ def response(self, messages: List[Message]) -> ModelResponse: model_response.content = assistant_message.get_content_string() if assistant_message.audio_output is not None: model_response.audio = assistant_message.audio_output + if provider_response.extra is not None: + model_response.extra.update(provider_response.extra) - # Handle tool calls - if ( - self.handle_tool_calls( - assistant_message=assistant_message, + return assistant_message, bool(assistant_message.tool_calls) + + def response(self, messages: List[Message]) -> ModelResponse: + """ + Generate a response from the model. 
+ + Args: + messages: List of messages in the conversation + + Returns: + ModelResponse: The model's response + """ + logger.debug(f"---------- {self.get_provider()} Response Start ----------") + self._log_messages(messages) + model_response = ModelResponse() + + while True: + # Get response from model + assistant_message, has_tool_calls = self._process_model_response( messages=messages, model_response=model_response, - **provider_response.extra, # Any other values set on the provider response is passed here ) - is not None - ): - return self.handle_post_tool_call_messages(messages=messages, model_response=model_response) + + # Handle tool calls if present + if has_tool_calls: + # Prepare function calls + function_calls_to_run = self._prepare_function_calls( + assistant_message=assistant_message, + messages=messages, + model_response=model_response, + ) + function_call_results: List[Message] = [] + + # Execute function calls + for function_call_response in self.run_function_calls( + function_calls=function_calls_to_run, + function_call_results=function_call_results + ): + if ( + function_call_response.event == ModelResponseEvent.tool_call_completed.value + and function_call_response.tool_calls is not None + ): + model_response.tool_calls.extend(function_call_response.tool_calls) + + # Format and add results to messages + self.format_function_call_results( + messages=messages, + function_call_results=function_call_results, + **model_response.extra + ) + + # Check if we should stop after tool calls + if any(m.stop_after_tool_call for m in function_call_results): + break + + # Continue loop to get next response + continue + + # No tool calls or finished processing them + break + logger.debug(f"---------- {self.get_provider()} Response End ----------") return model_response - async def aresponse(self, messages: List[Message]) -> ModelResponse: + async def _aprocess_model_response( + self, + messages: List[Message], + model_response: ModelResponse, + ) -> Tuple[Message, 
bool]: """ - Generate an asynchronous response from the model. + Process a single async model response and return the assistant message and whether to continue. - Args: - messages: List of messages in the conversation + Returns: + Tuple[Message, bool]: (assistant_message, should_continue) """ - logger.debug(f"---------- {self.get_provider()} Async Response Start ----------") - self._log_messages(messages) - model_response = ModelResponse() - # Create assistant message assistant_message = Message(role=self.assistant_message_role) @@ -1020,7 +780,7 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse: # Parse provider response provider_response = self.parse_provider_response(response) - # Add parsed data to assistant message + # Add parsed data to model response if provider_response.parsed is not None: model_response.parsed = provider_response.parsed @@ -1034,25 +794,76 @@ messages.append(assistant_message) # Log response and metrics - assistant_message.log() + assistant_message.log(metrics=True) # Update model response with assistant message content and audio if assistant_message.content is not None: model_response.content = assistant_message.get_content_string() if assistant_message.audio_output is not None: model_response.audio = assistant_message.audio_output + if provider_response.extra is not None: + model_response.extra.update(provider_response.extra) - # -*- Handle tool calls - if ( - await self.ahandle_tool_calls( - assistant_message=assistant_message, + return assistant_message, bool(assistant_message.tool_calls) + + async def aresponse(self, messages: List[Message]) -> ModelResponse: + """ + Generate an asynchronous response from the model. 
+ + Args: + messages: List of messages in the conversation + + Returns: + ModelResponse: The model's response + """ + logger.debug(f"---------- {self.get_provider()} Async Response Start ----------") + self._log_messages(messages) + model_response = ModelResponse() + + while True: + # Get response from model + assistant_message, has_tool_calls = await self._aprocess_model_response( messages=messages, model_response=model_response, - **provider_response.extra, # Any other values set on the provider response is passed here ) - is not None - ): - return await self.ahandle_post_tool_call_messages(messages=messages, model_response=model_response) + + # Handle tool calls if present + if has_tool_calls: + # Prepare function calls + function_calls_to_run = self._prepare_function_calls( + assistant_message=assistant_message, + messages=messages, + model_response=model_response, + ) + function_call_results: List[Message] = [] + + # Execute function calls + async for function_call_response in self.arun_function_calls( + function_calls=function_calls_to_run, + function_call_results=function_call_results + ): + if ( + function_call_response.event == ModelResponseEvent.tool_call_completed.value + and function_call_response.tool_calls is not None + ): + model_response.tool_calls.extend(function_call_response.tool_calls) + + # Format and add results to messages + self.format_function_call_results( + messages=messages, + function_call_results=function_call_results, + **model_response.extra + ) + + # Check if we should stop after tool calls + if any(m.stop_after_tool_call for m in function_call_results): + break + + # Continue loop to get next response + continue + + # No tool calls or finished processing them + break logger.debug(f"---------- {self.get_provider()} Async Response End ----------") return model_response @@ -1062,7 +873,6 @@ def process_response_stream(self, messages: List[Message], assistant_message: Me Process a streaming response from the model. 
""" for response_delta in self.invoke_stream(messages=messages): - provider_response_delta = self.parse_provider_response_delta(response_delta) if provider_response_delta: # Update metrics @@ -1071,7 +881,7 @@ def process_response_stream(self, messages: List[Message], assistant_message: Me assistant_message.metrics.set_time_to_first_token() if provider_response_delta.content is not None: - # Update stream data + # Update stream data and yield content stream_data.response_content += provider_response_delta.content yield ModelResponse(content=provider_response_delta.content) @@ -1082,7 +892,7 @@ def process_response_stream(self, messages: List[Message], assistant_message: Me stream_data.response_tool_calls.extend(provider_response_delta.tool_calls) if provider_response_delta.audio is not None: - # Update stream data + # Update stream data and yield audio stream_data.response_audio = provider_response_delta.audio yield ModelResponse(audio=provider_response_delta.audio) @@ -1109,43 +919,70 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: """ logger.debug(f"---------- {self.get_provider()} Response Stream Start ----------") self._log_messages(messages) - stream_data: MessageData = MessageData() - - # Create assistant message - assistant_message = Message(role=self.assistant_message_role) - # Generate response - assistant_message.metrics.start_timer() - yield from self.process_response_stream(messages=messages, assistant_message=assistant_message, stream_data=stream_data) - assistant_message.metrics.stop_timer() - - # Add response content and audio to assistant message - if stream_data.response_content != "": - assistant_message.content = stream_data.response_content + while True: + # Create assistant message and stream data + assistant_message = Message(role=self.assistant_message_role) + stream_data = MessageData() + + # Generate response + assistant_message.metrics.start_timer() + yield from self.process_response_stream( + 
messages=messages, + assistant_message=assistant_message, + stream_data=stream_data + ) + assistant_message.metrics.stop_timer() + + # Populate assistant message from stream data + if stream_data.response_content: + assistant_message.content = stream_data.response_content + if stream_data.response_audio: + assistant_message.audio_output = stream_data.response_audio + if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0: + assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls) + + # Add assistant message to messages + messages.append(assistant_message) + assistant_message.log(metrics=True) + + # Handle tool calls if present + if assistant_message.tool_calls: + yield ModelResponse(content="\n\n") + + # Prepare function calls + function_calls_to_run: List[FunctionCall] = self.get_function_calls_to_run(assistant_message, messages) + function_call_results: List[Message] = [] + + # Show tool calls if enabled + if self.show_tool_calls: + yield from self._show_stream_tool_calls( + function_calls_to_run=function_calls_to_run + ) - if stream_data.response_audio is not None: - assistant_message.audio_output = stream_data.response_audio + # Execute function calls + for function_call_response in self.run_function_calls( + function_calls=function_calls_to_run, + function_call_results=function_call_results + ): + yield function_call_response - # Add tool calls to assistant message - if stream_data.response_tool_calls is not None and len(stream_data.response_tool_calls) > 0: - parsed_tool_calls = self.parse_tool_calls(stream_data.response_tool_calls) - if len(parsed_tool_calls) > 0: - assistant_message.tool_calls = parsed_tool_calls + # Format and add results to messages + self.format_function_call_results( + messages=messages, + function_call_results=function_call_results, + **stream_data.extra + ) - # Add assistant message to messages - messages.append(assistant_message) + # Check if we should stop after tool calls + if 
any(m.stop_after_tool_call for m in function_call_results): + break - # Log response and metrics - assistant_message.log(metrics=True) + # Continue loop to get next response + continue - # Handle tool calls - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - yield from self.handle_stream_tool_calls( - assistant_message=assistant_message, - messages=messages, - **stream_data.extra, # Any other values set on the provider response is passed here - ) - yield from self.handle_post_tool_call_messages_stream(messages=messages) + # No tool calls or finished processing them + break logger.debug(f"---------- {self.get_provider()} Response Stream End ----------") @@ -1153,8 +990,7 @@ async def aprocess_response_stream(self, messages: List[Message], assistant_mess """ Process a streaming response from the model. """ - async for response_delta in self.ainvoke_stream(messages=messages): - + async for response_delta in await self.ainvoke_stream(messages=messages): provider_response_delta = self.parse_provider_response_delta(response_delta) if provider_response_delta: # Update metrics @@ -1163,7 +999,7 @@ async def aprocess_response_stream(self, messages: List[Message], assistant_mess assistant_message.metrics.set_time_to_first_token() if provider_response_delta.content is not None: - # Update stream data + # Update stream data and yield content stream_data.response_content += provider_response_delta.content yield ModelResponse(content=provider_response_delta.content) @@ -1174,7 +1010,7 @@ async def aprocess_response_stream(self, messages: List[Message], assistant_mess stream_data.response_tool_calls.extend(provider_response_delta.tool_calls) if provider_response_delta.audio is not None: - # Update stream data + # Update stream data and yield audio stream_data.response_audio = provider_response_delta.audio yield ModelResponse(audio=provider_response_delta.audio) @@ -1189,7 +1025,7 @@ async def aprocess_response_stream(self, messages: 
List[Message], assistant_mess response_usage=provider_response_delta.response_usage ) - async def aresponse_stream(self, messages: List[Message]) -> Any: + async def aresponse_stream(self, messages: List[Message]) -> AsyncIterator[ModelResponse]: """ Generate an asynchronous streaming response from the model. @@ -1197,49 +1033,73 @@ async def aresponse_stream(self, messages: List[Message]) -> Any: messages: List of messages in the conversation Returns: - Any: Async iterator of model responses + AsyncIterator[ModelResponse]: Async iterator of model responses """ logger.debug(f"---------- {self.get_provider()} Async Response Stream Start ----------") self._log_messages(messages) - stream_data = MessageData() - - # Create assistant message - assistant_message = Message(role=self.assistant_message_role) - # Generate response - assistant_message.metrics.start_timer() - async for response in self.aprocess_response_stream(messages=messages, assistant_message=assistant_message, stream_data=stream_data): - yield response - assistant_message.metrics.stop_timer() + while True: + # Create assistant message and stream data + assistant_message = Message(role=self.assistant_message_role) + stream_data = MessageData() - # Add response content and audio to assistant message - if stream_data.response_content != "": - assistant_message.content = stream_data.response_content + # Generate response + assistant_message.metrics.start_timer() + async for response in self.aprocess_response_stream( + messages=messages, + assistant_message=assistant_message, + stream_data=stream_data + ): + yield response + assistant_message.metrics.stop_timer() + + # Populate assistant message from stream data + if stream_data.response_content: + assistant_message.content = stream_data.response_content + if stream_data.response_audio: + assistant_message.audio_output = stream_data.response_audio + if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0: + assistant_message.tool_calls 
= self.parse_tool_calls(stream_data.response_tool_calls) + + # Add assistant message to messages + messages.append(assistant_message) + assistant_message.log(metrics=True) + + # Handle tool calls if present + if assistant_message.tool_calls: + yield ModelResponse(content="\n\n") + + # Prepare function calls + function_calls_to_run: List[FunctionCall] = self.get_function_calls_to_run(assistant_message, messages) + function_call_results: List[Message] = [] + + # Show tool calls if enabled + if self.show_tool_calls: + for model_response in self._show_stream_tool_calls(function_calls_to_run): + yield model_response + + # Execute function calls + async for function_call_response in self.arun_function_calls( + function_calls=function_calls_to_run, + function_call_results=function_call_results + ): + yield function_call_response - if stream_data.response_audio is not None: - assistant_message.audio_output = stream_data.response_audio + # Format and add results to messages + self.format_function_call_results( + messages=messages, + function_call_results=function_call_results, + **stream_data.extra + ) - # Add tool calls to assistant message - if stream_data.response_tool_calls is not None and len(stream_data.response_tool_calls) > 0: - parsed_tool_calls = self.parse_tool_calls(stream_data.response_tool_calls) - if len(parsed_tool_calls) > 0: - assistant_message.tool_calls = parsed_tool_calls + # Check if we should stop after tool calls + if any(m.stop_after_tool_call for m in function_call_results): + break - # Add assistant message to messages - messages.append(assistant_message) - - # Log response and metrics - assistant_message.log(metrics=True) + # Continue loop to get next response + continue - # Handle tool calls - if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0: - async for tool_call_response in self.ahandle_stream_tool_calls( - assistant_message=assistant_message, - messages=messages, - **stream_data.extra - ): - yield 
tool_call_response - async for post_tool_call_response in self.ahandle_post_tool_call_messages_stream(messages=messages): - yield post_tool_call_response + # No tool calls or finished processing them + break logger.debug(f"---------- {self.get_provider()} Async Response Stream End ----------") diff --git a/libs/agno/agno/models/cohere/chat.py b/libs/agno/agno/models/cohere/chat.py index 791543cdae..c59050df7c 100644 --- a/libs/agno/agno/models/cohere/chat.py +++ b/libs/agno/agno/models/cohere/chat.py @@ -3,6 +3,7 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple +from agno.exceptions import ModelProviderError from agno.models.base import MessageData, Model from agno.models.message import Message from agno.models.response import ModelResponse, ProviderResponse @@ -140,8 +141,12 @@ def invoke( """ request_kwargs = self.request_kwargs - - return self.get_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) + + try: + return self.get_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) + except Exception as e: + logger.error(f"Unexpected error calling Cohere API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e def invoke_stream( self, messages: List[Message] @@ -156,7 +161,12 @@ def invoke_stream( Iterator[StreamedChatResponseV2]: An iterator of streamed chat responses. 
""" request_kwargs = self.request_kwargs - return self.get_client().chat_stream(model=self.id, messages=self._format_messages(messages), **request_kwargs) + + try: + return self.get_client().chat_stream(model=self.id, messages=self._format_messages(messages), **request_kwargs) + except Exception as e: + logger.error(f"Unexpected error calling Cohere API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e async def ainvoke( self, messages: List[Message] @@ -172,7 +182,11 @@ async def ainvoke( """ request_kwargs = self.request_kwargs - return await self.get_async_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) + try: + return await self.get_async_client().chat(model=self.id, messages=self._format_messages(messages), **request_kwargs) + except Exception as e: + logger.error(f"Unexpected error calling Cohere API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e async def ainvoke_stream( self, messages: List[Message] @@ -187,9 +201,13 @@ async def ainvoke_stream( AsyncIterator[StreamedChatResponseV2]: An async iterator of streamed chat responses. 
""" request_kwargs = self.request_kwargs - - async for response in self.get_async_client().chat_stream(model=self.id, messages=self._format_messages(messages), **request_kwargs): - yield response + + try: + async for response in self.get_async_client().chat_stream(model=self.id, messages=self._format_messages(messages), **request_kwargs): + yield response + except Exception as e: + logger.error(f"Unexpected error calling Cohere API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e def parse_provider_response(self, response: ChatResponse) -> ProviderResponse: """ diff --git a/libs/agno/agno/models/deepseek/deepseek.py b/libs/agno/agno/models/deepseek/deepseek.py index 985b28d93f..a6bd9bf6f1 100644 --- a/libs/agno/agno/models/deepseek/deepseek.py +++ b/libs/agno/agno/models/deepseek/deepseek.py @@ -2,17 +2,7 @@ from os import getenv from typing import Optional -from agno.media import AudioOutput -from agno.models.base import Metrics -from agno.models.message import Message from agno.models.openai.like import OpenAILike -from agno.utils.log import logger - -try: - from openai.types.chat.chat_completion_message import ChatCompletionMessage - from openai.types.completion_usage import CompletionUsage -except ModuleNotFoundError: - raise ImportError("`openai` not installed. Please install using `pip install openai`") @dataclass @@ -29,47 +19,3 @@ class DeepSeek(OpenAILike): api_key: Optional[str] = getenv("DEEPSEEK_API_KEY", None) base_url: str = "https://api.deepseek.com" - - def create_assistant_message( - self, - response_message: ChatCompletionMessage, - metrics: Metrics, - response_usage: Optional[CompletionUsage], - ) -> Message: - """ - Create an assistant message from the response. - - Args: - response_message (ChatCompletionMessage): The response message. - metrics (Metrics): The metrics. - response_usage (Optional[CompletionUsage]): The response usage. - - Returns: - Message: The assistant message. 
- """ - assistant_message = Message( - role=response_message.role or "assistant", - content=response_message.content, - reasoning_content=response_message.reasoning_content - if hasattr(response_message, "reasoning_content") - else None, - ) - if response_message.tool_calls is not None and len(response_message.tool_calls) > 0: - try: - assistant_message.tool_calls = [t.model_dump() for t in response_message.tool_calls] - except Exception as e: - logger.warning(f"Error processing tool calls: {e}") - if hasattr(response_message, "audio") and response_message.audio is not None: - try: - assistant_message.audio_output = AudioOutput( - id=response_message.audio.id, - content=response_message.audio.data, - expires_at=response_message.audio.expires_at, - transcript=response_message.audio.transcript, - ) - except Exception as e: - logger.warning(f"Error processing audio: {e}") - - # Update metrics - self.update_usage_metrics(assistant_message, metrics, response_usage) - return assistant_message diff --git a/libs/agno/agno/models/fireworks/fireworks.py b/libs/agno/agno/models/fireworks/fireworks.py index d4c209761c..5cf0c0d0a3 100644 --- a/libs/agno/agno/models/fireworks/fireworks.py +++ b/libs/agno/agno/models/fireworks/fireworks.py @@ -27,20 +27,3 @@ class Fireworks(OpenAILike): api_key: Optional[str] = getenv("FIREWORKS_API_KEY", None) base_url: str = "https://api.fireworks.ai/inference/v1" - - def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]: - """ - Send a streaming chat completion request to the Fireworks API. - - Args: - messages (List[Message]): A list of message objects representing the conversation. - - Returns: - Iterator[ChatCompletionChunk]: An iterator of chat completion chunks. 
- """ - yield from self.get_client().chat.completions.create( - model=self.id, - messages=[m.to_dict() for m in messages], # type: ignore - stream=True, - **self.request_kwargs, - ) # type: ignore diff --git a/libs/agno/agno/models/groq/groq.py b/libs/agno/agno/models/groq/groq.py index 74dc64fbc0..a0a6f3b53d 100644 --- a/libs/agno/agno/models/groq/groq.py +++ b/libs/agno/agno/models/groq/groq.py @@ -4,6 +4,7 @@ import httpx +from agno.exceptions import ModelProviderError from agno.models.base import Model from agno.models.message import Message from agno.models.response import ProviderResponse @@ -221,11 +222,16 @@ def invoke(self, messages: List[Message]) -> ChatCompletion: Returns: ChatCompletion: The chat completion response from the API. """ - return self.get_client().chat.completions.create( - model=self.id, - messages=[format_message(m) for m in messages], # type: ignore - **self.request_kwargs, - ) + try: + return self.get_client().chat.completions.create( + model=self.id, + messages=[format_message(m) for m in messages], # type: ignore + **self.request_kwargs, + ) + except Exception as e: + logger.error(f"Unexpected error calling Groq API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e + async def ainvoke(self, messages: List[Message]) -> ChatCompletion: """ @@ -237,11 +243,16 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletion: Returns: ChatCompletion: The chat completion response from the API. 
""" - return await self.get_async_client().chat.completions.create( - model=self.id, - messages=[format_message(m) for m in messages], # type: ignore - **self.request_kwargs, - ) + try: + return await self.get_async_client().chat.completions.create( + model=self.id, + messages=[format_message(m) for m in messages], # type: ignore + **self.request_kwargs, + ) + except Exception as e: + logger.error(f"Unexpected error calling Groq API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e + def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]: """ @@ -253,12 +264,17 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk Returns: Iterator[ChatCompletionChunk]: An iterator of chat completion chunks. """ - return self.get_client().chat.completions.create( - model=self.id, - messages=[format_message(m) for m in messages], # type: ignore - stream=True, - **self.request_kwargs, - ) + try: + return self.get_client().chat.completions.create( + model=self.id, + messages=[format_message(m) for m in messages], # type: ignore + stream=True, + **self.request_kwargs, + ) + except Exception as e: + logger.error(f"Unexpected error calling Groq API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e + async def ainvoke_stream(self, messages: List[Message]) -> Any: """ @@ -269,13 +285,17 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any: Returns: Any: An asynchronous iterator of chat completion chunks. 
- """ - return await self.get_async_client().chat.completions.create( - model=self.id, - messages=[format_message(m) for m in messages], # type: ignore - stream=True, - **self.request_kwargs, - ) + """ + try: + return await self.get_async_client().chat.completions.create( + model=self.id, + messages=[format_message(m) for m in messages], # type: ignore + stream=True, + **self.request_kwargs, + ) + except Exception as e: + logger.error(f"Unexpected error calling Groq API: {str(e)}") + raise ModelProviderError(e, self.name, self.id) from e # Override base method diff --git a/libs/agno/agno/models/openai/chat.py b/libs/agno/agno/models/openai/chat.py index c44ce86dc5..8f7449ab45 100644 --- a/libs/agno/agno/models/openai/chat.py +++ b/libs/agno/agno/models/openai/chat.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncIterator from dataclasses import dataclass from os import getenv from typing import Any, Dict, Iterator, List, Optional, Union @@ -5,6 +6,7 @@ import httpx from pydantic import BaseModel +from agno.exceptions import ModelProviderError from agno.media import AudioOutput from agno.models.base import Model from agno.models.message import Message @@ -15,6 +17,7 @@ try: from openai import AsyncOpenAI as AsyncOpenAIClient from openai import OpenAI as OpenAIClient + from openai import RateLimitError, APIConnectionError, APIStatusError from openai.types.chat import ChatCompletionAudio from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ( @@ -266,8 +269,8 @@ def invoke(self, messages: List[Message]) -> Union[ChatCompletion, ParsedChatCom Returns: ChatCompletion: The chat completion response from the API. 
""" - if self.response_format is not None and self.structured_outputs: - try: + try: + if self.response_format is not None and self.structured_outputs: if isinstance(self.response_format, type) and issubclass(self.response_format, BaseModel): return self.get_client().beta.chat.completions.parse( model=self.id, @@ -276,14 +279,24 @@ def invoke(self, messages: List[Message]) -> Union[ChatCompletion, ParsedChatCom ) else: raise ValueError("response_format must be a subclass of BaseModel if structured_outputs=True") - except Exception as e: - logger.error(f"Error from OpenAI API: {e}") - return self.get_client().chat.completions.create( - model=self.id, - messages=[self._format_message(m) for m in messages], # type: ignore - **self.request_kwargs, - ) + return self.get_client().chat.completions.create( + model=self.id, + messages=[self._format_message(m) for m in messages], # type: ignore + **self.request_kwargs, + ) + except RateLimitError as e: + logger.error(f"Rate limit error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIConnectionError as e: + logger.error(f"API connection error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIStatusError as e: + logger.error(f"API status error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except Exception as e: + logger.error(f"Error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e async def ainvoke(self, messages: List[Message]) -> Union[ChatCompletion, ParsedChatCompletion]: """ @@ -295,8 +308,8 @@ async def ainvoke(self, messages: List[Message]) -> Union[ChatCompletion, Parsed Returns: ChatCompletion: The chat completion response from the API. 
""" - if self.response_format is not None and self.structured_outputs: - try: + try: + if self.response_format is not None and self.structured_outputs: if isinstance(self.response_format, type) and issubclass(self.response_format, BaseModel): return await self.get_async_client().beta.chat.completions.parse( model=self.id, @@ -305,14 +318,23 @@ async def ainvoke(self, messages: List[Message]) -> Union[ChatCompletion, Parsed ) else: raise ValueError("response_format must be a subclass of BaseModel if structured_outputs=True") - except Exception as e: - logger.error(f"Error from OpenAI API: {e}") - - return await self.get_async_client().chat.completions.create( - model=self.id, - messages=[self._format_message(m) for m in messages], # type: ignore - **self.request_kwargs, - ) + return await self.get_async_client().chat.completions.create( + model=self.id, + messages=[self._format_message(m) for m in messages], # type: ignore + **self.request_kwargs, + ) + except RateLimitError as e: + logger.error(f"Rate limit error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIConnectionError as e: + logger.error(f"API connection error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIStatusError as e: + logger.error(f"API status error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except Exception as e: + logger.error(f"Error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]: """ @@ -324,15 +346,28 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk Returns: Iterator[ChatCompletionChunk]: An iterator of chat completion chunks. 
""" - yield from self.get_client().chat.completions.create( - model=self.id, - messages=[self._format_message(m) for m in messages], # type: ignore - stream=True, - stream_options={"include_usage": True}, - **self.request_kwargs, - ) # type: ignore - - async def ainvoke_stream(self, messages: List[Message]) -> Any: + try: + yield from self.get_client().chat.completions.create( + model=self.id, + messages=[self._format_message(m) for m in messages], # type: ignore + stream=True, + stream_options={"include_usage": True}, + **self.request_kwargs, + ) # type: ignore + except RateLimitError as e: + logger.error(f"Rate limit error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIConnectionError as e: + logger.error(f"API connection error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIStatusError as e: + logger.error(f"API status error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except Exception as e: + logger.error(f"Error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + + async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[ChatCompletionChunk]: """ Sends an asynchronous streaming chat completion request to the OpenAI API. @@ -342,15 +377,27 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any: Returns: Any: An asynchronous iterator of chat completion chunks. 
""" - async_stream = await self.get_async_client().chat.completions.create( - model=self.id, - messages=[self._format_message(m) for m in messages], # type: ignore - stream=True, - stream_options={"include_usage": True}, - **self.request_kwargs, - ) - async for chunk in async_stream: # type: ignore - yield chunk + try: + async_stream = await self.get_async_client().chat.completions.create( + model=self.id, + messages=[self._format_message(m) for m in messages], # type: ignore + stream=True, + stream_options={"include_usage": True}, + **self.request_kwargs, + ) + return async_stream + except RateLimitError as e: + logger.error(f"Rate limit error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIConnectionError as e: + logger.error(f"API connection error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except APIStatusError as e: + logger.error(f"API status error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e + except Exception as e: + logger.error(f"Error from OpenAI API: {e}") + raise ModelProviderError(e, self.name, self.id) from e # Override base method @staticmethod @@ -455,6 +502,9 @@ def parse_provider_response( except Exception as e: logger.warning(f"Error processing audio: {e}") + if hasattr(response_message, "reasoning_content") and response_message.reasoning_content is not None: + provider_response.reasoning_content = response_message.reasoning_content + if response.usage is not None: provider_response.response_usage = response.usage diff --git a/libs/agno/agno/models/response.py b/libs/agno/agno/models/response.py index 26ecf2cde2..95168051a6 100644 --- a/libs/agno/agno/models/response.py +++ b/libs/agno/agno/models/response.py @@ -25,6 +25,8 @@ class ModelResponse: event: str = ModelResponseEvent.assistant_response.value created_at: int = int(time()) + extra: Optional[Dict[str, Any]] = field(default_factory=dict) + @dataclass class 
ProviderResponse: """Response parsed from the response that the model provider returns""" @@ -33,6 +35,7 @@ class ProviderResponse: content: Optional[str] = None parsed: Optional[Any] = None audio: Optional[AudioOutput] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = field(default_factory=list) response_usage: Optional[Any] = None