From ff74654aa6fc8970fbff79d72dd56fadd1bf2817 Mon Sep 17 00:00:00 2001 From: Jerry Liu Date: Sat, 29 Jun 2024 17:28:55 -0700 Subject: [PATCH] Add stateful and loop components (#14235) --- .../agent_runner/query_pipeline_agent.ipynb | 301 ++++++++---------- .../core/agent/custom/pipeline_worker.py | 85 ++++- .../core/query_pipeline/__init__.py | 5 + .../query_pipeline/components/__init__.py | 4 + .../core/query_pipeline/components/agent.py | 10 +- .../core/query_pipeline/components/loop.py | 86 +++++ .../query_pipeline/components/stateful.py | 91 ++++++ .../llama_index/core/query_pipeline/query.py | 68 ++++ .../tests/agent/custom/test_pipeline.py | 164 ++++++++++ .../tests/query_pipeline/test_components.py | 54 +++- 10 files changed, 679 insertions(+), 189 deletions(-) create mode 100644 llama-index-core/llama_index/core/query_pipeline/components/loop.py create mode 100644 llama-index-core/llama_index/core/query_pipeline/components/stateful.py create mode 100644 llama-index-core/tests/agent/custom/test_pipeline.py diff --git a/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb b/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb index 599228c74be47..ab9af2188bba3 100644 --- a/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb +++ b/docs/docs/examples/agent/agent_runner/query_pipeline_agent.ipynb @@ -88,7 +88,7 @@ "text": [ " % Total % Received % Xferd Average Speed Time Time Time Current\n", " Dload Upload Total Spent Left Speed\n", - "100 298k 100 298k 0 0 2327k 0 --:--:-- --:--:-- --:--:-- 2387k\n", + "100 298k 100 298k 0 0 3751k 0 --:--:-- --:--:-- --:--:-- 3926k\n", "curl: (6) Could not resolve host: .\n", "Archive: ./chinook.zip\n", " inflating: chinook.db \n" @@ -130,7 +130,17 @@ "execution_count": null, "id": "754f11d3-f053-46f7-acb4-ae8ee7d3fe07", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🌍 To view the Phoenix app in your browser, visit http://localhost:6006/\n", + "📺 To view the Phoenix app in a notebook, run `px.active_session().view()`\n", + "📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix\n" + ] + } + ], "source": [ "# setup Arize Phoenix for logging/observability\n", "import phoenix as px\n", @@ -187,12 +197,9 @@ "3. If tool/action is selected, call tool pipeline to execute tool + collect response.\n", "4. If response is generated, get response.\n", "\n", - "Throughout this we'll use a variety of agent-specific query components. Unlike normal query pipelines, these are specifically designed for query pipelines that are used in a `QueryPipelineAgentWorker`:\n", - "- An `AgentInputComponent` that allows you to convert the agent inputs (Task, state dictionary) into a set of inputs for the query pipeline.\n", - "- An `AgentFnComponent`: a general processor that allows you to take in the current Task, state, as well as any arbitrary inputs, and returns an output. In this cookbook we define a function component to format the ReAct prompt. However, you can put this anywhere.\n", - "- [Not used in this notebook] An `CustomAgentComponent`: similar to `AgentFnComponent`, you can implement `_run_component` to define your own logic, with access to Task and state. It is more verbose but more flexible than `AgentFnComponent` (e.g. you can define init variables, and callbacks are in the base class).\n", - "\n", - "Note that any function passed into `AgentFnComponent` and `AgentInputComponent` MUST include `task` and `state` as input variables, as these are inputs passed from the agent. \n", + "Throughout this we'll build a **stateful** agent pipeline. It contains the following components:\n", + "- A `QueryPipelineAgentWorker` - this is the agent that wraps a stateful query\n", + "- A `StatefulFnComponent` - these track global state over query pipeline executions. They track the agent `task` and `step_state` as special keys by default.\n", "\n", "Note that the output of an agentic query pipeline MUST be `Tuple[AgentChatResponse, bool]`. You'll see this below." ] @@ -233,9 +240,7 @@ ")\n", "from llama_index.core.agent import Task, AgentChatResponse\n", "from llama_index.core.query_pipeline import (\n", - " AgentInputComponent,\n", - " AgentFnComponent,\n", - " CustomAgentComponent,\n", + " StatefulFnComponent,\n", " QueryComponent,\n", " ToolRunnerComponent,\n", ")\n", @@ -243,10 +248,10 @@ "from typing import Dict, Any, Optional, Tuple, List, cast\n", "\n", "\n", - "## Agent Input Component\n", + "# Input Component\n", "## This is the component that produces agent inputs to the rest of the components\n", "## Can also put initialization logic here.\n", - "def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict[str, Any]:\n", + "def agent_input_fn(state: Dict[str, Any]) -> str:\n", " \"\"\"Agent input function.\n", "\n", " Returns:\n", @@ -255,15 +260,17 @@ " components, make sure the src_key matches the specified output_key.\n", "\n", " \"\"\"\n", + " step_state = state[\"step_state\"]\n", + " task = state[\"task\"]\n", " # initialize current_reasoning\n", - " if \"current_reasoning\" not in state:\n", - " state[\"current_reasoning\"] = []\n", + " if \"current_reasoning\" not in step_state:\n", + " step_state[\"current_reasoning\"] = []\n", " reasoning_step = ObservationReasoningStep(observation=task.input)\n", - " state[\"current_reasoning\"].append(reasoning_step)\n", - " return {\"input\": task.input}\n", + " step_state[\"current_reasoning\"].append(reasoning_step)\n", + " return task.input\n", "\n", "\n", - "agent_input_component = AgentInputComponent(fn=agent_input_fn)" + "agent_input_component = StatefulFnComponent(fn=agent_input_fn)" ] }, { @@ -291,18 +298,19 @@ "\n", "## define prompt function\n", "def react_prompt_fn(\n", - " task: Task, state: Dict[str, Any], input: str, tools: List[BaseTool]\n", + " state: Dict[str, Any], input: str, tools: List[BaseTool]\n", ") -> List[ChatMessage]:\n", + " task, step_state = state[\"task\"], state[\"step_state\"]\n", " # Add input to reasoning\n", " chat_formatter = ReActChatFormatter()\n", " return chat_formatter.format(\n", " tools,\n", - " chat_history=task.memory.get() + state[\"memory\"].get_all(),\n", - " current_reasoning=state[\"current_reasoning\"],\n", + " chat_history=task.memory.get() + step_state[\"memory\"].get_all(),\n", + " current_reasoning=step_state[\"current_reasoning\"],\n", " )\n", "\n", "\n", - "react_prompt_component = AgentFnComponent(\n", + "react_prompt_component = StatefulFnComponent(\n", " fn=react_prompt_fn, partial_dict={\"tools\": [sql_tool]}\n", ")" ] @@ -338,22 +346,19 @@ "from llama_index.core.agent.types import Task\n", "\n", "\n", - "def parse_react_output_fn(\n", - " task: Task, state: Dict[str, Any], chat_response: ChatResponse\n", - "):\n", + "def parse_react_output_fn(state: Dict[str, Any], chat_response: ChatResponse):\n", " \"\"\"Parse ReAct output into a reasoning step.\"\"\"\n", " output_parser = ReActOutputParser()\n", " reasoning_step = output_parser.parse(chat_response.message.content)\n", " return {\"done\": reasoning_step.is_done, \"reasoning_step\": reasoning_step}\n", "\n", "\n", - "parse_react_output = AgentFnComponent(fn=parse_react_output_fn)\n", + "parse_react_output = StatefulFnComponent(fn=parse_react_output_fn)\n", "\n", "\n", - "def run_tool_fn(\n", - " task: Task, state: Dict[str, Any], reasoning_step: ActionReasoningStep\n", - "):\n", + "def run_tool_fn(state: Dict[str, Any], reasoning_step: ActionReasoningStep):\n", " \"\"\"Run tool and process tool output.\"\"\"\n", + " task, step_state = state[\"task\"], state[\"step_state\"]\n", " tool_runner_component = ToolRunnerComponent(\n", " [sql_tool], callback_manager=task.callback_manager\n", " )\n", @@ -362,36 +367,37 @@ " tool_input=reasoning_step.action_input,\n", " )\n", " observation_step = ObservationReasoningStep(observation=str(tool_output))\n", - " state[\"current_reasoning\"].append(observation_step)\n", + " step_state[\"current_reasoning\"].append(observation_step)\n", " # TODO: get output\n", "\n", " return {\"response_str\": observation_step.get_content(), \"is_done\": False}\n", "\n", "\n", - "run_tool = AgentFnComponent(fn=run_tool_fn)\n", + "run_tool = StatefulFnComponent(fn=run_tool_fn)\n", "\n", "\n", "def process_response_fn(\n", - " task: Task, state: Dict[str, Any], response_step: ResponseReasoningStep\n", + " state: Dict[str, Any], response_step: ResponseReasoningStep\n", "):\n", " \"\"\"Process response.\"\"\"\n", - " state[\"current_reasoning\"].append(response_step)\n", + " task, step_state = state[\"task\"], state[\"step_state\"]\n", + " step_state[\"current_reasoning\"].append(response_step)\n", " response_str = response_step.response\n", " # Now that we're done with this step, put into memory\n", - " state[\"memory\"].put(ChatMessage(content=task.input, role=MessageRole.USER))\n", - " state[\"memory\"].put(\n", + " step_state[\"memory\"].put(\n", + " ChatMessage(content=task.input, role=MessageRole.USER)\n", + " )\n", + " step_state[\"memory\"].put(\n", " ChatMessage(content=response_str, role=MessageRole.ASSISTANT)\n", " )\n", "\n", " return {\"response_str\": response_str, \"is_done\": True}\n", "\n", "\n", - "process_response = AgentFnComponent(fn=process_response_fn)\n", + "process_response = StatefulFnComponent(fn=process_response_fn)\n", "\n", "\n", - "def process_agent_response_fn(\n", - " task: Task, state: Dict[str, Any], response_dict: dict\n", - "):\n", + "def process_agent_response_fn(state: Dict[str, Any], response_dict: dict):\n", " \"\"\"Process agent response.\"\"\"\n", " return (\n", " AgentChatResponse(response_dict[\"response_str\"]),\n", @@ -399,7 +405,7 @@ " )\n", "\n", "\n", - "process_agent_response = AgentFnComponent(fn=process_agent_response_fn)" + "process_agent_response = StatefulFnComponent(fn=process_agent_response_fn)" ] }, { @@ -503,7 +509,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": null, @@ -578,24 +584,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "> Running step edb9926c-7290-4b8d-ac80-1421432a0ea6. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "> Running step 01f68b00-b4b6-41a4-85cc-c1a37d1ec42d. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='b9b747a7-880f-4e91-9eed-b64574cbb6d0' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", "input: What are some tracks from the artist AC/DC? Limit it to 3\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", - "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "messages: [ChatMessage(role=, content='You are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Tools\\n\\n...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", - "chat_response: assistant: Thought: I need to use a tool to help me answer the question.\n", + "chat_response: assistant: Thought: The current language of the user is English. I need to use a tool to help me answer the question.\n", "Action: sql_tool\n", - "Action Input: {\"input\": \"Select track_name from tracks where artist_name = 'AC/DC' limit 3\"}\n", + "Action Input: {\"input\": \"SELECT track_name FROM music_database WH...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module run_tool with input: \n", - "reasoning_step: thought='I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': \"Select track_name from tracks where artist_name = 'AC/DC' limit 3\"}\n", + "reasoning_step: thought='The current language of the user is English. I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': \"SELECT track_name FROM music_database WHERE artist...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", "response_dict: {'response_str': 'Observation: {\\'output\\': ToolOutput(content=\\'The top 3 tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', tool_nam...\n", @@ -618,24 +622,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "> Running step 37e2312b-540b-4c79-9261-15318d4796d9. Step input: None\n", + "> Running step 4e39c5b6-39fd-433a-a6b2-f22c4831421b. Step input: None\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='b9b747a7-880f-4e91-9eed-b64574cbb6d0' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", "input: What are some tracks from the artist AC/DC? Limit it to 3\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", - "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "messages: [ChatMessage(role=, content='You are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Tools\\n\\n...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", - "chat_response: assistant: Thought: The user has repeated the request, possibly due to not noticing the previous response. I will provide the information again.\n", + "chat_response: assistant: Thought: The user has repeated the request, possibly not realizing that the answer has already been provided. I will reiterate the information given.\n", "\n", - "Answer: The top 3 tracks by AC/DC are \"For Those About...\n", + "Answer: The top 3 tracks by AC/DC are ...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_response with input: \n", - "response_step: thought='The user has repeated the request, possibly due to not noticing the previous response. I will provide the information again.' response='The top 3 tracks by AC/DC are \"For Those About To Rock ...\n", + "response_step: thought='The user has repeated the request, possibly not realizing that the answer has already been provided. I will reiterate the information given.' response='The top 3 tracks by AC/DC are \"For Thos...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", "response_dict: {'response_str': 'The top 3 tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\'s Get It Up\".', 'is_done': True}\n", @@ -707,47 +709,45 @@ "name": "stdout", "output_type": "stream", "text": [ - "> Running step 781d6e78-5bfe-4330-b8fc-3242deb6f64a. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "> Running step caac93ec-8762-42b0-a0c3-120cd96c68ca. Step input: What are some tracks from the artist AC/DC? Limit it to 3\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='c09dd358-19e8-4fcc-8b82-326783ba4af2' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", - "input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "input: {'input': 'What are some tracks from the artist AC/DC? Limit it to 3'}\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", - "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "messages: [ChatMessage(role=, content='You are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Tools\\n\\n...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", - "chat_response: assistant: Thought: I need to use a tool to help me answer the question.\n", + "chat_response: assistant: Thought: The current language of the user is English. I need to use a tool to help me answer the question.\n", "Action: sql_tool\n", - "Action Input: {\"input\": \"SELECT track_name FROM tracks WHERE artist_name = 'AC/DC' LIMIT 3\"}\n", + "Action Input: {\"input\": \"Select track_name from music_database wh...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module run_tool with input: \n", - "reasoning_step: thought='I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': \"SELECT track_name FROM tracks WHERE artist_name = 'AC/DC' LIMIT 3\"}\n", + "reasoning_step: thought='The current language of the user is English. I need to use a tool to help me answer the question.' action='sql_tool' action_input={'input': \"Select track_name from music_database where artist...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", - "response_dict: {'response_str': 'Observation: {\\'output\\': ToolOutput(content=\\'The top three tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s Get It Up\".\\', tool...\n", + "response_dict: {'response_str': 'Observation: {\\'output\\': ToolOutput(content=\\'The top 3 tracks by AC/DC in the music database are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\\\\\'s G...\n", "\n", - "\u001b[0m> Running step a65d44a6-7a98-49ec-86ce-eb4b3bcd6a48. Step input: None\n", + "\u001b[0m> Running step cfddbbed-c4f3-4f5a-8c43-5d50cbd63063. Step input: None\n", "\u001b[1;3;38;2;155;135;227m> Running module agent_input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='c09dd358-19e8-4fcc-8b82-326783ba4af2' input='What are some tracks from the artist AC/DC? Limit it to 3' memory=ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial( Running module react_prompt with input: \n", - "input: What are some tracks from the artist AC/DC? Limit it to 3\n", + "input: {'input': 'What are some tracks from the artist AC/DC? Limit it to 3'}\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", - "messages: [ChatMessage(role=, content='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Too...\n", + "messages: [ChatMessage(role=, content='You are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Tools\\n\\n...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module react_output_parser with input: \n", - "chat_response: assistant: Thought: The user has repeated the request for tracks from the artist AC/DC, limited to 3, despite having already received an answer. It's possible that the user did not see the previous re...\n", + "chat_response: assistant: Thought: The user has repeated the question, possibly not realizing that an answer has already been provided. I should reiterate the information given.\n", + "\n", + "Answer: The top 3 tracks by AC/DC in...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_response with input: \n", - "response_step: thought=\"The user has repeated the request for tracks from the artist AC/DC, limited to 3, despite having already received an answer. It's possible that the user did not see the previous response, or ...\n", + "response_step: thought='The user has repeated the question, possibly not realizing that an answer has already been provided. I should reiterate the information given.' response='The top 3 tracks by AC/DC in the musi...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module process_agent_response with input: \n", - "response_dict: {'response_str': 'The top three tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\'s Get It Up\".', 'is_done': True}\n", + "response_dict: {'response_str': 'The top 3 tracks by AC/DC in the music database are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let\\'s Get It Up\".', 'is_done': True}\n", "\n", "\u001b[0m" ] @@ -771,7 +771,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The top three tracks by AC/DC are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let's Get It Up\".\n" + "The top 3 tracks by AC/DC in the music database are \"For Those About To Rock (We Salute You)\", \"Put The Finger On You\", and \"Let's Get It Up\".\n" ] } ], @@ -825,24 +825,22 @@ "source": [ "from llama_index.core.agent import Task, AgentChatResponse\n", "from typing import Dict, Any\n", - "from llama_index.core.query_pipeline import (\n", - " AgentInputComponent,\n", - " AgentFnComponent,\n", - ")\n", + "from llama_index.core.query_pipeline import StatefulFnComponent\n", "\n", "\n", - "def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict:\n", + "def agent_input_fn(state: Dict[str, Any]) -> Dict:\n", " \"\"\"Agent input function.\"\"\"\n", + " task, step_state = state[\"task\"], state[\"step_state\"]\n", " # initialize current_reasoning\n", - " if \"convo_history\" not in state:\n", - " state[\"convo_history\"] = []\n", - " state[\"count\"] = 0\n", - " state[\"convo_history\"].append(f\"User: {task.input}\")\n", - " convo_history_str = \"\\n\".join(state[\"convo_history\"]) or \"None\"\n", + " if \"convo_history\" not in step_state:\n", + " step_state[\"convo_history\"] = []\n", + " step_state[\"count\"] = 0\n", + " step_state[\"convo_history\"].append(f\"User: {task.input}\")\n", + " convo_history_str = \"\\n\".join(step_state[\"convo_history\"]) or \"None\"\n", " return {\"input\": task.input, \"convo_history\": convo_history_str}\n", "\n", "\n", - "agent_input_component = AgentInputComponent(fn=agent_input_fn)" + "agent_input_component = StatefulFnComponent(fn=agent_input_fn)" ] }, { @@ -898,15 +896,16 @@ "\n", "\n", "def agent_output_fn(\n", - " task: Task, state: Dict[str, Any], output: Response\n", + " state: Dict[str, Any], output: Response\n", ") -> Tuple[AgentChatResponse, bool]:\n", " \"\"\"Agent output component.\"\"\"\n", + " task, step_state = state[\"task\"], state[\"step_state\"]\n", " print(f\"> Inferred SQL Query: {output.metadata['sql_query']}\")\n", " print(f\"> SQL Response: {str(output)}\")\n", - " state[\"convo_history\"].append(\n", + " step_state[\"convo_history\"].append(\n", " f\"Assistant (inferred SQL query): {output.metadata['sql_query']}\"\n", " )\n", - " state[\"convo_history\"].append(f\"Assistant (response): {str(output)}\")\n", + " step_state[\"convo_history\"].append(f\"Assistant (response): {str(output)}\")\n", "\n", " # run a mini chain to get response\n", " validate_prompt_partial = validate_prompt.as_query_component(\n", @@ -918,9 +917,9 @@ " qp = QP(chain=[validate_prompt_partial, llm])\n", " validate_output = qp.run(input=task.input)\n", "\n", - " state[\"count\"] += 1\n", + " step_state[\"count\"] += 1\n", " is_done = False\n", - " if state[\"count\"] >= MAX_ITER:\n", + " if step_state[\"count\"] >= MAX_ITER:\n", " is_done = True\n", " if \"YES\" in validate_output.message.content:\n", " is_done = True\n", @@ -928,7 +927,7 @@ " return AgentChatResponse(response=str(output)), is_done\n", "\n", "\n", - "agent_output_component = AgentFnComponent(fn=agent_output_fn)" + "agent_output_component = StatefulFnComponent(fn=agent_output_fn)" ] }, { @@ -954,9 +953,14 @@ " },\n", " verbose=True,\n", ")\n", - "qp.add_link(\"input\", \"retry_prompt\", src_key=\"input\", dest_key=\"input\")\n", "qp.add_link(\n", - " \"input\", \"retry_prompt\", src_key=\"convo_history\", dest_key=\"convo_history\"\n", + " \"input\", \"retry_prompt\", dest_key=\"input\", input_fn=lambda x: x[\"input\"]\n", + ")\n", + "qp.add_link(\n", + " \"input\",\n", + " \"retry_prompt\",\n", + " dest_key=\"convo_history\",\n", + " input_fn=lambda x: x[\"convo_history\"],\n", ")\n", "qp.add_chain([\"retry_prompt\", \"llm\", \"sql_query_engine\", \"output_component\"])" ] @@ -974,43 +978,13 @@ "execution_count": null, "id": "142dd638-578c-42c0-aa06-395052ca210a", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "agent_dag.html\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "from pyvis.network import Network\n", + "# from pyvis.network import Network\n", "\n", - "net = Network(notebook=True, cdn_resources=\"in_line\", directed=True)\n", - "net.from_nx(qp.dag)\n", - "net.show(\"agent_dag.html\")" + "# net = Network(notebook=True, cdn_resources=\"in_line\", directed=True)\n", + "# net.from_nx(qp.dag)\n", + "# net.show(\"agent_dag.html\")" ] }, { @@ -1049,8 +1023,6 @@ "output_type": "stream", "text": [ "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='2d8a63de-7410-4422-98f3-f0ca41884f58' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", @@ -1063,28 +1035,26 @@ "will convert the query to a SQL statement. I...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the conversation history, it seems that the previous attempt to generate a SQL query from the user's question may have resulted in an error. To avoid the same problem, we need to reph...\n", + "input: assistant: Given the user's input and the requirement that the answer should be non-zero, a proper natural language query that could be interpreted by a text-to-SQL agent might be:\n", + "\n", + "\"Retrieve the coun...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: I'm sorry, but there seems to be an error in the SQL query. The query you provided is invalid SQL. Please double-check the syntax and try again.\n", + "output: The count of albums released by the artist credited with writing the song 'Restless and Wild' is 1.\n", "\n", - "\u001b[0m> Inferred SQL Query: SELECT COUNT(albums.AlbumId) \n", - "FROM albums \n", - "JOIN tracks ON albums.AlbumId = tracks.AlbumId \n", - "WHERE tracks.Name = 'Restless and Wild' \n", - "AND albums.ArtistId = tracks.Composer \n", - "AND COUNT(albums.AlbumId) > 0\n", - "> SQL Response: I'm sorry, but there seems to be an error in the SQL query. The query you provided is invalid SQL. Please double-check the syntax and try again.\n", + "\u001b[0m> Inferred SQL Query: SELECT COUNT(a.AlbumId) AS AlbumCount\n", + "FROM albums a\n", + "JOIN tracks t ON a.AlbumId = t.AlbumId\n", + "WHERE t.Name = 'Restless and Wild';\n", + "> SQL Response: The count of albums released by the artist credited with writing the song 'Restless and Wild' is 1.\n", "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='2d8a63de-7410-4422-98f3-f0ca41884f58' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", "convo_history: User: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", - "Assistant (inferred SQL query): SELECT COUNT(albums.AlbumId) \n", - "FROM albums \n", - "JOIN tracks ON album...\n", + "Assistant (inferred SQL query): SELECT COUNT(a.AlbumId) AS AlbumCount\n", + "FROM albums a\n", + "JOIN tracks...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", "messages: You are trying to generate a proper natural language query given a user input.\n", @@ -1093,29 +1063,25 @@ "will convert the query to a SQL statement. I...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the previous error, it seems that the SQL query was incorrect because it attempted to use an aggregate function (`COUNT`) in the `WHERE` clause, which is not allowed. Additionally, th...\n", + "input: assistant: Given the conversation history, it seems that the previous query did not correctly account for the fact that the user is asking for the number of albums released by the specific artist who ...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: The number of albums released by the artist who composed the track 'Restless and Wild' is 1.\n", + "output: The revised SQL query has been executed to count the number of distinct albums released by the artist who wrote the song 'Restless and Wild'. The result of the query is 0, indicating that there are no...\n", "\n", "\u001b[0m> Inferred SQL Query: SELECT COUNT(DISTINCT albums.AlbumId) AS NumAlbums\n", - "FROM tracks\n", - "JOIN albums ON tracks.AlbumId = albums.AlbumId\n", + "FROM albums\n", + "JOIN tracks ON albums.AlbumId = tracks.AlbumId\n", "JOIN artists ON albums.ArtistId = artists.ArtistId\n", - "WHERE tracks.Name = 'Restless and Wild'\n", - "GROUP BY artists.ArtistId\n", - "HAVING NumAlbums > 0;\n", - "> SQL Response: The number of albums released by the artist who composed the track 'Restless and Wild' is 1.\n", + "WHERE tracks.Name = 'Restless and Wild' AND artists.Name = (SELECT artists.Name FROM tracks JOIN artists ON tracks.Composer = artists.Name WHERE tracks.Name = 'Restless and Wild');\n", + "> SQL Response: The revised SQL query has been executed to count the number of distinct albums released by the artist who wrote the song 'Restless and Wild'. The result of the query is 0, indicating that there are no albums in the database by the specific artist who wrote that song. This information should provide clarity on the number of albums released by the artist in question.\n", "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", - "state: {'sources': [], 'memory': ChatMemoryBuffer(token_limit=3000, tokenizer_fn=functools.partial(>, allowed_special='all'), chat_store=SimpleChatSto...\n", - "task: task_id='2d8a63de-7410-4422-98f3-f0ca41884f58' input=\"How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\" memory=ChatMemoryBuffer(token_limit=3000, toke...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module retry_prompt with input: \n", "input: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", "convo_history: User: How many albums did the artist who wrote 'Restless and Wild' release? (answer should be non-zero)?\n", - "Assistant (inferred SQL query): SELECT COUNT(albums.AlbumId) \n", - "FROM albums \n", - "JOIN tracks ON album...\n", + "Assistant (inferred SQL query): SELECT COUNT(a.AlbumId) AS AlbumCount\n", + "FROM albums a\n", + "JOIN tracks...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module llm with input: \n", "messages: You are trying to generate a proper natural language query given a user input.\n", @@ -1124,25 +1090,18 @@ "will convert the query to a SQL statement. I...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module sql_query_engine with input: \n", - "input: assistant: Given the conversation history, it seems that the previous SQL query was successful in retrieving the number of albums released by the artist who composed the track 'Restless and Wild'. How...\n", + "input: assistant: Given the previous failed attempts, it seems that the SQL queries were not correctly capturing the artist who wrote 'Restless and Wild'. To correct this, we need to ensure that we are joini...\n", "\n", "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module output_component with input: \n", - "output: I apologize, but there seems to be an error in the SQL query provided. Please double-check the syntax and try again.\n", + "output: The total number of distinct albums released by the composer of the song 'Restless and Wild' is 1.\n", "\n", "\u001b[0m> Inferred SQL Query: SELECT COUNT(DISTINCT albums.AlbumId) AS TotalAlbums\n", - "FROM albums\n", - "JOIN artists ON albums.ArtistId = artists.ArtistId\n", - "WHERE artists.ArtistId = (\n", - " SELECT artists.ArtistId\n", - " FROM tracks\n", - " JOIN albums ON tracks.AlbumId = albums.AlbumId\n", - " JOIN artists ON albums.ArtistId = artists.ArtistId\n", - " WHERE tracks.Name = 'Restless and Wild'\n", - " LIMIT 1\n", - ")\n", - "AND TotalAlbums > 0;\n", - "> SQL Response: I apologize, but there seems to be an error in the SQL query provided. Please double-check the syntax and try again.\n", - "I apologize, but there seems to be an error in the SQL query provided. Please double-check the syntax and try again.\n" + "FROM tracks\n", + "JOIN albums ON tracks.AlbumId = albums.AlbumId\n", + "WHERE tracks.Name = 'Restless and Wild'\n", + "GROUP BY albums.ArtistId;\n", + "> SQL Response: The total number of distinct albums released by the composer of the song 'Restless and Wild' is 1.\n", + "The total number of distinct albums released by the composer of the song 'Restless and Wild' is 1.\n" ] } ], diff --git a/llama-index-core/llama_index/core/agent/custom/pipeline_worker.py b/llama-index-core/llama_index/core/agent/custom/pipeline_worker.py index b60f30301d745..b943e947e4572 100644 --- a/llama-index-core/llama_index/core/agent/custom/pipeline_worker.py +++ b/llama-index-core/llama_index/core/agent/custom/pipeline_worker.py @@ -53,8 +53,15 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker): Barebones agent worker that takes in a query pipeline. - Assumes that the first component in the query pipeline is an - `AgentInputComponent` and last is `AgentFnComponent`. + **Default Workflow**: The default workflow assumes that you compose + a query pipeline with `StatefulFnComponent` objects. This allows you to store, update + and retrieve state throughout the executions of the query pipeline by the agent. + + The task and step state of the agent are stored in this `state` variable via a special key. + Of course you can choose to store other variables in this state as well. + + **Deprecated Workflow**: The deprecated workflow assumes that the first component in the + query pipeline is an `AgentInputComponent` and last is `AgentFnComponent`. Args: pipeline (QueryPipeline): Query pipeline @@ -63,6 +70,8 @@ class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker): pipeline: QueryPipeline = Field(..., description="Query pipeline") callback_manager: CallbackManager = Field(..., exclude=True) + task_key: str = Field("task", description="Key to store task in state") + step_state_key: str = Field("step_state", description="Key to store step in state") class Config: arbitrary_types_allowed = True @@ -71,6 +80,7 @@ def __init__( self, pipeline: QueryPipeline, callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, ) -> None: """Initialize.""" if callback_manager is not None: @@ -81,14 +91,19 @@ def __init__( super().__init__( pipeline=pipeline, callback_manager=callback_manager, + **kwargs, ) # validate query pipeline - self.agent_input_component + # self.agent_input_component self.agent_components @property def agent_input_component(self) -> AgentInputComponent: - """Get agent input component.""" + """Get agent input component. + + NOTE: This is deprecated and will be removed in the future. + + """ root_key = self.pipeline.get_root_keys()[0] if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent): raise ValueError( @@ -103,6 +118,26 @@ def agent_components(self) -> List[AgentFnComponent]: """Get agent output component.""" return _get_agent_components(self.pipeline) + def preprocess(self, task: Task, step: TaskStep) -> None: + """Preprocessing flow. + + This runs preprocessing to propagate the task and step as variables + to relevant components in the query pipeline. + + Contains deprecated flow of updating agent components. + But also contains main flow of updating StatefulFnComponent components. + + """ + # NOTE: this is deprecated + # partial agent output component with task and step + for agent_fn_component in self.agent_components: + agent_fn_component.partial(task=task, state=step.step_state) + + # update stateful components + self.pipeline.update_state( + {self.task_key: task, self.step_state_key: step.step_state} + ) + def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep: """Initialize step from task.""" sources: List[ToolOutput] = [] @@ -147,11 +182,21 @@ def _get_task_step_response( @trace_method("run_step") def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: """Run step.""" - # partial agent output component with task and step - for agent_fn_component in self.agent_components: - agent_fn_component.partial(task=task, state=step.step_state) - - agent_response, is_done = self.pipeline.run(state=step.step_state, task=task) + self.preprocess(task, step) + + # HACK: do a try/except for now. Fine since old agent components are deprecated + try: + self.agent_input_component + uses_deprecated = True + except ValueError: + uses_deprecated = False + + if uses_deprecated: + agent_response, is_done = self.pipeline.run( + state=step.step_state, task=task + ) + else: + agent_response, is_done = self.pipeline.run() response = self._get_task_step_response(agent_response, step, is_done) # sync step state with task state task.extra_state.update(step.step_state) @@ -162,13 +207,21 @@ async def arun_step( self, step: TaskStep, task: Task, **kwargs: Any ) -> TaskStepOutput: """Run step (async).""" - # partial agent output component with task and step - for agent_fn_component in self.agent_components: - agent_fn_component.partial(task=task, state=step.step_state) - - agent_response, is_done = await self.pipeline.arun( - state=step.step_state, task=task - ) + self.preprocess(task, step) + + # HACK: do a try/except for now. Fine since old agent components are deprecated + try: + self.agent_input_component + uses_deprecated = True + except ValueError: + uses_deprecated = False + + if uses_deprecated: + agent_response, is_done = await self.pipeline.arun( + state=step.step_state, task=task + ) + else: + agent_response, is_done = await self.pipeline.arun() response = self._get_task_step_response(agent_response, step, is_done) task.extra_state.update(step.step_state) return response diff --git a/llama-index-core/llama_index/core/query_pipeline/__init__.py b/llama-index-core/llama_index/core/query_pipeline/__init__.py index 109025f54058d..accea877d4dd9 100644 --- a/llama-index-core/llama_index/core/query_pipeline/__init__.py +++ b/llama-index-core/llama_index/core/query_pipeline/__init__.py @@ -20,6 +20,9 @@ ChainableMixin, QueryComponent, ) +from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent +from llama_index.core.query_pipeline.components.loop import LoopComponent + from llama_index.core.base.query_pipeline.query import ( CustomQueryComponent, ) @@ -40,4 +43,6 @@ "ChainableMixin", "QueryComponent", "CustomQueryComponent", + "StatefulFnComponent", + "LoopComponent", ] diff --git a/llama-index-core/llama_index/core/query_pipeline/components/__init__.py b/llama-index-core/llama_index/core/query_pipeline/components/__init__.py index 7a1d1d2493d38..445f7943f31a0 100644 --- a/llama-index-core/llama_index/core/query_pipeline/components/__init__.py +++ b/llama-index-core/llama_index/core/query_pipeline/components/__init__.py @@ -15,6 +15,8 @@ SelectorComponent, ) from llama_index.core.query_pipeline.components.tool_runner import ToolRunnerComponent +from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent +from llama_index.core.query_pipeline.components.loop import LoopComponent __all__ = [ "AgentFnComponent", @@ -28,4 +30,6 @@ "RouterComponent", "SelectorComponent", "ToolRunnerComponent", + "StatefulFnComponent", + "LoopComponent", ] diff --git a/llama-index-core/llama_index/core/query_pipeline/components/agent.py b/llama-index-core/llama_index/core/query_pipeline/components/agent.py index 8c2ea76ea5df5..384f5fa69d853 100644 --- a/llama-index-core/llama_index/core/query_pipeline/components/agent.py +++ b/llama-index-core/llama_index/core/query_pipeline/components/agent.py @@ -42,7 +42,11 @@ def default_agent_input_fn(task: Any, state: dict) -> dict: class AgentInputComponent(QueryComponent): - """Takes in agent inputs and transforms it into desired outputs.""" + """Takes in agent inputs and transforms it into desired outputs. + + NOTE: this is now deprecated in favor of using `StatefulFnComponent`. + + """ fn: Callable = Field(..., description="Function to run.") async_fn: Optional[Callable] = Field( @@ -149,6 +153,8 @@ class AgentFnComponent(BaseAgentComponent): Designed to let users easily modify state. + NOTE: this is now deprecated in favor of using `StatefulFnComponent`. + """ fn: Callable = Field(..., description="Function to run.") @@ -257,6 +263,8 @@ class CustomAgentComponent(BaseAgentComponent): Designed to let users easily modify state. + NOTE: this is now deprecated in favor of using `StatefulFnComponent`. + """ callback_manager: CallbackManager = Field( diff --git a/llama-index-core/llama_index/core/query_pipeline/components/loop.py b/llama-index-core/llama_index/core/query_pipeline/components/loop.py new file mode 100644 index 0000000000000..1e8f787372cda --- /dev/null +++ b/llama-index-core/llama_index/core/query_pipeline/components/loop.py @@ -0,0 +1,86 @@ +from llama_index.core.base.query_pipeline.query import ( + InputKeys, + OutputKeys, + QueryComponent, +) +from llama_index.core.query_pipeline.query import QueryPipeline +from llama_index.core.bridge.pydantic import Field +from llama_index.core.callbacks.base import CallbackManager +from typing import Any, Dict, Optional, Callable + + +class LoopComponent(QueryComponent): + """Loop component.""" + + pipeline: QueryPipeline = Field(..., description="Query pipeline") + should_exit_fn: Optional[Callable] = Field(..., description="Should exit function") + add_output_to_input_fn: Optional[Callable] = Field( + ..., + description="Add output to input function. If not provided, will reuse the original input for the next iteration. If provided, will call the function to combine the output into the input for the next iteration.", + ) + max_iterations: Optional[int] = Field(5, description="Max iterations") + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + pipeline: QueryPipeline, + should_exit_fn: Optional[Callable] = None, + add_output_to_input_fn: Optional[Callable] = None, + max_iterations: Optional[int] = 5, + ) -> None: + """Init params.""" + super().__init__( + pipeline=pipeline, + should_exit_fn=should_exit_fn, + add_output_to_input_fn=add_output_to_input_fn, + max_iterations=max_iterations, + ) + + def set_callback_manager(self, callback_manager: CallbackManager) -> None: + """Set callback manager.""" + # TODO: implement + + def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]: + return input + + def _run_component(self, **kwargs: Any) -> Dict: + """Run component.""" + current_input = kwargs + for i in range(self.max_iterations): + output = self.pipeline.run_component(**current_input) + if self.should_exit_fn: + should_exit = self.should_exit_fn(output) + if should_exit: + break + + if self.add_output_to_input_fn: + current_input = self.add_output_to_input_fn(current_input, output) + + return output + + async def _arun_component(self, **kwargs: Any) -> Any: + """Run component (async).""" + current_input = kwargs + for i in range(self.max_iterations): + output = await self.pipeline.arun_component(**current_input) + if self.should_exit_fn: + should_exit = self.should_exit_fn(output) + if should_exit: + break + + if self.add_output_to_input_fn: + current_input = self.add_output_to_input_fn(current_input, output) + + return output + + @property + def input_keys(self) -> InputKeys: + """Input keys.""" + return self.pipeline.input_keys + + @property + def output_keys(self) -> OutputKeys: + """Output keys.""" + return self.pipeline.output_keys diff --git a/llama-index-core/llama_index/core/query_pipeline/components/stateful.py b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py new file mode 100644 index 0000000000000..24044fa64af2f --- /dev/null +++ b/llama-index-core/llama_index/core/query_pipeline/components/stateful.py @@ -0,0 +1,91 @@ +"""Agent components.""" + +from typing import Any, Callable, Dict, Optional, Set + +from llama_index.core.base.query_pipeline.query import ( + QueryComponent, +) +from llama_index.core.bridge.pydantic import Field +from llama_index.core.query_pipeline.components.function import ( + FnComponent, + get_parameters, +) + +# from llama_index.core.query_pipeline.components.input import InputComponent + + +class BaseStatefulComponent(QueryComponent): + """Takes in agent inputs and transforms it into desired outputs.""" + + state: Dict[str, Any] = Field( + default_factory=dict, description="State of the pipeline." + ) + + def reset_state(self) -> None: + """Reset state.""" + self.state = {} + + +class StatefulFnComponent(BaseStatefulComponent, FnComponent): + """Query component that takes in an arbitrary function. + + Stateful version of `FnComponent`. Expects functions to have `state` as the first argument. + + """ + + def __init__( + self, + fn: Callable, + req_params: Optional[Set[str]] = None, + opt_params: Optional[Set[str]] = None, + state: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> None: + """Init params.""" + # determine parameters + default_req_params, default_opt_params = get_parameters(fn) + # make sure task and step are part of the list, and remove them from the list + if "state" not in default_req_params: + raise ValueError( + "StatefulFnComponent must have 'state' as required parameters" + ) + + default_req_params = default_req_params - {"state"} + default_opt_params = default_opt_params - {"state"} + + if req_params is None: + req_params = default_req_params + if opt_params is None: + opt_params = default_opt_params + + super().__init__( + fn=fn, + req_params=req_params, + opt_params=opt_params, + state=state or {}, + **kwargs + ) + + def _run_component(self, **kwargs: Any) -> Dict: + """Run component.""" + kwargs.update({"state": self.state}) + return super()._run_component(**kwargs) + + async def _arun_component(self, **kwargs: Any) -> Any: + """Async run component.""" + kwargs.update({"state": self.state}) + return await super()._arun_component(**kwargs) + + # @property + # def input_keys(self) -> InputKeys: + # """Input keys.""" + # return InputKeys.from_keys( + # required_keys={"state", *self._req_params}, + # optional_keys=self._opt_params, + # ) + + # @property + # def output_keys(self) -> OutputKeys: + # """Output keys.""" + # # output can be anything, overrode validate function + # return OutputKeys.from_keys({self.output_key}) diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py index ec7c33005a4e5..3b7fe130984e7 100644 --- a/llama-index-core/llama_index/core/query_pipeline/query.py +++ b/llama-index-core/llama_index/core/query_pipeline/query.py @@ -32,6 +32,7 @@ ComponentIntermediates, ) from llama_index.core.utils import print_text +from llama_index.core.query_pipeline.components.stateful import BaseStatefulComponent # TODO: Make this (safely) pydantic? @@ -154,6 +155,43 @@ def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiD return graph_copy +def get_stateful_components( + query_component: QueryComponent, +) -> List[BaseStatefulComponent]: + """Get stateful components.""" + stateful_components: List[BaseStatefulComponent] = [] + for c in query_component.sub_query_components: + if isinstance(c, BaseStatefulComponent): + stateful_components.append(cast(BaseStatefulComponent, c)) + + if len(c.sub_query_components) > 0: + stateful_components.extend(get_stateful_components(c)) + + return stateful_components + + +def update_stateful_components( + stateful_components: List[BaseStatefulComponent], state: Dict[str, Any] +) -> None: + """Update stateful components.""" + for stateful_component in stateful_components: + # stateful_component.partial(state=state) + stateful_component.state = state + + +def get_and_update_stateful_components( + query_component: QueryComponent, state: Dict[str, Any] +) -> List[BaseStatefulComponent]: + """Get and update stateful components. + + Assign all stateful components in the query component with the state. + + """ + stateful_components = get_stateful_components(query_component) + update_stateful_components(stateful_components, state) + return stateful_components + + CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str] @@ -184,6 +222,9 @@ class QueryPipeline(QueryComponent): num_workers: int = Field( default=4, description="Number of workers to use (currently async only)." ) + state: Dict[str, Any] = Field( + default_factory=dict, description="State of the pipeline." + ) class Config: arbitrary_types_allowed = True @@ -194,14 +235,33 @@ def __init__( chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None, modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None, links: Optional[List[Link]] = None, + state: Optional[Dict[str, Any]] = None, **kwargs: Any, ): super().__init__( callback_manager=callback_manager or CallbackManager([]), + state=state or {}, **kwargs, ) self._init_graph(chain=chain, modules=modules, links=links) + # Pydantic validator isn't called for __init__ so we need to call it manually + get_and_update_stateful_components(self, state) + + def set_state(self, state: Dict[str, Any]) -> None: + """Set state.""" + self.state = state + get_and_update_stateful_components(self, state) + + def update_state(self, state: Dict[str, Any]) -> None: + """Update state.""" + self.state.update(state) + get_and_update_stateful_components(self, state) + + def reset_state(self) -> None: + """Reset state.""" + # use pydantic validator to update state + self.set_state({}) def _init_graph( self, @@ -243,6 +303,11 @@ def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None: for i in range(len(chain) - 1): self.add_link(src=module_keys[i], dest=module_keys[i + 1]) + @property + def stateful_components(self) -> List[BaseStatefulComponent]: + """Get stateful component.""" + return get_stateful_components(self) + def add_links( self, links: List[Link], @@ -272,6 +337,9 @@ def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None: self.module_dict[module_key] = cast(QueryComponent, module) self.dag.add_node(module_key) + # propagate state to new modules added + # TODO: there's more efficient ways to do this + get_and_update_stateful_components(self, self.state) def add_link( self, diff --git a/llama-index-core/tests/agent/custom/test_pipeline.py b/llama-index-core/tests/agent/custom/test_pipeline.py new file mode 100644 index 0000000000000..4d7ba56df6b42 --- /dev/null +++ b/llama-index-core/tests/agent/custom/test_pipeline.py @@ -0,0 +1,164 @@ +"""Test query pipeline worker.""" + +from typing import Any, Dict, Set, Tuple + +from llama_index.core.agent.custom.pipeline_worker import ( + QueryPipelineAgentWorker, +) +from llama_index.core.agent.runner.base import AgentRunner +from llama_index.core.agent.types import Task +from llama_index.core.bridge.pydantic import Field +from llama_index.core.chat_engine.types import AgentChatResponse +from llama_index.core.query_pipeline import FnComponent, QueryPipeline +from llama_index.core.query_pipeline.components.agent import ( + AgentFnComponent, + AgentInputComponent, + CustomAgentComponent, +) +from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent + + +def mock_fn(a: str) -> str: + """Mock function.""" + return a + "3" + + +def mock_agent_input_fn(task: Task, state: dict) -> dict: + """Mock agent input function.""" + if "count" not in state: + state["count"] = 0 + state["max_count"] = 2 + state["input"] = task.input + return {"a": state["input"]} + + +def mock_agent_output_fn( + task: Task, state: dict, output: str +) -> Tuple[AgentChatResponse, bool]: + state["count"] += 1 + state["input"] = output + is_done = state["count"] >= state["max_count"] + return AgentChatResponse(response=str(output)), is_done + + +def mock_agent_input_fn_stateful(state: Dict[str, Any]) -> str: + """Mock agent input function (for StatefulFnComponent).""" + d = mock_agent_input_fn(state["task"], state["step_state"]) + return d["a"] + + +def mock_agent_output_fn_stateful( + state: Dict[str, Any], output: str +) -> Tuple[AgentChatResponse, bool]: + """Mock agent output function (for StatefulFnComponent).""" + return mock_agent_output_fn(state["task"], state["step_state"], output) + + +def mock_agent_output_fn( + task: Task, state: dict, output: str +) -> Tuple[AgentChatResponse, bool]: + state["count"] += 1 + state["input"] = output + is_done = state["count"] >= state["max_count"] + return AgentChatResponse(response=str(output)), is_done + + +def test_qp_agent_fn() -> None: + """Test query pipeline agent. + + Implement via function components. + + """ + agent_input = AgentInputComponent(fn=mock_agent_input_fn) + fn_component = FnComponent(fn=mock_fn) + agent_output = AgentFnComponent(fn=mock_agent_output_fn) + qp = QueryPipeline(chain=[agent_input, fn_component, agent_output]) + + agent_worker = QueryPipelineAgentWorker(pipeline=qp) + agent_runner = AgentRunner(agent_worker=agent_worker) + + # test create_task + task = agent_runner.create_task("foo") + assert task.input == "foo" + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo3" + assert step_output.is_last is False + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo33" + assert step_output.is_last is True + + +class MyCustomAgentComponent(CustomAgentComponent): + """Custom agent component.""" + + separator: str = Field(default=":", description="Separator") + + def _run_component(self, **kwargs: Any) -> Dict[str, Any]: + """Run component.""" + return {"output": kwargs["a"] + self.separator + kwargs["a"]} + + @property + def _input_keys(self) -> Set[str]: + """Input keys.""" + return {"a"} + + @property + def _output_keys(self) -> Set[str]: + """Output keys.""" + return {"output"} + + +def test_qp_agent_custom() -> None: + """Test query pipeline agent. + + Implement via `AgentCustomQueryComponent` subclass. + + """ + agent_input = AgentInputComponent(fn=mock_agent_input_fn) + fn_component = MyCustomAgentComponent(separator="/") + agent_output = AgentFnComponent(fn=mock_agent_output_fn) + qp = QueryPipeline(chain=[agent_input, fn_component, agent_output]) + + agent_worker = QueryPipelineAgentWorker(pipeline=qp) + agent_runner = AgentRunner(agent_worker=agent_worker) + + # test create_task + task = agent_runner.create_task("foo") + assert task.input == "foo" + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo/foo" + assert step_output.is_last is False + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo/foo/foo/foo" + assert step_output.is_last is True + + +def test_qp_agent_stateful_fn() -> None: + """Test query pipeline agent with stateful components. + + The old flows of using `AgentInputComponent` and `AgentFnComponent` are deprecated. + + """ + agent_input = StatefulFnComponent(fn=mock_agent_input_fn_stateful) + fn_component = FnComponent(fn=mock_fn) + agent_output = StatefulFnComponent(fn=mock_agent_output_fn_stateful) + qp = QueryPipeline(chain=[agent_input, fn_component, agent_output]) + + agent_worker = QueryPipelineAgentWorker(pipeline=qp) + agent_runner = AgentRunner(agent_worker=agent_worker) + + # test create_task + task = agent_runner.create_task("foo") + assert task.input == "foo" + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo3" + assert step_output.is_last is False + + step_output = agent_runner.run_step(task.task_id) + assert str(step_output.output) == "foo33" + assert step_output.is_last is True diff --git a/llama-index-core/tests/query_pipeline/test_components.py b/llama-index-core/tests/query_pipeline/test_components.py index 2dfdca600c2c4..0e9bf9b6a3dcc 100644 --- a/llama-index-core/tests/query_pipeline/test_components.py +++ b/llama-index-core/tests/query_pipeline/test_components.py @@ -1,5 +1,5 @@ """Test components.""" -from typing import Any, List, Sequence +from typing import Any, List, Sequence, Dict import pytest from llama_index.core.base.base_selector import ( @@ -15,6 +15,8 @@ ) from llama_index.core.query_pipeline.components.function import FnComponent from llama_index.core.query_pipeline.components.input import InputComponent +from llama_index.core.query_pipeline.components.stateful import StatefulFnComponent +from llama_index.core.query_pipeline.components.loop import LoopComponent from llama_index.core.query_pipeline.components.router import ( RouterComponent, SelectorComponent, @@ -155,3 +157,53 @@ def bar2_fn(a: Any) -> str: selector_c = SelectorComponent(selector=selector) output = selector_c.run_component(query="hello", choices=["t1", "t2"]) assert output["output"][0] == SingleSelection(index=1, reason="foo") + + +def stateful_foo_fn(state: Dict[str, Any], a: int, b: int = 2) -> Dict[str, Any]: + """Foo function.""" + old = state.get("prev", 0) + new = old + a + b + state["prev"] = new + return new + + +def test_stateful_fn_pipeline() -> None: + """Test pipeline with function components.""" + p = QueryPipeline() + p.add_modules( + { + "m1": StatefulFnComponent(fn=stateful_foo_fn), + "m2": StatefulFnComponent(fn=stateful_foo_fn), + } + ) + p.add_link("m1", "m2", src_key="output", dest_key="a") + output = p.run(a=1, b=2) + assert output == 8 + p.reset_state() + output = p.run(a=1, b=2) + assert output == 8 + + # try one iteration + p.reset_state() + loop_component = LoopComponent( + pipeline=p, + should_exit_fn=lambda x: x["output"] > 10, + # add_output_to_input_fn=lambda cur_input, output: {"a": output}, + max_iterations=1, + ) + output = loop_component.run_component(a=1, b=2) + assert output["output"] == 8 + + # try two iterations + p.reset_state() + # loop 1: 0 + 1 + 2 = 3, 3 + 3 + 2 = 8 + # loop 2: 8 + 8 + 2 = 18, 18 + 18 + 2 = 38 + loop_component = LoopComponent( + pipeline=p, + should_exit_fn=lambda x: x["output"] > 10, + add_output_to_input_fn=lambda cur_input, output: {"a": output["output"]}, + max_iterations=5, + ) + assert loop_component.run_component(a=1, b=2)["output"] == 38 + + # test loop component