agent notebook fixes (v1) #10610

Merged
merged 10 commits into from Feb 28, 2024
Changes from 9 commits
159 changes: 116 additions & 43 deletions docs/examples/agent/react_agent.ipynb
@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "5cfa0417",
"metadata": {},
@@ -22,7 +21,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "81d6fba5",
"metadata": {},
@@ -52,10 +50,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "e8ac1778-0585-43c9-9dad-014d13d7460d",
"metadata": {},
"outputs": [],
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /Users/jerryliu/Programmi\n",
"[nltk_data] ng/gpt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n",
"[nltk_data] Downloading package punkt to /Users/jerryliu/Programming/g\n",
"[nltk_data] pt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
}
],
"source": [
"from llama_index.core.agent import ReActAgent\n",
"from llama_index.llms.openai import OpenAI\n",
@@ -75,9 +90,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "26472aaf-1a12-49f9-9fe6-cbf41dd15f88",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def multiply(a: int, b: int) -> int:\n",
@@ -90,9 +107,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "df78ae85-bcf7-44c1-87ee-f301e646db20",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def add(a: int, b: int) -> int:\n",
@@ -115,9 +134,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "7ab300f1-b054-46d9-b1c8-dbcd0d538e5a",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(model=\"gpt-3.5-turbo-instruct\")\n",
@@ -126,21 +147,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "69bb1aa9-1ea3-4c88-a4f3-239b76392aa5",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1;3;38;5;200mThought: I need to use a tool to help me answer the question.\n",
"assistant: Action: multiply\n",
"assistant: Action Input: {\"a\": 2, \"b\": 4}\n",
"Action: multiply\n",
"Action Input: {\"a\": 2, \"b\": 4}\n",
"Observation: 8\n",
"assistant: Thought: I need to use a tool to help me answer the question.\n",
"assistant: Action: add\n",
"assistant: Action Input: {\"a\": 20, \"b\": 8}\n",
"Thought: I need to use a tool to help me answer the question.\n",
"Action: add\n",
"Action Input: {\"a\": 20, \"b\": 8}\n",
"Observation: 28\n",
"Thought: I can answer without using any more tools.\n",
"Answer: 28\n",
@@ -154,15 +177,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "76112bb6-a291-4235-ad00-d4f6ff20adfe",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 28"
"28"
]
}
],
@@ -181,9 +206,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "1bb7d49b-404c-4a46-9a84-1f7bb8792991",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(model=\"gpt-4\")\n",
@@ -192,19 +219,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "277f21d3-4f62-430e-b0fc-7561c64084d0",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1;3;38;5;200mThought: I need to use the tools to help me answer the question. According to the order of operations in mathematics (BIDMAS/BODMAS), multiplication should be done before addition. So, I will first multiply 2 and 4, then add the result to 2.\n",
"\u001b[1;3;38;5;200mThought: I need to use the tools to help me answer the question. According to the order of operations in mathematics (BIDMAS/BODMAS), multiplication should be done before addition. So, I will first multiply 2 and 4, and then add the result to 2.\n",
"Action: multiply\n",
"Action Input: {'a': 2, 'b': 4}\n",
"\u001b[0m\u001b[1;3;34mObservation: 8\n",
"\u001b[0m\u001b[1;3;38;5;200mThought: Now that I have the result of the multiplication, I need to add this to 2.\n",
"\u001b[0m\u001b[1;3;38;5;200mThought: Now that I have the result of the multiplication, I need to add this result to 2.\n",
"Action: add\n",
"Action Input: {'a': 2, 'b': 8}\n",
"\u001b[0m\u001b[1;3;34mObservation: 10\n",
@@ -233,9 +262,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "990597d5-f86c-44cb-ad91-24715a49a278",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(model=\"gpt-4\")\n",
@@ -244,9 +275,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "ad7964a5-a953-4a53-9865-6a0795cd2772",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@@ -324,9 +357,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "0766d978-b011-40a6-bdce-b0ea566d2475",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from llama_index.core import PromptTemplate\n",
@@ -391,19 +426,46 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "c32037ea-6c30-4059-bd32-4b9ba64912ab",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'agent_worker:system_prompt': PromptTemplate(metadata={'prompt_type': <PromptType.CUSTOM: 'custom'>}, template_vars=['tool_desc', 'tool_names'], kwargs={}, output_parser=None, template_var_mappings=None, function_mappings=None, template='\\nYou are designed to help with a variety of tasks, from answering questions to providing summaries to other types of analyses.\\n\\n## Tools\\nYou have access to a wide variety of tools. You are responsible for using\\nthe tools in any sequence you deem appropriate to complete the task at hand.\\nThis may require breaking the task into subtasks and using different tools\\nto complete each subtask.\\n\\nYou have access to the following tools:\\n{tool_desc}\\n\\n## Output Format\\nTo answer the question, please use the following format.\\n\\n```\\nThought: I need to use a tool to help me answer the question.\\nAction: tool name (one of {tool_names}) if using a tool.\\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{\"input\": \"hello world\", \"num_beams\": 5}})\\n```\\n\\nPlease ALWAYS start with a Thought.\\n\\nPlease use a valid JSON format for the Action Input. Do NOT do this {{\\'input\\': \\'hello world\\', \\'num_beams\\': 5}}.\\n\\nIf this format is used, the user will respond in the following format:\\n\\n```\\nObservation: tool response\\n```\\n\\nYou should keep repeating the above format until you have enough information\\nto answer the question without using any more tools. At that point, you MUST respond\\nin the one of the following two formats:\\n\\n```\\nThought: I can answer without using any more tools.\\nAnswer: [your answer here]\\n```\\n\\n```\\nThought: I cannot answer the question with the provided tools.\\nAnswer: Sorry, I cannot answer your query.\\n```\\n\\n## Current Conversation\\nBelow is the current conversation consisting of interleaving human and assistant messages.\\n\\n')}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.get_prompts()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "67581223-f625-4b28-90aa-c8e5a232879e",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"agent.update_prompts({\"agent_worker:system_prompt\": react_system_prompt})"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "9e1b2ce0-8f46-4d1d-8504-5b4b5d3f8478",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@@ -413,17 +475,19 @@
"Action: add\n",
"Action Input: {'a': 5, 'b': 3}\n",
"\u001b[0m\u001b[1;3;34mObservation: 8\n",
"\u001b[0m\u001b[1;3;38;5;200mThought: Now I need to add the result from the previous operation with 2.\n",
"\u001b[0m\u001b[1;3;38;5;200mThought: Now I need to add the result from the previous operation to 2.\n",
"Action: add\n",
"Action Input: {'a': 8, 'b': 2}\n",
"\u001b[0m\u001b[1;3;34mObservation: 10\n",
"\u001b[0m\u001b[1;3;38;5;200mThought: I can answer without using any more tools.\n",
"Answer: The result of 5+3+2 is 10.\n",
"- First, I added 5 and 3 using the add tool, which resulted in 8.\n",
"- Then, I added the result (8) to 2 using the add tool, which resulted in 10.\n",
"\n",
"- I first added 5 and 3 using the add tool, which resulted in 8.\n",
"- Then I added the result (8) to 2 using the add tool again, which resulted in 10.\n",
"\u001b[0mThe result of 5+3+2 is 10.\n",
"- First, I added 5 and 3 using the add tool, which resulted in 8.\n",
"- Then, I added the result (8) to 2 using the add tool, which resulted in 10.\n"
"\n",
"- I first added 5 and 3 using the add tool, which resulted in 8.\n",
"- Then I added the result (8) to 2 using the add tool again, which resulted in 10.\n"
]
}
],
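Tying the last few hunks together: the get_prompts / update_prompts cells amount to the flow below. The template string here is a heavily shortened stand-in for the notebook's full ReAct system header (the real one spells out the complete Thought/Action/Observation format), so don't copy it verbatim.

```python
from llama_index.core import PromptTemplate

# Heavily shortened stand-in for the notebook's full ReAct system header.
react_system_header_str = """\
You are designed to help with a variety of tasks.

## Tools
You have access to the following tools:
{tool_desc}

## Output Format
Thought: I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names})
Action Input: the tool kwargs as JSON, e.g. {{"input": "hello world", "num_beams": 5}}
"""

react_system_prompt = PromptTemplate(react_system_header_str)

# Inspect the current prompts, then swap in the customized system prompt.
agent.get_prompts()
agent.update_prompts({"agent_worker:system_prompt": react_system_prompt})

response = agent.chat("What is 5+3+2")
print(response)
```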
@@ -432,6 +496,14 @@
"response = agent.chat(\"What is 5+3+2\")\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b60d6e4-06c5-424b-a5b5-1616340374b5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -449,7 +521,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
5 changes: 5 additions & 0 deletions llama-index-core/llama_index/core/agent/react/base.py
@@ -30,6 +30,7 @@
from llama_index.core.objects.base import ObjectRetriever
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.prompts.mixin import PromptMixinType


class ReActAgent(AgentRunner):
@@ -126,3 +127,7 @@ def from_tools(
verbose=verbose,
context=context,
)

def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt modules."""
return {"agent_worker": self.agent_worker}