diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/openai_assistant_agent.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/openai_assistant_agent.py index b0c6c669f5776..df81310d73daa 100644 --- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/openai_assistant_agent.py +++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/openai_assistant_agent.py @@ -3,7 +3,7 @@ import json import logging import time -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable from llama_index.agent.openai.utils import get_function_by_name from llama_index.core.agent.types import BaseAgent @@ -59,7 +59,8 @@ def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage] def call_function( - tools: List[BaseTool], fn_obj: Any, verbose: bool = False, get_tool_by: function = get_function_by_name + tools: List[BaseTool], fn_obj: Any, verbose: bool = False, + get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name ) -> Tuple[ChatMessage, ToolOutput]: """Call a function and return the output as a string.""" from openai.types.beta.threads.required_action_function_tool_call import Function @@ -71,7 +72,7 @@ def call_function( if verbose: print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") - get_tool_by(tools, name) + tool = get_tool_by(tools, name) argument_dict = json.loads(arguments_str) output = tool(**argument_dict) if verbose: @@ -90,7 +91,8 @@ def call_function( async def acall_function( - tools: List[BaseTool], fn_obj: Any, verbose: bool = False, get_tool_by: function = get_function_by_name + tools: List[BaseTool], fn_obj: Any, verbose: bool = False, + get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name ) -> Tuple[ChatMessage, ToolOutput]: """Call an async function and return the output as a string.""" from openai.types.beta.threads.required_action_function_tool_call import Function @@ -102,7 +104,7 @@ async def acall_function( if verbose: print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") - get_tool_by(tools, name) + tool = get_tool_by(tools, name) argument_dict = json.loads(arguments_str) async_tool = adapt_to_async_tool(tool) output = await async_tool.acall(**argument_dict) @@ -169,7 +171,7 @@ def __init__( self._run_retrieve_sleep_time = run_retrieve_sleep_time self._verbose = verbose self.file_dict = file_dict - self._get_tool_by = get_tool_by + self._get_tool_fn = get_tool_by self.callback_manager = callback_manager or CallbackManager([]) @@ -358,7 +360,8 @@ def _run_function_calling(self, run: Any) -> List[ToolOutput]: tool_output_objs: List[ToolOutput] = [] for tool_call in tool_calls: fn_obj = tool_call.function - _, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_by) + _, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose, + get_tool_by=self._get_tool_fn) tool_output_dicts.append( {"tool_call_id": tool_call.id, "output": str(tool_output)} ) @@ -381,7 +384,7 @@ async def _arun_function_calling(self, run: Any) -> List[ToolOutput]: for tool_call in tool_calls: fn_obj = tool_call.function _, tool_output = await acall_function( - self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_by + self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_fn ) tool_output_dicts.append( {"tool_call_id": tool_call.id, "output": str(tool_output)} diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py index 79912dc6c3e14..935d2db19f44c 100644 --- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py +++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py @@ -96,7 +96,7 @@ def call_function( if verbose: print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") - tool = get_function_by_name(tools, name) + tool = get_tool_by(tools, name) argument_dict = json.loads(arguments_str) # Call tool @@ -119,7 +119,8 @@ def call_function( async def acall_function( - tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False + tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False, + get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name, ) -> Tuple[ChatMessage, ToolOutput]: """Call a function and return the output as a string.""" # validations to get passed mypy @@ -135,7 +136,7 @@ async def acall_function( if verbose: print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") - tool = get_function_by_name(tools, name) + tool = get_tool_by(tools, name) async_tool = adapt_to_async_tool(tool) argument_dict = json.loads(arguments_str) output = await async_tool.acall(**argument_dict) @@ -167,12 +168,14 @@ def __init__( max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS, callback_manager: Optional[CallbackManager] = None, tool_retriever: Optional[ObjectRetriever[BaseTool]] = None, + get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name, ): self._llm = llm self._verbose = verbose self._max_function_calls = max_function_calls self.prefix_messages = prefix_messages self.callback_manager = callback_manager or self._llm.callback_manager + self._get_tools_fn = get_tool_by if len(tools) > 0 and tool_retriever is not None: raise ValueError("Cannot specify both tools and tool_retriever") @@ -196,6 +199,7 @@ def from_tools( callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, prefix_messages: Optional[List[ChatMessage]] = None, + get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name, **kwargs: Any, ) -> "OpenAIAgentWorker": """Create an OpenAIAgent from a list of tools. @@ -236,6 +240,7 @@ def from_tools( verbose=verbose, max_function_calls=max_function_calls, callback_manager=callback_manager, + get_tool_by=get_tool_by, ) def get_all_messages(self, task: Task) -> List[ChatMessage]: @@ -354,13 +359,13 @@ def _call_function( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: function_call.arguments, - EventPayload.TOOL: get_function_by_name( + EventPayload.TOOL: self._get_tools_fn( tools, function_call.name ).metadata, }, ) as event: function_message, tool_output = call_function( - tools, tool_call, verbose=self._verbose + tools, tool_call, verbose=self._verbose, get_tool_by=self._get_tools_fn ) event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) sources.append(tool_output) @@ -383,13 +388,13 @@ async def _acall_function( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: function_call.arguments, - EventPayload.TOOL: get_function_by_name( + EventPayload.TOOL: self._get_tools_fn( tools, function_call.name ).metadata, }, ) as event: function_message, tool_output = await acall_function( - tools, tool_call, verbose=self._verbose + tools, tool_call, verbose=self._verbose, get_tool_by=self._get_tools_fn ) event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) sources.append(tool_output)