Fix AssistantAgent Tool Call Behavior #4602

Merged · 24 commits · Dec 10, 2024
Changes from 10 commits
python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -14,6 +14,7 @@
SystemMessage,
UserMessage,
)
from autogen_core.components.models._types import CreateResult
from autogen_core.components.tools import FunctionTool, Tool
from typing_extensions import deprecated

@@ -54,6 +55,7 @@ class AssistantAgent(BaseChatAgent):
The assistant agent is not thread-safe or coroutine-safe.
It should not be shared between multiple tasks or coroutines, and its
methods should not be called concurrently.
If multiple handoffs are detected, only the first handoff is executed.
```

Args:
@@ -66,11 +68,13 @@ class AssistantAgent(BaseChatAgent):
If a handoff is a string, it should represent the target agent's name.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model.
max_tool_call_iterations (int, optional): The maximum number of consecutive model calls to make before returning; iteration stops early once the model returns a string response rather than a list of tool calls. Defaults to 1.

Raises:
ValueError: If tool names are not unique.
ValueError: If handoff names are not unique.
ValueError: If handoff names overlap with tool names.
ValueError: If the maximum number of tool call iterations is less than 1.

Examples:
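
For instance, a minimal sketch of constructing the agent with the new parameter (the model client configuration, the toy tool, and the import paths are illustrative assumptions, not taken from this diff):

```python
# Sketch only: import paths follow the 0.4 preview layout and may differ by version.
from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def get_weather(city: str) -> str:
    """Toy tool used for illustration."""
    return f"It is sunny in {city}."


agent = AssistantAgent(
    "assistant",
    model_client=OpenAIChatCompletionClient(model="gpt-4o"),
    tools=[get_weather],
    # Allow up to 3 model calls; stop early if the model replies with text.
    max_tool_call_iterations=3,
)
```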

@@ -181,9 +185,13 @@ def __init__(
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
max_tool_call_iterations: int = 1,
):
super().__init__(name=name, description=description)
self._model_client = model_client
if max_tool_call_iterations < 1:
raise ValueError("The maximum number of tool iterations must be at least 1.")
self._max_tool_call_iterations = max_tool_call_iterations
if system_message is None:
self._system_messages = []
else:
@@ -257,63 +265,82 @@ async def on_messages_stream(

# Inner messages.
inner_messages: List[AgentMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Detect handoff requests.
handoffs: List[HandoffBase] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
return

# Model response holder and tool call messages.
result: CreateResult | None = None
tool_call_msg: ToolCallMessage | None = None
tool_call_result_msg: ToolCallResultMessage | None = None

# Call the model up to _max_tool_call_iterations times, or until the response is a string.
tool_call_iteration = 0
while tool_call_iteration < self._max_tool_call_iterations:
tool_call_iteration += 1
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
# If the response is a list of tool calls, run them; otherwise stop iterating.
if isinstance(result.content, str):
break
elif isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Detect handoff requests.
handoffs: List[HandoffBase] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
# Warn when multiple handoffs are detected; only the first is executed.
warnings.warn(
f"Multiple handoffs detected, only the first is executed: {[handoff.name for handoff in handoffs]}",
stacklevel=2,
)
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
return

assert result is not None
# If the last model response is still a list of tool calls, summarize them as text.
if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_summary = "Tool calls:"
assert isinstance(tool_call_msg, ToolCallMessage)
assert isinstance(tool_call_result_msg, ToolCallResultMessage)
for i in range(len(tool_call_msg.content)):
tool_call_summary += f"\n{tool_call_msg.content[i].name}({tool_call_msg.content[i].arguments}) = {tool_call_result_msg.content[i].content}"
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
# If the last model response is a string, return it directly.
else:
assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
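When the iteration budget runs out while the model is still emitting tool calls, the agent's final response is a plain-text summary of the last round of calls. A self-contained sketch of that formatting, using stand-in types for the autogen_core FunctionCall and FunctionExecutionResult (the real types carry more fields):

```python
# Stand-in dataclasses; the real types live in autogen_core.
from dataclasses import dataclass


@dataclass
class FunctionCall:
    name: str
    arguments: str


@dataclass
class FunctionExecutionResult:
    content: str


def summarize_tool_calls(calls: list[FunctionCall], results: list[FunctionExecutionResult]) -> str:
    """Mirror the summary format built in on_messages_stream above."""
    summary = "Tool calls:"
    for call, result in zip(calls, results):
        summary += f"\n{call.name}({call.arguments}) = {result.content}"
    return summary


# Matches the expectation in the unit test below:
assert summarize_tool_calls(
    [FunctionCall(name="_pass_function", arguments='{"input": "task"}')],
    [FunctionExecutionResult(content="pass")],
) == 'Tool calls:\n_pass_function({"input": "task"}) = pass'
```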
26 changes: 26 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -118,6 +118,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
@@ -128,6 +129,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].content == 'Tool calls:\n_pass_function({"input": "task"}) = pass'
assert result.messages[3].models_usage is not None
assert result.messages[3].models_usage.completion_tokens == 5
assert result.messages[3].models_usage.prompt_tokens == 10
@@ -152,6 +154,30 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2
# Test with max tool call iterations = 2
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
max_tool_call_iterations=2,
)
result = await agent.run(task="task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].content == "Hello"
assert result.messages[3].models_usage is not None
assert result.messages[3].models_usage.completion_tokens == 5
assert result.messages[3].models_usage.prompt_tokens == 10


@pytest.mark.asyncio
30 changes: 29 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -306,28 +306,30 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
),
]
# Test with tool calls repeated once (max_tool_call_iterations=2).
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
tool = FunctionTool(_pass_function, name="pass", description="pass function")
tool_use_agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[tool],
max_tool_call_iterations=2,
)
echo_agent = _EchoAgent("echo_agent", description="echo agent")
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent], termination_condition=termination)
result = await team.run(
task="Write a program that prints 'Hello, world!'",
)

assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response

assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

context = tool_use_agent._model_context # pyright: ignore
@@ -363,6 +365,32 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result
# Test with no tool call repeat (default max_tool_call_iterations=1).
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
tool = FunctionTool(_pass_function, name="pass", description="pass function")
tool_use_agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[tool],
)
echo_agent = _EchoAgent("echo_agent", description="echo agent")
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent], termination_condition=termination)
result = await team.run(
task="Write a program that prints 'Hello, world!'",
)
assert len(result.messages) == 8
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert isinstance(result.messages[6], TextMessage) # echo agent response
assert isinstance(result.messages[7], TextMessage)  # tool use agent response containing TERMINATE

assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"


@pytest.mark.asyncio
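Both test files drive the agent through a _MockChatCompletion helper that monkeypatches AsyncCompletions.create. The helper's definition is not part of this diff, so the sketch below is an assumption about the pattern rather than the actual implementation:

```python
# Assumed shape of the helper: hand back queued ChatCompletion objects one per
# call, so every model invocation in a test is deterministic.
from typing import Any

from openai.types.chat import ChatCompletion


class _MockChatCompletion:
    def __init__(self, completions: list[ChatCompletion]) -> None:
        self._completions = list(completions)
        self._index = 0

    async def mock_create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
        completion = self._completions[self._index]
        self._index += 1
        return completion
```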