From 5981dd363a7268622b410d0e03a4f715f04b9a62 Mon Sep 17 00:00:00 2001 From: Jerry Liu Date: Sun, 11 Feb 2024 11:45:36 -0800 Subject: [PATCH] agent + query pipeline cleanups (#10563) --- .../agent_runner/query_pipeline_agent.ipynb | 500 +++++++++--------- llama_index/agent/custom/pipeline_worker.py | 6 + llama_index/agent/custom/simple.py | 5 + llama_index/agent/openai/step.py | 5 + llama_index/agent/react/step.py | 5 + llama_index/agent/react_multimodal/step.py | 5 + llama_index/agent/runner/base.py | 29 +- llama_index/agent/types.py | 15 +- 8 files changed, 327 insertions(+), 243 deletions(-) diff --git a/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb b/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb index 2dcc394e386e5..17a038b2e8346 100644 --- a/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb +++ b/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb @@ -168,13 +168,25 @@ "Throughout this we'll use a variety of agent-specific query components. Unlike normal query pipelines, these are specifically designed for query pipelines that are used in a `QueryPipelineAgentWorker`:\n", "- An `AgentInputComponent` that allows you to convert the agent inputs (Task, state dictionary) into a set of inputs for the query pipeline.\n", "- An `AgentFnComponent`: a general processor that allows you to take in the current Task, state, as well as any arbitrary inputs, and returns an output. In this cookbook we define a function component to format the ReAct prompt. However, you can put this anywhere.\n", - "- An `CustomAgentComponent`: similar to `AgentFnComponent`, you can implement `_run_component` to define your own logic, with access to Task and state. It is more verbose but more flexible than `AgentFnComponent` (e.g. you can define init variables, and callbacks are in the base class).\n", + "- [Not used in this notebook] An `CustomAgentComponent`: similar to `AgentFnComponent`, you can implement `_run_component` to define your own logic, with access to Task and state. It is more verbose but more flexible than `AgentFnComponent` (e.g. you can define init variables, and callbacks are in the base class).\n", "\n", "Note that any function passed into `AgentFnComponent` and `AgentInputComponent` MUST include `task` and `state` as input variables, as these are inputs passed from the agent. \n", "\n", "Note that the output of an agentic query pipeline MUST be `Tuple[AgentChatResponse, bool]`. You'll see this below." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9884b8c0-0ec3-4b14-9538-db9f5be63466", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.query_pipeline import QueryPipeline as QP\n", + "\n", + "qp = QP(verbose=True)" + ] + }, { "cell_type": "markdown", "id": "a024201f-a39b-4e23-a567-9737026dd771", @@ -202,8 +214,8 @@ " AgentInputComponent,\n", " AgentFnComponent,\n", " CustomAgentComponent,\n", - " ToolRunnerComponent,\n", " QueryComponent,\n", + " ToolRunnerComponent,\n", ")\n", "from llama_index.llms import MessageRole\n", "from typing import Dict, Any, Optional, Tuple, List, cast\n", @@ -284,7 +296,7 @@ "1. If an answer is given, then we're done. Process the output\n", "2. If an action is given, we need to execute the specified tool with the specified args, and then process the output.\n", "\n", - "Tool calling can be done via the `ToolRunnerComponent` module. This is a standalone module that takes in a list of tools, and can be \"executed\" with the specified tool name (every tool has a name) and tool action.\n", + "Tool calling can be done via the `ToolRunnerComponent` module. This is a simple wrapper module that takes in a list of tools, and can be \"executed\" with the specified tool name (every tool has a name) and tool action.\n", "\n", "We implement this overall module `OutputAgentComponent` that subclasses `CustomAgentComponent`.\n", "\n", @@ -300,115 +312,72 @@ "source": [ "from typing import Set, Optional\n", "from llama_index.agent.react.output_parser import ReActOutputParser\n", + "from llama_index.llms import ChatResponse\n", + "from llama_index.agent.types import Task\n", "\n", "\n", - "## Agent Output Component\n", - "## Process reasoning step/tool outputs, and return agent response\n", - "def finalize_fn(\n", - " task: Task,\n", - " state: Dict[str, Any],\n", - " reasoning_step: Any,\n", - " is_done: bool = False,\n", - " tool_output: Optional[Any] = None,\n", - ") -> Tuple[AgentChatResponse, bool]:\n", - " \"\"\"Finalize function.\n", + "def parse_react_output_fn(\n", + " task: Task, state: Dict[str, Any], chat_response: ChatResponse\n", + "):\n", + " \"\"\"Parse ReAct output into a reasoning step.\"\"\"\n", + " output_parser = ReActOutputParser()\n", + " reasoning_step = output_parser.parse(chat_response.message.content)\n", + " return {\"done\": reasoning_step.is_done, \"reasoning_step\": reasoning_step}\n", "\n", - " Here we take the latest reasoning step, and a tool output (if provided),\n", - " and return the agent output (and decide if agent is done).\n", "\n", - " This function returns an `AgentChatResponse` and `is_done` tuple. and\n", - " is the last component of the query pipeline. This is the expected\n", - " return type for any query pipeline passed to `QueryPipelineAgentWorker`.\n", + "parse_react_output = AgentFnComponent(fn=parse_react_output_fn)\n", "\n", - " \"\"\"\n", - " current_reasoning = state[\"current_reasoning\"]\n", - " current_reasoning.append(reasoning_step)\n", - " # if tool_output is not None, add to current reasoning\n", - " if tool_output is not None:\n", - " observation_step = ObservationReasoningStep(\n", - " observation=str(tool_output)\n", - " )\n", - " current_reasoning.append(observation_step)\n", - " if isinstance(current_reasoning[-1], ResponseReasoningStep):\n", - " response_step = cast(ResponseReasoningStep, current_reasoning[-1])\n", - " response_str = response_step.response\n", - " else:\n", - " response_str = current_reasoning[-1].get_content()\n", - "\n", - " # if is_done, add to memory\n", - " # NOTE: memory is a reserved keyword in `state`, but you can add your own too\n", - " if is_done:\n", - " state[\"memory\"].put(\n", - " ChatMessage(content=task.input, role=MessageRole.USER)\n", - " )\n", - " state[\"memory\"].put(\n", - " ChatMessage(content=response_str, role=MessageRole.ASSISTANT)\n", - " )\n", - "\n", - " return AgentChatResponse(response=response_str), is_done\n", - "\n", - "\n", - "class OutputAgentComponent(CustomAgentComponent):\n", - " \"\"\"Output agent component.\"\"\"\n", - "\n", - " tool_runner_component: ToolRunnerComponent\n", - " output_parser: ReActOutputParser\n", - "\n", - " def __init__(self, tools, **kwargs):\n", - " tool_runner_component = ToolRunnerComponent(tools)\n", - " super().__init__(\n", - " tool_runner_component=tool_runner_component,\n", - " output_parser=ReActOutputParser(),\n", - " **kwargs\n", - " )\n", - "\n", - " def _run_component(self, **kwargs: Any) -> Any:\n", - " \"\"\"Run component.\"\"\"\n", - " chat_response = kwargs[\"chat_response\"]\n", - " task = kwargs[\"task\"]\n", - " state = kwargs[\"state\"]\n", - " reasoning_step = self.output_parser.parse(\n", - " chat_response.message.content\n", - " )\n", - " if reasoning_step.is_done:\n", - " return {\n", - " \"output\": finalize_fn(\n", - " task, state, reasoning_step, is_done=True\n", - " )\n", - " }\n", - " else:\n", - " tool_output = self.tool_runner_component.run_component(\n", - " tool_name=reasoning_step.action,\n", - " tool_input=reasoning_step.action_input,\n", - " )\n", - " return {\n", - " \"output\": finalize_fn(\n", - " task,\n", - " state,\n", - " reasoning_step,\n", - " is_done=False,\n", - " tool_output=tool_output,\n", - " )\n", - " }\n", - "\n", - " @property\n", - " def _input_keys(self) -> Set[str]:\n", - " return {\"chat_response\"}\n", - "\n", - " @property\n", - " def _optional_input_keys(self) -> Set[str]:\n", - " return {\"is_done\", \"tool_output\"}\n", - "\n", - " @property\n", - " def _output_keys(self) -> Set[str]:\n", - " return {\"output\"}\n", - "\n", - " @property\n", - " def sub_query_components(self) -> List[QueryComponent]:\n", - " return [self.tool_runner_component]\n", - "\n", - "\n", - "react_output_component = OutputAgentComponent([sql_tool])" + "\n", + "def run_tool_fn(\n", + " task: Task, state: Dict[str, Any], reasoning_step: ActionReasoningStep\n", + "):\n", + " \"\"\"Run tool and process tool output.\"\"\"\n", + " tool_runner_component = ToolRunnerComponent(\n", + " [sql_tool], callback_manager=task.callback_manager\n", + " )\n", + " tool_output = tool_runner_component.run_component(\n", + " tool_name=reasoning_step.action,\n", + " tool_input=reasoning_step.action_input,\n", + " )\n", + " observation_step = ObservationReasoningStep(observation=str(tool_output))\n", + " state[\"current_reasoning\"].append(observation_step)\n", + " # TODO: get output\n", + "\n", + " return {\"response_str\": observation_step.get_content(), \"is_done\": False}\n", + "\n", + "\n", + "run_tool = AgentFnComponent(fn=run_tool_fn)\n", + "\n", + "\n", + "def process_response_fn(\n", + " task: Task, state: Dict[str, Any], response_step: ResponseReasoningStep\n", + "):\n", + " \"\"\"Process response.\"\"\"\n", + " state[\"current_reasoning\"].append(response_step)\n", + " response_str = response_step.response\n", + " # Now that we're done with this step, put into memory\n", + " state[\"memory\"].put(ChatMessage(content=task.input, role=MessageRole.USER))\n", + " state[\"memory\"].put(\n", + " ChatMessage(content=response_str, role=MessageRole.ASSISTANT)\n", + " )\n", + "\n", + " return {\"response_str\": response_str, \"is_done\": True}\n", + "\n", + "\n", + "process_response = AgentFnComponent(fn=process_response_fn)\n", + "\n", + "\n", + "def process_agent_response_fn(\n", + " task: Task, state: Dict[str, Any], response_dict: dict\n", + "):\n", + " \"\"\"Process agent response.\"\"\"\n", + " return (\n", + " AgentChatResponse(response_dict[\"response_str\"]),\n", + " response_dict[\"is_done\"],\n", + " )\n", + "\n", + "\n", + "process_agent_response = AgentFnComponent(fn=process_agent_response_fn)" ] }, { @@ -433,16 +402,47 @@ "from llama_index.query_pipeline import QueryPipeline as QP\n", "from llama_index.llms import OpenAI\n", "\n", - "qp = QP(\n", - " modules={\n", + "qp.add_modules(\n", + " {\n", " \"agent_input\": agent_input_component,\n", " \"react_prompt\": react_prompt_component,\n", " \"llm\": OpenAI(model=\"gpt-4-1106-preview\"),\n", - " \"react_output\": react_output_component,\n", - " },\n", - " verbose=True,\n", + " \"react_output_parser\": parse_react_output,\n", + " \"run_tool\": run_tool,\n", + " \"process_response\": process_response,\n", + " \"process_agent_response\": process_agent_response,\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be90408b-0893-4cab-a7c7-32206979eecd", + "metadata": {}, + "outputs": [], + "source": [ + "# link input to react prompt to parsed out response (either tool action/input or observation)\n", + "qp.add_chain([\"agent_input\", \"react_prompt\", \"llm\", \"react_output_parser\"])\n", + "\n", + "# add conditional link from react output to tool call (if not done)\n", + "qp.add_link(\n", + " \"react_output_parser\",\n", + " \"run_tool\",\n", + " condition_fn=lambda x: not x[\"done\"],\n", + " input_fn=lambda x: x[\"reasoning_step\"],\n", + ")\n", + "# add conditional link from react output to final response processing (if done)\n", + "qp.add_link(\n", + " \"react_output_parser\",\n", + " \"process_response\",\n", + " condition_fn=lambda x: x[\"done\"],\n", + " input_fn=lambda x: x[\"reasoning_step\"],\n", ")\n", - "qp.add_chain([\"agent_input\", \"react_prompt\", \"llm\", \"react_output\"])" + "\n", + "# whether response processing or tool output processing, add link to final agent response\n", + "qp.add_link(\"process_response\", \"process_agent_response\")\n", + "qp.add_link(\"run_tool\", \"process_agent_response\")" ] }, { @@ -481,7 +481,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": null, @@ -493,7 +493,7 @@ "from pyvis.network import Network\n", "\n", "net = Network(notebook=True, cdn_resources=\"in_line\", directed=True)\n", - "net.from_nx(qp.dag)\n", + "net.from_nx(qp.clean_dag)\n", "net.show(\"agent_dag.html\")" ] }, @@ -507,42 +507,6 @@ "This is our way to setup an agent around a text-to-SQL Query Pipeline" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "73d6d178-97e9-4525-82ba-e14030506691", - "metadata": {}, - "outputs": [], - "source": [ - "from llama_index.agent import QueryPipelineAgentWorker, AgentRunner\n", - "from llama_index.callbacks import CallbackManager\n", - "\n", - "agent_worker = QueryPipelineAgentWorker(qp)\n", - "agent = AgentRunner(agent_worker, callback_manager=CallbackManager([]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f927ac2e-ef69-4b15-8a10-f1e57ba4aa03", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[AgentFnComponent(partial_dict={'tools': []}, fn=, async_fn=None),\n", - " OutputAgentComponent(partial_dict={}, callback_manager=, tool_runner_component=ToolRunnerComponent(partial_dict={}, tool_dict={'sql_tool': }, callback_manager=), output_parser=)]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_worker.agent_components" - ] - }, { "cell_type": "code", "execution_count": null, @@ -554,29 +518,9 @@ "from llama_index.callbacks import CallbackManager\n", "\n", "agent_worker = QueryPipelineAgentWorker(qp)\n", - "agent = AgentRunner(agent_worker, callback_manager=CallbackManager([]))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bdd8835d-1044-4cca-a6aa-59f7635096ca", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[AgentFnComponent(partial_dict={'tools': [], 'task': Task(task_id='fe91205c-62ca-4a6f-96dd-c72685b6fa27', input='What are some tracks from the artist AC/DC? Limit it to 3', memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatStore(store={}), chat_store_key='chat_history'), extra_state={}), 'state': {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatStore(store={}), chat_store_key='chat_history'), 'current_reasoning': [ObservationReasoningStep(observation='What are some tracks from the artist AC/DC? Limit it to 3'), ActionReasoningStep(thought='I need to use a tool to help me answer the question.', action='sql_tool', action_input={'input': \"SELECT track_name FROM tracks WHERE artist_name = 'AC/DC' LIMIT 3\"}), ObservationReasoningStep(observation='{\\'output\\': ToolOutput(content=\\'Some of AC/DC\\\\\\'s popular tracks are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', tool_name=\\'sql_tool\\', raw_input={\\'input\\': \"SELECT track_name FROM tracks WHERE artist_name = \\'AC/DC\\' LIMIT 3\"}, raw_output=Response(response=\\'Some of AC/DC\\\\\\'s popular tracks are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', source_nodes=[NodeWithScore(node=TextNode(id_=\\'0b519588-2fc1-49d9-ac43-d744654ba1dc\\', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text=\\'[(\\\\\\'For Those About To Rock (We Salute You)\\\\\\',), (\\\\\\'Put The Finger On You\\\\\\',), (\"Let\\\\\\'s Get It Up\",)]\\', start_char_idx=None, end_char_idx=None, text_template=\\'{metadata_str}\\\\n\\\\n{content}\\', metadata_template=\\'{key}: {value}\\', metadata_seperator=\\'\\\\n\\'), score=None)], metadata={\\'0b519588-2fc1-49d9-ac43-d744654ba1dc\\': {}, \\'sql_query\\': \"SELECT tracks.Name AS track_name\\\\nFROM tracks\\\\nJOIN albums ON tracks.AlbumId = albums.AlbumId\\\\nJOIN artists ON albums.ArtistId = artists.ArtistId\\\\nWHERE artists.Name = \\'AC/DC\\'\\\\nLIMIT 3\", \\'result\\': [(\\'For Those About To Rock (We Salute You)\\',), (\\'Put The Finger On You\\',), (\"Let\\'s Get It Up\",)], \\'col_keys\\': [\\'track_name\\']}))}')]}}, fn=, async_fn=None),\n", - " OutputAgentComponent(partial_dict={}, callback_manager=, tool_runner_component=ToolRunnerComponent(partial_dict={}, tool_dict={'sql_tool': }, callback_manager=), output_parser=)]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_worker.agent_components" + "agent = AgentRunner(\n", + " agent_worker, callback_manager=CallbackManager([]), verbose=True\n", + ")" ] }, { @@ -612,9 +556,10 @@ "name": "stdout", "output_type": "stream", "text": [ + "> Running step 1778ae52-0a31-4199-be65-af574fbf70f1. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='79a8d443-5707-4632-82b2-51fd253cd294' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", "input: What are some tracks from the artist AC/DC? Limit it to 3\n", @@ -622,10 +567,16 @@ "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output with input: \n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", "chat_response: assistant: Thought: I need to use a tool to help me answer the question.\n", "Action: sql_tool\n", - "Action Input: {\"input\": \"Select track_name from music_database where artist_name = 'AC/DC' limit 3\"}\n", + "Action Input: {\"input\": \"What are some tracks from the artist AC/DC? Limit it to 3\"}\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module run_tool with input: \n", + "reasoning_step: thought='I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': 'What are some tracks from the artist AC/DC? Limit it to 3'}\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", + "response_dict: {'response_str': 'Observation: {\\'output\\': ToolOutput(content=\\'Some tracks from the artist AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', ...\n", "\n", "\u001b[0m" ] @@ -645,9 +596,10 @@ "name": "stdout", "output_type": "stream", "text": [ + "> Running step 07ac733e-d234-4751-bcad-002a655fb2e3. Step input: None\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='79a8d443-5707-4632-82b2-51fd253cd294' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", "input: What are some tracks from the artist AC/DC? Limit it to 3\n", @@ -655,10 +607,16 @@ "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output with input: \n", - "chat_response: assistant: Thought: The user has repeated the question, but I have already provided the answer using the tool. I will restate the answer.\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", + "chat_response: assistant: Thought: The user has repeated the request, but it seems they might not have noticed the previous response. I will reiterate the information provided by the tool.\n", "\n", - "Answer: The top 3 tracks by AC/DC are \"For Those About To Roc...\n", + "Answer: Some tracks from ...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_response with input: \n", + "response_step: thought='The user has repeated the request, but it seems they might not have noticed the previous response. I will reiterate the information provided by the tool.' response='Some tracks from the artis...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", + "response_dict: {'response_str': 'Some tracks from the artist AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\'s Get It Up\".', 'is_done': True}\n", "\n", "\u001b[0m" ] @@ -704,6 +662,88 @@ "execution_count": null, "id": "7ba842bd-cc86-422b-8c17-2ff87c1962b6", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Some tracks from the artist AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let's Get It Up\".\n" + ] + } + ], + "source": [ + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6ef125b-ccd7-4937-820a-e3b0a8b1f825", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> Running step 0b0ed367-3246-442e-a8be-34ee4ef320c4. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", + "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", + "task: task_id='d699cfb9-9632-4ea0-ae43-0fe86f623a6e' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", + "input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", + "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", + "chat_response: assistant: Thought: I need to use a tool to help me answer the question.\n", + "Action: sql_tool\n", + "Action Input: {\"input\": \"SELECT track_name FROM tracks WHERE artist_name = 'AC/DC' LIMIT 3\"}\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module run_tool with input: \n", + "reasoning_step: thought='I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': \"SELECT track_name FROM tracks WHERE artist_name = 'AC/DC' LIMIT 3\"}\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", + "response_dict: {'response_str': 'Observation: {\\'output\\': ToolOutput(content=\\'The top 3 tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', tool_nam...\n", + "\n", + "\u001b[0m> Running step 2d670abc-4d8f-4cd6-8ad3-36152340ec8b. Step input: None\n", + "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", + "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", + "task: task_id='d699cfb9-9632-4ea0-ae43-0fe86f623a6e' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", + "input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", + "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", + "chat_response: assistant: Thought: The user has repeated the request for tracks from the artist AC/DC, limited to 3, despite having already received an answer. It's possible that the user did not see the previous re...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_response with input: \n", + "response_step: thought=\"The user has repeated the request for tracks from the artist AC/DC, limited to 3, despite having already received an answer. It's possible that the user did not see the previous response or t...\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", + "response_dict: {'response_str': 'The top 3 tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\'s Get It Up\".', 'is_done': True}\n", + "\n", + "\u001b[0m" + ] + } + ], + "source": [ + "# run this e2e\n", + "agent.reset()\n", + "response = agent.chat(\n", + " \"What are some tracks from the artist AC/DC? Limit it to 3\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de05be63-319f-48e3-bc16-6155b23b13a0", + "metadata": {}, "outputs": [ { "name": "stdout", @@ -932,7 +972,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": null, @@ -966,8 +1006,11 @@ "from llama_index.agent import QueryPipelineAgentWorker, AgentRunner\n", "from llama_index.callbacks import CallbackManager\n", "\n", + "# callback manager is passed from query pipeline to agent worker/agent\n", "agent_worker = QueryPipelineAgentWorker(qp)\n", - "agent = AgentRunner(agent_worker, callback_manager=CallbackManager([]))" + "agent = AgentRunner(\n", + " agent_worker, callback_manager=CallbackManager(), verbose=False\n", + ")" ] }, { @@ -982,7 +1025,7 @@ "text": [ "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='741c0d59-fa40-44a2-acab-cc4c36fdf0c7' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", + "task: task_id='0952d9f0-458d-4bca-8531-56fd8f880d47' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", @@ -995,67 +1038,35 @@ "will convert the query to a SQL statement. I...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the user input and the requirement that the answer should be non-zero, the natural language query should be specific enough to avoid ambiguity and errors when converted to SQL. The pr...\n", + "input: assistant: Given the user's input and the requirement that the answer should be non-zero, a proper natural language query that could be interpreted by a text-to-SQL agent might be:\n", "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: The count of albums released by the artist credited with writing the song titled 'Restless and Wild', excluding any counts of zero, is 1.\n", - "\n", - "\u001b[0m> Inferred SQL Query: SELECT COUNT(*) \n", - "FROM albums \n", - "JOIN artists ON albums.ArtistId = artists.ArtistId \n", - "JOIN tracks ON albums.AlbumId = tracks.AlbumId \n", - "WHERE tracks.Name = 'Restless and Wild' \n", - "AND albums.ArtistId IS NOT NULL \n", - "AND albums.AlbumId IS NOT NULL \n", - "AND albums.Title IS NOT NULL \n", - "AND artists.Name IS NOT NULL \n", - "AND tracks.TrackId IS NOT NULL \n", - "AND tracks.MediaTypeId IS NOT NULL \n", - "AND tracks.GenreId IS NOT NULL \n", - "AND tracks.Composer IS NOT NULL \n", - "AND tracks.Milliseconds IS NOT NULL \n", - "AND tracks.Bytes IS NOT NULL \n", - "AND tracks.UnitPrice IS NOT NULL\n", - "> SQL Response: The count of albums released by the artist credited with writing the song titled 'Restless and Wild', excluding any counts of zero, is 1.\n", - "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='741c0d59-fa40-44a2-acab-cc4c36fdf0c7' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", - "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", - "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", - "convo_history: User: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", - "Assistant (inferred SQL query): SELECT COUNT(*) \n", - "FROM albums \n", - "JOIN artists ON albums.ArtistId =...\n", - "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", - "messages: You are trying to generate a proper natural language query given a user input.\n", - "\n", - "This query will then be interpreted by a downstream text-to-SQL agent which\n", - "will convert the query to a SQL statement. I...\n", - "\n", - "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the previous failed attempt, it seems that the SQL query was overly complex and included unnecessary conditions. The query should focus on finding the artist who wrote 'Restless and W...\n", + "\"Find the total nu...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: The number of albums released by the artist who composed the track 'Restless and Wild' is 347.\n", + "output: The total number of albums released by the artist credited with writing the song 'Restless and Wild' is 1.\n", "\n", - "\u001b[0m> Inferred SQL Query: SELECT COUNT(*) \n", - "FROM albums \n", - "WHERE ArtistId = (SELECT ArtistId \n", - " FROM tracks \n", - " WHERE Name = 'Restless and Wild')\n", - "> SQL Response: The number of albums released by the artist who composed the track 'Restless and Wild' is 347.\n", + "\u001b[0m> Inferred SQL Query: SELECT COUNT(DISTINCT albums.AlbumId) AS TotalAlbums\n", + "FROM albums\n", + "JOIN tracks ON albums.AlbumId = tracks.AlbumId\n", + "WHERE tracks.Name = 'Restless and Wild'\n", + "AND albums.ArtistId = (\n", + " SELECT ArtistId\n", + " FROM tracks\n", + " JOIN albums ON tracks.AlbumId = albums.AlbumId\n", + " WHERE tracks.Name = 'Restless and Wild'\n", + " LIMIT 1\n", + ")\n", + "HAVING TotalAlbums > 0;\n", + "> SQL Response: The total number of albums released by the artist credited with writing the song 'Restless and Wild' is 1.\n", "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='741c0d59-fa40-44a2-acab-cc4c36fdf0c7' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", + "task: task_id='0952d9f0-458d-4bca-8531-56fd8f880d47' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", "convo_history: User: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", - "Assistant (inferred SQL query): SELECT COUNT(*) \n", - "FROM albums \n", - "JOIN artists ON albums.ArtistId =...\n", + "Assistant (inferred SQL query): SELECT COUNT(DISTINCT albums.AlbumId) AS TotalAlbums\n", + "FROM album...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", "messages: You are trying to generate a proper natural language query given a user input.\n", @@ -1064,19 +1075,26 @@ "will convert the query to a SQL statement. I...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the previous failed attempts and the user's insistence on a non-zero answer, it seems that the SQL queries might have been correct in structure but the responses provided were not sat...\n", + "input: assistant: Given the conversation history, it seems that the previous query was successful in fetching a non-zero count of albums released by the artist who wrote 'Restless and Wild'. However, the use...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: The number of albums released by the composer of the track 'Restless and Wild' is 1.\n", + "output: Based on the revised query, it seems that there are no distinct albums released by the composer who wrote the song 'Restless and Wild'.\n", "\n", "\u001b[0m> Inferred SQL Query: SELECT COUNT(DISTINCT albums.AlbumId) AS AlbumCount\n", "FROM albums\n", "JOIN tracks ON albums.AlbumId = tracks.AlbumId\n", - "WHERE tracks.Composer IN (SELECT Composer FROM tracks WHERE Name = 'Restless and Wild')\n", - "AND albums.AlbumId IS NOT NULL\n", + "WHERE tracks.Composer = 'Restless and Wild'\n", "HAVING AlbumCount > 0;\n", - "> SQL Response: The number of albums released by the composer of the track 'Restless and Wild' is 1.\n", - "The number of albums released by the composer of the track 'Restless and Wild' is 1.\n" + "> SQL Response: Based on the revised query, it seems that there are no distinct albums released by the composer who wrote the song 'Restless and Wild'.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "KeyboardInterrupt\n", + "\n" ] } ], diff --git a/llama_index/agent/custom/pipeline_worker.py b/llama_index/agent/custom/pipeline_worker.py index 03685b45c1b0c..808037e530362 100644 --- a/llama_index/agent/custom/pipeline_worker.py +++ b/llama_index/agent/custom/pipeline_worker.py @@ -191,3 +191,9 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None: task.memory.set(task.memory.get() + task.extra_state["memory"].get_all()) # reset new memory task.extra_state["memory"].reset() + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls) + self.callback_manager = callback_manager + self.pipeline.set_callback_manager(callback_manager) diff --git a/llama_index/agent/custom/simple.py b/llama_index/agent/custom/simple.py index e735ac66b9761..bbab9a13344fe 100644 --- a/llama_index/agent/custom/simple.py +++ b/llama_index/agent/custom/simple.py @@ -254,3 +254,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None: # reset new memory task.extra_state["memory"].reset() self._finalize_task(task.extra_state, **kwargs) + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls) + self.callback_manager = callback_manager diff --git a/llama_index/agent/openai/step.py b/llama_index/agent/openai/step.py index 7e1b712c57502..4fd046116bfaf 100644 --- a/llama_index/agent/openai/step.py +++ b/llama_index/agent/openai/step.py @@ -637,3 +637,8 @@ def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]: # # break # # while cast(AgentChatResponse, last_step_output.output).response != + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls) + self.callback_manager = callback_manager diff --git a/llama_index/agent/react/step.py b/llama_index/agent/react/step.py index 381f4bb361730..d3396b0a0f59a 100644 --- a/llama_index/agent/react/step.py +++ b/llama_index/agent/react/step.py @@ -633,3 +633,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None: task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all()) # reset new memory task.extra_state["new_memory"].reset() + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls) + self.callback_manager = callback_manager diff --git a/llama_index/agent/react_multimodal/step.py b/llama_index/agent/react_multimodal/step.py index c961540ffae53..5df795221f269 100644 --- a/llama_index/agent/react_multimodal/step.py +++ b/llama_index/agent/react_multimodal/step.py @@ -472,3 +472,8 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None: task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all()) # reset new memory task.extra_state["new_memory"].reset() + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls) + self.callback_manager = callback_manager diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py index bd05cdbaa4551..d352ae1169abe 100644 --- a/llama_index/agent/runner/base.py +++ b/llama_index/agent/runner/base.py @@ -208,15 +208,33 @@ def __init__( init_task_state_kwargs: Optional[dict] = None, delete_task_on_finish: bool = False, default_tool_choice: str = "auto", + verbose: bool = False, ) -> None: """Initialize.""" self.agent_worker = agent_worker self.state = state or AgentState() self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm) - self.callback_manager = callback_manager or CallbackManager([]) + + # get and set callback manager + if callback_manager is not None: + self.agent_worker.set_callback_manager(callback_manager) + self.callback_manager = callback_manager + else: + # TODO: This is *temporary* + # Stopgap before having a callback on the BaseAgentWorker interface. + # Doing that requires a bit more refactoring to make sure existing code + # doesn't break. + if hasattr(self.agent_worker, "callback_manager"): + self.callback_manager = ( + self.agent_worker.callback_manager or CallbackManager() + ) + else: + self.callback_manager = CallbackManager() + self.init_task_state_kwargs = init_task_state_kwargs or {} self.delete_task_on_finish = delete_task_on_finish self.default_tool_choice = default_tool_choice + self.verbose = verbose @staticmethod def from_llm( @@ -263,10 +281,13 @@ def create_task(self, input: str, **kwargs: Any) -> Task: ) else: extra_state = self.init_task_state_kwargs + + callback_manager = kwargs.pop("callback_manager", self.callback_manager) task = Task( input=input, memory=self.memory, extra_state=extra_state, + callback_manager=callback_manager, **kwargs, ) # # put input into memory @@ -325,6 +346,9 @@ def _run_step( if input is not None: step.input = input + if self.verbose: + print(f"> Running step {step.step_id}. Step input: {step.input}") + # TODO: figure out if you can dynamically swap in different step executors # not clear when you would do that by theoretically possible @@ -359,6 +383,9 @@ async def _arun_step( if input is not None: step.input = input + if self.verbose: + print(f"> Running step {step.step_id}. Step input: {step.input}") + # TODO: figure out if you can dynamically swap in different step executors # not clear when you would do that by theoretically possible if mode == ChatResponseMode.WAIT: diff --git a/llama_index/agent/types.py b/llama_index/agent/types.py index 7e59448d895a8..a523e49502b8f 100644 --- a/llama_index/agent/types.py +++ b/llama_index/agent/types.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from llama_index.bridge.pydantic import BaseModel, Field -from llama_index.callbacks import trace_method +from llama_index.callbacks import CallbackManager, trace_method from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.core.llms.types import ChatMessage @@ -147,6 +147,9 @@ class Task(BaseModel): """ + class Config: + arbitrary_types_allowed = True + task_id: str = Field( default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID" ) @@ -161,6 +164,12 @@ class Task(BaseModel): ), ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, + exclude=True, + description="Callback manager for the task.", + ) + extra_state: Dict[str, Any] = Field( default_factory=dict, description=( @@ -220,3 +229,7 @@ async def astream_step( @abstractmethod def finalize_task(self, task: Task, **kwargs: Any) -> None: """Finalize task, after all the steps are completed.""" + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: make this abstractmethod (right now will break some agent impls)