From ff9afa61d56f81cb92de8b056de54ed877ef6821 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Tue, 10 Dec 2024 16:32:29 -0600 Subject: [PATCH 01/22] initial working e2e --- .../llama_index/core/agent/multi_agent/BUILD | 1 + .../core/agent/multi_agent/__init__.py | 0 .../core/agent/multi_agent/agent_config.py | 35 ++ .../agent/multi_agent/multi_agent_workflow.py | 577 ++++++++++++++++++ .../core/agent/multi_agent/workflow_events.py | 70 +++ .../llama_index/core/workflow/__init__.py | 2 + .../core/workflow/function_context_tool.py | 131 ++++ 7 files changed, 816 insertions(+) create mode 100644 llama-index-core/llama_index/core/agent/multi_agent/BUILD create mode 100644 llama-index-core/llama_index/core/agent/multi_agent/__init__.py create mode 100644 llama-index-core/llama_index/core/agent/multi_agent/agent_config.py create mode 100644 llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py create mode 100644 llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py create mode 100644 llama-index-core/llama_index/core/workflow/function_context_tool.py diff --git a/llama-index-core/llama_index/core/agent/multi_agent/BUILD b/llama-index-core/llama_index/core/agent/multi_agent/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-core/llama_index/core/agent/multi_agent/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-core/llama_index/core/agent/multi_agent/__init__.py b/llama-index-core/llama_index/core/agent/multi_agent/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py new file mode 100644 index 0000000000000..4c1fb311917f9 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py @@ -0,0 +1,35 @@ +from enum import Enum +from typing import List, Optional + +from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict +from llama_index.core.llms import LLM +from llama_index.core.objects import ObjectRetriever +from llama_index.core.tools import BaseTool + + +class AgentMode(str, Enum): + """Agent mode.""" + + DEFAULT = "default" + REACT = "react" + FUNCTION = "function" + + +class AgentConfig(BaseModel): + """Configuration for a single agent in the multi-agent system.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str + description: str + system_prompt: Optional[str] = None + tools: Optional[List[BaseTool]] = None + tool_retriever: Optional[ObjectRetriever] = None + tools_requiring_human_confirmation: Optional[List[str]] = Field( + default_factory=list + ) + can_handoff_to: Optional[List[str]] = Field(default=None) + handoff_prompt_template: Optional[str] = None + llm: Optional[LLM] = None + is_root_agent: bool = False + mode: AgentMode = AgentMode.DEFAULT diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py new file mode 100644 index 0000000000000..9ffe3a04474a8 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -0,0 +1,577 @@ +import asyncio +from typing import Any, Dict, List, Optional, Union, cast + +from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode +from llama_index.core.agent.multi_agent.workflow_events import ( + HandoffEvent, + ToolCall, + AgentInput, + AgentSetup, + AgentStream, + AgentOutput, +) +from llama_index.core.agent.react.output_parser import ReActOutputParser +from llama_index.core.agent.react.formatter import ReActChatFormatter +from llama_index.core.agent.react.types import ( + ActionReasoningStep, + ObservationReasoningStep, + ResponseReasoningStep, +) +from llama_index.core.llms import ChatMessage, LLM +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.memory import BaseMemory, ChatMemoryBuffer +from llama_index.core.prompts import BasePromptTemplate, PromptTemplate +from llama_index.core.tools import ( + BaseTool, + AsyncBaseTool, + FunctionTool, + ToolOutput, + adapt_to_async_tool, +) +from llama_index.core.workflow import ( + Context, + FunctionToolWithContext, + StartEvent, + StopEvent, + Workflow, + step, +) +from llama_index.core.settings import Settings + + +DEFAULT_HANDOFF_PROMPT = """Useful for handing off to another agent. +If you are currently not equipped to handle the user's request, please hand off to the appropriate agent. + +Currently available agents: +{agent_info} +""" + + +async def handoff(to_agent: str, reason: str) -> HandoffEvent: + """Handoff to the given agent.""" + return f"Handed off to {to_agent} because: {reason}" + + +class MultiAgentWorkflow(Workflow): + """A workflow for managing multiple agents with handoffs.""" + + def __init__( + self, + agent_configs: List[AgentConfig], + initial_state: Optional[Dict] = None, + memory: Optional[BaseMemory] = None, + handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None, + state_prompt: Optional[Union[str, BasePromptTemplate]] = None, + timeout: Optional[float] = None, + **workflow_kwargs: Any, + ): + super().__init__(timeout=timeout, **workflow_kwargs) + if not agent_configs: + raise ValueError("At least one agent config must be provided") + + self.agent_configs = {cfg.name: cfg for cfg in agent_configs} + only_one_root_agent = sum(cfg.is_root_agent for cfg in agent_configs) == 1 + if not only_one_root_agent: + raise ValueError("Exactly one root agent must be provided") + + self.root_agent = next(cfg.name for cfg in agent_configs if cfg.is_root_agent) + + self.initial_state = initial_state or {} + self.memory = memory or ChatMemoryBuffer.from_defaults( + llm=agent_configs[0].llm or Settings.llm + ) + + self.handoff_prompt = handoff_prompt or DEFAULT_HANDOFF_PROMPT + if isinstance(self.handoff_prompt, str): + self.handoff_prompt = PromptTemplate(self.handoff_prompt) + if "{agent_info}" not in self.handoff_prompt.template: + raise ValueError("Handoff prompt must contain {agent_info}") + + self.state_prompt = state_prompt + if isinstance(self.state_prompt, str): + self.state_prompt = PromptTemplate(self.state_prompt) + if ( + "{state}" not in self.state_prompt.template + or "{msg}" not in self.state_prompt.template + ): + raise ValueError("State prompt must contain {state} and {msg}") + + def _ensure_tools_are_async(self, tools: List[BaseTool]) -> List[AsyncBaseTool]: + """Ensure all tools are async.""" + return [adapt_to_async_tool(tool) for tool in tools] + + def _get_handoff_tool(self, current_agent_config: AgentConfig) -> AsyncBaseTool: + """Creates a handoff tool for the given agent.""" + agent_info = {cfg.name: cfg.description for cfg in self.agent_configs.values()} + + # Filter out agents that the current agent cannot handoff to + configs_to_remove = [] + for name in agent_info: + if name == current_agent_config.name: + configs_to_remove.append(name) + elif ( + current_agent_config.can_handoff_to is not None + and name not in current_agent_config.can_handoff_to + ): + configs_to_remove.append(name) + + for name in configs_to_remove: + agent_info.pop(name) + + fn_tool_prompt = self.handoff_prompt.format(agent_info=str(agent_info)) + return FunctionTool.from_defaults(async_fn=handoff, description=fn_tool_prompt) + + async def _init_context(self, ctx: Context) -> None: + """Initialize the context once, if needed.""" + if not await ctx.get("memory", default=None): + await ctx.set("memory", self.memory) + if not await ctx.get("agent_configs", default=None): + await ctx.set("agent_configs", self.agent_configs) + if not await ctx.get("current_state", default=None): + await ctx.set("current_state", self.initial_state) + if not await ctx.get("current_agent", default=None): + await ctx.set("current_agent", self.root_agent) + + async def _call_tool( + self, ctx: Context, tool: AsyncBaseTool, tool_input: dict + ) -> ToolOutput: + """Call the given tool with the given input.""" + try: + if isinstance(tool, FunctionToolWithContext): + tool_output = await tool.acall(ctx=ctx, **tool_input) + else: + tool_output = await tool.acall(**tool_input) + except Exception as e: + tool_output = ToolOutput( + content=str(e), + tool_name=tool.metadata.name, + raw_input=tool_input, + raw_output=str(e), + is_error=True, + ) + + ctx.write_event_to_stream( + ToolCall( + tool_name=tool.metadata.name, + tool_kwargs=tool_input, + tool_output=tool_output.content, + ) + ) + + return tool_output + + async def _call_function_calling_agent( + self, + ctx: Context, + llm: FunctionCallingLLM, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + ) -> AgentOutput: + """Call the LLM as a function calling agent.""" + memory: BaseMemory = await ctx.get("memory") + tools_by_name = {tool.metadata.name: tool for tool in tools} + + current_llm_input = [*llm_input] + response = await llm.astream_chat_with_tools( + tools, chat_history=current_llm_input, allow_parallel_tool_calls=True + ) + async for r in response: + tool_calls = llm.get_tool_calls_from_response( + r, error_on_no_tool_call=False + ) + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=tool_calls or [], + raw_response=r.raw, + ) + ) + + current_llm_input.append(r.message) + await memory.aput(r.message) + tool_calls = llm.get_tool_calls_from_response(r, error_on_no_tool_call=False) + + all_tool_results = [] + while tool_calls: + tool_results: List[ToolOutput] = [] + tool_ids: List[str] = [] + jobs = [] + should_return_direct = False + for tool_call in tool_calls: + tool_ids.append(tool_call.tool_call_id) + if tool_call.tool_name not in tools_by_name: + tool_results.append( + ToolOutput( + content=f"Tool {tool_call.tool_name} not found. Please select a tool that is available.", + tool_name=tool_call.tool_name, + raw_input=tool_call.tool_kwargs, + raw_output=None, + is_error=True, + ) + ) + else: + tool = tools_by_name[tool_call.tool_name] + if tool.metadata.return_direct: + should_return_direct = True + + job = self._call_tool(ctx, tool, tool_call.tool_kwargs) + jobs.append(job) + + tool_results.extend(await asyncio.gather(*jobs)) + all_tool_results.extend(tool_results) + tool_messages = [ + ChatMessage( + role="tool", + content=str(result), + additional_kwargs={"tool_call_id": tool_id}, + ) + for result, tool_id in zip(tool_results, tool_ids) + ] + + for tool_message in tool_messages: + await memory.aput(tool_message) + + if should_return_direct: + return AgentOutput( + response=tool_results[0].content, + tool_outputs=all_tool_results, + raw_response=None, + ) + + current_llm_input.extend(tool_messages) + response = await llm.astream_chat_with_tools( + tools, chat_history=current_llm_input, allow_parallel_tool_calls=True + ) + async for r in response: + tool_calls = llm.get_tool_calls_from_response( + r, error_on_no_tool_call=False + ) + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=tool_calls or [], + raw_response=r.raw, + ) + ) + + current_llm_input.append(r.message) + await memory.aput(r.message) + tool_calls = llm.get_tool_calls_from_response( + r, error_on_no_tool_call=False + ) + + return AgentOutput( + response=r.message.content, + tool_outputs=all_tool_results, + raw_response=r.raw, + ) + + async def _call_react_agent( + self, + ctx: Context, + llm: LLM, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + ) -> AgentOutput: + """Call the LLM as a react agent.""" + memory: BaseMemory = await ctx.get("memory") + + # remove system prompt, since the react prompt will be combined with it + if llm_input[0].role == "system": + system_prompt = llm_input[0].content or "" + llm_input = llm_input[1:] + else: + system_prompt = "" + + output_parser = ReActOutputParser() + react_chat_formatter = ReActChatFormatter(context=system_prompt) + + # Format initial chat input + current_reasoning = [] + input_chat = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + + # Initial LLM call + response = await llm.astream_chat(input_chat) + async for r in response: + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=[], + raw_response=r.raw, + ) + ) + + await memory.aput(r.message) + + # Parse reasoning step and check if done + message_content = r.message.content + if not message_content: + raise ValueError("Got empty message") + + try: + reasoning_step = output_parser.parse(message_content, is_streaming=False) + except ValueError as e: + # If we can't parse the output, return an error message + error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + return AgentOutput( + response=error_msg, + tool_outputs=[], + raw_response=r.raw, + ) + + # If response step, we're done + all_tool_outputs = [] + if reasoning_step.is_done: + current_reasoning.append(reasoning_step) + + latest_react_messages = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + for msg in latest_react_messages: + await memory.aput(msg) + + response = ( + reasoning_step.response + if hasattr(reasoning_step, "response") + else reasoning_step.get_content() + ) + return AgentOutput( + response=response, + tool_outputs=all_tool_outputs, + raw_response=r.raw, + ) + + # Otherwise process action step + while True: + current_reasoning.append(reasoning_step) + + reasoning_step = cast(ActionReasoningStep, reasoning_step) + if not isinstance(reasoning_step, ActionReasoningStep): + raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") + + # Call tool + tools_by_name = {tool.metadata.name: tool for tool in tools} + tool = None + if reasoning_step.action not in tools_by_name: + tool_output = ToolOutput( + content=f"Error: No such tool named `{reasoning_step.action}`.", + tool_name=reasoning_step.action, + raw_input={"kwargs": reasoning_step.action_input}, + raw_output=None, + is_error=True, + ) + else: + tool = tools_by_name[reasoning_step.action] + tool_output = await self._call_tool( + ctx, tool, reasoning_step.action_input + ) + all_tool_outputs.append(tool_output) + + # Add observation to chat history + current_reasoning.append( + ObservationReasoningStep( + observation=str(tool_output), + return_direct=tool.metadata.return_direct, + ) + ) + + if tool and tool.metadata.return_direct: + current_reasoning.append( + ResponseReasoningStep(response=tool_output.content) + ) + latest_react_messages = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + for msg in latest_react_messages: + await memory.aput(msg) + + return AgentOutput( + response=tool_output.content, + tool_outputs=all_tool_outputs, + raw_response=r.raw, + ) + + # Get next action from LLM + input_chat = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + response = await llm.astream_chat(input_chat) + async for r in response: + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=[], + raw_response=r.raw, + ) + ) + + await memory.aput(r.message) + + # Parse next reasoning step + message_content = r.message.content + if not message_content: + raise ValueError("Got empty message") + + try: + reasoning_step = output_parser.parse( + message_content, is_streaming=False + ) + except ValueError as e: + # If we can't parse the output, return an error message + error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + + current_reasoning.append(ResponseReasoningStep(response=error_msg)) + + latest_react_messages = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + for msg in latest_react_messages: + await memory.aput(msg) + + return AgentOutput( + response=error_msg, + tool_outputs=all_tool_outputs, + raw_response=r.raw, + ) + + # If response step, we're done + if reasoning_step.is_done: + latest_react_messages = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + for msg in latest_react_messages: + await memory.aput(msg) + + return AgentOutput( + response=reasoning_step.response, + tool_outputs=all_tool_outputs, + raw_response=r.raw, + ) + + async def _call_llm( + self, + ctx: Context, + llm: LLM, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + mode: AgentMode, + ) -> AgentOutput: + """Call the LLM with the given input and tools.""" + if mode == AgentMode.DEFAULT: + if llm.metadata.is_function_calling_model: + return await self._call_function_calling_agent( + ctx, llm, llm_input, tools + ) + else: + return await self._call_react_agent(ctx, llm, llm_input, tools) + elif mode == AgentMode.REACT: + return await self._call_react_agent(ctx, llm, llm_input, tools) + elif mode == AgentMode.FUNCTION_CALLING: + return await self._call_function_calling_agent(ctx, llm, llm_input, tools) + else: + raise ValueError(f"Invalid agent mode: {mode}") + + @step + async def init_run(self, ctx: Context, ev: StartEvent | AgentOutput) -> AgentInput: + """Sets up the workflow and validates inputs.""" + if isinstance(ev, StartEvent): + await self._init_context(ctx) + + user_msg = ev.get("user_msg") + chat_history = ev.get("chat_history") + if user_msg and chat_history: + raise ValueError("Cannot provide both user_msg and chat_history") + + if isinstance(user_msg, str): + user_msg = ChatMessage(role="user", content=user_msg) + + await ctx.set("user_msg_str", user_msg.content) + + # Add messages to memory + memory: BaseMemory = await ctx.get("memory") + if user_msg: + await memory.aput(user_msg) + input_messages = memory.get(input=user_msg.content) + + # Add the state to the user message if it exists and if requested + current_state = await ctx.get("current_state") + if self.state_prompt and current_state: + user_msg.content = self.state_prompt.format( + state=current_state, msg=user_msg.content + ) + + await memory.aput(user_msg) + else: + memory.set(chat_history) + input_messages = memory.get() + else: + user_msg_str = await ctx.get("user_msg_str") + memory: BaseMemory = await ctx.get("memory") + input_messages = memory.get(input=user_msg_str) + + # send to the current agent + current_agent = await ctx.get("current_agent") + return AgentInput(input=input_messages, current_agent=current_agent) + + @step + async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: + """Main agent handling logic.""" + agent_config: AgentConfig = (await ctx.get("agent_configs"))[ev.current_agent] + llm_input = ev.input + + # Setup the tools + tools = list(agent_config.tools or []) + if agent_config.tool_retriever: + retrieved_tools = await agent_config.tool_retriever.aretrieve( + llm_input[-1].content or str(llm_input) + ) + tools.extend(retrieved_tools) + + if agent_config.can_handoff_to: + handoff_tool = self._get_handoff_tool(agent_config) + tools.append(handoff_tool) + + tools = self._ensure_tools_are_async(tools) + + ctx.write_event_to_stream( + AgentInput(input=llm_input, current_agent=ev.current_agent) + ) + + if agent_config.system_prompt: + llm_input = [ + ChatMessage(role="system", content=agent_config.system_prompt), + *llm_input, + ] + + return AgentSetup(input=llm_input, current_agent=ev.current_agent, tools=tools) + + @step + async def run_agent(self, ctx: Context, ev: AgentSetup) -> AgentOutput | StopEvent: + """Run the agent.""" + current_agent = ev.current_agent + agent_config: AgentConfig = (await ctx.get("agent_configs"))[current_agent] + llm = agent_config.llm or Settings.llm + + agent_output: AgentOutput = await self._call_llm( + ctx, llm, ev.input, ev.tools, agent_config.mode + ) + ctx.write_event_to_stream(agent_output) + if agent_output.tool_outputs: + ctx.write_event_to_stream(agent_output) + return agent_output + else: + return StopEvent(result=agent_output.response) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py new file mode 100644 index 0000000000000..cec908fe65d71 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py @@ -0,0 +1,70 @@ +from typing import Any, Optional + +from llama_index.core.tools import AsyncBaseTool, ToolSelection, ToolOutput +from llama_index.core.llms import ChatMessage +from llama_index.core.workflow import Event + + +class ToolApprovalNeeded(Event): + """Emitted when a tool call needs approval.""" + + id: str + tool_name: str + tool_kwargs: dict + + +class ApproveTool(Event): + """Required to approve a tool.""" + + id: str + tool_name: str + tool_kwargs: dict + approved: bool + reason: Optional[str] = None + + +class AgentInput(Event): + """LLM input.""" + + input: list[ChatMessage] + current_agent: str + + +class AgentSetup(Event): + """Agent setup.""" + + input: list[ChatMessage] + current_agent: str + tools: list[AsyncBaseTool] + + +class AgentStream(Event): + """Agent stream.""" + + delta: str + tool_calls: list[ToolSelection] + raw_response: Any + + +class AgentOutput(Event): + """LLM output.""" + + response: str + tool_outputs: list[ToolOutput] + raw_response: Any + + +class ToolCall(Event): + """All tool calls are surfaced.""" + + tool_name: str + tool_kwargs: dict + tool_output: Any + + +class HandoffEvent(Event): + """Internal event for agent handoffs.""" + + from_agent: str + to_agent: str + reason: str diff --git a/llama-index-core/llama_index/core/workflow/__init__.py b/llama-index-core/llama_index/core/workflow/__init__.py index 54dfd2660540a..ae1c7cb1c2319 100644 --- a/llama-index-core/llama_index/core/workflow/__init__.py +++ b/llama-index-core/llama_index/core/workflow/__init__.py @@ -16,6 +16,7 @@ InputRequiredEvent, HumanResponseEvent, ) +from llama_index.core.workflow.function_context_tool import FunctionToolWithContext from llama_index.core.workflow.workflow import Workflow from llama_index.core.workflow.context import Context from llama_index.core.workflow.context_serializers import ( @@ -46,4 +47,5 @@ "JsonSerializer", "WorkflowCheckpointer", "Checkpoint", + "FunctionToolWithContext", ] diff --git a/llama-index-core/llama_index/core/workflow/function_context_tool.py b/llama-index-core/llama_index/core/workflow/function_context_tool.py new file mode 100644 index 0000000000000..df4907338dadf --- /dev/null +++ b/llama-index-core/llama_index/core/workflow/function_context_tool.py @@ -0,0 +1,131 @@ +from inspect import signature +from typing import Any, Awaitable, Optional, Callable, Type, List, Tuple, Union, cast + +from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model +from llama_index.core.tools import ( + FunctionTool, + ToolOutput, + ToolMetadata, +) +from llama_index.core.workflow import ( + Context, +) + +AsyncCallable = Callable[..., Awaitable[Any]] + + +def create_schema_from_function( + name: str, + func: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], + additional_fields: Optional[ + List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] + ] = None, +) -> Type[BaseModel]: + """Create schema from function.""" + fields = {} + params = signature(func).parameters + for param_name in params: + # TODO: Very hacky way to remove the ctx parameter from the signature + if param_name == "ctx": + continue + + param_type = params[param_name].annotation + param_default = params[param_name].default + + if param_type is params[param_name].empty: + param_type = Any + + if param_default is params[param_name].empty: + # Required field + fields[param_name] = (param_type, FieldInfo()) + elif isinstance(param_default, FieldInfo): + # Field with pydantic.Field as default value + fields[param_name] = (param_type, param_default) + else: + fields[param_name] = (param_type, FieldInfo(default=param_default)) + + additional_fields = additional_fields or [] + for field_info in additional_fields: + if len(field_info) == 3: + field_info = cast(Tuple[str, Type, Any], field_info) + field_name, field_type, field_default = field_info + fields[field_name] = (field_type, FieldInfo(default=field_default)) + elif len(field_info) == 2: + # Required field has no default value + field_info = cast(Tuple[str, Type], field_info) + field_name, field_type = field_info + fields[field_name] = (field_type, FieldInfo()) + else: + raise ValueError( + f"Invalid additional field info: {field_info}. " + "Must be a tuple of length 2 or 3." + ) + + return create_model(name, **fields) # type: ignore + + +class FunctionToolWithContext(FunctionTool): + """ + A function tool that also includes passing in workflow context. + + Only overrides the call methods to include the context. + """ + + @classmethod + def from_defaults( + cls, + fn: Optional[Callable[..., Any]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + fn_schema: Optional[Type[BaseModel]] = None, + async_fn: Optional[AsyncCallable] = None, + tool_metadata: Optional[ToolMetadata] = None, + ) -> "FunctionToolWithContext": + if tool_metadata is None: + fn_to_parse = fn or async_fn + assert fn_to_parse is not None, "fn or async_fn must be provided." + name = name or fn_to_parse.__name__ + docstring = fn_to_parse.__doc__ + + # TODO: Very hacky way to remove the ctx parameter from the signature + signature_str = str(signature(fn_to_parse)) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context, ", "" + ) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context", "" + ) + + description = description or f"{name}{signature_str}\n{docstring}" + if fn_schema is None: + fn_schema = create_schema_from_function( + f"{name}", fn_to_parse, additional_fields=None + ) + tool_metadata = ToolMetadata( + name=name, + description=description, + fn_schema=fn_schema, + return_direct=return_direct, + ) + return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) + + def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + """Call.""" + tool_output = self._fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) + + async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + """Call.""" + tool_output = await self._async_fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) From 3f2574f8845b33cbf92cf96239b38fd804954bdb Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Wed, 18 Dec 2024 16:24:16 -0600 Subject: [PATCH 02/22] e2e --- .../agent/multi_agent/multi_agent_workflow.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 9ffe3a04474a8..7ef85f12e73e1 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -24,7 +24,6 @@ from llama_index.core.tools import ( BaseTool, AsyncBaseTool, - FunctionTool, ToolOutput, adapt_to_async_tool, ) @@ -40,15 +39,22 @@ DEFAULT_HANDOFF_PROMPT = """Useful for handing off to another agent. -If you are currently not equipped to handle the user's request, please hand off to the appropriate agent. +If you are currently not equipped to handle the user's request, or another agent is better suited to handle the request, please hand off to the appropriate agent. Currently available agents: {agent_info} """ -async def handoff(to_agent: str, reason: str) -> HandoffEvent: - """Handoff to the given agent.""" +async def handoff(ctx: Context, to_agent: str, reason: str) -> HandoffEvent: + """Handoff control of that chat to the given agent.""" + agent_configs = await ctx.get("agent_configs") + current_agent = await ctx.get("current_agent") + if to_agent not in agent_configs: + valid_agents = ", ".join([x for x in agent_configs if x != current_agent]) + return f"Agent {to_agent} not found. Please select a valid agent to hand off to. Valid agents: {valid_agents}" + + await ctx.set("current_agent", to_agent) return f"Handed off to {to_agent} because: {reason}" @@ -119,7 +125,9 @@ def _get_handoff_tool(self, current_agent_config: AgentConfig) -> AsyncBaseTool: agent_info.pop(name) fn_tool_prompt = self.handoff_prompt.format(agent_info=str(agent_info)) - return FunctionTool.from_defaults(async_fn=handoff, description=fn_tool_prompt) + return FunctionToolWithContext.from_defaults( + async_fn=handoff, description=fn_tool_prompt, return_direct=True + ) async def _init_context(self, ctx: Context) -> None: """Initialize the context once, if needed.""" @@ -198,7 +206,7 @@ async def _call_function_calling_agent( jobs = [] should_return_direct = False for tool_call in tool_calls: - tool_ids.append(tool_call.tool_call_id) + tool_ids.append(tool_call.tool_id) if tool_call.tool_name not in tools_by_name: tool_results.append( ToolOutput( @@ -480,7 +488,7 @@ async def _call_llm( return await self._call_react_agent(ctx, llm, llm_input, tools) elif mode == AgentMode.REACT: return await self._call_react_agent(ctx, llm, llm_input, tools) - elif mode == AgentMode.FUNCTION_CALLING: + elif mode == AgentMode.FUNCTION: return await self._call_function_calling_agent(ctx, llm, llm_input, tools) else: raise ValueError(f"Invalid agent mode: {mode}") From c992dd93ab190626311be2967bf281ea85c795ab Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 19 Dec 2024 08:27:35 -0600 Subject: [PATCH 03/22] move file --- .../llama_index/core/workflow/tools.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 llama-index-core/llama_index/core/workflow/tools.py diff --git a/llama-index-core/llama_index/core/workflow/tools.py b/llama-index-core/llama_index/core/workflow/tools.py new file mode 100644 index 0000000000000..df4907338dadf --- /dev/null +++ b/llama-index-core/llama_index/core/workflow/tools.py @@ -0,0 +1,131 @@ +from inspect import signature +from typing import Any, Awaitable, Optional, Callable, Type, List, Tuple, Union, cast + +from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model +from llama_index.core.tools import ( + FunctionTool, + ToolOutput, + ToolMetadata, +) +from llama_index.core.workflow import ( + Context, +) + +AsyncCallable = Callable[..., Awaitable[Any]] + + +def create_schema_from_function( + name: str, + func: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], + additional_fields: Optional[ + List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] + ] = None, +) -> Type[BaseModel]: + """Create schema from function.""" + fields = {} + params = signature(func).parameters + for param_name in params: + # TODO: Very hacky way to remove the ctx parameter from the signature + if param_name == "ctx": + continue + + param_type = params[param_name].annotation + param_default = params[param_name].default + + if param_type is params[param_name].empty: + param_type = Any + + if param_default is params[param_name].empty: + # Required field + fields[param_name] = (param_type, FieldInfo()) + elif isinstance(param_default, FieldInfo): + # Field with pydantic.Field as default value + fields[param_name] = (param_type, param_default) + else: + fields[param_name] = (param_type, FieldInfo(default=param_default)) + + additional_fields = additional_fields or [] + for field_info in additional_fields: + if len(field_info) == 3: + field_info = cast(Tuple[str, Type, Any], field_info) + field_name, field_type, field_default = field_info + fields[field_name] = (field_type, FieldInfo(default=field_default)) + elif len(field_info) == 2: + # Required field has no default value + field_info = cast(Tuple[str, Type], field_info) + field_name, field_type = field_info + fields[field_name] = (field_type, FieldInfo()) + else: + raise ValueError( + f"Invalid additional field info: {field_info}. " + "Must be a tuple of length 2 or 3." + ) + + return create_model(name, **fields) # type: ignore + + +class FunctionToolWithContext(FunctionTool): + """ + A function tool that also includes passing in workflow context. + + Only overrides the call methods to include the context. + """ + + @classmethod + def from_defaults( + cls, + fn: Optional[Callable[..., Any]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + fn_schema: Optional[Type[BaseModel]] = None, + async_fn: Optional[AsyncCallable] = None, + tool_metadata: Optional[ToolMetadata] = None, + ) -> "FunctionToolWithContext": + if tool_metadata is None: + fn_to_parse = fn or async_fn + assert fn_to_parse is not None, "fn or async_fn must be provided." + name = name or fn_to_parse.__name__ + docstring = fn_to_parse.__doc__ + + # TODO: Very hacky way to remove the ctx parameter from the signature + signature_str = str(signature(fn_to_parse)) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context, ", "" + ) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context", "" + ) + + description = description or f"{name}{signature_str}\n{docstring}" + if fn_schema is None: + fn_schema = create_schema_from_function( + f"{name}", fn_to_parse, additional_fields=None + ) + tool_metadata = ToolMetadata( + name=name, + description=description, + fn_schema=fn_schema, + return_direct=return_direct, + ) + return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) + + def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + """Call.""" + tool_output = self._fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) + + async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + """Call.""" + tool_output = await self._async_fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) From 969b64a3acd0af781199e1f7c6ded14485fa8d26 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Tue, 24 Dec 2024 12:47:56 -0600 Subject: [PATCH 04/22] remove human confirmation concept, add wait_for_event --- .../agent/multi_agent/multi_agent_workflow.py | 13 ++++++-- .../llama_index/core/workflow/context.py | 31 ++++++++++++++++++- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 7ef85f12e73e1..4ba2f402d363f 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -1,4 +1,5 @@ import asyncio +import uuid from typing import Any, Dict, List, Optional, Union, cast from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode @@ -141,9 +142,13 @@ async def _init_context(self, ctx: Context) -> None: await ctx.set("current_agent", self.root_agent) async def _call_tool( - self, ctx: Context, tool: AsyncBaseTool, tool_input: dict + self, ctx: Context, tool: AsyncBaseTool, tool_input: dict, tool_id: str ) -> ToolOutput: """Call the given tool with the given input.""" + current_agent = await ctx.get("current_agent") + agent_config = self.agent_configs[current_agent] + + # Call the tool once approved try: if isinstance(tool, FunctionToolWithContext): tool_output = await tool.acall(ctx=ctx, **tool_input) @@ -222,7 +227,9 @@ async def _call_function_calling_agent( if tool.metadata.return_direct: should_return_direct = True - job = self._call_tool(ctx, tool, tool_call.tool_kwargs) + job = await self._call_tool( + ctx, tool, tool_call.tool_kwargs, tool_call.tool_id + ) jobs.append(job) tool_results.extend(await asyncio.gather(*jobs)) @@ -377,7 +384,7 @@ async def _call_react_agent( else: tool = tools_by_name[reasoning_step.action] tool_output = await self._call_tool( - ctx, tool, reasoning_step.action_input + ctx, tool, reasoning_step.action_input, tool_id=uuid.uuid4().hex[:8] ) all_tool_outputs.append(tool_output) diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index c829de6409fa4..edd004bcacd4f 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -1,8 +1,9 @@ import asyncio import json import warnings +import uuid from collections import defaultdict -from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple +from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple, TypeVar from .context_serializers import BaseSerializer, JsonSerializer from .decorators import StepConfig @@ -12,6 +13,8 @@ if TYPE_CHECKING: # pragma: no cover from .workflow import Workflow +T = TypeVar("T", bound=Event) + class Context: """A global object representing a context for a given workflow run. @@ -281,6 +284,32 @@ def send_event(self, message: Event, step: Optional[str] = None) -> None: self._broker_log.append(message) + async def wait_for_event( + self, event_type: Type[T], requirements: Optional[Dict[str, Any]] = None + ) -> T: + """Asynchronously wait for a specific event type to be received. + + Returns: + The event type that was requested. + """ + requirements = requirements or {} + waiter_id = uuid.uuid4() + self._queues[waiter_id] = asyncio.Queue() + + try: + while True: + event = await self._queues[waiter_id].get() + if isinstance(event, event_type): + if all( + event.get(k, default=None) == v for k, v in requirements.items() + ): + return event + else: + continue + finally: + # Ensure queue cleanup happens even if cancelled + del self._queues[waiter_id] + def write_event_to_stream(self, ev: Optional[Event]) -> None: self._streaming_queue.put_nowait(ev) From ae072b8449c6553786c32108fd0e98647a2b7a4d Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 26 Dec 2024 10:55:43 -0600 Subject: [PATCH 05/22] add tests --- .../core/agent/multi_agent/agent_config.py | 5 +- .../agent/multi_agent/multi_agent_workflow.py | 71 +++-- .../core/agent/multi_agent/workflow_events.py | 2 + .../llama_index/core/tools/types.py | 3 +- .../llama_index/core/workflow/__init__.py | 2 +- .../core/workflow/function_context_tool.py | 131 -------- .../tests/agent/multi/__init__.py | 0 .../tests/agent/multi/test_multi_agent.py | 282 ++++++++++++++++++ 8 files changed, 335 insertions(+), 161 deletions(-) delete mode 100644 llama-index-core/llama_index/core/workflow/function_context_tool.py create mode 100644 llama-index-core/tests/agent/multi/__init__.py create mode 100644 llama-index-core/tests/agent/multi/test_multi_agent.py diff --git a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py index 4c1fb311917f9..1c37d36e26dc2 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py @@ -25,11 +25,8 @@ class AgentConfig(BaseModel): system_prompt: Optional[str] = None tools: Optional[List[BaseTool]] = None tool_retriever: Optional[ObjectRetriever] = None - tools_requiring_human_confirmation: Optional[List[str]] = Field( - default_factory=list - ) can_handoff_to: Optional[List[str]] = Field(default=None) handoff_prompt_template: Optional[str] = None llm: Optional[LLM] = None - is_root_agent: bool = False + is_entrypoint_agent: bool = False mode: AgentMode = AgentMode.DEFAULT diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 4ba2f402d363f..eb707de82ed79 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -1,5 +1,4 @@ import asyncio -import uuid from typing import Any, Dict, List, Optional, Union, cast from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode @@ -66,7 +65,6 @@ def __init__( self, agent_configs: List[AgentConfig], initial_state: Optional[Dict] = None, - memory: Optional[BaseMemory] = None, handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None, state_prompt: Optional[Union[str, BasePromptTemplate]] = None, timeout: Optional[float] = None, @@ -77,16 +75,15 @@ def __init__( raise ValueError("At least one agent config must be provided") self.agent_configs = {cfg.name: cfg for cfg in agent_configs} - only_one_root_agent = sum(cfg.is_root_agent for cfg in agent_configs) == 1 + only_one_root_agent = sum(cfg.is_entrypoint_agent for cfg in agent_configs) == 1 if not only_one_root_agent: raise ValueError("Exactly one root agent must be provided") - self.root_agent = next(cfg.name for cfg in agent_configs if cfg.is_root_agent) + self.root_agent = next( + cfg.name for cfg in agent_configs if cfg.is_entrypoint_agent + ) self.initial_state = initial_state or {} - self.memory = memory or ChatMemoryBuffer.from_defaults( - llm=agent_configs[0].llm or Settings.llm - ) self.handoff_prompt = handoff_prompt or DEFAULT_HANDOFF_PROMPT if isinstance(self.handoff_prompt, str): @@ -130,10 +127,14 @@ def _get_handoff_tool(self, current_agent_config: AgentConfig) -> AsyncBaseTool: async_fn=handoff, description=fn_tool_prompt, return_direct=True ) - async def _init_context(self, ctx: Context) -> None: + async def _init_context(self, ctx: Context, ev: StartEvent) -> None: """Initialize the context once, if needed.""" if not await ctx.get("memory", default=None): - await ctx.set("memory", self.memory) + default_memory = ev.get("memory", default=None) + default_memory = default_memory or ChatMemoryBuffer.from_defaults( + llm=self.agent_configs[self.root_agent].llm or Settings.llm + ) + await ctx.set("memory", default_memory) if not await ctx.get("agent_configs", default=None): await ctx.set("agent_configs", self.agent_configs) if not await ctx.get("current_state", default=None): @@ -142,13 +143,12 @@ async def _init_context(self, ctx: Context) -> None: await ctx.set("current_agent", self.root_agent) async def _call_tool( - self, ctx: Context, tool: AsyncBaseTool, tool_input: dict, tool_id: str + self, + ctx: Context, + tool: AsyncBaseTool, + tool_input: dict, ) -> ToolOutput: """Call the given tool with the given input.""" - current_agent = await ctx.get("current_agent") - agent_config = self.agent_configs[current_agent] - - # Call the tool once approved try: if isinstance(tool, FunctionToolWithContext): tool_output = await tool.acall(ctx=ctx, **tool_input) @@ -182,6 +182,7 @@ async def _call_function_calling_agent( ) -> AgentOutput: """Call the LLM as a function calling agent.""" memory: BaseMemory = await ctx.get("memory") + current_agent = await ctx.get("current_agent") tools_by_name = {tool.metadata.name: tool for tool in tools} current_llm_input = [*llm_input] @@ -197,6 +198,7 @@ async def _call_function_calling_agent( delta=r.delta or "", tool_calls=tool_calls or [], raw_response=r.raw, + current_agent=current_agent, ) ) @@ -227,10 +229,13 @@ async def _call_function_calling_agent( if tool.metadata.return_direct: should_return_direct = True - job = await self._call_tool( - ctx, tool, tool_call.tool_kwargs, tool_call.tool_id + jobs.append( + self._call_tool( + ctx, + tool, + tool_call.tool_kwargs, + ) ) - jobs.append(job) tool_results.extend(await asyncio.gather(*jobs)) all_tool_results.extend(tool_results) @@ -251,6 +256,7 @@ async def _call_function_calling_agent( response=tool_results[0].content, tool_outputs=all_tool_results, raw_response=None, + current_agent=current_agent, ) current_llm_input.extend(tool_messages) @@ -266,6 +272,7 @@ async def _call_function_calling_agent( delta=r.delta or "", tool_calls=tool_calls or [], raw_response=r.raw, + current_agent=current_agent, ) ) @@ -279,6 +286,7 @@ async def _call_function_calling_agent( response=r.message.content, tool_outputs=all_tool_results, raw_response=r.raw, + current_agent=current_agent, ) async def _call_react_agent( @@ -290,6 +298,7 @@ async def _call_react_agent( ) -> AgentOutput: """Call the LLM as a react agent.""" memory: BaseMemory = await ctx.get("memory") + current_agent = await ctx.get("current_agent") # remove system prompt, since the react prompt will be combined with it if llm_input[0].role == "system": @@ -317,6 +326,7 @@ async def _call_react_agent( delta=r.delta or "", tool_calls=[], raw_response=r.raw, + current_agent=current_agent, ) ) @@ -336,6 +346,7 @@ async def _call_react_agent( response=error_msg, tool_outputs=[], raw_response=r.raw, + current_agent=current_agent, ) # If response step, we're done @@ -360,6 +371,7 @@ async def _call_react_agent( response=response, tool_outputs=all_tool_outputs, raw_response=r.raw, + current_agent=current_agent, ) # Otherwise process action step @@ -384,7 +396,9 @@ async def _call_react_agent( else: tool = tools_by_name[reasoning_step.action] tool_output = await self._call_tool( - ctx, tool, reasoning_step.action_input, tool_id=uuid.uuid4().hex[:8] + ctx, + tool, + reasoning_step.action_input, ) all_tool_outputs.append(tool_output) @@ -412,6 +426,7 @@ async def _call_react_agent( response=tool_output.content, tool_outputs=all_tool_outputs, raw_response=r.raw, + current_agent=current_agent, ) # Get next action from LLM @@ -426,6 +441,7 @@ async def _call_react_agent( AgentStream( delta=r.delta or "", tool_calls=[], + current_agent=current_agent, raw_response=r.raw, ) ) @@ -445,7 +461,9 @@ async def _call_react_agent( # If we can't parse the output, return an error message error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" - current_reasoning.append(ResponseReasoningStep(response=error_msg)) + current_reasoning.append( + ObservationReasoningStep(observation=error_msg) + ) latest_react_messages = react_chat_formatter.format( tools, @@ -459,6 +477,7 @@ async def _call_react_agent( response=error_msg, tool_outputs=all_tool_outputs, raw_response=r.raw, + current_agent=current_agent, ) # If response step, we're done @@ -475,6 +494,7 @@ async def _call_react_agent( response=reasoning_step.response, tool_outputs=all_tool_outputs, raw_response=r.raw, + current_agent=current_agent, ) async def _call_llm( @@ -504,7 +524,7 @@ async def _call_llm( async def init_run(self, ctx: Context, ev: StartEvent | AgentOutput) -> AgentInput: """Sets up the workflow and validates inputs.""" if isinstance(ev, StartEvent): - await self._init_context(ctx) + await self._init_context(ctx, ev) user_msg = ev.get("user_msg") chat_history = ev.get("chat_history") @@ -556,7 +576,7 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: ) tools.extend(retrieved_tools) - if agent_config.can_handoff_to: + if agent_config.can_handoff_to or agent_config.can_handoff_to is None: handoff_tool = self._get_handoff_tool(agent_config) tools.append(handoff_tool) @@ -585,8 +605,11 @@ async def run_agent(self, ctx: Context, ev: AgentSetup) -> AgentOutput | StopEve ctx, llm, ev.input, ev.tools, agent_config.mode ) ctx.write_event_to_stream(agent_output) - if agent_output.tool_outputs: - ctx.write_event_to_stream(agent_output) + + if any( + tool_output.tool_name == "handoff" + for tool_output in agent_output.tool_outputs + ): return agent_output else: - return StopEvent(result=agent_output.response) + return StopEvent(result=agent_output) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py index cec908fe65d71..29a0f0d6914e4 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py @@ -42,6 +42,7 @@ class AgentStream(Event): """Agent stream.""" delta: str + current_agent: str tool_calls: list[ToolSelection] raw_response: Any @@ -52,6 +53,7 @@ class AgentOutput(Event): response: str tool_outputs: list[ToolOutput] raw_response: Any + current_agent: str class ToolCall(Event): diff --git a/llama-index-core/llama_index/core/tools/types.py b/llama-index-core/llama_index/core/tools/types.py index a35a990af7708..941fa9cfa9f0e 100644 --- a/llama-index-core/llama_index/core/tools/types.py +++ b/llama-index-core/llama_index/core/tools/types.py @@ -1,3 +1,4 @@ +import asyncio import json from abc import abstractmethod from dataclasses import dataclass @@ -195,7 +196,7 @@ def call(self, input: Any) -> ToolOutput: return self.base_tool(input) async def acall(self, input: Any) -> ToolOutput: - return self.call(input) + return await asyncio.to_thread(self.call, input) def adapt_to_async_tool(tool: BaseTool) -> AsyncBaseTool: diff --git a/llama-index-core/llama_index/core/workflow/__init__.py b/llama-index-core/llama_index/core/workflow/__init__.py index ae1c7cb1c2319..f842cd20d7499 100644 --- a/llama-index-core/llama_index/core/workflow/__init__.py +++ b/llama-index-core/llama_index/core/workflow/__init__.py @@ -16,7 +16,7 @@ InputRequiredEvent, HumanResponseEvent, ) -from llama_index.core.workflow.function_context_tool import FunctionToolWithContext +from llama_index.core.workflow.tools import FunctionToolWithContext from llama_index.core.workflow.workflow import Workflow from llama_index.core.workflow.context import Context from llama_index.core.workflow.context_serializers import ( diff --git a/llama-index-core/llama_index/core/workflow/function_context_tool.py b/llama-index-core/llama_index/core/workflow/function_context_tool.py deleted file mode 100644 index df4907338dadf..0000000000000 --- a/llama-index-core/llama_index/core/workflow/function_context_tool.py +++ /dev/null @@ -1,131 +0,0 @@ -from inspect import signature -from typing import Any, Awaitable, Optional, Callable, Type, List, Tuple, Union, cast - -from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model -from llama_index.core.tools import ( - FunctionTool, - ToolOutput, - ToolMetadata, -) -from llama_index.core.workflow import ( - Context, -) - -AsyncCallable = Callable[..., Awaitable[Any]] - - -def create_schema_from_function( - name: str, - func: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], - additional_fields: Optional[ - List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] - ] = None, -) -> Type[BaseModel]: - """Create schema from function.""" - fields = {} - params = signature(func).parameters - for param_name in params: - # TODO: Very hacky way to remove the ctx parameter from the signature - if param_name == "ctx": - continue - - param_type = params[param_name].annotation - param_default = params[param_name].default - - if param_type is params[param_name].empty: - param_type = Any - - if param_default is params[param_name].empty: - # Required field - fields[param_name] = (param_type, FieldInfo()) - elif isinstance(param_default, FieldInfo): - # Field with pydantic.Field as default value - fields[param_name] = (param_type, param_default) - else: - fields[param_name] = (param_type, FieldInfo(default=param_default)) - - additional_fields = additional_fields or [] - for field_info in additional_fields: - if len(field_info) == 3: - field_info = cast(Tuple[str, Type, Any], field_info) - field_name, field_type, field_default = field_info - fields[field_name] = (field_type, FieldInfo(default=field_default)) - elif len(field_info) == 2: - # Required field has no default value - field_info = cast(Tuple[str, Type], field_info) - field_name, field_type = field_info - fields[field_name] = (field_type, FieldInfo()) - else: - raise ValueError( - f"Invalid additional field info: {field_info}. " - "Must be a tuple of length 2 or 3." - ) - - return create_model(name, **fields) # type: ignore - - -class FunctionToolWithContext(FunctionTool): - """ - A function tool that also includes passing in workflow context. - - Only overrides the call methods to include the context. - """ - - @classmethod - def from_defaults( - cls, - fn: Optional[Callable[..., Any]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - return_direct: bool = False, - fn_schema: Optional[Type[BaseModel]] = None, - async_fn: Optional[AsyncCallable] = None, - tool_metadata: Optional[ToolMetadata] = None, - ) -> "FunctionToolWithContext": - if tool_metadata is None: - fn_to_parse = fn or async_fn - assert fn_to_parse is not None, "fn or async_fn must be provided." - name = name or fn_to_parse.__name__ - docstring = fn_to_parse.__doc__ - - # TODO: Very hacky way to remove the ctx parameter from the signature - signature_str = str(signature(fn_to_parse)) - signature_str = signature_str.replace( - "ctx: llama_index.core.workflow.context.Context, ", "" - ) - signature_str = signature_str.replace( - "ctx: llama_index.core.workflow.context.Context", "" - ) - - description = description or f"{name}{signature_str}\n{docstring}" - if fn_schema is None: - fn_schema = create_schema_from_function( - f"{name}", fn_to_parse, additional_fields=None - ) - tool_metadata = ToolMetadata( - name=name, - description=description, - fn_schema=fn_schema, - return_direct=return_direct, - ) - return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) - - def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" - tool_output = self._fn(ctx, *args, **kwargs) - return ToolOutput( - content=str(tool_output), - tool_name=self.metadata.name, - raw_input={"args": args, "kwargs": kwargs}, - raw_output=tool_output, - ) - - async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" - tool_output = await self._async_fn(ctx, *args, **kwargs) - return ToolOutput( - content=str(tool_output), - tool_name=self.metadata.name, - raw_input={"args": args, "kwargs": kwargs}, - raw_output=tool_output, - ) diff --git a/llama-index-core/tests/agent/multi/__init__.py b/llama-index-core/tests/agent/multi/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-core/tests/agent/multi/test_multi_agent.py b/llama-index-core/tests/agent/multi/test_multi_agent.py new file mode 100644 index 0000000000000..57433140d85e1 --- /dev/null +++ b/llama-index-core/tests/agent/multi/test_multi_agent.py @@ -0,0 +1,282 @@ +from typing import Any, List +import pytest + +from llama_index.core.llms import MockLLM +from llama_index.core.agent.multi_agent.multi_agent_workflow import MultiAgentWorkflow +from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode +from llama_index.core.llms import ( + ChatMessage, + ChatResponse, + MessageRole, + ChatResponseAsyncGen, + LLMMetadata, +) +from llama_index.core.tools import FunctionTool, ToolSelection +from llama_index.core.memory import ChatMemoryBuffer + + +class MockLLM(MockLLM): + def __init__(self, responses: List[ChatMessage], is_function_calling: bool = False): + super().__init__() + self._responses = responses + self._response_index = 0 + self._is_function_calling = is_function_calling + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata(is_function_calling_model=self._is_function_calling) + + async def astream_chat( + self, messages: List[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + response_msg = self._responses[self._response_index] + self._response_index = (self._response_index + 1) % len(self._responses) + + async def _gen(): + yield ChatResponse( + message=response_msg, + delta=response_msg.content, + raw={"content": response_msg.content}, + ) + + return _gen() + + async def astream_chat_with_tools( + self, tools: List[Any], chat_history: List[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + response_msg = self._responses[self._response_index] + self._response_index = (self._response_index + 1) % len(self._responses) + + async def _gen(): + yield ChatResponse( + message=response_msg, + delta=response_msg.content, + raw={"content": response_msg.content}, + ) + + return _gen() + + def get_tool_calls_from_response( + self, response: ChatResponse, **kwargs: Any + ) -> List[ToolSelection]: + return response.message.additional_kwargs.get("tool_calls", []) + + +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b + + +@pytest.fixture() +def calculator_agent(): + return AgentConfig( + name="calculator", + description="Performs basic arithmetic operations", + system_prompt="You are a calculator assistant.", + mode=AgentMode.REACT, + tools=[ + FunctionTool.from_defaults(fn=add), + FunctionTool.from_defaults(fn=subtract), + ], + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content='Thought: I need to add these numbers\nAction: add\nAction Input: {"a": 5, "b": 3}\n', + ), + ChatMessage( + role=MessageRole.ASSISTANT, + content=r"Thought: The result is 8\Answer: The sum is 8", + ), + ] + ), + ) + + +@pytest.fixture() +def retriever_agent(): + return AgentConfig( + name="retriever", + description="Manages data retrieval", + system_prompt="You are a retrieval assistant.", + is_entrypoint_agent=True, + mode=AgentMode.FUNCTION, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content="Let me help you with that calculation. I'll hand this off to the calculator.", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="one", + tool_name="handoff", + tool_kwargs={ + "to_agent": "calculator", + "reason": "This requires arithmetic operations.", + }, + ) + ] + }, + ), + ChatMessage( + role=MessageRole.ASSISTANT, + content="handoff calculator Because this requires arithmetic operations.", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="one", + tool_name="handoff", + tool_kwargs={ + "to_agent": "calculator", + "reason": "This requires arithmetic operations.", + }, + ) + ] + }, + ), + ], + is_function_calling=True, + ), + ) + + +@pytest.mark.asyncio() +async def test_basic_workflow(calculator_agent, retriever_agent): + """Test basic workflow initialization and validation.""" + workflow = MultiAgentWorkflow( + agent_configs=[calculator_agent, retriever_agent], + ) + + assert workflow.root_agent == "retriever" + assert len(workflow.agent_configs) == 2 + assert "calculator" in workflow.agent_configs + assert "retriever" in workflow.agent_configs + + +@pytest.mark.asyncio() +async def test_workflow_requires_root_agent(): + """Test that workflow requires exactly one root agent.""" + with pytest.raises(ValueError, match="Exactly one root agent must be provided"): + MultiAgentWorkflow( + agent_configs=[ + AgentConfig( + name="agent1", + description="test", + is_entrypoint_agent=True, + ), + AgentConfig( + name="agent2", + description="test", + is_entrypoint_agent=True, + ), + ] + ) + + +@pytest.mark.asyncio() +async def test_workflow_execution(calculator_agent, retriever_agent): + """Test basic workflow execution with agent handoff.""" + workflow = MultiAgentWorkflow( + agent_configs=[calculator_agent, retriever_agent], + ) + + memory = ChatMemoryBuffer.from_defaults() + handler = workflow.run(user_msg="Can you add 5 and 3?", memory=memory) + + events = [] + async for event in handler.stream_events(): + events.append(event) + + response = await handler + + # Verify we got events indicating handoff and calculation + assert any( + ev.current_agent == "retriever" if hasattr(ev, "current_agent") else False + for ev in events + ) + assert any( + ev.current_agent == "calculator" if hasattr(ev, "current_agent") else False + for ev in events + ) + assert "8" in response.response + + +@pytest.mark.asyncio() +async def test_invalid_handoff(): + """Test handling of invalid agent handoff.""" + agent1 = AgentConfig( + name="agent1", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content="handoff invalid_agent Because reasons", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="one", + tool_name="handoff", + tool_kwargs={ + "to_agent": "invalid_agent", + "reason": "Because reasons", + }, + ) + ] + }, + ), + ChatMessage(role=MessageRole.ASSISTANT, content="guess im stuck here"), + ], + is_function_calling=True, + ), + ) + + workflow = MultiAgentWorkflow( + agent_configs=[agent1], + ) + + handler = workflow.run(user_msg="test") + events = [] + async for event in handler.stream_events(): + events.append(event) + + response = await handler + assert "Agent invalid_agent not found" in str(events) + + +@pytest.mark.asyncio() +async def test_workflow_with_state(): + """Test workflow with state management.""" + agent = AgentConfig( + name="agent", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, content="Current state processed" + ) + ], + is_function_calling=True, + ), + ) + + workflow = MultiAgentWorkflow( + agent_configs=[agent], + initial_state={"counter": 0}, + state_prompt="Current state: {state}. User message: {msg}", + ) + + handler = workflow.run(user_msg="test") + async for _ in handler.stream_events(): + pass + + response = await handler + assert response is not None From be035d097f69cd0a9567a96726d0f764acf91ee9 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Fri, 27 Dec 2024 12:30:34 -0600 Subject: [PATCH 06/22] wip --- .../agent/multi_agent/multi_agent_workflow.py | 556 ++++++++++-------- .../core/agent/multi_agent/workflow_events.py | 6 +- .../llama_index/core/workflow/context.py | 26 + 3 files changed, 331 insertions(+), 257 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index eb707de82ed79..67de69becdbe3 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -1,4 +1,4 @@ -import asyncio +import uuid from typing import Any, Dict, List, Optional, Union, cast from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode @@ -14,11 +14,11 @@ from llama_index.core.agent.react.formatter import ReActChatFormatter from llama_index.core.agent.react.types import ( ActionReasoningStep, - ObservationReasoningStep, - ResponseReasoningStep, + BaseReasoningStep, ) from llama_index.core.llms import ChatMessage, LLM from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.llms.llm import ToolSelection from llama_index.core.memory import BaseMemory, ChatMemoryBuffer from llama_index.core.prompts import BasePromptTemplate, PromptTemplate from llama_index.core.tools import ( @@ -206,89 +206,96 @@ async def _call_function_calling_agent( await memory.aput(r.message) tool_calls = llm.get_tool_calls_from_response(r, error_on_no_tool_call=False) - all_tool_results = [] - while tool_calls: - tool_results: List[ToolOutput] = [] - tool_ids: List[str] = [] - jobs = [] - should_return_direct = False - for tool_call in tool_calls: - tool_ids.append(tool_call.tool_id) - if tool_call.tool_name not in tools_by_name: - tool_results.append( - ToolOutput( - content=f"Tool {tool_call.tool_name} not found. Please select a tool that is available.", - tool_name=tool_call.tool_name, - raw_input=tool_call.tool_kwargs, - raw_output=None, - is_error=True, - ) - ) - else: - tool = tools_by_name[tool_call.tool_name] - if tool.metadata.return_direct: - should_return_direct = True - - jobs.append( - self._call_tool( - ctx, - tool, - tool_call.tool_kwargs, - ) - ) - - tool_results.extend(await asyncio.gather(*jobs)) - all_tool_results.extend(tool_results) - tool_messages = [ - ChatMessage( - role="tool", - content=str(result), - additional_kwargs={"tool_call_id": tool_id}, - ) - for result, tool_id in zip(tool_results, tool_ids) - ] - - for tool_message in tool_messages: - await memory.aput(tool_message) - - if should_return_direct: - return AgentOutput( - response=tool_results[0].content, - tool_outputs=all_tool_results, - raw_response=None, - current_agent=current_agent, - ) - - current_llm_input.extend(tool_messages) - response = await llm.astream_chat_with_tools( - tools, chat_history=current_llm_input, allow_parallel_tool_calls=True - ) - async for r in response: - tool_calls = llm.get_tool_calls_from_response( - r, error_on_no_tool_call=False - ) - ctx.write_event_to_stream( - AgentStream( - delta=r.delta or "", - tool_calls=tool_calls or [], - raw_response=r.raw, - current_agent=current_agent, - ) - ) - - current_llm_input.append(r.message) - await memory.aput(r.message) - tool_calls = llm.get_tool_calls_from_response( - r, error_on_no_tool_call=False - ) - return AgentOutput( response=r.message.content, - tool_outputs=all_tool_results, + tool_calls=tool_calls or [], raw_response=r.raw, current_agent=current_agent, ) + # all_tool_results = [] + # while tool_calls: + # tool_results: List[ToolOutput] = [] + # tool_ids: List[str] = [] + # jobs = [] + # should_return_direct = False + # for tool_call in tool_calls: + # tool_ids.append(tool_call.tool_id) + # if tool_call.tool_name not in tools_by_name: + # tool_results.append( + # ToolOutput( + # content=f"Tool {tool_call.tool_name} not found. Please select a tool that is available.", + # tool_name=tool_call.tool_name, + # raw_input=tool_call.tool_kwargs, + # raw_output=None, + # is_error=True, + # ) + # ) + # else: + # tool = tools_by_name[tool_call.tool_name] + # if tool.metadata.return_direct: + # should_return_direct = True + + # jobs.append( + # self._call_tool( + # ctx, + # tool, + # tool_call.tool_kwargs, + # ) + # ) + + # tool_results.extend(await asyncio.gather(*jobs)) + # all_tool_results.extend(tool_results) + # tool_messages = [ + # ChatMessage( + # role="tool", + # content=str(result), + # additional_kwargs={"tool_call_id": tool_id}, + # ) + # for result, tool_id in zip(tool_results, tool_ids) + # ] + + # for tool_message in tool_messages: + # await memory.aput(tool_message) + + # if should_return_direct: + # return AgentOutput( + # response=tool_results[0].content, + # tool_outputs=all_tool_results, + # raw_response=None, + # current_agent=current_agent, + # ) + + # current_llm_input.extend(tool_messages) + # response = await llm.astream_chat_with_tools( + # tools, chat_history=current_llm_input, allow_parallel_tool_calls=True + # ) + # async for r in response: + # tool_calls = llm.get_tool_calls_from_response( + # r, error_on_no_tool_call=False + # ) + # ctx.write_event_to_stream( + # AgentStream( + # delta=r.delta or "", + # tool_calls=tool_calls or [], + # raw_response=r.raw, + # current_agent=current_agent, + # ) + # ) + + # current_llm_input.append(r.message) + # await memory.aput(r.message) + # tool_calls = llm.get_tool_calls_from_response( + # r, error_on_no_tool_call=False + # ) + + # return AgentOutput( + # response=r.message.content, + # tool_outputs=all_tool_results, + # raw_response=r.raw, + # current_agent=current_agent, + # ) + async def _call_react_agent( self, ctx: Context, @@ -311,7 +318,9 @@ async def _call_react_agent( react_chat_formatter = ReActChatFormatter(context=system_prompt) # Format initial chat input - current_reasoning = [] + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] + ) input_chat = react_chat_formatter.format( tools, chat_history=llm_input, @@ -330,8 +339,6 @@ async def _call_react_agent( ) ) - await memory.aput(r.message) - # Parse reasoning step and check if done message_content = r.message.content if not message_content: @@ -342,25 +349,25 @@ async def _call_react_agent( except ValueError as e: # If we can't parse the output, return an error message error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + + await memory.aput(r.message.content) + await memory.aput(ChatMessage(role="user", content=error_msg)) + return AgentOutput( - response=error_msg, - tool_outputs=[], + response=r.message.content, + tool_calls=[], raw_response=r.raw, current_agent=current_agent, ) + current_reasoning.append(reasoning_step) + await ctx.set("current_reasoning", current_reasoning) + # If response step, we're done - all_tool_outputs = [] if reasoning_step.is_done: - current_reasoning.append(reasoning_step) - - latest_react_messages = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - for msg in latest_react_messages: - await memory.aput(msg) + reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) + reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) + await memory.aput(reasoning_msg) response = ( reasoning_step.response @@ -369,133 +376,144 @@ async def _call_react_agent( ) return AgentOutput( response=response, - tool_outputs=all_tool_outputs, + tool_calls=[], raw_response=r.raw, current_agent=current_agent, ) - # Otherwise process action step - while True: - current_reasoning.append(reasoning_step) - - reasoning_step = cast(ActionReasoningStep, reasoning_step) - if not isinstance(reasoning_step, ActionReasoningStep): - raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") - - # Call tool - tools_by_name = {tool.metadata.name: tool for tool in tools} - tool = None - if reasoning_step.action not in tools_by_name: - tool_output = ToolOutput( - content=f"Error: No such tool named `{reasoning_step.action}`.", - tool_name=reasoning_step.action, - raw_input={"kwargs": reasoning_step.action_input}, - raw_output=None, - is_error=True, - ) - else: - tool = tools_by_name[reasoning_step.action] - tool_output = await self._call_tool( - ctx, - tool, - reasoning_step.action_input, - ) - all_tool_outputs.append(tool_output) + reasoning_step = cast(ActionReasoningStep, reasoning_step) + if not isinstance(reasoning_step, ActionReasoningStep): + raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") - # Add observation to chat history - current_reasoning.append( - ObservationReasoningStep( - observation=str(tool_output), - return_direct=tool.metadata.return_direct, - ) + # Call tool + tool_calls = [ + ToolSelection( + tool_id=str(uuid.uuid4()), + tool_name=reasoning_step.action, + tool_kwargs=reasoning_step.action_input, ) + ] - if tool and tool.metadata.return_direct: - current_reasoning.append( - ResponseReasoningStep(response=tool_output.content) - ) - latest_react_messages = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - for msg in latest_react_messages: - await memory.aput(msg) - - return AgentOutput( - response=tool_output.content, - tool_outputs=all_tool_outputs, - raw_response=r.raw, - current_agent=current_agent, - ) - - # Get next action from LLM - input_chat = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - response = await llm.astream_chat(input_chat) - async for r in response: - ctx.write_event_to_stream( - AgentStream( - delta=r.delta or "", - tool_calls=[], - current_agent=current_agent, - raw_response=r.raw, - ) - ) - - await memory.aput(r.message) - - # Parse next reasoning step - message_content = r.message.content - if not message_content: - raise ValueError("Got empty message") - - try: - reasoning_step = output_parser.parse( - message_content, is_streaming=False - ) - except ValueError as e: - # If we can't parse the output, return an error message - error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" - - current_reasoning.append( - ObservationReasoningStep(observation=error_msg) - ) - - latest_react_messages = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - for msg in latest_react_messages: - await memory.aput(msg) - - return AgentOutput( - response=error_msg, - tool_outputs=all_tool_outputs, - raw_response=r.raw, - current_agent=current_agent, - ) - - # If response step, we're done - if reasoning_step.is_done: - latest_react_messages = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - for msg in latest_react_messages: - await memory.aput(msg) + return AgentOutput( + response=r.message.content, + tool_calls=tool_calls, + raw_response=r.raw, + current_agent=current_agent, + ) - return AgentOutput( - response=reasoning_step.response, - tool_outputs=all_tool_outputs, - raw_response=r.raw, - current_agent=current_agent, - ) + # tools_by_name = {tool.metadata.name: tool for tool in tools} + # tool = None + # if reasoning_step.action not in tools_by_name: + # tool_output = ToolOutput( + # content=f"Error: No such tool named `{reasoning_step.action}`.", + # tool_name=reasoning_step.action, + # raw_input={"kwargs": reasoning_step.action_input}, + # raw_output=None, + # is_error=True, + # ) + # else: + # tool = tools_by_name[reasoning_step.action] + # tool_output = await self._call_tool( + # ctx, + # tool, + # reasoning_step.action_input, + # ) + # all_tool_outputs.append(tool_output) + + # # Add observation to chat history + # current_reasoning.append( + # ObservationReasoningStep( + # observation=str(tool_output), + # return_direct=tool.metadata.return_direct, + # ) + # ) + + # if tool and tool.metadata.return_direct: + # current_reasoning.append( + # ResponseReasoningStep(response=tool_output.content) + # ) + # latest_react_messages = react_chat_formatter.format( + # tools, + # chat_history=llm_input, + # current_reasoning=current_reasoning, + # ) + # for msg in latest_react_messages: + # await memory.aput(msg) + + # return AgentOutput( + # response=tool_output.content, + # tool_outputs=all_tool_outputs, + # raw_response=r.raw, + # current_agent=current_agent, + # ) + + # # Get next action from LLM + # input_chat = react_chat_formatter.format( + # tools, + # chat_history=llm_input, + # current_reasoning=current_reasoning, + # ) + # response = await llm.astream_chat(input_chat) + # async for r in response: + # ctx.write_event_to_stream( + # AgentStream( + # delta=r.delta or "", + # tool_calls=[], + # current_agent=current_agent, + # raw_response=r.raw, + # ) + # ) + + # await memory.aput(r.message) + + # # Parse next reasoning step + # message_content = r.message.content + # if not message_content: + # raise ValueError("Got empty message") + + # try: + # reasoning_step = output_parser.parse( + # message_content, is_streaming=False + # ) + # except ValueError as e: + # # If we can't parse the output, return an error message + # error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + + # current_reasoning.append( + # ObservationReasoningStep(observation=error_msg) + # ) + + # latest_react_messages = react_chat_formatter.format( + # tools, + # chat_history=llm_input, + # current_reasoning=current_reasoning, + # ) + # for msg in latest_react_messages: + # await memory.aput(msg) + + # return AgentOutput( + # response=error_msg, + # tool_outputs=all_tool_outputs, + # raw_response=r.raw, + # current_agent=current_agent, + # ) + + # # If response step, we're done + # if reasoning_step.is_done: + # latest_react_messages = react_chat_formatter.format( + # tools, + # chat_history=llm_input, + # current_reasoning=current_reasoning, + # ) + # for msg in latest_react_messages: + # await memory.aput(msg) + + # return AgentOutput( + # response=reasoning_step.response, + # tool_outputs=all_tool_outputs, + # raw_response=r.raw, + # current_agent=current_agent, + # ) async def _call_llm( self, @@ -521,42 +539,37 @@ async def _call_llm( raise ValueError(f"Invalid agent mode: {mode}") @step - async def init_run(self, ctx: Context, ev: StartEvent | AgentOutput) -> AgentInput: + async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: """Sets up the workflow and validates inputs.""" - if isinstance(ev, StartEvent): - await self._init_context(ctx, ev) + await self._init_context(ctx, ev) - user_msg = ev.get("user_msg") - chat_history = ev.get("chat_history") - if user_msg and chat_history: - raise ValueError("Cannot provide both user_msg and chat_history") + user_msg = ev.get("user_msg") + chat_history = ev.get("chat_history") + if user_msg and chat_history: + raise ValueError("Cannot provide both user_msg and chat_history") - if isinstance(user_msg, str): - user_msg = ChatMessage(role="user", content=user_msg) + if isinstance(user_msg, str): + user_msg = ChatMessage(role="user", content=user_msg) - await ctx.set("user_msg_str", user_msg.content) + await ctx.set("user_msg_str", user_msg.content) - # Add messages to memory - memory: BaseMemory = await ctx.get("memory") - if user_msg: - await memory.aput(user_msg) - input_messages = memory.get(input=user_msg.content) - - # Add the state to the user message if it exists and if requested - current_state = await ctx.get("current_state") - if self.state_prompt and current_state: - user_msg.content = self.state_prompt.format( - state=current_state, msg=user_msg.content - ) + # Add messages to memory + memory: BaseMemory = await ctx.get("memory") + if user_msg: + await memory.aput(user_msg) + input_messages = memory.get(input=user_msg.content) + + # Add the state to the user message if it exists and if requested + current_state = await ctx.get("current_state") + if self.state_prompt and current_state: + user_msg.content = self.state_prompt.format( + state=current_state, msg=user_msg.content + ) - await memory.aput(user_msg) - else: - memory.set(chat_history) - input_messages = memory.get() + await memory.aput(user_msg) else: - user_msg_str = await ctx.get("user_msg_str") - memory: BaseMemory = await ctx.get("memory") - input_messages = memory.get(input=user_msg_str) + memory.set(chat_history) + input_messages = memory.get() # send to the current agent current_agent = await ctx.get("current_agent") @@ -592,24 +605,57 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: *llm_input, ] - return AgentSetup(input=llm_input, current_agent=ev.current_agent, tools=tools) + return AgentSetup( + input=llm_input, + current_agent=ev.current_agent, + current_config=agent_config, + tools=tools, + ) @step - async def run_agent(self, ctx: Context, ev: AgentSetup) -> AgentOutput | StopEvent: + async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: """Run the agent.""" - current_agent = ev.current_agent - agent_config: AgentConfig = (await ctx.get("agent_configs"))[current_agent] + agent_config = ev.current_config llm = agent_config.llm or Settings.llm - agent_output: AgentOutput = await self._call_llm( + agent_output = await self._call_llm( ctx, llm, ev.input, ev.tools, agent_config.mode ) ctx.write_event_to_stream(agent_output) - if any( - tool_output.tool_name == "handoff" - for tool_output in agent_output.tool_outputs - ): - return agent_output - else: - return StopEvent(result=agent_output) + return agent_output + + @step + async def parse_agent_output( + self, ctx: Context, ev: AgentOutput + ) -> StopEvent | ToolCall: + pass + + def _add_tool_call_result_to_memory(self, ctx: Context, ev: ToolCallResult) -> None: + """Either adds to memory or adds to the react reasoning list.""" + + @step + async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: + pass + + @step + async def aggregate_tool_results( + self, ctx: Context, ev: ToolCallResult + ) -> AgentInput | None: + """Aggregate tool results and return the next agent input.""" + num_tool_calls = await ctx.get("num_tool_calls", default=0) + if num_tool_calls == 0: + raise ValueError("No tool calls found, cannot aggregate results.") + + tool_call_results = ctx.collect_events( + ev, expected=[ToolCallResult] * num_tool_calls + ) + if not tool_call_results: + return None + + user_msg_str = await ctx.get("user_msg_str") + memory: BaseMemory = await ctx.get("memory") + input_messages = memory.get(input=user_msg_str) + current_agent = await ctx.get("current_agent") + + return AgentInput(input=input_messages, current_agent=current_agent) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py index 29a0f0d6914e4..b5a0a58fd2b8a 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py @@ -1,8 +1,9 @@ from typing import Any, Optional -from llama_index.core.tools import AsyncBaseTool, ToolSelection, ToolOutput +from llama_index.core.tools import AsyncBaseTool, ToolSelection from llama_index.core.llms import ChatMessage from llama_index.core.workflow import Event +from llama_index.core.agent.multi_agent.agent_config import AgentConfig class ToolApprovalNeeded(Event): @@ -35,6 +36,7 @@ class AgentSetup(Event): input: list[ChatMessage] current_agent: str + current_config: AgentConfig tools: list[AsyncBaseTool] @@ -51,7 +53,7 @@ class AgentOutput(Event): """LLM output.""" response: str - tool_outputs: list[ToolOutput] + tool_calls: list[ToolSelection] raw_response: Any current_agent: str diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index edd004bcacd4f..06cc66d369fa5 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -310,6 +310,32 @@ async def wait_for_event( # Ensure queue cleanup happens even if cancelled del self._queues[waiter_id] + async def wait_for_event( + self, event_type: Type[T], requirements: Optional[Dict[str, Any]] = None + ) -> T: + """Asynchronously wait for a specific event type to be received. + + Returns: + The event type that was requested. + """ + requirements = requirements or {} + waiter_id = uuid.uuid4() + self._queues[waiter_id] = asyncio.Queue() + + try: + while True: + event = await self._queues[waiter_id].get() + if isinstance(event, event_type): + if all( + event.get(k, default=None) == v for k, v in requirements.items() + ): + return event + else: + continue + finally: + # Ensure queue cleanup happens even if cancelled + del self._queues[waiter_id] + def write_event_to_stream(self, ev: Optional[Event]) -> None: self._streaming_queue.put_nowait(ev) From 6974a9d72b501e097afb58a1197fe3ef32f3d408 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sat, 28 Dec 2024 20:17:32 -0600 Subject: [PATCH 07/22] refactor --- .../core/agent/multi_agent/agent_config.py | 11 + .../agent/multi_agent/multi_agent_workflow.py | 423 ++++++++---------- .../core/agent/multi_agent/workflow_events.py | 11 +- 3 files changed, 208 insertions(+), 237 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py index 1c37d36e26dc2..4a9102b85f399 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py @@ -30,3 +30,14 @@ class AgentConfig(BaseModel): llm: Optional[LLM] = None is_entrypoint_agent: bool = False mode: AgentMode = AgentMode.DEFAULT + + def get_mode(self) -> AgentMode: + """Resolve the mode of the agent.""" + if self.mode == AgentMode.DEFAULT: + return ( + AgentMode.FUNCTION + if self.llm.metadata.is_function_calling_model + else AgentMode.REACT + ) + + return self.mode diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 67de69becdbe3..7495a491b1aab 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -5,6 +5,7 @@ from llama_index.core.agent.multi_agent.workflow_events import ( HandoffEvent, ToolCall, + ToolCallResult, AgentInput, AgentSetup, AgentStream, @@ -15,6 +16,8 @@ from llama_index.core.agent.react.types import ( ActionReasoningStep, BaseReasoningStep, + ObservationReasoningStep, + ResponseReasoningStep, ) from llama_index.core.llms import ChatMessage, LLM from llama_index.core.llms.function_calling import FunctionCallingLLM @@ -163,15 +166,40 @@ async def _call_tool( is_error=True, ) - ctx.write_event_to_stream( - ToolCall( - tool_name=tool.metadata.name, - tool_kwargs=tool_input, - tool_output=tool_output.content, - ) + return tool_output + + async def _handle_react_tool_call( + self, ctx: Context, results: List[ToolCallResult] + ) -> None: + """Adds to the react reasoning list.""" + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] ) + for tool_call_result in results: + current_reasoning.append( + ObservationReasoningStep( + observation=str(tool_call_result.tool_output.content), + return_direct=tool_call_result.return_direct, + ) + ) - return tool_output + await ctx.set("current_reasoning", current_reasoning) + + async def _handle_function_tool_call( + self, ctx: Context, results: List[ToolCallResult] + ) -> None: + """Adds to memory.""" + memory: BaseMemory = await ctx.get("memory") + for tool_call_result in results: + await memory.aput( + ChatMessage( + role="tool", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + + await ctx.set("memory", memory) async def _call_function_calling_agent( self, @@ -183,7 +211,6 @@ async def _call_function_calling_agent( """Call the LLM as a function calling agent.""" memory: BaseMemory = await ctx.get("memory") current_agent = await ctx.get("current_agent") - tools_by_name = {tool.metadata.name: tool for tool in tools} current_llm_input = [*llm_input] response = await llm.astream_chat_with_tools( @@ -204,6 +231,8 @@ async def _call_function_calling_agent( current_llm_input.append(r.message) await memory.aput(r.message) + await ctx.set("memory", memory) + tool_calls = llm.get_tool_calls_from_response(r, error_on_no_tool_call=False) return AgentOutput( @@ -213,89 +242,6 @@ async def _call_function_calling_agent( current_agent=current_agent, ) - # all_tool_results = [] - # while tool_calls: - # tool_results: List[ToolOutput] = [] - # tool_ids: List[str] = [] - # jobs = [] - # should_return_direct = False - # for tool_call in tool_calls: - # tool_ids.append(tool_call.tool_id) - # if tool_call.tool_name not in tools_by_name: - # tool_results.append( - # ToolOutput( - # content=f"Tool {tool_call.tool_name} not found. Please select a tool that is available.", - # tool_name=tool_call.tool_name, - # raw_input=tool_call.tool_kwargs, - # raw_output=None, - # is_error=True, - # ) - # ) - # else: - # tool = tools_by_name[tool_call.tool_name] - # if tool.metadata.return_direct: - # should_return_direct = True - - # jobs.append( - # self._call_tool( - # ctx, - # tool, - # tool_call.tool_kwargs, - # ) - # ) - - # tool_results.extend(await asyncio.gather(*jobs)) - # all_tool_results.extend(tool_results) - # tool_messages = [ - # ChatMessage( - # role="tool", - # content=str(result), - # additional_kwargs={"tool_call_id": tool_id}, - # ) - # for result, tool_id in zip(tool_results, tool_ids) - # ] - - # for tool_message in tool_messages: - # await memory.aput(tool_message) - - # if should_return_direct: - # return AgentOutput( - # response=tool_results[0].content, - # tool_outputs=all_tool_results, - # raw_response=None, - # current_agent=current_agent, - # ) - - # current_llm_input.extend(tool_messages) - # response = await llm.astream_chat_with_tools( - # tools, chat_history=current_llm_input, allow_parallel_tool_calls=True - # ) - # async for r in response: - # tool_calls = llm.get_tool_calls_from_response( - # r, error_on_no_tool_call=False - # ) - # ctx.write_event_to_stream( - # AgentStream( - # delta=r.delta or "", - # tool_calls=tool_calls or [], - # raw_response=r.raw, - # current_agent=current_agent, - # ) - # ) - - # current_llm_input.append(r.message) - # await memory.aput(r.message) - # tool_calls = llm.get_tool_calls_from_response( - # r, error_on_no_tool_call=False - # ) - - # return AgentOutput( - # response=r.message.content, - # tool_outputs=all_tool_results, - # raw_response=r.raw, - # current_agent=current_agent, - # ) - async def _call_react_agent( self, ctx: Context, @@ -352,6 +298,7 @@ async def _call_react_agent( await memory.aput(r.message.content) await memory.aput(ChatMessage(role="user", content=error_msg)) + await ctx.set("memory", memory) return AgentOutput( response=r.message.content, @@ -365,17 +312,8 @@ async def _call_react_agent( # If response step, we're done if reasoning_step.is_done: - reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) - reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) - await memory.aput(reasoning_msg) - - response = ( - reasoning_step.response - if hasattr(reasoning_step, "response") - else reasoning_step.get_content() - ) return AgentOutput( - response=response, + response=r.message.content, tool_calls=[], raw_response=r.raw, current_agent=current_agent, @@ -385,7 +323,7 @@ async def _call_react_agent( if not isinstance(reasoning_step, ActionReasoningStep): raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") - # Call tool + # Create tool call tool_calls = [ ToolSelection( tool_id=str(uuid.uuid4()), @@ -401,142 +339,55 @@ async def _call_react_agent( current_agent=current_agent, ) - # tools_by_name = {tool.metadata.name: tool for tool in tools} - # tool = None - # if reasoning_step.action not in tools_by_name: - # tool_output = ToolOutput( - # content=f"Error: No such tool named `{reasoning_step.action}`.", - # tool_name=reasoning_step.action, - # raw_input={"kwargs": reasoning_step.action_input}, - # raw_output=None, - # is_error=True, - # ) - # else: - # tool = tools_by_name[reasoning_step.action] - # tool_output = await self._call_tool( - # ctx, - # tool, - # reasoning_step.action_input, - # ) - # all_tool_outputs.append(tool_output) - - # # Add observation to chat history - # current_reasoning.append( - # ObservationReasoningStep( - # observation=str(tool_output), - # return_direct=tool.metadata.return_direct, - # ) - # ) - - # if tool and tool.metadata.return_direct: - # current_reasoning.append( - # ResponseReasoningStep(response=tool_output.content) - # ) - # latest_react_messages = react_chat_formatter.format( - # tools, - # chat_history=llm_input, - # current_reasoning=current_reasoning, - # ) - # for msg in latest_react_messages: - # await memory.aput(msg) - - # return AgentOutput( - # response=tool_output.content, - # tool_outputs=all_tool_outputs, - # raw_response=r.raw, - # current_agent=current_agent, - # ) - - # # Get next action from LLM - # input_chat = react_chat_formatter.format( - # tools, - # chat_history=llm_input, - # current_reasoning=current_reasoning, - # ) - # response = await llm.astream_chat(input_chat) - # async for r in response: - # ctx.write_event_to_stream( - # AgentStream( - # delta=r.delta or "", - # tool_calls=[], - # current_agent=current_agent, - # raw_response=r.raw, - # ) - # ) - - # await memory.aput(r.message) - - # # Parse next reasoning step - # message_content = r.message.content - # if not message_content: - # raise ValueError("Got empty message") - - # try: - # reasoning_step = output_parser.parse( - # message_content, is_streaming=False - # ) - # except ValueError as e: - # # If we can't parse the output, return an error message - # error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" - - # current_reasoning.append( - # ObservationReasoningStep(observation=error_msg) - # ) - - # latest_react_messages = react_chat_formatter.format( - # tools, - # chat_history=llm_input, - # current_reasoning=current_reasoning, - # ) - # for msg in latest_react_messages: - # await memory.aput(msg) - - # return AgentOutput( - # response=error_msg, - # tool_outputs=all_tool_outputs, - # raw_response=r.raw, - # current_agent=current_agent, - # ) - - # # If response step, we're done - # if reasoning_step.is_done: - # latest_react_messages = react_chat_formatter.format( - # tools, - # chat_history=llm_input, - # current_reasoning=current_reasoning, - # ) - # for msg in latest_react_messages: - # await memory.aput(msg) - - # return AgentOutput( - # response=reasoning_step.response, - # tool_outputs=all_tool_outputs, - # raw_response=r.raw, - # current_agent=current_agent, - # ) - async def _call_llm( self, ctx: Context, llm: LLM, llm_input: List[ChatMessage], tools: List[AsyncBaseTool], - mode: AgentMode, + agent_config: AgentConfig, ) -> AgentOutput: """Call the LLM with the given input and tools.""" - if mode == AgentMode.DEFAULT: - if llm.metadata.is_function_calling_model: - return await self._call_function_calling_agent( - ctx, llm, llm_input, tools - ) - else: - return await self._call_react_agent(ctx, llm, llm_input, tools) - elif mode == AgentMode.REACT: + if agent_config.get_mode() == AgentMode.REACT: return await self._call_react_agent(ctx, llm, llm_input, tools) - elif mode == AgentMode.FUNCTION: + elif agent_config.get_mode() == AgentMode.FUNCTION: return await self._call_function_calling_agent(ctx, llm, llm_input, tools) else: - raise ValueError(f"Invalid agent mode: {mode}") + raise ValueError(f"Invalid agent mode: {agent_config.get_mode()}") + + async def _finalize_function_calling_agent(self, ctx: Context) -> None: + """Finalizes the function calling agent. + + This is a no-op for the function calling agent, since we've been writing to memory as we go. + """ + + async def _finalize_react_agent(self, ctx: Context) -> None: + """Finalizes the react agent by writing the current reasoning to memory.""" + memory: BaseMemory = await ctx.get("memory") + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] + ) + + # if we returned a direct tool call, we should add the final reasoning to memory + if ( + len(current_reasoning) > 0 + and isinstance(current_reasoning[-1], ObservationReasoningStep) + and current_reasoning[-1].return_direct + ): + current_reasoning.append( + ResponseReasoningStep( + thought=current_reasoning[-1].observation, + response=current_reasoning[-1].observation, + is_streaming=False, + ) + ) + + reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) + reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) + + await memory.aput(reasoning_msg) + await ctx.set("memory", memory) + await ctx.set("current_reasoning", []) @step async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: @@ -578,7 +429,7 @@ async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: @step async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: """Main agent handling logic.""" - agent_config: AgentConfig = (await ctx.get("agent_configs"))[ev.current_agent] + agent_config: AgentConfig = self.agent_configs[ev.current_agent] llm_input = ev.input # Setup the tools @@ -605,6 +456,8 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: *llm_input, ] + await ctx.set("tools_by_name", {tool.metadata.name: tool for tool in tools}) + return AgentSetup( input=llm_input, current_agent=ev.current_agent, @@ -619,7 +472,11 @@ async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: llm = agent_config.llm or Settings.llm agent_output = await self._call_llm( - ctx, llm, ev.input, ev.tools, agent_config.mode + ctx, + llm, + ev.input, + ev.tools, + agent_config, ) ctx.write_event_to_stream(agent_output) @@ -628,15 +485,61 @@ async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: @step async def parse_agent_output( self, ctx: Context, ev: AgentOutput - ) -> StopEvent | ToolCall: - pass + ) -> StopEvent | ToolCall | None: + if not ev.tool_calls: + agent_configs = self.agent_configs + current_config = agent_configs[ev.current_agent] + if current_config.get_mode() == AgentMode.REACT: + await self._finalize_react_agent(ctx) + elif current_config.get_mode() == AgentMode.FUNCTION: + await self._finalize_function_calling_agent(ctx) + else: + raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") + + return StopEvent(result=ev) + + await ctx.set("num_tool_calls", len(ev.tool_calls)) + + for tool_call in ev.tool_calls: + ctx.send_event( + ToolCall( + tool_name=tool_call.tool_name, + tool_kwargs=tool_call.tool_kwargs, + tool_id=tool_call.tool_id, + ) + ) - def _add_tool_call_result_to_memory(self, ctx: Context, ev: ToolCallResult) -> None: - """Either adds to memory or adds to the react reasoning list.""" + return None @step async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: - pass + """Calls the tool and handles the result.""" + ctx.write_event_to_stream(ev) + + tools_by_name: dict[str, AsyncBaseTool] = await ctx.get("tools_by_name") + if ev.tool_name not in tools_by_name: + result = ToolOutput( + content=f"Tool {ev.tool_name} not found. Please select a tool that is available.", + tool_name=ev.tool_name, + raw_input=ev.tool_kwargs, + raw_output=None, + is_error=True, + return_direct=False, + ) + else: + tool = tools_by_name[ev.tool_name] + result = await self._call_tool(ctx, tool, ev.tool_kwargs) + + result_ev = ToolCallResult( + tool_name=ev.tool_name, + tool_kwargs=ev.tool_kwargs, + tool_id=ev.tool_id, + tool_output=result, + return_direct=tool.metadata.return_direct, + ) + + ctx.write_event_to_stream(result_ev) + return result_ev @step async def aggregate_tool_results( @@ -647,15 +550,65 @@ async def aggregate_tool_results( if num_tool_calls == 0: raise ValueError("No tool calls found, cannot aggregate results.") - tool_call_results = ctx.collect_events( + tool_call_results: list[ToolCallResult] = ctx.collect_events( ev, expected=[ToolCallResult] * num_tool_calls ) if not tool_call_results: return None + current_agent = await ctx.get("current_agent") + current_config: AgentConfig = self.agent_configs[current_agent] + + if current_config.get_mode() == AgentMode.REACT: + await self._handle_react_tool_call(ctx, tool_call_results) + elif current_config.get_mode() == AgentMode.FUNCTION: + await self._handle_function_tool_call(ctx, tool_call_results) + else: + raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") + + if any( + tool_call_result.return_direct for tool_call_result in tool_call_results + ): + # if any tool calls return directly, take the first one + return_direct_tool = next( + [ + tool_call_result + for tool_call_result in tool_call_results + if tool_call_result.return_direct + ] + ) + + # we don't want to finalize the agent if we're just handing off + if return_direct_tool.tool_name != "handoff": + current_config = self.agent_configs[current_agent] + if current_config.get_mode() == AgentMode.REACT: + await self._finalize_react_agent(ctx) + elif current_config.get_mode() == AgentMode.FUNCTION: + await self._finalize_function_calling_agent(ctx) + else: + raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") + + return StopEvent( + result=AgentOutput( + response=return_direct_tool.tool_output.content, + tool_calls=[ + ToolSelection( + tool_id=t.tool_id, + tool_name=t.tool_name, + tool_kwargs=t.tool_kwargs, + ) + for t in tool_call_results + ], + raw_response=return_direct_tool.tool_output.raw_output, + current_agent=current_agent, + ) + ) + user_msg_str = await ctx.get("user_msg_str") memory: BaseMemory = await ctx.get("memory") input_messages = memory.get(input=user_msg_str) + + # get this again, in case it changed current_agent = await ctx.get("current_agent") return AgentInput(input=input_messages, current_agent=current_agent) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py index b5a0a58fd2b8a..c5c4d367bc233 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from llama_index.core.tools import AsyncBaseTool, ToolSelection +from llama_index.core.tools import AsyncBaseTool, ToolSelection, ToolOutput from llama_index.core.llms import ChatMessage from llama_index.core.workflow import Event from llama_index.core.agent.multi_agent.agent_config import AgentConfig @@ -63,7 +63,14 @@ class ToolCall(Event): tool_name: str tool_kwargs: dict - tool_output: Any + tool_id: str + + +class ToolCallResult(ToolCall): + """Tool call result.""" + + tool_output: ToolOutput + return_direct: bool class HandoffEvent(Event): From d91f93ebacccceb5ebdd431f4d8d8c087df50c8f Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sat, 28 Dec 2024 20:20:44 -0600 Subject: [PATCH 08/22] refactor --- .../agent/multi_agent/multi_agent_workflow.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 7495a491b1aab..b3830747aea77 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -183,6 +183,16 @@ async def _handle_react_tool_call( ) ) + if tool_call_result.return_direct: + current_reasoning.append( + ResponseReasoningStep( + thought=current_reasoning[-1].observation, + response=current_reasoning[-1].observation, + is_streaming=False, + ) + ) + break + await ctx.set("current_reasoning", current_reasoning) async def _handle_function_tool_call( @@ -199,6 +209,16 @@ async def _handle_function_tool_call( ) ) + if tool_call_result.return_direct: + await memory.aput( + ChatMessage( + role="assistant", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + break + await ctx.set("memory", memory) async def _call_function_calling_agent( @@ -368,20 +388,6 @@ async def _finalize_react_agent(self, ctx: Context) -> None: "current_reasoning", default=[] ) - # if we returned a direct tool call, we should add the final reasoning to memory - if ( - len(current_reasoning) > 0 - and isinstance(current_reasoning[-1], ObservationReasoningStep) - and current_reasoning[-1].return_direct - ): - current_reasoning.append( - ResponseReasoningStep( - thought=current_reasoning[-1].observation, - response=current_reasoning[-1].observation, - is_streaming=False, - ) - ) - reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) @@ -571,11 +577,9 @@ async def aggregate_tool_results( ): # if any tool calls return directly, take the first one return_direct_tool = next( - [ - tool_call_result - for tool_call_result in tool_call_results - if tool_call_result.return_direct - ] + tool_call_result + for tool_call_result in tool_call_results + if tool_call_result.return_direct ) # we don't want to finalize the agent if we're just handing off From 85b89afcdc9ff77547d9b0dd22593c437d044171 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sun, 29 Dec 2024 12:42:41 -0600 Subject: [PATCH 09/22] finish refactor --- .../agent/multi_agent/multi_agent_workflow.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index b3830747aea77..535e2194ab667 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -176,6 +176,10 @@ async def _handle_react_tool_call( "current_reasoning", default=[] ) for tool_call_result in results: + # don't add handoff tool calls to reasoning + if tool_call_result.tool_name == "handoff": + continue + current_reasoning.append( ObservationReasoningStep( observation=str(tool_call_result.tool_output.content), @@ -201,6 +205,10 @@ async def _handle_function_tool_call( """Adds to memory.""" memory: BaseMemory = await ctx.get("memory") for tool_call_result in results: + # don't add handoff tool calls to memory + if tool_call_result.tool_name == "handoff": + continue + await memory.aput( ChatMessage( role="tool", @@ -249,12 +257,14 @@ async def _call_function_calling_agent( ) ) - current_llm_input.append(r.message) - await memory.aput(r.message) - await ctx.set("memory", memory) - tool_calls = llm.get_tool_calls_from_response(r, error_on_no_tool_call=False) + # only add to memory if we didn't select the handoff tool + if not any(tool_call.tool_name == "handoff" for tool_call in tool_calls): + current_llm_input.append(r.message) + await memory.aput(r.message) + await ctx.set("memory", memory) + return AgentOutput( response=r.message.content, tool_calls=tool_calls or [], @@ -327,8 +337,10 @@ async def _call_react_agent( current_agent=current_agent, ) - current_reasoning.append(reasoning_step) - await ctx.set("current_reasoning", current_reasoning) + # add to reasoning if not a handoff + if hasattr(reasoning_step, "action") and reasoning_step.action != "handoff": + current_reasoning.append(reasoning_step) + await ctx.set("current_reasoning", current_reasoning) # If response step, we're done if reasoning_step.is_done: @@ -550,7 +562,7 @@ async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: @step async def aggregate_tool_results( self, ctx: Context, ev: ToolCallResult - ) -> AgentInput | None: + ) -> AgentInput | StopEvent | None: """Aggregate tool results and return the next agent input.""" num_tool_calls = await ctx.get("num_tool_calls", default=0) if num_tool_calls == 0: @@ -582,16 +594,16 @@ async def aggregate_tool_results( if tool_call_result.return_direct ) - # we don't want to finalize the agent if we're just handing off - if return_direct_tool.tool_name != "handoff": - current_config = self.agent_configs[current_agent] - if current_config.get_mode() == AgentMode.REACT: - await self._finalize_react_agent(ctx) - elif current_config.get_mode() == AgentMode.FUNCTION: - await self._finalize_function_calling_agent(ctx) - else: - raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") + current_config = self.agent_configs[current_agent] + if current_config.get_mode() == AgentMode.REACT: + await self._finalize_react_agent(ctx) + elif current_config.get_mode() == AgentMode.FUNCTION: + await self._finalize_function_calling_agent(ctx) + else: + raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") + # we don't want to stop the system if we're just handing off + if return_direct_tool.tool_name != "handoff": return StopEvent( result=AgentOutput( response=return_direct_tool.tool_output.content, From faabe0d236a061dc823e4bdddd1b60310553ff8c Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sun, 29 Dec 2024 13:16:05 -0600 Subject: [PATCH 10/22] add docs --- docs/docs/understanding/agent/multi_agents.md | 250 ++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 docs/docs/understanding/agent/multi_agents.md diff --git a/docs/docs/understanding/agent/multi_agents.md b/docs/docs/understanding/agent/multi_agents.md new file mode 100644 index 0000000000000..7f7b8c3525b01 --- /dev/null +++ b/docs/docs/understanding/agent/multi_agents.md @@ -0,0 +1,250 @@ +# Multi-Agent Workflows + +The MultiAgentWorkflow allows you to create a system of multiple agents that can collaborate and hand off tasks to each other based on their specialized capabilities. This enables building more complex agent systems where different agents handle different aspects of a task. + +## Quick Start + +Here's a simple example of setting up a multi-agent workflow with a calculator agent and a retriever agent: + +```python +from llama_index.core.agent.multi_agent import ( + MultiAgentWorkflow, + AgentConfig, + AgentMode, +) +from llama_index.core.tools import FunctionTool +from llama_index.core.workflow import FunctionToolWithContext + + +# Define some tools +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b + + +# Create agent configs +calculator_agent = AgentConfig( + name="calculator", + description="Performs basic arithmetic operations", + system_prompt="You are a calculator assistant.", + mode=AgentMode.REACT, + tools=[ + FunctionTool.from_defaults(fn=add), + FunctionTool.from_defaults(fn=subtract), + ], + llm=OpenAI(model="gpt-4"), +) + +retriever_agent = AgentConfig( + name="retriever", + description="Manages data retrieval", + system_prompt="You are a retrieval assistant.", + mode=AgentMode.FUNCTION, + is_entrypoint_agent=True, + llm=OpenAI(model="gpt-4"), +) + +# Create and run the workflow +workflow = MultiAgentWorkflow( + agent_configs=[calculator_agent, retriever_agent] +) + +# Run the system +response = await workflow.run(user_msg="Can you add 5 and 3?") + +# Or stream the events +handler = workflow.run(user_msg="Can you add 5 and 3?") +async for event in handler.stream_events(): + if hasattr(event, "delta"): + print(event.delta, end="", flush=True) +``` + +## How It Works + +The MultiAgentWorkflow manages a collection of agents, each with their own specialized capabilities. One agent must be designated as the entry point agent (is_entrypoint_agent=True). + +When a user message comes in, it's first routed to the entry point agent. Each agent can then: + +1. Handle the request directly using their tools +2. Hand off to another agent better suited for the task +3. Return a response to the user + +Agents can be configured in two modes: +- REACT: Uses ReAct prompting for reasoning about tool usage +- FUNCTION: Uses OpenAI function calling style for tool usage + +## Configuration Options + +### Agent Config + +Each agent is configured with an `AgentConfig`: + +```python +AgentConfig( + # Unique name for the agent (str) + name="name", + # Description of agent's capabilities (str) + description="description", + # System prompt for the agent (str) + system_prompt="system_prompt", + # react or function -- defaults to function when possible. (str) + mode="function", + # Tools available to this agent (List[BaseTool]) + tools=[...], + # LLM to use for this agent. (BaseLLM) + llm=OpenAI(model="gpt-4"), + # Whether this is the entry point. (bool) + is_entrypoint_agent=True, + # List of agents this one can hand off to. Defaults to all agents. (List[str]) + can_handoff_to=[...], +) +``` + +### Workflow Options + +The MultiAgentWorkflow constructor accepts: + +```python +MultiAgentWorkflow( + # List of agent configs. (List[AgentConfig]) + agent_configs=[...], + # Initial state dict. (Optional[dict]) + initial_state=None, + # Custom prompt for handoffs. Should contain the `agent_info` string variable. (Optional[str]) + handoff_prompt=None, + # Custom prompt for state. Should contain the `state` and `msg` string variables. (Optional[str]) + state_prompt=None, +) +``` + +### State Management + +You can provide an initial state dict that will be available to all agents: + +```python +workflow = MultiAgentWorkflow( + agent_configs=[...], + initial_state={"counter": 0}, + state_prompt="Current state: {state}. User message: {msg}", +) +``` + +The state is stored in the `state` key of the workflow context. + +In order to persist state between runs, you can pass in the context from the previous run: + +```python +workflow = MultiAgentWorkflow(...) + +# Run the workflow +handler = workflow.run(user_msg="Can you add 5 and 3?") +response = await handler + +# Pass in the context from the previous run +response = await workflow.run(ctx=handler.ctx, user_msg="Can you add 5 and 3?") +``` + +As with normal workflows, the context is serializable: + +```python +from llama_index.core.workflow import ( + Context, + JsonSerializer, + JsonPickleSerializer, +) + +# the default serializer is JsonSerializer for safety +ctx_dict = handler.ctx.to_dict(serializer=JsonSerializer()) + +# then you can rehydrate the context +ctx = Context.from_dict(ctx_dict, serializer=JsonSerializer()) +``` + +## Streaming Events + +The workflow emits various events during execution that you can stream: + +```python +async for event in workflow.run(...).stream_events(): + if isinstance(event, AgentInput): + print(event.input) + print(event.current_agent) + elif isinstance(event, AgentStream): + # Agent thinking/tool calling response stream + print(event.delta) + print(event.current_agent) + elif isinstance(event, AgentOutput): + print(event.response) + print(event.tool_calls) + print(event.raw_response) + print(event.current_agent) + elif isinstance(event, ToolCall): + # Tool being called + print(event.tool_name) + print(event.tool_kwargs) + elif isinstance(event, ToolCallResult): + # Result of tool call + print(event.tool_output) +``` + +## Accessing Context in Tools + +The `FunctionToolWithContext` allows tools to access the workflow context: + +```python +from llama_index.core.workflow import FunctionToolWithContext + + +async def get_counter(ctx: Context) -> int: + """Get the current counter value.""" + return await ctx.get("counter", default=0) + + +counter_tool = FunctionToolWithContext.from_defaults( + async_fn=get_counter, description="Get the current counter value" +) +``` + +### Human in the Loop + +Using the context, you can implement a human in the loop pattern in your tools: + +```python +from llama_index.core.workflow import Event + + +class AskForConfirmationEvent(Event): + """Ask for confirmation event.""" + + confirmation_id: str + + +class ConfirmationEvent(Event): + """Confirmation event.""" + + confirmation: bool + confirmation_id: str + + +async def ask_for_confirmation(ctx: Context) -> bool: + """Ask the user for confirmation.""" + ctx.write_event_to_stream(AskForConfirmationEvent(confirmation_id="1234")) + + result = await ctx.wait_for_event( + ConfirmationEvent, requirements={"confirmation_id": "1234"} + ) + return result.confirmation +``` + +When this function is called, it will block the workflow execution until the user sends the required confirmation event. + +```python +handler.ctx.send_event( + ConfirmationEvent(confirmation=True, confirmation_id="1234") +) +``` From e9fde686165606769e7d55289ffa44f1abac1e34 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sun, 29 Dec 2024 13:17:50 -0600 Subject: [PATCH 11/22] nit --- .../core/agent/multi_agent/multi_agent_workflow.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py index 535e2194ab667..2b610ca6687fc 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py @@ -3,7 +3,6 @@ from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode from llama_index.core.agent.multi_agent.workflow_events import ( - HandoffEvent, ToolCall, ToolCallResult, AgentInput, @@ -49,7 +48,7 @@ """ -async def handoff(ctx: Context, to_agent: str, reason: str) -> HandoffEvent: +async def handoff(ctx: Context, to_agent: str, reason: str) -> str: """Handoff control of that chat to the given agent.""" agent_configs = await ctx.get("agent_configs") current_agent = await ctx.get("current_agent") @@ -140,8 +139,8 @@ async def _init_context(self, ctx: Context, ev: StartEvent) -> None: await ctx.set("memory", default_memory) if not await ctx.get("agent_configs", default=None): await ctx.set("agent_configs", self.agent_configs) - if not await ctx.get("current_state", default=None): - await ctx.set("current_state", self.initial_state) + if not await ctx.get("state", default=None): + await ctx.set("state", self.initial_state) if not await ctx.get("current_agent", default=None): await ctx.set("current_agent", self.root_agent) @@ -429,7 +428,7 @@ async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: input_messages = memory.get(input=user_msg.content) # Add the state to the user message if it exists and if requested - current_state = await ctx.get("current_state") + current_state = await ctx.get("state") if self.state_prompt and current_state: user_msg.content = self.state_prompt.format( state=current_state, msg=user_msg.content @@ -450,7 +449,7 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: agent_config: AgentConfig = self.agent_configs[ev.current_agent] llm_input = ev.input - # Setup the tools + # Set up the tools tools = list(agent_config.tools or []) if agent_config.tool_retriever: retrieved_tools = await agent_config.tool_retriever.aretrieve( From 81efcec015a1f2c4e0302d47814d1526cacbaecf Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Sun, 29 Dec 2024 13:21:31 -0600 Subject: [PATCH 12/22] add build file --- llama-index-core/tests/agent/multi/BUILD | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 llama-index-core/tests/agent/multi/BUILD diff --git a/llama-index-core/tests/agent/multi/BUILD b/llama-index-core/tests/agent/multi/BUILD new file mode 100644 index 0000000000000..57341b1358b56 --- /dev/null +++ b/llama-index-core/tests/agent/multi/BUILD @@ -0,0 +1,3 @@ +python_tests( + name="tests", +) From 2d4777199545d8dbe3a67bc165bc1be585db4be7 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Mon, 30 Dec 2024 10:14:35 -0600 Subject: [PATCH 13/22] remove unused event --- .../llama_index/core/agent/multi_agent/workflow_events.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py index c5c4d367bc233..5da011d03e581 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py @@ -71,11 +71,3 @@ class ToolCallResult(ToolCall): tool_output: ToolOutput return_direct: bool - - -class HandoffEvent(Event): - """Internal event for agent handoffs.""" - - from_agent: str - to_agent: str - reason: str From 4718dce779a8cba10e7db4e03c3fde0f768bc161 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Wed, 1 Jan 2025 22:47:10 -0600 Subject: [PATCH 14/22] refactor --- .../core/agent/multi_agent/agent_config.py | 43 -- .../agent/multi_agent/multi_agent_workflow.py | 629 ------------------ .../agent/{multi_agent => workflow}/BUILD | 0 .../core/agent/workflow/__init__.py | 26 + .../core/agent/workflow/base_agent.py | 51 ++ .../core/agent/workflow/function_agent.py | 102 +++ .../agent/workflow/multi_agent_workflow.py | 356 ++++++++++ .../core/agent/workflow/react_agent.py | 176 +++++ .../workflow_events.py | 30 +- .../tests/agent/multi/__init__.py | 0 .../tests/agent/{multi => workflow}/BUILD | 0 .../agent/workflow}/__init__.py | 0 .../test_multi_agent_workflow.py} | 79 +-- 13 files changed, 752 insertions(+), 740 deletions(-) delete mode 100644 llama-index-core/llama_index/core/agent/multi_agent/agent_config.py delete mode 100644 llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py rename llama-index-core/llama_index/core/agent/{multi_agent => workflow}/BUILD (100%) create mode 100644 llama-index-core/llama_index/core/agent/workflow/__init__.py create mode 100644 llama-index-core/llama_index/core/agent/workflow/base_agent.py create mode 100644 llama-index-core/llama_index/core/agent/workflow/function_agent.py create mode 100644 llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py create mode 100644 llama-index-core/llama_index/core/agent/workflow/react_agent.py rename llama-index-core/llama_index/core/agent/{multi_agent => workflow}/workflow_events.py (61%) delete mode 100644 llama-index-core/tests/agent/multi/__init__.py rename llama-index-core/tests/agent/{multi => workflow}/BUILD (100%) rename llama-index-core/{llama_index/core/agent/multi_agent => tests/agent/workflow}/__init__.py (100%) rename llama-index-core/tests/agent/{multi/test_multi_agent.py => workflow/test_multi_agent_workflow.py} (78%) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py b/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py deleted file mode 100644 index 4a9102b85f399..0000000000000 --- a/llama-index-core/llama_index/core/agent/multi_agent/agent_config.py +++ /dev/null @@ -1,43 +0,0 @@ -from enum import Enum -from typing import List, Optional - -from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict -from llama_index.core.llms import LLM -from llama_index.core.objects import ObjectRetriever -from llama_index.core.tools import BaseTool - - -class AgentMode(str, Enum): - """Agent mode.""" - - DEFAULT = "default" - REACT = "react" - FUNCTION = "function" - - -class AgentConfig(BaseModel): - """Configuration for a single agent in the multi-agent system.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - name: str - description: str - system_prompt: Optional[str] = None - tools: Optional[List[BaseTool]] = None - tool_retriever: Optional[ObjectRetriever] = None - can_handoff_to: Optional[List[str]] = Field(default=None) - handoff_prompt_template: Optional[str] = None - llm: Optional[LLM] = None - is_entrypoint_agent: bool = False - mode: AgentMode = AgentMode.DEFAULT - - def get_mode(self) -> AgentMode: - """Resolve the mode of the agent.""" - if self.mode == AgentMode.DEFAULT: - return ( - AgentMode.FUNCTION - if self.llm.metadata.is_function_calling_model - else AgentMode.REACT - ) - - return self.mode diff --git a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py deleted file mode 100644 index 2b610ca6687fc..0000000000000 --- a/llama-index-core/llama_index/core/agent/multi_agent/multi_agent_workflow.py +++ /dev/null @@ -1,629 +0,0 @@ -import uuid -from typing import Any, Dict, List, Optional, Union, cast - -from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode -from llama_index.core.agent.multi_agent.workflow_events import ( - ToolCall, - ToolCallResult, - AgentInput, - AgentSetup, - AgentStream, - AgentOutput, -) -from llama_index.core.agent.react.output_parser import ReActOutputParser -from llama_index.core.agent.react.formatter import ReActChatFormatter -from llama_index.core.agent.react.types import ( - ActionReasoningStep, - BaseReasoningStep, - ObservationReasoningStep, - ResponseReasoningStep, -) -from llama_index.core.llms import ChatMessage, LLM -from llama_index.core.llms.function_calling import FunctionCallingLLM -from llama_index.core.llms.llm import ToolSelection -from llama_index.core.memory import BaseMemory, ChatMemoryBuffer -from llama_index.core.prompts import BasePromptTemplate, PromptTemplate -from llama_index.core.tools import ( - BaseTool, - AsyncBaseTool, - ToolOutput, - adapt_to_async_tool, -) -from llama_index.core.workflow import ( - Context, - FunctionToolWithContext, - StartEvent, - StopEvent, - Workflow, - step, -) -from llama_index.core.settings import Settings - - -DEFAULT_HANDOFF_PROMPT = """Useful for handing off to another agent. -If you are currently not equipped to handle the user's request, or another agent is better suited to handle the request, please hand off to the appropriate agent. - -Currently available agents: -{agent_info} -""" - - -async def handoff(ctx: Context, to_agent: str, reason: str) -> str: - """Handoff control of that chat to the given agent.""" - agent_configs = await ctx.get("agent_configs") - current_agent = await ctx.get("current_agent") - if to_agent not in agent_configs: - valid_agents = ", ".join([x for x in agent_configs if x != current_agent]) - return f"Agent {to_agent} not found. Please select a valid agent to hand off to. Valid agents: {valid_agents}" - - await ctx.set("current_agent", to_agent) - return f"Handed off to {to_agent} because: {reason}" - - -class MultiAgentWorkflow(Workflow): - """A workflow for managing multiple agents with handoffs.""" - - def __init__( - self, - agent_configs: List[AgentConfig], - initial_state: Optional[Dict] = None, - handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None, - state_prompt: Optional[Union[str, BasePromptTemplate]] = None, - timeout: Optional[float] = None, - **workflow_kwargs: Any, - ): - super().__init__(timeout=timeout, **workflow_kwargs) - if not agent_configs: - raise ValueError("At least one agent config must be provided") - - self.agent_configs = {cfg.name: cfg for cfg in agent_configs} - only_one_root_agent = sum(cfg.is_entrypoint_agent for cfg in agent_configs) == 1 - if not only_one_root_agent: - raise ValueError("Exactly one root agent must be provided") - - self.root_agent = next( - cfg.name for cfg in agent_configs if cfg.is_entrypoint_agent - ) - - self.initial_state = initial_state or {} - - self.handoff_prompt = handoff_prompt or DEFAULT_HANDOFF_PROMPT - if isinstance(self.handoff_prompt, str): - self.handoff_prompt = PromptTemplate(self.handoff_prompt) - if "{agent_info}" not in self.handoff_prompt.template: - raise ValueError("Handoff prompt must contain {agent_info}") - - self.state_prompt = state_prompt - if isinstance(self.state_prompt, str): - self.state_prompt = PromptTemplate(self.state_prompt) - if ( - "{state}" not in self.state_prompt.template - or "{msg}" not in self.state_prompt.template - ): - raise ValueError("State prompt must contain {state} and {msg}") - - def _ensure_tools_are_async(self, tools: List[BaseTool]) -> List[AsyncBaseTool]: - """Ensure all tools are async.""" - return [adapt_to_async_tool(tool) for tool in tools] - - def _get_handoff_tool(self, current_agent_config: AgentConfig) -> AsyncBaseTool: - """Creates a handoff tool for the given agent.""" - agent_info = {cfg.name: cfg.description for cfg in self.agent_configs.values()} - - # Filter out agents that the current agent cannot handoff to - configs_to_remove = [] - for name in agent_info: - if name == current_agent_config.name: - configs_to_remove.append(name) - elif ( - current_agent_config.can_handoff_to is not None - and name not in current_agent_config.can_handoff_to - ): - configs_to_remove.append(name) - - for name in configs_to_remove: - agent_info.pop(name) - - fn_tool_prompt = self.handoff_prompt.format(agent_info=str(agent_info)) - return FunctionToolWithContext.from_defaults( - async_fn=handoff, description=fn_tool_prompt, return_direct=True - ) - - async def _init_context(self, ctx: Context, ev: StartEvent) -> None: - """Initialize the context once, if needed.""" - if not await ctx.get("memory", default=None): - default_memory = ev.get("memory", default=None) - default_memory = default_memory or ChatMemoryBuffer.from_defaults( - llm=self.agent_configs[self.root_agent].llm or Settings.llm - ) - await ctx.set("memory", default_memory) - if not await ctx.get("agent_configs", default=None): - await ctx.set("agent_configs", self.agent_configs) - if not await ctx.get("state", default=None): - await ctx.set("state", self.initial_state) - if not await ctx.get("current_agent", default=None): - await ctx.set("current_agent", self.root_agent) - - async def _call_tool( - self, - ctx: Context, - tool: AsyncBaseTool, - tool_input: dict, - ) -> ToolOutput: - """Call the given tool with the given input.""" - try: - if isinstance(tool, FunctionToolWithContext): - tool_output = await tool.acall(ctx=ctx, **tool_input) - else: - tool_output = await tool.acall(**tool_input) - except Exception as e: - tool_output = ToolOutput( - content=str(e), - tool_name=tool.metadata.name, - raw_input=tool_input, - raw_output=str(e), - is_error=True, - ) - - return tool_output - - async def _handle_react_tool_call( - self, ctx: Context, results: List[ToolCallResult] - ) -> None: - """Adds to the react reasoning list.""" - current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] - ) - for tool_call_result in results: - # don't add handoff tool calls to reasoning - if tool_call_result.tool_name == "handoff": - continue - - current_reasoning.append( - ObservationReasoningStep( - observation=str(tool_call_result.tool_output.content), - return_direct=tool_call_result.return_direct, - ) - ) - - if tool_call_result.return_direct: - current_reasoning.append( - ResponseReasoningStep( - thought=current_reasoning[-1].observation, - response=current_reasoning[-1].observation, - is_streaming=False, - ) - ) - break - - await ctx.set("current_reasoning", current_reasoning) - - async def _handle_function_tool_call( - self, ctx: Context, results: List[ToolCallResult] - ) -> None: - """Adds to memory.""" - memory: BaseMemory = await ctx.get("memory") - for tool_call_result in results: - # don't add handoff tool calls to memory - if tool_call_result.tool_name == "handoff": - continue - - await memory.aput( - ChatMessage( - role="tool", - content=str(tool_call_result.tool_output.content), - additional_kwargs={"tool_call_id": tool_call_result.tool_id}, - ) - ) - - if tool_call_result.return_direct: - await memory.aput( - ChatMessage( - role="assistant", - content=str(tool_call_result.tool_output.content), - additional_kwargs={"tool_call_id": tool_call_result.tool_id}, - ) - ) - break - - await ctx.set("memory", memory) - - async def _call_function_calling_agent( - self, - ctx: Context, - llm: FunctionCallingLLM, - llm_input: List[ChatMessage], - tools: List[AsyncBaseTool], - ) -> AgentOutput: - """Call the LLM as a function calling agent.""" - memory: BaseMemory = await ctx.get("memory") - current_agent = await ctx.get("current_agent") - - current_llm_input = [*llm_input] - response = await llm.astream_chat_with_tools( - tools, chat_history=current_llm_input, allow_parallel_tool_calls=True - ) - async for r in response: - tool_calls = llm.get_tool_calls_from_response( - r, error_on_no_tool_call=False - ) - ctx.write_event_to_stream( - AgentStream( - delta=r.delta or "", - tool_calls=tool_calls or [], - raw_response=r.raw, - current_agent=current_agent, - ) - ) - - tool_calls = llm.get_tool_calls_from_response(r, error_on_no_tool_call=False) - - # only add to memory if we didn't select the handoff tool - if not any(tool_call.tool_name == "handoff" for tool_call in tool_calls): - current_llm_input.append(r.message) - await memory.aput(r.message) - await ctx.set("memory", memory) - - return AgentOutput( - response=r.message.content, - tool_calls=tool_calls or [], - raw_response=r.raw, - current_agent=current_agent, - ) - - async def _call_react_agent( - self, - ctx: Context, - llm: LLM, - llm_input: List[ChatMessage], - tools: List[AsyncBaseTool], - ) -> AgentOutput: - """Call the LLM as a react agent.""" - memory: BaseMemory = await ctx.get("memory") - current_agent = await ctx.get("current_agent") - - # remove system prompt, since the react prompt will be combined with it - if llm_input[0].role == "system": - system_prompt = llm_input[0].content or "" - llm_input = llm_input[1:] - else: - system_prompt = "" - - output_parser = ReActOutputParser() - react_chat_formatter = ReActChatFormatter(context=system_prompt) - - # Format initial chat input - current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] - ) - input_chat = react_chat_formatter.format( - tools, - chat_history=llm_input, - current_reasoning=current_reasoning, - ) - - # Initial LLM call - response = await llm.astream_chat(input_chat) - async for r in response: - ctx.write_event_to_stream( - AgentStream( - delta=r.delta or "", - tool_calls=[], - raw_response=r.raw, - current_agent=current_agent, - ) - ) - - # Parse reasoning step and check if done - message_content = r.message.content - if not message_content: - raise ValueError("Got empty message") - - try: - reasoning_step = output_parser.parse(message_content, is_streaming=False) - except ValueError as e: - # If we can't parse the output, return an error message - error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" - - await memory.aput(r.message.content) - await memory.aput(ChatMessage(role="user", content=error_msg)) - await ctx.set("memory", memory) - - return AgentOutput( - response=r.message.content, - tool_calls=[], - raw_response=r.raw, - current_agent=current_agent, - ) - - # add to reasoning if not a handoff - if hasattr(reasoning_step, "action") and reasoning_step.action != "handoff": - current_reasoning.append(reasoning_step) - await ctx.set("current_reasoning", current_reasoning) - - # If response step, we're done - if reasoning_step.is_done: - return AgentOutput( - response=r.message.content, - tool_calls=[], - raw_response=r.raw, - current_agent=current_agent, - ) - - reasoning_step = cast(ActionReasoningStep, reasoning_step) - if not isinstance(reasoning_step, ActionReasoningStep): - raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") - - # Create tool call - tool_calls = [ - ToolSelection( - tool_id=str(uuid.uuid4()), - tool_name=reasoning_step.action, - tool_kwargs=reasoning_step.action_input, - ) - ] - - return AgentOutput( - response=r.message.content, - tool_calls=tool_calls, - raw_response=r.raw, - current_agent=current_agent, - ) - - async def _call_llm( - self, - ctx: Context, - llm: LLM, - llm_input: List[ChatMessage], - tools: List[AsyncBaseTool], - agent_config: AgentConfig, - ) -> AgentOutput: - """Call the LLM with the given input and tools.""" - if agent_config.get_mode() == AgentMode.REACT: - return await self._call_react_agent(ctx, llm, llm_input, tools) - elif agent_config.get_mode() == AgentMode.FUNCTION: - return await self._call_function_calling_agent(ctx, llm, llm_input, tools) - else: - raise ValueError(f"Invalid agent mode: {agent_config.get_mode()}") - - async def _finalize_function_calling_agent(self, ctx: Context) -> None: - """Finalizes the function calling agent. - - This is a no-op for the function calling agent, since we've been writing to memory as we go. - """ - - async def _finalize_react_agent(self, ctx: Context) -> None: - """Finalizes the react agent by writing the current reasoning to memory.""" - memory: BaseMemory = await ctx.get("memory") - current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] - ) - - reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) - reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) - - await memory.aput(reasoning_msg) - await ctx.set("memory", memory) - await ctx.set("current_reasoning", []) - - @step - async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: - """Sets up the workflow and validates inputs.""" - await self._init_context(ctx, ev) - - user_msg = ev.get("user_msg") - chat_history = ev.get("chat_history") - if user_msg and chat_history: - raise ValueError("Cannot provide both user_msg and chat_history") - - if isinstance(user_msg, str): - user_msg = ChatMessage(role="user", content=user_msg) - - await ctx.set("user_msg_str", user_msg.content) - - # Add messages to memory - memory: BaseMemory = await ctx.get("memory") - if user_msg: - await memory.aput(user_msg) - input_messages = memory.get(input=user_msg.content) - - # Add the state to the user message if it exists and if requested - current_state = await ctx.get("state") - if self.state_prompt and current_state: - user_msg.content = self.state_prompt.format( - state=current_state, msg=user_msg.content - ) - - await memory.aput(user_msg) - else: - memory.set(chat_history) - input_messages = memory.get() - - # send to the current agent - current_agent = await ctx.get("current_agent") - return AgentInput(input=input_messages, current_agent=current_agent) - - @step - async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: - """Main agent handling logic.""" - agent_config: AgentConfig = self.agent_configs[ev.current_agent] - llm_input = ev.input - - # Set up the tools - tools = list(agent_config.tools or []) - if agent_config.tool_retriever: - retrieved_tools = await agent_config.tool_retriever.aretrieve( - llm_input[-1].content or str(llm_input) - ) - tools.extend(retrieved_tools) - - if agent_config.can_handoff_to or agent_config.can_handoff_to is None: - handoff_tool = self._get_handoff_tool(agent_config) - tools.append(handoff_tool) - - tools = self._ensure_tools_are_async(tools) - - ctx.write_event_to_stream( - AgentInput(input=llm_input, current_agent=ev.current_agent) - ) - - if agent_config.system_prompt: - llm_input = [ - ChatMessage(role="system", content=agent_config.system_prompt), - *llm_input, - ] - - await ctx.set("tools_by_name", {tool.metadata.name: tool for tool in tools}) - - return AgentSetup( - input=llm_input, - current_agent=ev.current_agent, - current_config=agent_config, - tools=tools, - ) - - @step - async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: - """Run the agent.""" - agent_config = ev.current_config - llm = agent_config.llm or Settings.llm - - agent_output = await self._call_llm( - ctx, - llm, - ev.input, - ev.tools, - agent_config, - ) - ctx.write_event_to_stream(agent_output) - - return agent_output - - @step - async def parse_agent_output( - self, ctx: Context, ev: AgentOutput - ) -> StopEvent | ToolCall | None: - if not ev.tool_calls: - agent_configs = self.agent_configs - current_config = agent_configs[ev.current_agent] - if current_config.get_mode() == AgentMode.REACT: - await self._finalize_react_agent(ctx) - elif current_config.get_mode() == AgentMode.FUNCTION: - await self._finalize_function_calling_agent(ctx) - else: - raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") - - return StopEvent(result=ev) - - await ctx.set("num_tool_calls", len(ev.tool_calls)) - - for tool_call in ev.tool_calls: - ctx.send_event( - ToolCall( - tool_name=tool_call.tool_name, - tool_kwargs=tool_call.tool_kwargs, - tool_id=tool_call.tool_id, - ) - ) - - return None - - @step - async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: - """Calls the tool and handles the result.""" - ctx.write_event_to_stream(ev) - - tools_by_name: dict[str, AsyncBaseTool] = await ctx.get("tools_by_name") - if ev.tool_name not in tools_by_name: - result = ToolOutput( - content=f"Tool {ev.tool_name} not found. Please select a tool that is available.", - tool_name=ev.tool_name, - raw_input=ev.tool_kwargs, - raw_output=None, - is_error=True, - return_direct=False, - ) - else: - tool = tools_by_name[ev.tool_name] - result = await self._call_tool(ctx, tool, ev.tool_kwargs) - - result_ev = ToolCallResult( - tool_name=ev.tool_name, - tool_kwargs=ev.tool_kwargs, - tool_id=ev.tool_id, - tool_output=result, - return_direct=tool.metadata.return_direct, - ) - - ctx.write_event_to_stream(result_ev) - return result_ev - - @step - async def aggregate_tool_results( - self, ctx: Context, ev: ToolCallResult - ) -> AgentInput | StopEvent | None: - """Aggregate tool results and return the next agent input.""" - num_tool_calls = await ctx.get("num_tool_calls", default=0) - if num_tool_calls == 0: - raise ValueError("No tool calls found, cannot aggregate results.") - - tool_call_results: list[ToolCallResult] = ctx.collect_events( - ev, expected=[ToolCallResult] * num_tool_calls - ) - if not tool_call_results: - return None - - current_agent = await ctx.get("current_agent") - current_config: AgentConfig = self.agent_configs[current_agent] - - if current_config.get_mode() == AgentMode.REACT: - await self._handle_react_tool_call(ctx, tool_call_results) - elif current_config.get_mode() == AgentMode.FUNCTION: - await self._handle_function_tool_call(ctx, tool_call_results) - else: - raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") - - if any( - tool_call_result.return_direct for tool_call_result in tool_call_results - ): - # if any tool calls return directly, take the first one - return_direct_tool = next( - tool_call_result - for tool_call_result in tool_call_results - if tool_call_result.return_direct - ) - - current_config = self.agent_configs[current_agent] - if current_config.get_mode() == AgentMode.REACT: - await self._finalize_react_agent(ctx) - elif current_config.get_mode() == AgentMode.FUNCTION: - await self._finalize_function_calling_agent(ctx) - else: - raise ValueError(f"Invalid agent mode: {current_config.get_mode()}") - - # we don't want to stop the system if we're just handing off - if return_direct_tool.tool_name != "handoff": - return StopEvent( - result=AgentOutput( - response=return_direct_tool.tool_output.content, - tool_calls=[ - ToolSelection( - tool_id=t.tool_id, - tool_name=t.tool_name, - tool_kwargs=t.tool_kwargs, - ) - for t in tool_call_results - ], - raw_response=return_direct_tool.tool_output.raw_output, - current_agent=current_agent, - ) - ) - - user_msg_str = await ctx.get("user_msg_str") - memory: BaseMemory = await ctx.get("memory") - input_messages = memory.get(input=user_msg_str) - - # get this again, in case it changed - current_agent = await ctx.get("current_agent") - - return AgentInput(input=input_messages, current_agent=current_agent) diff --git a/llama-index-core/llama_index/core/agent/multi_agent/BUILD b/llama-index-core/llama_index/core/agent/workflow/BUILD similarity index 100% rename from llama-index-core/llama_index/core/agent/multi_agent/BUILD rename to llama-index-core/llama_index/core/agent/workflow/BUILD diff --git a/llama-index-core/llama_index/core/agent/workflow/__init__.py b/llama-index-core/llama_index/core/agent/workflow/__init__.py new file mode 100644 index 0000000000000..aba7b324bcbcc --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/__init__.py @@ -0,0 +1,26 @@ +from llama_index.core.agent.workflow.multi_agent_workflow import MultiAgentWorkflow +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.function_agent import FunctionAgent +from llama_index.core.agent.workflow.react_agent import ReactAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentSetup, + AgentStream, + AgentOutput, + ToolCall, + ToolCallResult, +) + + +__all__ = [ + "AgentInput", + "AgentSetup", + "AgentStream", + "AgentOutput", + "BaseWorkflowAgent", + "FunctionAgent", + "MultiAgentWorkflow", + "ReactAgent", + "ToolCall", + "ToolCallResult", +] diff --git a/llama-index-core/llama_index/core/agent/workflow/base_agent.py b/llama-index-core/llama_index/core/agent/workflow/base_agent.py new file mode 100644 index 0000000000000..de850ab26455c --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/base_agent.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from llama_index.core.agent.workflow.workflow_events import ( + AgentOutput, + ToolCallResult, +) +from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict +from llama_index.core.llms import ChatMessage, LLM +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import BaseTool, AsyncBaseTool +from llama_index.core.workflow import Context +from llama_index.core.objects import ObjectRetriever + + +class BaseWorkflowAgent(BaseModel, ABC): + """Base class for all agents, combining config and logic.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str + description: str + system_prompt: Optional[str] = None + tools: Optional[List[BaseTool]] = None + tool_retriever: Optional[ObjectRetriever] = None + can_handoff_to: Optional[List[str]] = Field(default=None) + handoff_prompt_template: Optional[str] = None + llm: Optional[LLM] = None + is_entrypoint_agent: bool = False + + @abstractmethod + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the agent.""" + + @abstractmethod + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results.""" + + @abstractmethod + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the agent's execution.""" diff --git a/llama-index-core/llama_index/core/agent/workflow/function_agent.py b/llama-index-core/llama_index/core/agent/workflow/function_agent.py new file mode 100644 index 0000000000000..f8ef3dda17b34 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/function_agent.py @@ -0,0 +1,102 @@ +from typing import List + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentOutput, + AgentStream, + ToolCallResult, +) +from llama_index.core.llms import ChatMessage +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import AsyncBaseTool +from llama_index.core.workflow import Context + + +class FunctionAgent(BaseWorkflowAgent): + """Function calling agent implementation.""" + + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the function calling agent.""" + if not self.llm.metadata.is_function_calling_model: + raise ValueError("LLM must be a FunctionCallingLLM") + + current_llm_input = [*llm_input] + + ctx.write_event_to_stream( + AgentInput(input=current_llm_input, current_agent_name=self.name) + ) + + response = await self.llm.astream_chat_with_tools( + tools, chat_history=current_llm_input, allow_parallel_tool_calls=True + ) + async for r in response: + tool_calls = self.llm.get_tool_calls_from_response( + r, error_on_no_tool_call=False + ) + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=tool_calls or [], + raw_response=r.raw, + current_agent_name=self.name, + ) + ) + + tool_calls = self.llm.get_tool_calls_from_response( + r, error_on_no_tool_call=False + ) + + # only add to memory if we didn't select the handoff tool + if not any(tool_call.tool_name == "handoff" for tool_call in tool_calls): + current_llm_input.append(r.message) + await memory.aput(r.message) + + return AgentOutput( + response=r.message.content, + tool_calls=tool_calls or [], + raw_response=r.raw, + current_agent_name=self.name, + ) + + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results for function calling agent.""" + for tool_call_result in results: + # don't add handoff tool calls to memory + if tool_call_result.tool_name == "handoff": + continue + + await memory.aput( + ChatMessage( + role="tool", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + + if tool_call_result.return_direct: + await memory.aput( + ChatMessage( + role="assistant", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + break + + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the function calling agent. + + This is a no-op for function calling agents since we write to memory as we go. + """ + return output diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py new file mode 100644 index 0000000000000..16d462b51c90f --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -0,0 +1,356 @@ +from typing import Any, Dict, List, Optional, Union + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + ToolCall, + ToolCallResult, + AgentInput, + AgentSetup, + AgentOutput, +) +from llama_index.core.llms import ChatMessage +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.memory import BaseMemory, ChatMemoryBuffer +from llama_index.core.prompts import BasePromptTemplate, PromptTemplate +from llama_index.core.tools import ( + BaseTool, + AsyncBaseTool, + ToolOutput, + adapt_to_async_tool, +) +from llama_index.core.workflow import ( + Context, + FunctionToolWithContext, + StartEvent, + StopEvent, + Workflow, + step, +) +from llama_index.core.settings import Settings + + +DEFAULT_HANDOFF_PROMPT = """Useful for handing off to another agent. +If you are currently not equipped to handle the user's request, or another agent is better suited to handle the request, please hand off to the appropriate agent. + +Currently available agents: +{agent_info} +""" + + +async def handoff(ctx: Context, to_agent: str, reason: str) -> str: + """Handoff control of that chat to the given agent.""" + agents: dict[str, BaseWorkflowAgent] = await ctx.get("agents") + current_agent: BaseWorkflowAgent = await ctx.get("current_agent") + if to_agent not in agents: + valid_agents = ", ".join([x for x in agents if x != current_agent.name]) + return f"Agent {to_agent} not found. Please select a valid agent to hand off to. Valid agents: {valid_agents}" + + await ctx.set("next_agent", agents[to_agent].name) + return f"Handed off to {to_agent} because: {reason}" + + +class MultiAgentWorkflow(Workflow): + """A workflow for managing multiple agents with handoffs.""" + + def __init__( + self, + agents: List[BaseWorkflowAgent], + initial_state: Optional[Dict] = None, + handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None, + state_prompt: Optional[Union[str, BasePromptTemplate]] = None, + timeout: Optional[float] = None, + **workflow_kwargs: Any, + ): + super().__init__(timeout=timeout, **workflow_kwargs) + if not agents: + raise ValueError("At least one agent must be provided") + + self.agents = {cfg.name: cfg for cfg in agents} + only_one_root_agent = sum(cfg.is_entrypoint_agent for cfg in agents) == 1 + if not only_one_root_agent: + raise ValueError("Exactly one root agent must be provided") + + self.root_agent = next(agent for agent in agents if agent.is_entrypoint_agent) + + self.initial_state = initial_state or {} + + self.handoff_prompt = handoff_prompt or DEFAULT_HANDOFF_PROMPT + if isinstance(self.handoff_prompt, str): + self.handoff_prompt = PromptTemplate(self.handoff_prompt) + if "{agent_info}" not in self.handoff_prompt.template: + raise ValueError("Handoff prompt must contain {agent_info}") + + self.state_prompt = state_prompt + if isinstance(self.state_prompt, str): + self.state_prompt = PromptTemplate(self.state_prompt) + if ( + "{state}" not in self.state_prompt.template + or "{msg}" not in self.state_prompt.template + ): + raise ValueError("State prompt must contain {state} and {msg}") + + def _ensure_tools_are_async(self, tools: List[BaseTool]) -> List[AsyncBaseTool]: + """Ensure all tools are async.""" + return [adapt_to_async_tool(tool) for tool in tools] + + def _get_handoff_tool(self, current_agent: BaseWorkflowAgent) -> AsyncBaseTool: + """Creates a handoff tool for the given agent.""" + agent_info = {cfg.name: cfg.description for cfg in self.agents.values()} + + # Filter out agents that the current agent cannot handoff to + configs_to_remove = [] + for name in agent_info: + if name == current_agent.name: + configs_to_remove.append(name) + elif ( + current_agent.can_handoff_to is not None + and name not in current_agent.can_handoff_to + ): + configs_to_remove.append(name) + + for name in configs_to_remove: + agent_info.pop(name) + + fn_tool_prompt = self.handoff_prompt.format(agent_info=str(agent_info)) + return FunctionToolWithContext.from_defaults( + async_fn=handoff, description=fn_tool_prompt, return_direct=True + ) + + async def _init_context(self, ctx: Context, ev: StartEvent) -> None: + """Initialize the context once, if needed.""" + if not await ctx.get("memory", default=None): + default_memory = ev.get("memory", default=None) + default_memory = default_memory or ChatMemoryBuffer.from_defaults( + llm=self.agents[self.root_agent.name].llm or Settings.llm + ) + await ctx.set("memory", default_memory) + if not await ctx.get("agents", default=None): + await ctx.set("agents", self.agents) + if not await ctx.get("state", default=None): + await ctx.set("state", self.initial_state) + if not await ctx.get("current_agent", default=None): + await ctx.set("current_agent", self.root_agent) + + async def _call_tool( + self, + ctx: Context, + tool: AsyncBaseTool, + tool_input: dict, + ) -> ToolOutput: + """Call the given tool with the given input.""" + try: + if isinstance(tool, FunctionToolWithContext): + tool_output = await tool.acall(ctx=ctx, **tool_input) + else: + tool_output = await tool.acall(**tool_input) + except Exception as e: + tool_output = ToolOutput( + content=str(e), + tool_name=tool.metadata.name, + raw_input=tool_input, + raw_output=str(e), + is_error=True, + ) + + return tool_output + + @step + async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: + """Sets up the workflow and validates inputs.""" + await self._init_context(ctx, ev) + + user_msg = ev.get("user_msg") + chat_history = ev.get("chat_history") + if user_msg and chat_history: + raise ValueError("Cannot provide both user_msg and chat_history") + + if isinstance(user_msg, str): + user_msg = ChatMessage(role="user", content=user_msg) + + await ctx.set("user_msg_str", user_msg.content) + + # Add messages to memory + memory: BaseMemory = await ctx.get("memory") + if user_msg: + # Add the state to the user message if it exists and if requested + current_state = await ctx.get("state") + if self.state_prompt and current_state: + user_msg.content = self.state_prompt.format( + state=current_state, msg=user_msg.content + ) + + await memory.aput(user_msg) + input_messages = memory.get(input=user_msg.content) + else: + memory.set(chat_history) + input_messages = memory.get() + + # send to the current agent + current_agent: BaseWorkflowAgent = await ctx.get("current_agent") + return AgentInput(input=input_messages, current_agent_name=current_agent.name) + + @step + async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: + """Main agent handling logic.""" + current_agent_name = ev.current_agent_name + agent = self.agents[current_agent_name] + llm_input = ev.input + + # Set up the tools + tools = agent.tools or [] + if agent.tool_retriever: + retrieved_tools = await agent.tool_retriever.aretrieve( + llm_input[-1].content or str(llm_input) + ) + tools.extend(retrieved_tools) + + if agent.can_handoff_to or agent.can_handoff_to is None: + handoff_tool = self._get_handoff_tool(agent) + tools.append(handoff_tool) + + tools = self._ensure_tools_are_async(tools) + + if agent.system_prompt: + llm_input = [ + ChatMessage(role="system", content=agent.system_prompt), + *llm_input, + ] + + await ctx.set("tools_by_name", {tool.metadata.name: tool for tool in tools}) + + return AgentSetup( + input=llm_input, + current_agent_name=ev.current_agent_name, + tools=tools, + ) + + @step + async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: + """Run the agent.""" + memory: BaseMemory = await ctx.get("memory") + agent = self.agents[ev.current_agent_name] + + return await agent.take_step( + ctx, + ev.input, + ev.tools, + memory, + ) + + @step + async def parse_agent_output( + self, ctx: Context, ev: AgentOutput + ) -> StopEvent | ToolCall | None: + if not ev.tool_calls: + agent = self.agents[ev.current_agent_name] + memory: BaseMemory = await ctx.get("memory") + output = await agent.finalize(ctx, ev, memory) + + return StopEvent(result=output) + + await ctx.set("num_tool_calls", len(ev.tool_calls)) + + for tool_call in ev.tool_calls: + ctx.send_event( + ToolCall( + tool_name=tool_call.tool_name, + tool_kwargs=tool_call.tool_kwargs, + tool_id=tool_call.tool_id, + ) + ) + + return None + + @step + async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: + """Calls the tool and handles the result.""" + ctx.write_event_to_stream(ev) + + tools_by_name: dict[str, AsyncBaseTool] = await ctx.get("tools_by_name") + if ev.tool_name not in tools_by_name: + tool = None + result = ToolOutput( + content=f"Tool {ev.tool_name} not found. Please select a tool that is available.", + tool_name=ev.tool_name, + raw_input=ev.tool_kwargs, + raw_output=None, + is_error=True, + ) + else: + tool = tools_by_name[ev.tool_name] + result = await self._call_tool(ctx, tool, ev.tool_kwargs) + + result_ev = ToolCallResult( + tool_name=ev.tool_name, + tool_kwargs=ev.tool_kwargs, + tool_id=ev.tool_id, + tool_output=result, + return_direct=tool.metadata.return_direct if tool else False, + ) + + ctx.write_event_to_stream(result_ev) + return result_ev + + @step + async def aggregate_tool_results( + self, ctx: Context, ev: ToolCallResult + ) -> AgentInput | StopEvent | None: + """Aggregate tool results and return the next agent input.""" + num_tool_calls = await ctx.get("num_tool_calls", default=0) + if num_tool_calls == 0: + raise ValueError("No tool calls found, cannot aggregate results.") + + tool_call_results: list[ToolCallResult] = ctx.collect_events( + ev, expected=[ToolCallResult] * num_tool_calls + ) + if not tool_call_results: + return None + + memory: BaseMemory = await ctx.get("memory") + agent: BaseWorkflowAgent = await ctx.get("current_agent") + + await agent.handle_tool_call_results(ctx, tool_call_results, memory) + + # set the next agent, if needed + # the handoff tool sets this + next_agent_name = await ctx.get("next_agent", default=None) + if next_agent_name: + await ctx.set("current_agent", self.agents[next_agent_name]) + + if any( + tool_call_result.return_direct for tool_call_result in tool_call_results + ): + # if any tool calls return directly, take the first one + return_direct_tool = next( + tool_call_result + for tool_call_result in tool_call_results + if tool_call_result.return_direct + ) + + # always finalize the agent, even if we're just handing off + result = AgentOutput( + response=return_direct_tool.tool_output.content, + tool_calls=[ + ToolSelection( + tool_id=t.tool_id, + tool_name=t.tool_name, + tool_kwargs=t.tool_kwargs, + ) + for t in tool_call_results + ], + raw_response=return_direct_tool.tool_output.raw_output, + current_agent_name=agent.name, + ) + result = await agent.finalize(ctx, result, memory) + + # we don't want to stop the system if we're just handing off + if return_direct_tool.tool_name != "handoff": + return StopEvent(result=result) + + user_msg_str = await ctx.get("user_msg_str") + input_messages = memory.get(input=user_msg_str) + + # get this again, in case it changed + agent: BaseWorkflowAgent = await ctx.get("current_agent") + + return AgentInput(input=input_messages, current_agent_name=agent.name) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py new file mode 100644 index 0000000000000..7e50dfc8e5e3b --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -0,0 +1,176 @@ +import uuid +from typing import List, cast + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentOutput, + AgentStream, + ToolCallResult, +) +from llama_index.core.agent.react.formatter import ReActChatFormatter +from llama_index.core.agent.react.output_parser import ReActOutputParser +from llama_index.core.agent.react.types import ( + ActionReasoningStep, + BaseReasoningStep, + ObservationReasoningStep, + ResponseReasoningStep, +) +from llama_index.core.llms import ChatMessage +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import AsyncBaseTool +from llama_index.core.workflow import Context + + +class ReactAgent(BaseWorkflowAgent): + """React agent implementation.""" + + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the React agent.""" + # remove system prompt, since the react prompt will be combined with it + if llm_input[0].role == "system": + system_prompt = llm_input[0].content or "" + llm_input = llm_input[1:] + else: + system_prompt = "" + + output_parser = ReActOutputParser() + react_chat_formatter = ReActChatFormatter(context=system_prompt) + + # Format initial chat input + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] + ) + input_chat = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + + ctx.write_event_to_stream( + AgentInput(input=input_chat, current_agent_name=self.name) + ) + + # Initial LLM call + response = await self.llm.astream_chat(input_chat) + async for r in response: + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + tool_calls=[], + raw_response=r.raw, + current_agent_name=self.name, + ) + ) + + # Parse reasoning step and check if done + message_content = r.message.content + if not message_content: + raise ValueError("Got empty message") + + try: + reasoning_step = output_parser.parse(message_content, is_streaming=False) + except ValueError as e: + error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + await memory.aput(r.message) + await memory.aput(ChatMessage(role="user", content=error_msg)) + + return AgentOutput( + response=r.message.content, + tool_calls=[], + raw_response=r.raw, + current_agent_name=self.name, + ) + + # add to reasoning if not a handoff + if hasattr(reasoning_step, "action") and reasoning_step.action != "handoff": + current_reasoning.append(reasoning_step) + await ctx.set("current_reasoning", current_reasoning) + + # If response step, we're done + if reasoning_step.is_done: + return AgentOutput( + response=r.message.content, + tool_calls=[], + raw_response=r.raw, + current_agent_name=self.name, + ) + + reasoning_step = cast(ActionReasoningStep, reasoning_step) + if not isinstance(reasoning_step, ActionReasoningStep): + raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") + + # Create tool call + tool_calls = [ + ToolSelection( + tool_id=str(uuid.uuid4()), + tool_name=reasoning_step.action, + tool_kwargs=reasoning_step.action_input, + ) + ] + + return AgentOutput( + response=r.message.content, + tool_calls=tool_calls, + raw_response=r.raw, + current_agent_name=self.name, + ) + + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results for React agent.""" + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] + ) + for tool_call_result in results: + # don't add handoff tool calls to reasoning + if tool_call_result.tool_name == "handoff": + continue + + current_reasoning.append( + ObservationReasoningStep( + observation=str(tool_call_result.tool_output.content), + return_direct=tool_call_result.return_direct, + ) + ) + + if tool_call_result.return_direct: + current_reasoning.append( + ResponseReasoningStep( + thought=current_reasoning[-1].observation, + response=current_reasoning[-1].observation, + is_streaming=False, + ) + ) + break + + await ctx.set("current_reasoning", current_reasoning) + + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the React agent.""" + current_reasoning: list[BaseReasoningStep] = await ctx.get( + "current_reasoning", default=[] + ) + + reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) + reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) + + await memory.aput(reasoning_msg) + await ctx.set("current_reasoning", []) + + # remove "Answer:" from the response + if output.response and "Answer:" in output.response: + start_idx = output.response.index("Answer:") + output.response = output.response[start_idx + len("Answer:") :].strip() + + return output diff --git a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py similarity index 61% rename from llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py rename to llama-index-core/llama_index/core/agent/workflow/workflow_events.py index 5da011d03e581..21f47eb9a57d4 100644 --- a/llama-index-core/llama_index/core/agent/multi_agent/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py @@ -1,42 +1,22 @@ -from typing import Any, Optional +from typing import Any from llama_index.core.tools import AsyncBaseTool, ToolSelection, ToolOutput from llama_index.core.llms import ChatMessage from llama_index.core.workflow import Event -from llama_index.core.agent.multi_agent.agent_config import AgentConfig - - -class ToolApprovalNeeded(Event): - """Emitted when a tool call needs approval.""" - - id: str - tool_name: str - tool_kwargs: dict - - -class ApproveTool(Event): - """Required to approve a tool.""" - - id: str - tool_name: str - tool_kwargs: dict - approved: bool - reason: Optional[str] = None class AgentInput(Event): """LLM input.""" input: list[ChatMessage] - current_agent: str + current_agent_name: str class AgentSetup(Event): """Agent setup.""" input: list[ChatMessage] - current_agent: str - current_config: AgentConfig + current_agent_name: str tools: list[AsyncBaseTool] @@ -44,7 +24,7 @@ class AgentStream(Event): """Agent stream.""" delta: str - current_agent: str + current_agent_name: str tool_calls: list[ToolSelection] raw_response: Any @@ -55,7 +35,7 @@ class AgentOutput(Event): response: str tool_calls: list[ToolSelection] raw_response: Any - current_agent: str + current_agent_name: str class ToolCall(Event): diff --git a/llama-index-core/tests/agent/multi/__init__.py b/llama-index-core/tests/agent/multi/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/llama-index-core/tests/agent/multi/BUILD b/llama-index-core/tests/agent/workflow/BUILD similarity index 100% rename from llama-index-core/tests/agent/multi/BUILD rename to llama-index-core/tests/agent/workflow/BUILD diff --git a/llama-index-core/llama_index/core/agent/multi_agent/__init__.py b/llama-index-core/tests/agent/workflow/__init__.py similarity index 100% rename from llama-index-core/llama_index/core/agent/multi_agent/__init__.py rename to llama-index-core/tests/agent/workflow/__init__.py diff --git a/llama-index-core/tests/agent/multi/test_multi_agent.py b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py similarity index 78% rename from llama-index-core/tests/agent/multi/test_multi_agent.py rename to llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py index 57433140d85e1..340c98cdcdbc5 100644 --- a/llama-index-core/tests/agent/multi/test_multi_agent.py +++ b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py @@ -2,8 +2,9 @@ import pytest from llama_index.core.llms import MockLLM -from llama_index.core.agent.multi_agent.multi_agent_workflow import MultiAgentWorkflow -from llama_index.core.agent.multi_agent.agent_config import AgentConfig, AgentMode +from llama_index.core.agent.workflow.multi_agent_workflow import MultiAgentWorkflow +from llama_index.core.agent.workflow.function_agent import FunctionAgent +from llama_index.core.agent.workflow.react_agent import ReactAgent from llama_index.core.llms import ( ChatMessage, ChatResponse, @@ -16,15 +17,14 @@ class MockLLM(MockLLM): - def __init__(self, responses: List[ChatMessage], is_function_calling: bool = False): + def __init__(self, responses: List[ChatMessage]): super().__init__() self._responses = responses self._response_index = 0 - self._is_function_calling = is_function_calling @property def metadata(self) -> LLMMetadata: - return LLMMetadata(is_function_calling_model=self._is_function_calling) + return LLMMetadata(is_function_calling_model=True) async def astream_chat( self, messages: List[ChatMessage], **kwargs: Any @@ -74,11 +74,10 @@ def subtract(a: int, b: int) -> int: @pytest.fixture() def calculator_agent(): - return AgentConfig( + return ReactAgent( name="calculator", description="Performs basic arithmetic operations", system_prompt="You are a calculator assistant.", - mode=AgentMode.REACT, tools=[ FunctionTool.from_defaults(fn=add), FunctionTool.from_defaults(fn=subtract), @@ -100,12 +99,11 @@ def calculator_agent(): @pytest.fixture() def retriever_agent(): - return AgentConfig( + return FunctionAgent( name="retriever", description="Manages data retrieval", system_prompt="You are a retrieval assistant.", is_entrypoint_agent=True, - mode=AgentMode.FUNCTION, llm=MockLLM( responses=[ ChatMessage( @@ -124,24 +122,7 @@ def retriever_agent(): ] }, ), - ChatMessage( - role=MessageRole.ASSISTANT, - content="handoff calculator Because this requires arithmetic operations.", - additional_kwargs={ - "tool_calls": [ - ToolSelection( - tool_id="one", - tool_name="handoff", - tool_kwargs={ - "to_agent": "calculator", - "reason": "This requires arithmetic operations.", - }, - ) - ] - }, - ), ], - is_function_calling=True, ), ) @@ -150,13 +131,13 @@ def retriever_agent(): async def test_basic_workflow(calculator_agent, retriever_agent): """Test basic workflow initialization and validation.""" workflow = MultiAgentWorkflow( - agent_configs=[calculator_agent, retriever_agent], + agents=[calculator_agent, retriever_agent], ) - assert workflow.root_agent == "retriever" - assert len(workflow.agent_configs) == 2 - assert "calculator" in workflow.agent_configs - assert "retriever" in workflow.agent_configs + assert workflow.root_agent == retriever_agent + assert len(workflow.agents) == 2 + assert "calculator" in workflow.agents + assert "retriever" in workflow.agents @pytest.mark.asyncio() @@ -164,16 +145,26 @@ async def test_workflow_requires_root_agent(): """Test that workflow requires exactly one root agent.""" with pytest.raises(ValueError, match="Exactly one root agent must be provided"): MultiAgentWorkflow( - agent_configs=[ - AgentConfig( + agents=[ + FunctionAgent( name="agent1", description="test", is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage(role=MessageRole.ASSISTANT, content="test"), + ] + ), ), - AgentConfig( + ReactAgent( name="agent2", description="test", is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage(role=MessageRole.ASSISTANT, content="test"), + ] + ), ), ] ) @@ -183,7 +174,7 @@ async def test_workflow_requires_root_agent(): async def test_workflow_execution(calculator_agent, retriever_agent): """Test basic workflow execution with agent handoff.""" workflow = MultiAgentWorkflow( - agent_configs=[calculator_agent, retriever_agent], + agents=[calculator_agent, retriever_agent], ) memory = ChatMemoryBuffer.from_defaults() @@ -197,11 +188,15 @@ async def test_workflow_execution(calculator_agent, retriever_agent): # Verify we got events indicating handoff and calculation assert any( - ev.current_agent == "retriever" if hasattr(ev, "current_agent") else False + ev.current_agent_name == "retriever" + if hasattr(ev, "current_agent_name") + else False for ev in events ) assert any( - ev.current_agent == "calculator" if hasattr(ev, "current_agent") else False + ev.current_agent_name == "calculator" + if hasattr(ev, "current_agent_name") + else False for ev in events ) assert "8" in response.response @@ -210,7 +205,7 @@ async def test_workflow_execution(calculator_agent, retriever_agent): @pytest.mark.asyncio() async def test_invalid_handoff(): """Test handling of invalid agent handoff.""" - agent1 = AgentConfig( + agent1 = FunctionAgent( name="agent1", description="test", is_entrypoint_agent=True, @@ -234,12 +229,11 @@ async def test_invalid_handoff(): ), ChatMessage(role=MessageRole.ASSISTANT, content="guess im stuck here"), ], - is_function_calling=True, ), ) workflow = MultiAgentWorkflow( - agent_configs=[agent1], + agents=[agent1], ) handler = workflow.run(user_msg="test") @@ -254,7 +248,7 @@ async def test_invalid_handoff(): @pytest.mark.asyncio() async def test_workflow_with_state(): """Test workflow with state management.""" - agent = AgentConfig( + agent = FunctionAgent( name="agent", description="test", is_entrypoint_agent=True, @@ -264,12 +258,11 @@ async def test_workflow_with_state(): role=MessageRole.ASSISTANT, content="Current state processed" ) ], - is_function_calling=True, ), ) workflow = MultiAgentWorkflow( - agent_configs=[agent], + agents=[agent], initial_state={"counter": 0}, state_prompt="Current state: {state}. User message: {msg}", ) From f2f2a1eca3a004e8c750c4ca63c12b94f421498a Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 09:15:10 -0600 Subject: [PATCH 15/22] fix tests --- .../llama_index/core/agent/workflow/function_agent.py | 3 ++- .../core/agent/workflow/multi_agent_workflow.py | 6 +++--- .../llama_index/core/agent/workflow/react_agent.py | 9 +++++---- .../llama_index/core/agent/workflow/workflow_events.py | 8 ++++++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/function_agent.py b/llama-index-core/llama_index/core/agent/workflow/function_agent.py index f8ef3dda17b34..88edcb140e444 100644 --- a/llama-index-core/llama_index/core/agent/workflow/function_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/function_agent.py @@ -43,6 +43,7 @@ async def take_step( ctx.write_event_to_stream( AgentStream( delta=r.delta or "", + response=r.message.content, tool_calls=tool_calls or [], raw_response=r.raw, current_agent_name=self.name, @@ -61,7 +62,7 @@ async def take_step( return AgentOutput( response=r.message.content, tool_calls=tool_calls or [], - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py index 16d462b51c90f..26cdd12def039 100644 --- a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -240,7 +240,7 @@ async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: @step async def parse_agent_output( self, ctx: Context, ev: AgentOutput - ) -> StopEvent | ToolCall | None: + ) -> Union[StopEvent, ToolCall, None]: if not ev.tool_calls: agent = self.agents[ev.current_agent_name] memory: BaseMemory = await ctx.get("memory") @@ -294,7 +294,7 @@ async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: @step async def aggregate_tool_results( self, ctx: Context, ev: ToolCallResult - ) -> AgentInput | StopEvent | None: + ) -> Union[AgentInput, StopEvent, None]: """Aggregate tool results and return the next agent input.""" num_tool_calls = await ctx.get("num_tool_calls", default=0) if num_tool_calls == 0: @@ -338,7 +338,7 @@ async def aggregate_tool_results( ) for t in tool_call_results ], - raw_response=return_direct_tool.tool_output.raw_output, + raw=return_direct_tool.tool_output.raw_output, current_agent_name=agent.name, ) result = await agent.finalize(ctx, result, memory) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py index 7e50dfc8e5e3b..70ed8e1f303ff 100644 --- a/llama-index-core/llama_index/core/agent/workflow/react_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -64,8 +64,9 @@ async def take_step( ctx.write_event_to_stream( AgentStream( delta=r.delta or "", + response=r.message.content, tool_calls=[], - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) ) @@ -85,7 +86,7 @@ async def take_step( return AgentOutput( response=r.message.content, tool_calls=[], - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) @@ -99,7 +100,7 @@ async def take_step( return AgentOutput( response=r.message.content, tool_calls=[], - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) @@ -119,7 +120,7 @@ async def take_step( return AgentOutput( response=r.message.content, tool_calls=tool_calls, - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) diff --git a/llama-index-core/llama_index/core/agent/workflow/workflow_events.py b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py index 21f47eb9a57d4..2a0532e79c229 100644 --- a/llama-index-core/llama_index/core/agent/workflow/workflow_events.py +++ b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py @@ -24,9 +24,10 @@ class AgentStream(Event): """Agent stream.""" delta: str + response: str current_agent_name: str tool_calls: list[ToolSelection] - raw_response: Any + raw: Any class AgentOutput(Event): @@ -34,9 +35,12 @@ class AgentOutput(Event): response: str tool_calls: list[ToolSelection] - raw_response: Any + raw: Any current_agent_name: str + def __str__(self) -> str: + return str(self.response) + class ToolCall(Event): """All tool calls are surfaced.""" From 91516737d0dfb39fa253ea0fe6e1c12957477d3e Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 09:29:31 -0600 Subject: [PATCH 16/22] clean up types --- .../core/agent/workflow/base_agent.py | 41 +++++++++++---- .../core/agent/workflow/function_agent.py | 8 +-- .../agent/workflow/multi_agent_workflow.py | 14 ++--- .../core/agent/workflow/react_agent.py | 13 +++-- .../llama_index/core/workflow/context.py | 51 ++++--------------- .../llama_index/core/workflow/tools.py | 4 +- 6 files changed, 63 insertions(+), 68 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/base_agent.py b/llama-index-core/llama_index/core/agent/workflow/base_agent.py index de850ab26455c..9a4fcc62ef91f 100644 --- a/llama-index-core/llama_index/core/agent/workflow/base_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/base_agent.py @@ -11,6 +11,11 @@ from llama_index.core.tools import BaseTool, AsyncBaseTool from llama_index.core.workflow import Context from llama_index.core.objects import ObjectRetriever +from llama_index.core.settings import Settings + + +def get_default_llm() -> LLM: + return Settings.llm class BaseWorkflowAgent(BaseModel, ABC): @@ -18,15 +23,33 @@ class BaseWorkflowAgent(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) - name: str - description: str - system_prompt: Optional[str] = None - tools: Optional[List[BaseTool]] = None - tool_retriever: Optional[ObjectRetriever] = None - can_handoff_to: Optional[List[str]] = Field(default=None) - handoff_prompt_template: Optional[str] = None - llm: Optional[LLM] = None - is_entrypoint_agent: bool = False + name: str = Field(description="The name of the agent") + description: str = Field( + description="The description of what the agent does and is responsible for" + ) + system_prompt: Optional[str] = Field( + default=None, description="The system prompt for the agent" + ) + tools: Optional[List[BaseTool]] = Field( + default=None, description="The tools that the agent can use" + ) + tool_retriever: Optional[ObjectRetriever] = Field( + default=None, + description="The tool retriever for the agent, can be provided instead of tools", + ) + can_handoff_to: Optional[List[str]] = Field( + default=None, description="The agent names that this agent can hand off to" + ) + handoff_prompt_template: Optional[str] = Field( + default=None, description="The prompt template for an artificial handoff tool" + ) + llm: LLM = Field( + default_factory=get_default_llm, description="The LLM that the agent uses" + ) + is_entrypoint_agent: bool = Field( + default=False, + description="Whether the agent is the entrypoint agent in a multi-agent workflow", + ) @abstractmethod async def take_step( diff --git a/llama-index-core/llama_index/core/agent/workflow/function_agent.py b/llama-index-core/llama_index/core/agent/workflow/function_agent.py index 88edcb140e444..2b236361a42f6 100644 --- a/llama-index-core/llama_index/core/agent/workflow/function_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/function_agent.py @@ -33,11 +33,11 @@ async def take_step( AgentInput(input=current_llm_input, current_agent_name=self.name) ) - response = await self.llm.astream_chat_with_tools( + response = await self.llm.astream_chat_with_tools( # type: ignore tools, chat_history=current_llm_input, allow_parallel_tool_calls=True ) async for r in response: - tool_calls = self.llm.get_tool_calls_from_response( + tool_calls = self.llm.get_tool_calls_from_response( # type: ignore r, error_on_no_tool_call=False ) ctx.write_event_to_stream( @@ -45,12 +45,12 @@ async def take_step( delta=r.delta or "", response=r.message.content, tool_calls=tool_calls or [], - raw_response=r.raw, + raw=r.raw, current_agent_name=self.name, ) ) - tool_calls = self.llm.get_tool_calls_from_response( + tool_calls = self.llm.get_tool_calls_from_response( # type: ignore r, error_on_no_tool_call=False ) diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py index 26cdd12def039..1e1d0533fb764 100644 --- a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent from llama_index.core.agent.workflow.workflow_events import ( @@ -89,7 +89,9 @@ def __init__( ): raise ValueError("State prompt must contain {state} and {msg}") - def _ensure_tools_are_async(self, tools: List[BaseTool]) -> List[AsyncBaseTool]: + def _ensure_tools_are_async( + self, tools: Sequence[BaseTool] + ) -> Sequence[AsyncBaseTool]: """Ensure all tools are async.""" return [adapt_to_async_tool(tool) for tool in tools] @@ -208,7 +210,7 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: handoff_tool = self._get_handoff_tool(agent) tools.append(handoff_tool) - tools = self._ensure_tools_are_async(tools) + async_tools = self._ensure_tools_are_async(tools) if agent.system_prompt: llm_input = [ @@ -221,7 +223,7 @@ async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: return AgentSetup( input=llm_input, current_agent_name=ev.current_agent_name, - tools=tools, + tools=async_tools, ) @step @@ -300,7 +302,7 @@ async def aggregate_tool_results( if num_tool_calls == 0: raise ValueError("No tool calls found, cannot aggregate results.") - tool_call_results: list[ToolCallResult] = ctx.collect_events( + tool_call_results: list[ToolCallResult] = ctx.collect_events( # type: ignore ev, expected=[ToolCallResult] * num_tool_calls ) if not tool_call_results: @@ -351,6 +353,6 @@ async def aggregate_tool_results( input_messages = memory.get(input=user_msg_str) # get this again, in case it changed - agent: BaseWorkflowAgent = await ctx.get("current_agent") + agent = await ctx.get("current_agent") return AgentInput(input=input_messages, current_agent_name=agent.name) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py index 70ed8e1f303ff..5bdbe35aa6b7a 100644 --- a/llama-index-core/llama_index/core/agent/workflow/react_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -136,18 +136,17 @@ async def handle_tool_call_results( if tool_call_result.tool_name == "handoff": continue - current_reasoning.append( - ObservationReasoningStep( - observation=str(tool_call_result.tool_output.content), - return_direct=tool_call_result.return_direct, - ) + obs_step = ObservationReasoningStep( + observation=str(tool_call_result.tool_output.content), + return_direct=tool_call_result.return_direct, ) + current_reasoning.append(obs_step) if tool_call_result.return_direct: current_reasoning.append( ResponseReasoningStep( - thought=current_reasoning[-1].observation, - response=current_reasoning[-1].observation, + thought=obs_step.observation, + response=obs_step.observation, is_streaming=False, ) ) diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index 06cc66d369fa5..636f9c39f6ee0 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -293,48 +293,19 @@ async def wait_for_event( The event type that was requested. """ requirements = requirements or {} - waiter_id = uuid.uuid4() + waiter_id = str(uuid.uuid4()) self._queues[waiter_id] = asyncio.Queue() - try: - while True: - event = await self._queues[waiter_id].get() - if isinstance(event, event_type): - if all( - event.get(k, default=None) == v for k, v in requirements.items() - ): - return event - else: - continue - finally: - # Ensure queue cleanup happens even if cancelled - del self._queues[waiter_id] - - async def wait_for_event( - self, event_type: Type[T], requirements: Optional[Dict[str, Any]] = None - ) -> T: - """Asynchronously wait for a specific event type to be received. - - Returns: - The event type that was requested. - """ - requirements = requirements or {} - waiter_id = uuid.uuid4() - self._queues[waiter_id] = asyncio.Queue() - - try: - while True: - event = await self._queues[waiter_id].get() - if isinstance(event, event_type): - if all( - event.get(k, default=None) == v for k, v in requirements.items() - ): - return event - else: - continue - finally: - # Ensure queue cleanup happens even if cancelled - del self._queues[waiter_id] + while True: + event = await self._queues[waiter_id].get() + if isinstance(event, event_type): + if all( + event.get(k, default=None) == v for k, v in requirements.items() + ): + del self._queues[waiter_id] + return event + else: + continue def write_event_to_stream(self, ev: Optional[Event]) -> None: self._streaming_queue.put_nowait(ev) diff --git a/llama-index-core/llama_index/core/workflow/tools.py b/llama-index-core/llama_index/core/workflow/tools.py index df4907338dadf..de2a89ab568ca 100644 --- a/llama-index-core/llama_index/core/workflow/tools.py +++ b/llama-index-core/llama_index/core/workflow/tools.py @@ -110,7 +110,7 @@ def from_defaults( ) return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) - def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: # type: ignore """Call.""" tool_output = self._fn(ctx, *args, **kwargs) return ToolOutput( @@ -120,7 +120,7 @@ def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: raw_output=tool_output, ) - async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: + async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: # type: ignore """Call.""" tool_output = await self._async_fn(ctx, *args, **kwargs) return ToolOutput( From 935e4aeb9c0b97caec3b5103a1c8123210705429 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 09:37:43 -0600 Subject: [PATCH 17/22] add more tests --- .../tests/workflow/test_context.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/llama-index-core/tests/workflow/test_context.py b/llama-index-core/tests/workflow/test_context.py index 06e4aa3568010..1cb35c2379075 100644 --- a/llama-index-core/tests/workflow/test_context.py +++ b/llama-index-core/tests/workflow/test_context.py @@ -1,3 +1,4 @@ +import asyncio from unittest import mock from typing import Union, Optional @@ -130,3 +131,42 @@ async def test_empty_inprogress_when_workflow_done(workflow): # there shouldn't be any in progress events for inprogress_list in h.ctx._in_progress.values(): assert len(inprogress_list) == 0 + + +@pytest.mark.asyncio() +async def test_wait_for_event(ctx): + wait_job = asyncio.create_task(ctx.wait_for_event(Event)) + await asyncio.sleep(0.01) + ctx.send_event(Event(msg="foo")) + ev = await wait_job + assert ev.msg == "foo" + + +@pytest.mark.asyncio() +async def test_wait_for_event_with_requirements(ctx): + wait_job = asyncio.create_task(ctx.wait_for_event(Event, {"msg": "foo"})) + await asyncio.sleep(0.01) + ctx.send_event(Event(msg="bar")) + ctx.send_event(Event(msg="foo")) + ev = await wait_job + assert ev.msg == "foo" + + +@pytest.mark.asyncio() +async def test_wait_for_event_in_workflow(): + class TestWorkflow(Workflow): + @step + async def step1(self, ctx: Context, ev: StartEvent) -> StopEvent: + ctx.write_event_to_stream(Event(msg="foo")) + result = await ctx.wait_for_event(Event) + return StopEvent(result=result.msg) + + workflow = TestWorkflow() + handler = workflow.run() + async for ev in handler.stream_events(): + if isinstance(ev, Event) and ev.msg == "foo": + handler.ctx.send_event(Event(msg="bar")) + break + + result = await handler + assert result == "bar" From f527a05d06dd32f52691d735e865f8a3fec9e5e5 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 09:50:23 -0600 Subject: [PATCH 18/22] fix small bug in finalize for react agent --- .../llama_index/core/agent/workflow/react_agent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py index 5bdbe35aa6b7a..befe39588af69 100644 --- a/llama-index-core/llama_index/core/agent/workflow/react_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -163,10 +163,11 @@ async def finalize( ) reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) - reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) - await memory.aput(reasoning_msg) - await ctx.set("current_reasoning", []) + if reasoning_str: + reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) + await memory.aput(reasoning_msg) + await ctx.set("current_reasoning", []) # remove "Answer:" from the response if output.response and "Answer:" in output.response: From d94d22fd1d21011460c8b2cdcd9e3f9bbabb4811 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 10:01:52 -0600 Subject: [PATCH 19/22] make react components configurable --- .../core/agent/workflow/react_agent.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py index befe39588af69..1f51c2a539815 100644 --- a/llama-index-core/llama_index/core/agent/workflow/react_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -1,5 +1,5 @@ import uuid -from typing import List, cast +from typing import List, Optional, cast from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent from llama_index.core.agent.workflow.workflow_events import ( @@ -16,6 +16,7 @@ ObservationReasoningStep, ResponseReasoningStep, ) +from llama_index.core.bridge.pydantic import Field from llama_index.core.llms import ChatMessage from llama_index.core.llms.llm import ToolSelection from llama_index.core.memory import BaseMemory @@ -26,6 +27,15 @@ class ReactAgent(BaseWorkflowAgent): """React agent implementation.""" + reasoning_key: str = "current_reasoning" + output_parser: Optional[ReActOutputParser] = Field( + default=None, description="The react output parser" + ) + formatter: Optional[ReActChatFormatter] = Field( + default=None, + description="The react chat formatter to format the reasoning steps and chat history into an llm input.", + ) + async def take_step( self, ctx: Context, @@ -41,12 +51,14 @@ async def take_step( else: system_prompt = "" - output_parser = ReActOutputParser() - react_chat_formatter = ReActChatFormatter(context=system_prompt) + output_parser = self.output_parser or ReActOutputParser() + react_chat_formatter = self.formatter or ReActChatFormatter( + context=system_prompt + ) # Format initial chat input current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] + self.reasoning_key, default=[] ) input_chat = react_chat_formatter.format( tools, @@ -93,7 +105,7 @@ async def take_step( # add to reasoning if not a handoff if hasattr(reasoning_step, "action") and reasoning_step.action != "handoff": current_reasoning.append(reasoning_step) - await ctx.set("current_reasoning", current_reasoning) + await ctx.set(self.reasoning_key, current_reasoning) # If response step, we're done if reasoning_step.is_done: @@ -129,7 +141,7 @@ async def handle_tool_call_results( ) -> None: """Handle tool call results for React agent.""" current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] + self.reasoning_key, default=[] ) for tool_call_result in results: # don't add handoff tool calls to reasoning @@ -152,14 +164,14 @@ async def handle_tool_call_results( ) break - await ctx.set("current_reasoning", current_reasoning) + await ctx.set(self.reasoning_key, current_reasoning) async def finalize( self, ctx: Context, output: AgentOutput, memory: BaseMemory ) -> AgentOutput: """Finalize the React agent.""" current_reasoning: list[BaseReasoningStep] = await ctx.get( - "current_reasoning", default=[] + self.reasoning_key, default=[] ) reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) @@ -167,7 +179,7 @@ async def finalize( if reasoning_str: reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) await memory.aput(reasoning_msg) - await ctx.set("current_reasoning", []) + await ctx.set(self.reasoning_key, []) # remove "Answer:" from the response if output.response and "Answer:" in output.response: From 347bda1c5b0077942b326861975ec34e2df00513 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 10:02:12 -0600 Subject: [PATCH 20/22] make function agent use scratchpad --- .../core/agent/workflow/function_agent.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/function_agent.py b/llama-index-core/llama_index/core/agent/workflow/function_agent.py index 2b236361a42f6..dab790679d2cf 100644 --- a/llama-index-core/llama_index/core/agent/workflow/function_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/function_agent.py @@ -16,6 +16,8 @@ class FunctionAgent(BaseWorkflowAgent): """Function calling agent implementation.""" + scratchpad_key: str = "scratchpad" + async def take_step( self, ctx: Context, @@ -27,7 +29,8 @@ async def take_step( if not self.llm.metadata.is_function_calling_model: raise ValueError("LLM must be a FunctionCallingLLM") - current_llm_input = [*llm_input] + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + current_llm_input = [*llm_input, *scratchpad] ctx.write_event_to_stream( AgentInput(input=current_llm_input, current_agent_name=self.name) @@ -54,10 +57,10 @@ async def take_step( r, error_on_no_tool_call=False ) - # only add to memory if we didn't select the handoff tool + # only add to scratchpad if we didn't select the handoff tool if not any(tool_call.tool_name == "handoff" for tool_call in tool_calls): - current_llm_input.append(r.message) - await memory.aput(r.message) + scratchpad.append(r.message) + await ctx.set(self.scratchpad_key, scratchpad) return AgentOutput( response=r.message.content, @@ -70,12 +73,14 @@ async def handle_tool_call_results( self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory ) -> None: """Handle tool call results for function calling agent.""" + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + for tool_call_result in results: # don't add handoff tool calls to memory if tool_call_result.tool_name == "handoff": continue - await memory.aput( + scratchpad.append( ChatMessage( role="tool", content=str(tool_call_result.tool_output.content), @@ -84,7 +89,7 @@ async def handle_tool_call_results( ) if tool_call_result.return_direct: - await memory.aput( + scratchpad.append( ChatMessage( role="assistant", content=str(tool_call_result.tool_output.content), @@ -93,11 +98,17 @@ async def handle_tool_call_results( ) break + await ctx.set(self.scratchpad_key, scratchpad) + async def finalize( self, ctx: Context, output: AgentOutput, memory: BaseMemory ) -> AgentOutput: """Finalize the function calling agent. - This is a no-op for function calling agents since we write to memory as we go. + Adds all in-progress messages to memory. """ + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + for msg in scratchpad: + await memory.aput(msg) + return output From aa555a2839d39b19970b37313a4663f83ce6b1ea Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 10:14:54 -0600 Subject: [PATCH 21/22] update docs --- docs/docs/api_reference/agent/workflow.md | 12 ++++ docs/docs/understanding/agent/multi_agents.md | 67 ++++++++++--------- .../core/agent/workflow/base_agent.py | 3 - 3 files changed, 49 insertions(+), 33 deletions(-) create mode 100644 docs/docs/api_reference/agent/workflow.md diff --git a/docs/docs/api_reference/agent/workflow.md b/docs/docs/api_reference/agent/workflow.md new file mode 100644 index 0000000000000..aefb274d845cf --- /dev/null +++ b/docs/docs/api_reference/agent/workflow.md @@ -0,0 +1,12 @@ +::: llama_index.core.agent.workflow + options: + members: + - MultiAgentWorkflow + - BaseWorkflowAgent + - FunctionAgent + - ReactAgent + - AgentInput + - AgentStream + - AgentOutput + - ToolCall + - ToolCallResult diff --git a/docs/docs/understanding/agent/multi_agents.md b/docs/docs/understanding/agent/multi_agents.md index 7f7b8c3525b01..3d1a299bcea17 100644 --- a/docs/docs/understanding/agent/multi_agents.md +++ b/docs/docs/understanding/agent/multi_agents.md @@ -1,19 +1,18 @@ # Multi-Agent Workflows -The MultiAgentWorkflow allows you to create a system of multiple agents that can collaborate and hand off tasks to each other based on their specialized capabilities. This enables building more complex agent systems where different agents handle different aspects of a task. +The MultiAgentWorkflow uses Workflow Agents to allow you to create a system of multiple agents that can collaborate and hand off tasks to each other based on their specialized capabilities. This enables building more complex agent systems where different agents handle different aspects of a task. ## Quick Start Here's a simple example of setting up a multi-agent workflow with a calculator agent and a retriever agent: ```python -from llama_index.core.agent.multi_agent import ( +from llama_index.core.agent.workflow import ( MultiAgentWorkflow, - AgentConfig, - AgentMode, + FunctionAgent, + ReactAgent, ) from llama_index.core.tools import FunctionTool -from llama_index.core.workflow import FunctionToolWithContext # Define some tools @@ -28,11 +27,13 @@ def subtract(a: int, b: int) -> int: # Create agent configs -calculator_agent = AgentConfig( +# NOTE: we can use FunctionAgent or ReactAgent here. +# FunctionAgent works for LLMs with a function calling API. +# ReactAgent works for any LLM. +calculator_agent = FunctionAgent( name="calculator", description="Performs basic arithmetic operations", system_prompt="You are a calculator assistant.", - mode=AgentMode.REACT, tools=[ FunctionTool.from_defaults(fn=add), FunctionTool.from_defaults(fn=subtract), @@ -40,11 +41,10 @@ calculator_agent = AgentConfig( llm=OpenAI(model="gpt-4"), ) -retriever_agent = AgentConfig( +retriever_agent = FunctionAgent( name="retriever", description="Manages data retrieval", system_prompt="You are a retrieval assistant.", - mode=AgentMode.FUNCTION, is_entrypoint_agent=True, llm=OpenAI(model="gpt-4"), ) @@ -66,7 +66,7 @@ async for event in handler.stream_events(): ## How It Works -The MultiAgentWorkflow manages a collection of agents, each with their own specialized capabilities. One agent must be designated as the entry point agent (is_entrypoint_agent=True). +The MultiAgentWorkflow manages a collection of agents, each with their own specialized capabilities. One agent must be designated as the entry point agent (`is_entrypoint_agent=True`). When a user message comes in, it's first routed to the entry point agent. Each agent can then: @@ -74,26 +74,20 @@ When a user message comes in, it's first routed to the entry point agent. Each a 2. Hand off to another agent better suited for the task 3. Return a response to the user -Agents can be configured in two modes: -- REACT: Uses ReAct prompting for reasoning about tool usage -- FUNCTION: Uses OpenAI function calling style for tool usage - ## Configuration Options ### Agent Config -Each agent is configured with an `AgentConfig`: +Each agent holds a certain set of configuration options. Whether you use `FunctionAgent` or `ReactAgent`, the core options are the same. ```python -AgentConfig( +FunctionAgent( # Unique name for the agent (str) name="name", # Description of agent's capabilities (str) description="description", # System prompt for the agent (str) system_prompt="system_prompt", - # react or function -- defaults to function when possible. (str) - mode="function", # Tools available to this agent (List[BaseTool]) tools=[...], # LLM to use for this agent. (BaseLLM) @@ -111,8 +105,8 @@ The MultiAgentWorkflow constructor accepts: ```python MultiAgentWorkflow( - # List of agent configs. (List[AgentConfig]) - agent_configs=[...], + # List of agent configs. (List[BaseWorkflowAgent]) + agents=[...], # Initial state dict. (Optional[dict]) initial_state=None, # Custom prompt for handoffs. Should contain the `agent_info` string variable. (Optional[str]) @@ -124,11 +118,13 @@ MultiAgentWorkflow( ### State Management +#### Initial Global State + You can provide an initial state dict that will be available to all agents: ```python workflow = MultiAgentWorkflow( - agent_configs=[...], + agents=[...], initial_state={"counter": 0}, state_prompt="Current state: {state}. User message: {msg}", ) @@ -136,6 +132,8 @@ workflow = MultiAgentWorkflow( The state is stored in the `state` key of the workflow context. +#### Persisting State Between Runs + In order to persist state between runs, you can pass in the context from the previous run: ```python @@ -146,9 +144,12 @@ handler = workflow.run(user_msg="Can you add 5 and 3?") response = await handler # Pass in the context from the previous run -response = await workflow.run(ctx=handler.ctx, user_msg="Can you add 5 and 3?") +handler = workflow.run(ctx=handler.ctx, user_msg="Can you add 5 and 3?") +response = await handler ``` +#### Serializing Context / State + As with normal workflows, the context is serializable: ```python @@ -173,16 +174,16 @@ The workflow emits various events during execution that you can stream: async for event in workflow.run(...).stream_events(): if isinstance(event, AgentInput): print(event.input) - print(event.current_agent) + print(event.current_agent_name) elif isinstance(event, AgentStream): # Agent thinking/tool calling response stream print(event.delta) - print(event.current_agent) + print(event.current_agent_name) elif isinstance(event, AgentOutput): print(event.response) print(event.tool_calls) - print(event.raw_response) - print(event.current_agent) + print(event.raw) + print(event.current_agent_name) elif isinstance(event, ToolCall): # Tool being called print(event.tool_name) @@ -210,7 +211,7 @@ counter_tool = FunctionToolWithContext.from_defaults( ) ``` -### Human in the Loop +## Human in the Loop Using the context, you can implement a human in the loop pattern in your tools: @@ -244,7 +245,13 @@ async def ask_for_confirmation(ctx: Context) -> bool: When this function is called, it will block the workflow execution until the user sends the required confirmation event. ```python -handler.ctx.send_event( - ConfirmationEvent(confirmation=True, confirmation_id="1234") -) +handler = workflow.run(user_msg="Can you add 5 and 3?") + +async for event in handler.stream_events(): + if isinstance(event, AskForConfirmationEvent): + print(event.confirmation_id) + handler.ctx.send_event( + ConfirmationEvent(confirmation=True, confirmation_id="1234") + ) + ... ``` diff --git a/llama-index-core/llama_index/core/agent/workflow/base_agent.py b/llama-index-core/llama_index/core/agent/workflow/base_agent.py index 9a4fcc62ef91f..ea20220fcf54c 100644 --- a/llama-index-core/llama_index/core/agent/workflow/base_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/base_agent.py @@ -40,9 +40,6 @@ class BaseWorkflowAgent(BaseModel, ABC): can_handoff_to: Optional[List[str]] = Field( default=None, description="The agent names that this agent can hand off to" ) - handoff_prompt_template: Optional[str] = Field( - default=None, description="The prompt template for an artificial handoff tool" - ) llm: LLM = Field( default_factory=get_default_llm, description="The LLM that the agent uses" ) From ee70508c551dfda1c2487715120b3f45f8556e36 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 2 Jan 2025 10:17:21 -0600 Subject: [PATCH 22/22] add to nav --- docs/mkdocs.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 0f5b103a290af..c5f5ee60ea269 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -50,6 +50,7 @@ nav: - Enhancing with LlamaParse: ./understanding/agent/llamaparse.md - Memory: ./understanding/agent/memory.md - Adding other tools: ./understanding/agent/tools.md + - Multi-agent workflows: ./understanding/agent/multi_agents.md - Building Workflows: - Introduction to workflows: ./understanding/workflows/index.md - A basic workflow: ./understanding/workflows/basic_flow.md @@ -852,6 +853,7 @@ nav: - ./api_reference/agent/openai.md - ./api_reference/agent/openai_legacy.md - ./api_reference/agent/react.md + - ./api_reference/agent/workflow.md - Callbacks: - ./api_reference/callbacks/agentops.md - ./api_reference/callbacks/aim.md