"""Define a data enrichment agent.

Works with a chat model that supports tool calling.
"""
import json
from typing import Any, Dict, List, Literal, Optional, cast

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field

from enrichment_agent import prompts
from enrichment_agent.configuration import Configuration
from enrichment_agent.state import InputState, OutputState, State
from enrichment_agent.tools import scrape_website, search
from enrichment_agent.utils import init_model


async def call_agent_model(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Call the primary Language Model (LLM) to decide on the next research action.

    This asynchronous function performs the following steps:
    1. Initializes configuration and sets up the 'Info' tool, which is the user-defined extraction schema.
    2. Prepares the prompt and message history for the LLM.
    3. Initializes and configures the LLM with available tools.
    4. Invokes the LLM and processes its response.
    5. Handles the LLM's decision to either continue research or submit final info.
    """
    # Load configuration from the provided RunnableConfig
    configuration = Configuration.from_runnable_config(config)

    # Define the 'Info' tool, which is the user-defined extraction schema
    info_tool = {
        "name": "Info",
        "description": "Call this when you have gathered all the relevant info",
        "parameters": state.extraction_schema,
    }
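
    # For illustration only: state.extraction_schema is expected to be a JSON
    # Schema object supplied by the user. A hypothetical example (the field
    # names here are assumptions, not part of this module):
    #
    #     {
    #         "type": "object",
    #         "properties": {
    #             "founded_year": {"type": "integer"},
    #             "headquarters": {"type": "string"},
    #         },
    #         "required": ["founded_year", "headquarters"],
    #     }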

    # Format the prompt defined in prompts.py with the extraction schema and topic
    p = configuration.prompt.format(
        info=json.dumps(state.extraction_schema, indent=2), topic=state.topic
    )

    # Create the messages list with the formatted prompt and the previous messages
    messages = [HumanMessage(content=p)] + state.messages

    # Initialize the raw model with the provided configuration and bind the tools
    raw_model = init_model(config)
    model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any")
    response = cast(AIMessage, await model.ainvoke(messages))

    # Initialize info to None
    info = None

    # Check if the response has tool calls
    if response.tool_calls:
        for tool_call in response.tool_calls:
            if tool_call["name"] == "Info":
                info = tool_call["args"]
                break
    if info is not None:
        # The agent is submitting its answer;
        # ensure it isn't erroneously attempting to simultaneously perform research
        response.tool_calls = [
            next(tc for tc in response.tool_calls if tc["name"] == "Info")
        ]
    response_messages: List[BaseMessage] = [response]
    if not response.tool_calls:  # If the LLM didn't respect the tool_choice
        response_messages.append(
            HumanMessage(content="Please respond by calling one of the provided tools.")
        )
    return {
        "messages": response_messages,
        "info": info,
        # Add 1 to the step count (loop_step is aggregated additively in State)
        "loop_step": 1,
    }


class InfoIsSatisfactory(BaseModel):
    """Validate whether the current extracted info is satisfactory and complete."""

    reason: List[str] = Field(
        description="First, provide reasoning for why this is either good or bad as a final result. Must include at least 3 reasons."
    )
    is_satisfactory: bool = Field(
        description="After providing your reasoning, provide a value indicating whether the result is satisfactory. If not, you will continue researching."
    )
    improvement_instructions: Optional[str] = Field(
        description="If the result is not satisfactory, provide clear and specific instructions on what needs to be improved or added to make the information satisfactory."
        " This should include details on missing information, areas that need more depth, or specific aspects to focus on in further research.",
        default=None,
    )
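
# Illustrative only: a validator response produced via structured output might
# look like the following (all values here are hypothetical):
#
#     InfoIsSatisfactory(
#         reason=[
#             "All required schema fields are populated",
#             "Claims are supported by the scraped sources",
#             "No contradictions across sources",
#         ],
#         is_satisfactory=True,
#         improvement_instructions=None,
#     )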


async def reflect(
    state: State, *, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
    """Validate the quality of the data enrichment agent's output.

    This asynchronous function performs the following steps:
    1. Prepares the initial prompt using the main prompt template.
    2. Constructs a message history for the model.
    3. Prepares a checker prompt to evaluate the presumed info.
    4. Initializes and configures a language model with structured output.
    5. Invokes the model to assess the quality of the gathered information.
    6. Processes the model's response and determines if the info is satisfactory.
    """
    p = prompts.MAIN_PROMPT.format(
        info=json.dumps(state.extraction_schema, indent=2), topic=state.topic
    )
    last_message = state.messages[-1]
    if not isinstance(last_message, AIMessage):
        raise ValueError(
            f"{reflect.__name__} expects the last message in the state to be an AI message with tool calls."
            f" Got: {type(last_message)}"
        )
    messages = [HumanMessage(content=p)] + state.messages[:-1]
    presumed_info = state.info
    checker_prompt = """I am thinking of calling the info tool with the info below. \
Is this good? Give your reasoning as well. \
You can encourage the Assistant to look at specific URLs if that seems relevant, or do more searches.
If you don't think it is good, you should be very specific about what could be improved.

{presumed_info}"""
    p1 = checker_prompt.format(presumed_info=json.dumps(presumed_info or {}, indent=2))
    messages.append(HumanMessage(content=p1))
    raw_model = init_model(config)
    bound_model = raw_model.with_structured_output(InfoIsSatisfactory)
    response = cast(InfoIsSatisfactory, await bound_model.ainvoke(messages))
    if response.is_satisfactory and presumed_info:
        return {
            "info": presumed_info,
            "messages": [
                ToolMessage(
                    tool_call_id=last_message.tool_calls[0]["id"],
                    content="\n".join(response.reason),
                    name="Info",
                    additional_kwargs={"artifact": response.model_dump()},
                    status="success",
                )
            ],
        }
    else:
        return {
            "messages": [
                ToolMessage(
                    tool_call_id=last_message.tool_calls[0]["id"],
                    content=f"Unsatisfactory response:\n{response.improvement_instructions}",
                    name="Info",
                    additional_kwargs={"artifact": response.model_dump()},
                    status="error",
                )
            ]
        }


def route_after_agent(
    state: State,
) -> Literal["reflect", "tools", "call_agent_model", "__end__"]:
    """Schedule the next node after the agent's action.

    This function determines the next step in the research process based on the
    last message in the state. It handles three main scenarios:
    1. Error recovery: If the last message is unexpectedly not an AIMessage.
    2. Info submission: If the agent has called the "Info" tool to submit findings.
    3. Continued research: If the agent has called any other tool.
    """
    last_message = state.messages[-1]

    # If for some reason the last message is not an AIMessage (due to a bug or
    # unexpected behavior elsewhere in the code), don't crash; try to recover by
    # calling the agent model again.
    if not isinstance(last_message, AIMessage):
        return "call_agent_model"
    # If the "Info" tool was called, the model has produced its extraction output.
    # Reflect on the result.
    if last_message.tool_calls and last_message.tool_calls[0]["name"] == "Info":
        return "reflect"
    # The last message is a tool call other than "Info": continue researching
    else:
        return "tools"


def route_after_checker(
    state: State, config: RunnableConfig
) -> Literal["__end__", "call_agent_model"]:
    """Schedule the next node after the checker's evaluation.

    This function determines whether to continue the research process or end it
    based on the checker's evaluation and the current state of the research.
    """
    configurable = Configuration.from_runnable_config(config)
    last_message = state.messages[-1]

    if state.loop_step < configurable.max_loops:
        if not state.info:
            return "call_agent_model"
        if not isinstance(last_message, ToolMessage):
            raise ValueError(
                f"{route_after_checker.__name__} expected a ToolMessage. Received: {type(last_message)}."
            )
        if last_message.status == "error":
            # Research was deemed unsatisfactory; try again
            return "call_agent_model"
        # The info is satisfactory: finish
        return "__end__"
    else:
        return "__end__"


# Create the graph
workflow = StateGraph(
    State, input=InputState, output=OutputState, config_schema=Configuration
)
workflow.add_node(call_agent_model)
workflow.add_node(reflect)
workflow.add_node("tools", ToolNode([search, scrape_website]))
workflow.add_edge("__start__", "call_agent_model")
workflow.add_conditional_edges("call_agent_model", route_after_agent)
workflow.add_edge("tools", "call_agent_model")
workflow.add_conditional_edges("reflect", route_after_checker)

graph = workflow.compile()
graph.name = "ResearchTopic"
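
# Example usage (a minimal sketch, not part of the original module). It assumes
# the `topic` and `extraction_schema` input fields referenced above, and that
# model and search credentials are configured in the environment:
#
#     import asyncio
#
#     schema = {
#         "type": "object",
#         "properties": {"description": {"type": "string"}},
#         "required": ["description"],
#     }
#     result = asyncio.run(
#         graph.ainvoke({"topic": "LangGraph", "extraction_schema": schema})
#     )
#     print(result["info"])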