WIP: optimizable LLM function node #35

Merged: 54 commits (branch llm_function into main), Oct 8, 2024

Commits:
ab59bc8  initial commit (rizar, Sep 27, 2024)
6d06979  2 tests pass!! (rizar, Sep 27, 2024)
db4dccf  cot test (rizar, Sep 28, 2024)
0f7c137  function node, function template, agent using both of them (rizar, Sep 29, 2024)
3ffb037  rag + cot agent (rizar, Sep 30, 2024)
3007ad5  add rag demos as test data (rizar, Oct 1, 2024)
7f32d5c  rag prompt reproduced! (rizar, Oct 1, 2024)
dca71ef  evaluation on one example works (rizar, Oct 1, 2024)
91f25e1  Merge branch 'main' into llm_function (rizar, Oct 1, 2024)
569ff8c  move tests to separate file (rizar, Oct 1, 2024)
6cdf261  separate examples and tests (rizar, Oct 1, 2024)
5a79111  less verbose logging (rizar, Oct 2, 2024)
4ed3a08  minor tweaks, mostly less verbose logging (rizar, Oct 2, 2024)
5a39968  run rag agent, evaluate and browser tapes (rizar, Oct 2, 2024)
2511a3c  fix writing too much (rizar, Oct 2, 2024)
fdbb11d  add classmethod to get step kind (rizar, Oct 2, 2024)
ffea118  our own FunctionCall, ToolCall, ToolCalls; less strict signatures (rizar, Oct 2, 2024)
8d1ea52  fix environment to work with parsed arguments (rizar, Oct 2, 2024)
cae6d91  Merge branch 'main' into llm_function (rizar, Oct 2, 2024)
171a4cd  first successful pass of agentic rag (rizar, Oct 2, 2024)
680a307  things start looking good! (rizar, Oct 3, 2024)
16c7281  refactor and fix retrieval accuracy calc (rizar, Oct 3, 2024)
e37d154  agentic rag evaluation (rizar, Oct 3, 2024)
8154511  finish up optimization code (rizar, Oct 3, 2024)
90dbd35  typo (rizar, Oct 3, 2024)
f679030  better experiment management (rizar, Oct 3, 2024)
9adbbd2  add a test for the query prompt - it works (rizar, Oct 3, 2024)
a478494  change context rendering (rizar, Oct 3, 2024)
f09b491  add rag examples loading (rizar, Oct 4, 2024)
1606101  fix tests (rizar, Oct 4, 2024)
9d29d22  make it all work (rizar, Oct 4, 2024)
22f868c  make TapeMetadata.result a dict by default (rizar, Oct 4, 2024)
c9ebfa7  update the readme (rizar, Oct 4, 2024)
4a6c603  fix some of the issues with the notebook (rizar, Oct 4, 2024)
cb81121  Merge branch 'main' into llm_function (rizar, Oct 4, 2024)
045b46d  fix tape browser (rizar, Oct 4, 2024)
ad55d1c  break up the code (rizar, Oct 4, 2024)
5af9582  fix launch commands (rizar, Oct 4, 2024)
692434c  notebook update (rizar, Oct 4, 2024)
1bf0b35  remove debug stuff rom the intro (rizar, Oct 4, 2024)
0b8446e  move out some functionality to tapeagents.optimize (rizar, Oct 4, 2024)
e340a45  remove comment (rizar, Oct 4, 2024)
e628de3  typo (rizar, Oct 4, 2024)
ec65938  rename a whole bunch of things (rizar, Oct 4, 2024)
447269c  renaming, type fixes (rizar, Oct 4, 2024)
5bee52a  add a test for optimize.py (rizar, Oct 4, 2024)
fdb1a73  cute references (rizar, Oct 4, 2024)
c39aa51  add dspy as a dev dependency (rizar, Oct 4, 2024)
a218e71  force creation of sqlite db (rizar, Oct 5, 2024)
645384f  typo (rizar, Oct 5, 2024)
8276d03  go back after running test in tmp dir (rizar, Oct 5, 2024)
29f7625  small readme fixes (rizar, Oct 7, 2024)
3e27217  rename parse into get_node_runs (rizar, Oct 8, 2024)
1c2ce75  regenerate outputs (rizar, Oct 8, 2024)
conf/tapeagent/hotpot_qa.yaml (new file, +23)

```yaml
hydra:
  job:
    chdir: True
  run:
    dir: outputs/${exp_name}_load=${load_demos}_optimize=${optimize.do}_seed=${seed}
target: evaluate
exp_name: ${agent}
seed: 1
agent: ???
llm_cache: true
load_demos: false
optimize:
  do: false
  n_paragraphs: 2
  max_n_demos: 4
rag:
  partial_demos: true
  demos: true
agentic_rag:
  max_hops: 2
dataset:
  dev_size: 50
question: How many storeys are in the castle that David Gregory inherited?
```
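Any field in this config can be overridden on the Hydra command line, which is how the README commands below select agents and toggle optimization. For example, a hypothetical run with a different seed:

```bash
python -m examples.optimize.optimize agent=rag seed=2 target=evaluate
```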
examples/data_science/data_science.py (1 addition, 1 deletion)

```diff
@@ -11,7 +11,7 @@
 from tapeagents.llms import LLM, LiteLLM
 from tapeagents.rendering import BasicRenderer, PrettyRenderer
 from tapeagents.runtime import main_loop
-from tapeagents.utils import run_in_tmp_dir_to_make_test_data
+from tapeagents.test_utils import run_in_tmp_dir_to_make_test_data
 from tapeagents.view import Call, Respond

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
```
examples/intro_clean.ipynb (17 additions, 33 deletions)

```diff
@@ -41,7 +41,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -501,7 +501,7 @@
 "            messages=[system_message] + tape_to_messages(tape) + [guidance_message], tools=env.get_tool_schema_dicts()\n",
 "        )\n",
 "\n",
-"    def generate_steps(self, agent, tape, llm_stream: LLMStream):\n",
+"    def generate_steps(self, agent, tape , llm_stream: LLMStream):\n",
 "        if content := llm_stream.get_message().content:\n",
 "            yield AssistantThought(content=content)\n",
 "        else:\n",
@@ -522,7 +522,7 @@
 "            yield AssistantStep(content=m.content)\n",
 "            yield SetNextNode(next_node=0)\n",
 "        elif m.tool_calls:\n",
-"            yield ToolCalls(tool_calls=m.tool_calls)\n",
+"            yield ToolCalls.from_llm_output(m)\n",
 "            yield SetNextNode(next_node=1)\n",
 "        else:\n",
 "            raise ValueError()\n",
@@ -615,9 +615,8 @@
 "outputs": [],
 "source": [
 "import json\n",
+"from tapeagents.dialog_tape import FunctionCall, ToolCall\n",
 "from tapeagents.llms import TrainableLLM\n",
-"from litellm.utils import ChatCompletionMessageToolCall\n",
-"from litellm.utils import Function\n",
 "\n",
 "from tapeagents.prompting import step_to_message\n",
@@ -684,8 +683,8 @@
 "        if data.get(\"kind\") == \"response\":\n",
 "            response = data[\"content\"]\n",
 "        elif data.get(\"kind\") == \"tool_call\":\n",
-"            tool_call = ChatCompletionMessageToolCall(\n",
-"                function=Function(name=data[\"tool_name\"], arguments=json.dumps(data[\"parameters\"])),\n",
+"            tool_call = ToolCall(\n",
+"                function=FunctionCall(name=data[\"tool_name\"], arguments=json.dumps(data[\"parameters\"])),\n",
 "                # tool call must be a unique string, it helps to make it something deterministic\n",
 "                id=f\"tool_call_{len(tool_calls)}_node_starts_at_{len(tape)}\",\n",
 "            )\n",
@@ -1030,7 +1029,7 @@
 "        yield Respond(content=m.content)\n",
 "    elif m.tool_calls:\n",
 "        # while the LLM suggests tool calls, yield them as action steps\n",
-"        yield ToolCalls(tool_calls=m.tool_calls)\n",
+"        yield ToolCalls.from_llm_output(m)\n",
 "        yield SetNextNode(next_node=0)\n",
 "    else:\n",
 "        raise ValueError()\n",
@@ -1119,20 +1118,19 @@
 "        elif m.tool_calls:\n",
 "            yield SetNextNode(next_node=1)\n",
 "            # only keep the tool calls before the call to another agent\n",
-"            tool_calls = []\n",
-"            for tc in m.tool_calls:\n",
+"            agent_call = None\n",
+"            for i, tc in enumerate(m.tool_calls):\n",
 "                if tc.function.name == \"call_search_agent\":\n",
+"                    agent_call = tc\n",
+"                    m.tool_calls = m.tool_calls[:i]\n",
 "                    break\n",
-"                else:\n",
-"                    tool_calls.append(tc)\n",
 "            # either produce the ToolCalls action OR call another agent\n",
-"            if tool_calls:\n",
-"                yield ToolCalls(tool_calls=tool_calls)\n",
+"            if m.tool_calls:\n",
+"                yield ToolCalls.from_llm_output(m)\n",
 "            else:\n",
-"                tc = m.tool_calls[0]\n",
-"                assert tc.function.name == \"call_search_agent\"\n",
-"                yield Call(agent_name=\"search_agent\", content=json.loads(m.tool_calls[0].function.arguments)[\"query\"])\n",
-"\n",
+"                assert agent_call and agent_call.function.name == \"call_search_agent\"\n",
+"                yield Call(agent_name=\"search_agent\", content=json.loads(agent_call.function.arguments)[\"query\"])\n",
+" \n",
 "        else:\n",
 "            raise ValueError()\n",
 "\n",
@@ -1168,22 +1166,8 @@
 }
 ],
 "metadata": {
- "kernelspec": {
-  "display_name": "tapeagents",
-  "language": "python",
-  "name": "python3"
- },
  "language_info": {
-  "codemirror_mode": {
-   "name": "ipython",
-   "version": 3
-  },
-  "file_extension": ".py",
-  "mimetype": "text/x-python",
-  "name": "python",
-  "nbconvert_exporter": "python",
-  "pygments_lexer": "ipython3",
-  "version": "3.10.14"
+  "name": "python"
  }
 },
 "nbformat": 4,
```
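A recurring change in this notebook is replacing litellm's `ChatCompletionMessageToolCall`/`Function` with TapeAgents' own `ToolCall`/`FunctionCall` steps and building `ToolCalls` via `ToolCalls.from_llm_output(m)`. The sketch below shows roughly what such a conversion does; it is an assumption for illustration, not the PR's actual implementation:

```python
# Hypothetical sketch of the ToolCalls.from_llm_output conversion;
# the real classmethod lives in tapeagents.dialog_tape and may differ.
from tapeagents.dialog_tape import FunctionCall, ToolCall, ToolCalls


def tool_calls_from_llm_output(message) -> ToolCalls:
    """Convert provider-specific tool calls on an LLM message into a ToolCalls step."""
    return ToolCalls(
        tool_calls=[
            ToolCall(
                id=tc.id,
                function=FunctionCall(name=tc.function.name, arguments=tc.function.arguments),
            )
            for tc in message.tool_calls
        ]
    )
```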
examples/optimize/README.md (new file, +82)
# Agent optimization in TapeAgents

This example demonstrates how one can optimize an agent's prompt templates in TapeAgents.

There are two common ways to optimize prompts while keeping the overall structure of the agent the same:
- add demonstrations to a prompt
- change the instruction part of the prompt

In TapeAgents we have a structured prompt template called [LLMFunctionTemplate](tapeagents/llm_function.py) that enables both of these approaches to changing prompts. If you are familiar with [DSPy](https://github.com/stanfordnlp/dspy), you will recognize in it DSPy's `Signature` (pun intended). The equivalent of DSPy's modules are [LLMFunctionNode](tapeagents/llm_function.py) nodes, which apply the respective function template to the tape in order to build the prompt and generate the next steps.
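For example, the answer template defined in `examples/optimize/func_templates.py` (included in this PR) declares its inputs and outputs explicitly; wrapping it in a node is sketched below, where the `LLMFunctionNode` constructor argument is an assumption rather than the exact API:

```python
# The template construction mirrors func_templates.py from this PR;
# the LLMFunctionNode argument name is assumed for illustration.
from tapeagents.llm_function import (
    AssistantOutput,
    Input,
    LLMFunctionNode,
    LLMFunctionTemplate,
    RationaleOutput,
)

answer_template = LLMFunctionTemplate(
    desc="Answer questions with short factoid answers.",
    inputs=[
        Input(name="context", desc="may contain relevant facts", separator="\n"),
        Input(name="question"),
    ],
    outputs=[
        RationaleOutput.for_output("answer"),
        AssistantOutput(name="answer", desc="often between 1 and 5 words"),
    ],
)

# An LLMFunctionNode applies the template to the tape: it renders the
# declared inputs into a prompt and parses the completion into output steps.
answer_node = LLMFunctionNode(template_name="answer")  # assumed argument name
```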

In our [agent optimization](examples/optimize) example we show how one can build a 2-hop Retrieval-Augmented Generation agent (a.k.a. agentic RAG) and optimize its query-generation prompts using weak supervision. This example is a reimplementation of the [DSPy intro](https://github.com/stanfordnlp/dspy/blob/main/intro.ipynb) in TapeAgents. It uses questions from the [HotPotQA dataset](https://hotpotqa.github.io/) and the Wikipedia paragraph retrieval service generously provided by DSPy.

# How to run the example

## Setup

First, install the extra dependencies:

```bash
pip install -r examples/optimize/requirements.txt
```

## Explore the setting

To better understand the setup, you can launch a pre-optimized agent in TapeAgents Studio and run it by pressing the `Run Loop` button.

```bash
python examples/optimize/optimize.py agent=agentic_rag target=studio load_demos=true
```

Check out the prompts: they contain supporting demonstrations of how to use the search engine for complex questions, like this one:

> Context:
> N/A
> Question: Which of these publications was most recently published, Who Put the Bomp or Self?
> Reasoning: Let's think step by step in order to produce the query. We know that publication dates are typically included in metadata for articles or books. By searching for the publication date of each article, we can determine which one was most recently published.
> Query: "publication date of Who Put the Bomp" OR "publication date of Self"

Such demonstrations are exactly what we will be learning below.

## Optimize and benchmark different agents

Let's benchmark a basic RAG agent first. In basic RAG, the user's question is used verbatim as the search query.

```bash
$ python -m examples.optimize.optimize agent=rag target=evaluate
Retrieval accuracy: 0.26
Answer accuracy: 0.54
```

The retrieval accuracy is not that high. Let's try 2-hop agentic RAG, in which the agent makes two retrieval queries, with the second query based on the paragraphs that were retrieved for the first one.

```bash
$ python -m examples.optimize.optimize agent=agentic_rag target=evaluate
Retrieval accuracy: 0.50
Answer accuracy: 0.62
```

The retrieval accuracy is higher, but we can do better. Let's optimize the agent's prompts using weak supervision.

```bash
$ python -m examples.optimize.optimize agent=agentic_rag optimize.do=true target=evaluate
Retrieval accuracy: 0.56
Answer accuracy: 0.52
```

This way we get higher retrieval accuracy, though the answer accuracy went down.

Notes:
- We found the quantitative results of this experiment to be very unstable due to LLM non-determinism and the small training and dev set sizes. In future work we will add validation of the selected examples and evaluate on a larger dev set.
- By default the LLM cache is on, so if you rerun an experiment you will get the exact same results. To run a fresh experiment, pass `exp_name=<another_name>` to Hydra.
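For example, to rerun the optimized agentic RAG experiment under a fresh name:

```bash
python -m examples.optimize.optimize agent=agentic_rag optimize.do=true target=evaluate exp_name=agentic_rag_run2
```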

## Explore resulting tapes

Change `target` to `browse` to launch the TapeBrowser app.

```bash
$ python examples/optimize/optimize.py agent=agentic_rag optimize.do=true target=browse
```

You can now explore the agent's tapes on the dev set, as well as the "good" and the "bad" training tapes. The good tapes are the ones we used to mine demonstrations for the function templates. The bad tapes are the ones we filtered out by various criteria (see the `result` field in the tape metadata in the tape browser for the reason a tape was filtered out).

examples/optimize/func_templates.py (new file, +41)

```python
from tapeagents.dialog_tape import ToolResult
from tapeagents.llm_function import Input, LLMFunctionTemplate, AssistantOutput, RationaleOutput, ToolCallOutput


def render_contexts(contexts: list[str]) -> str:
    """Render retrieved paragraphs as a numbered list of «...» quotes, or "N/A" when empty."""
    if not contexts:
        return "N/A"
    return "\n".join(f"[{i + 1}] «{t}»" for i, t in enumerate(contexts))


class ContextInput(Input):
    def render(self, step: ToolResult):
        return render_contexts(step.content)


def make_answer_template() -> LLMFunctionTemplate:
    return LLMFunctionTemplate(
        desc="Answer questions with short factoid answers.",
        inputs=[
            ContextInput(name="context", desc="may contain relevant facts", separator="\n"),
            Input(name="question"),
        ],
        outputs=[
            RationaleOutput.for_output("answer"),
            AssistantOutput(name="answer", desc="often between 1 and 5 words"),
        ],
    )


def make_query_template() -> LLMFunctionTemplate:
    return LLMFunctionTemplate(
        desc="Write a simple search query that will help answer a complex question.",
        inputs=[
            ContextInput(name="context", desc="may contain relevant facts", separator="\n"),
            Input(name="question"),
        ],
        outputs=[
            RationaleOutput.for_output("query"),
            ToolCallOutput(name="query", tool_name="retrieve", arg_name="query"),
        ],
    )
```
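For a quick look at what a template contains, you can instantiate one directly; the attribute access below assumes `LLMFunctionTemplate` is a regular Pydantic-style model, as its keyword construction suggests:

```python
template = make_query_template()
print(template.desc)                          # "Write a simple search query ..."
print([inp.name for inp in template.inputs])  # ["context", "question"]
```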
examples/optimize/load_demos.py (new file, +64)

```python
import json
import pathlib

from tapeagents.dialog_tape import AssistantStep, AssistantThought, FunctionCall, ToolCall, ToolCalls, ToolResult, UserStep

# Demonstration JSON files live in examples/res/
res_dir = pathlib.Path(__file__).parent.parent.resolve() / "res"


def load_rag_demos() -> tuple[list, list]:
    with open(res_dir / "llm_function_rag_demos.json") as f:
        demos_json = json.load(f)
    partial_demos = []
    demos = []
    for demo in demos_json:
        if demo.get("augmented"):
            # "augmented" demos are complete: they include the retrieved context and the rationale
            demo = {
                "question": UserStep(content=demo["question"]),
                "context": ToolResult(content=demo["context"], tool_call_id=""),
                "rationale": AssistantThought(content=demo["rationale"]),
                "answer": AssistantStep(content=demo["answer"]),
            }
            demos.append(demo)
        else:
            demo = {
                "question": UserStep(content=demo["question"]),
                "answer": AssistantStep(content=demo["answer"]),
            }
            partial_demos.append(demo)
    return partial_demos, demos


def load_agentic_rag_demos() -> dict[str, tuple[list, list]]:
    """Loads full demos only"""
    with open(res_dir / "agentic_rag_demos.json") as f:
        demos_json = json.load(f)
    result = {}
    for predictor, predictor_demos in demos_json.items():
        # keep only complete ("augmented") demonstrations
        predictor_demos = [d for d in predictor_demos if d.get("augmented")]
        demos = []
        if "query" in predictor:
            for demo in predictor_demos:
                tc = ToolCall(function=FunctionCall(name="retrieve", arguments={"query": demo["query"]}))
                demo = {
                    "question": UserStep(content=demo["question"]),
                    "context": ToolResult(content=demo["context"]),
                    "rationale": AssistantThought(content=demo["rationale"]),
                    "query": ToolCalls(tool_calls=[tc]),
                }
                demos.append(demo)
            result[f"query{predictor[-2]}"] = ([], demos)
        elif predictor == "generate_answer":
            for demo in predictor_demos:
                demo = {
                    "question": UserStep(content=demo["question"]),
                    "context": ToolResult(content=demo["context"], tool_call_id=""),
                    "rationale": AssistantThought(content=demo["rationale"]),
                    "answer": AssistantStep(content=demo["answer"]),
                }
                demos.append(demo)
            result["answer"] = ([], demos)
        else:
            raise ValueError(f"Unknown predictor {predictor}")
    return result
```
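As an illustrative smoke test of the two loaders (assuming the JSON files are present under `examples/res/`):

```python
partial_demos, demos = load_rag_demos()
print(f"{len(partial_demos)} partial demos, {len(demos)} full demos")

for name, (_, full_demos) in load_agentic_rag_demos().items():
    print(f"{name}: {len(full_demos)} full demos")
```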