Skip to content

Commit

Permalink
Fix save
Browse files Browse the repository at this point in the history
  • Loading branch information
Théophilus Homawoo committed Feb 16, 2024
1 parent 2e7ba38 commit d8e1424
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable

from llama_index.agent.openai.utils import get_function_by_name
from llama_index.core.agent.types import BaseAgent
Expand Down Expand Up @@ -59,7 +59,8 @@ def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage]


def call_function(
tools: List[BaseTool], fn_obj: Any, verbose: bool = False, get_tool_by: function = get_function_by_name
tools: List[BaseTool], fn_obj: Any, verbose: bool = False,
get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
from openai.types.beta.threads.required_action_function_tool_call import Function
Expand All @@ -71,7 +72,7 @@ def call_function(
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
get_tool_by(tools, name)
tool = get_tool_by(tools, name)
argument_dict = json.loads(arguments_str)
output = tool(**argument_dict)
if verbose:
Expand All @@ -90,7 +91,8 @@ def call_function(


async def acall_function(
tools: List[BaseTool], fn_obj: Any, verbose: bool = False, get_tool_by: function = get_function_by_name
tools: List[BaseTool], fn_obj: Any, verbose: bool = False,
get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name
) -> Tuple[ChatMessage, ToolOutput]:
"""Call an async function and return the output as a string."""
from openai.types.beta.threads.required_action_function_tool_call import Function
Expand All @@ -102,7 +104,7 @@ async def acall_function(
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
get_tool_by(tools, name)
tool = get_tool_by(tools, name)
argument_dict = json.loads(arguments_str)
async_tool = adapt_to_async_tool(tool)
output = await async_tool.acall(**argument_dict)
Expand Down Expand Up @@ -169,7 +171,7 @@ def __init__(
self._run_retrieve_sleep_time = run_retrieve_sleep_time
self._verbose = verbose
self.file_dict = file_dict
self._get_tool_by = get_tool_by
self._get_tool_fn = get_tool_by

self.callback_manager = callback_manager or CallbackManager([])

Expand Down Expand Up @@ -358,7 +360,8 @@ def _run_function_calling(self, run: Any) -> List[ToolOutput]:
tool_output_objs: List[ToolOutput] = []
for tool_call in tool_calls:
fn_obj = tool_call.function
_, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_by)
_, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose,
get_tool_by=self._get_tool_fn)
tool_output_dicts.append(
{"tool_call_id": tool_call.id, "output": str(tool_output)}
)
Expand All @@ -381,7 +384,7 @@ async def _arun_function_calling(self, run: Any) -> List[ToolOutput]:
for tool_call in tool_calls:
fn_obj = tool_call.function
_, tool_output = await acall_function(
self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_by
self._tools, fn_obj, verbose=self._verbose, get_tool_by=self._get_tool_fn
)
tool_output_dicts.append(
{"tool_call_id": tool_call.id, "output": str(tool_output)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def call_function(
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
tool = get_tool_by(tools, name)
argument_dict = json.loads(arguments_str)

# Call tool
Expand All @@ -119,7 +119,8 @@ def call_function(


async def acall_function(
tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False,
get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name,
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
# validations to get passed mypy
Expand All @@ -135,7 +136,7 @@ async def acall_function(
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
tool = get_tool_by(tools, name)
async_tool = adapt_to_async_tool(tool)
argument_dict = json.loads(arguments_str)
output = await async_tool.acall(**argument_dict)
Expand Down Expand Up @@ -167,12 +168,14 @@ def __init__(
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name,
):
self._llm = llm
self._verbose = verbose
self._max_function_calls = max_function_calls
self.prefix_messages = prefix_messages
self.callback_manager = callback_manager or self._llm.callback_manager
self._get_tools_fn = get_tool_by

if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
Expand All @@ -196,6 +199,7 @@ def from_tools(
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
get_tool_by: Callable[[List[BaseTool], str], BaseTool] = get_function_by_name,
**kwargs: Any,
) -> "OpenAIAgentWorker":
"""Create an OpenAIAgent from a list of tools.
Expand Down Expand Up @@ -236,6 +240,7 @@ def from_tools(
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
get_tool_by=get_tool_by,
)

def get_all_messages(self, task: Task) -> List[ChatMessage]:
Expand Down Expand Up @@ -354,13 +359,13 @@ def _call_function(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
EventPayload.TOOL: self._get_tools_fn(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = call_function(
tools, tool_call, verbose=self._verbose
tools, tool_call, verbose=self._verbose, get_tool_by=self._get_tools_fn
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
sources.append(tool_output)
Expand All @@ -383,13 +388,13 @@ async def _acall_function(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
EventPayload.TOOL: self._get_tools_fn(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = await acall_function(
tools, tool_call, verbose=self._verbose
tools, tool_call, verbose=self._verbose, get_tool_by=self._get_tools_fn
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
sources.append(tool_output)
Expand Down

0 comments on commit d8e1424

Please sign in to comment.