agentObject.py
import os
import dspy
from datetime import datetime, timezone
from dotenv import load_dotenv
from dbObject import dbObject
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
from dspy import Suggest
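# NOTE (assumption): load_dotenv is imported above but never called in this file; calling it here
# so the OpenAI API key and related settings can be read from a .env file, if one exists.
load_dotenv()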
class Agent():
def __init__(self, network_id: str, agent_id=None, db: dbObject=None):
# ensure all required stuff is passed in
if db is None: raise ValueError("Database object cannot be None.")
self.db = db
if network_id is None: raise ValueError("Network ID cannot be None.")
self.network_id = network_id
if agent_id is None: raise ValueError("Agent ID cannot be None.")
self.agent_id = agent_id
        # configure which DSPy model to use; eventually we should just pass this in
self.turbo = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=2000, model_type="chat", temperature=0.8)
dspy.settings.configure(lm=self.turbo, trace=[])
        self.user_chat_module = UserChatModule()  # needed by handle_user_chat below
        # load the user's settings from the db; it's faster to do it here, but they can't be updated mid-conversation (for example during a user interaction; a problem for later)
self.instructions = self.db.get_instructions(self.agent_id)
self.toxicity_settings = self.db.get_toxicty_settings(self.agent_id)
self.chat_history = []
# helper modules
self.interaction_expectation = InteractionLengthExpectation()
self.chat_summarizer = ChatHistorySummarizer()
# metrics
self.toxicity_metric = ToxicityMetric(self.db)
self.toxicity_threshold = 6 # 1-very toxic, 10-not toxic, flags memory as toxic if below this
self.interest_metric = ConversationMetric()
self.interest_threshold = 6 # 1-not interested, 10-very interested, gets added to metadata of memory
# agent chat module (and activate suggestions)
self.agent_chat_module = AgentChatModule(self.db, self.agent_id, relevance_threshold=7) # set the relevance threshold, it retries if it's below this
self.agent_chat_module = assert_transform_module(self.agent_chat_module, backtrack_handler)
######################## USER CHAT ########################
    ########### TODO: needs to be completely reworked
# just responds to a message
def handle_user_chat(self, message):
user_response = self.user_chat_module(prompt=message, instructions=self.instructions)
self.chat_history.append({"user": message, "agent": user_response["answer"]})
return user_response["answer"]
######################### AGENTS CHAT ########################
def handle_agent_interaction(self, away_agent_id, init_prompt="Hey, talk to me", environment="In line at a coffee shop"):
# logic could be added for the initial prompt to be generated by what's in the memory retrieval, or maybe what's most recent
# init DSPy chat modules
# we should init models here and pass them into these modules (use "with dspy.context" inside modules to select model)
home_agent = self.agent_chat_module
away_agent = assert_transform_module(AgentChatModule(self.db, away_agent_id), backtrack_handler)
# predict number of interactions
expected_interactions = self.interaction_expectation(environment)
# run interactions (recursive extension logic built in)
interest_metric_output = self._run_interactions(home_agent, away_agent, init_prompt, expected_interactions)
formatted_chat_history = home_agent.format_chat_history() # make it a str
# summarize chat history
interaction_summary = self.chat_summarizer(self.agent_id, home_agent.chat_history)
# Check toxicity of the chat history
toxicity_metric_output = self.toxicity_metric(formatted_chat_history, self.agent_id)
# Store summary in memory
self.db.add_agent_data(self.agent_id, interaction_summary, toxicity_metric_output, interest_metric=interest_metric_output)
# Testing
print(f"Chat history: \n{formatted_chat_history}")
print(f"Chat history summary: \n{interaction_summary}")
        return interaction_summary, interest_metric_output, toxicity_metric_output  # note: this now returns three values instead of one; callers need to unpack all three
def _run_interactions(self, home_agent, away_agent, init_prompt, expected_interactions, extension=0, max_extensions=3):
for i in range(expected_interactions):
if i == 0: # for initializing conversation
# home responds to init prompt
home_response = home_agent.forward(prompt=init_prompt, retrieved_memories=str(self.db.get_agent_memory(self.agent_id, init_prompt)))
# away responds to home's response
away_response = away_agent.forward(prompt=home_response['answer'], retrieved_memories=str(self.db.get_agent_memory(away_agent.agent_id, home_response['answer'])))
else:
# home responds to away's response
home_response = home_agent.forward(prompt=away_response['answer'], retrieved_memories=str(self.db.get_agent_memory(self.agent_id, away_response['answer'])))
# away responds to home's response
away_response = away_agent.forward(prompt=home_response['answer'], retrieved_memories=str(self.db.get_agent_memory(away_agent.agent_id, home_response['answer'])))
# format chat history
away_agent_chat_history_str = away_agent.format_chat_history() # make it a str
# eval conversation importance
interest_metric_output = self.interest_metric(
away_agent_chat_history_str,
str(self.db.get_agent_memory(
self.agent_id,
away_agent_chat_history_str)
)
)
# recursive
if extension < max_extensions and interest_metric_output > self.interest_threshold:
extension += 1 # update extension count
            interest_metric_output = self._run_interactions(home_agent, away_agent, "I'm interested in learning more about you.", expected_interactions, extension=extension, max_extensions=max_extensions)  # carry the extended conversation's interest score back up
return interest_metric_output
########################## DSPY CLASSES ##########################
##################### AGENT CHAT MODULES #####################
class AgentChatModule(dspy.Module):
class AgentChatSignature(dspy.Signature):
"""
Exchange information with another agent, following the instructions provided. Do not make up any information or experiences.
Find commonalities and relevant things in your memory retrieval based on what the other agent asks you. The conversation takes place in the environment provided.
"""
        guidelines = dspy.InputField(desc="User-set guidelines for you to follow during social interactions.")
        prompt = dspy.InputField(desc="A message from the other agent.")
        memory_retrieval = dspy.InputField(desc="Memories retrieved for the prompt; ground your response in these.")
        chat_history = dspy.InputField(desc="Chat history so far, to avoid repeating yourself.") # is this too inefficient to pass in every time?
answer = dspy.OutputField(desc="A response to the other agent.")
    def __init__(self, db, agent_id, relevance_threshold=7):
        super().__init__()
        self.db = db
self.agent_id = agent_id
self.chat_history = [] # stored in the class so we can call it outside of the module
self.instructions = self.db.get_instructions(self.agent_id)
# main chain of thought
self.respond = dspy.ChainOfThought(self.AgentChatSignature)
# metrics
self.relevance_metric = RelevanceMetric()
        self.relevance_threshold = relevance_threshold # 1-not relevant, 10-highly relevant; retries if below this
def forward(self, prompt, retrieved_memories):
        # TODO: add a module here for interpreting the stored memory summaries and picking the ones relevant to this conversation, to compose a context string for the prompt
agent_chat_response = self.respond(
guidelines=self.instructions,
prompt=prompt,
memory_retrieval=retrieved_memories,
chat_history=str(self.chat_history),
)
# check relevance of the response
relevance_metric_output = self.relevance_metric(
memory_context=retrieved_memories,
previous_chat_msg=agent_chat_response['answer']
) # outputs a float
# retry logic for relevance check
Suggest(relevance_metric_output > self.relevance_threshold,
"Your response should be more relevant to the memories or previous chat message.",
target_module="AgentChatSignature"
)
# Add the latest response in the chat history
self.append_chat_history(prompt, agent_chat_response['answer'])
return agent_chat_response #, relevance_metric_output
    # chat history is currently set up to use only this agent's own history, but we can zip it together with its partner agent's history later
def append_chat_history(self, prompt, response):
timestamp = datetime.now(timezone.utc).isoformat()
chat_entry = {
"timestamp": timestamp,
"agent_id": self.agent_id,
"prompt": prompt,
"response": response
}
self.chat_history.append(chat_entry)
# for better interpretation during summary
def format_chat_history(self):
formatted_history = []
for entry in self.chat_history:
formatted_entry = f"Timestamp: {entry['timestamp']}\n"
formatted_entry += f"Agent {entry['agent_id']}\nPrompt: {entry['prompt']}\n"
formatted_entry += f"Response: {entry['response']}\n"
formatted_entry += "-" * 50
formatted_history.append(formatted_entry + "\n")
return "\n".join(formatted_history)
##################### EVALUATIONS/CHECKERS #####################
class RelevanceMetric(dspy.Module):
class RelevanceMetricSignature(dspy.Signature):
"""
You are a relevance checker for a conversation, to make sure it stays on track and grounded in factual memories.
Your job is to evaluate how relevant a given string is to a set of memories.
Output a float score of 1.0-10.0 (1-less relevant, 10-more relevant)
"""
        memory_context = dspy.InputField(desc="The grounded, factual memories the conversation should stick to.")
previous_chat_msg = dspy.InputField(desc="Text string to check for relevance to the memories.")
is_relevant = dspy.OutputField(desc="1.0-10.0 relevance score")
class ParseRelevanceMetricSignature(dspy.Signature):
"""
        Your job is to extract the float value (a metric we want to isolate) from a given string.
"""
given_str = dspy.InputField(desc="The string containing our desired float value")
metric = dspy.OutputField(desc="The isolated float value")
    def __init__(self):
        super().__init__()
        self.check_relevance = dspy.ChainOfThought(self.RelevanceMetricSignature)
self.parse = dspy.Predict(self.ParseRelevanceMetricSignature)
def forward(self, memory_context, previous_chat_msg):
result = self.check_relevance(
memory_context=memory_context,
previous_chat_msg=previous_chat_msg,
).is_relevant
result = float(self.parse(given_str=result).metric)
return result
class ToxicityMetric(dspy.Module):
class ToxicityMetricSignature(dspy.Signature):
"""
Your job is to evaluate if a conversation violated any of the given toxicity settings.
Return a float metric 1.0-10.0 (1-the conversation was extremely toxic, 10-the conversation was completely safe and fair).
"""
toxicity_settings = dspy.InputField(desc="User settings for what is considered toxic.")
chat_history = dspy.InputField(desc="Chat history to check for toxicity")
answer = dspy.OutputField(desc="1.0-10.0")
class ParseToxicityMetricSignature(dspy.Signature):
"""
        Your job is to extract the float value (a metric we want to isolate) from a given string.
"""
given_str = dspy.InputField(desc="The string containing our desired float value")
metric = dspy.OutputField(desc="The isolated float value")
    def __init__(self, db: dbObject):
        super().__init__()
        self.db = db
self.toxicity_metric = dspy.ChainOfThought(self.ToxicityMetricSignature)
self.parse = dspy.Predict(self.ParseToxicityMetricSignature)
def forward(self, chat_history, home_agent_id):
# call toxicity settings from db
toxicity_settings = self.db.get_toxicty_settings(home_agent_id)
result = self.toxicity_metric(
toxicity_settings=toxicity_settings,
chat_history=chat_history,
).answer
        result = float(self.parse(given_str=result).metric)  # cast so callers can compare against numeric thresholds
return result
# idea: create another tag like the toxicity flag, but for positive things; e.g. anything found to be highly relevant gets a positive tag
# module to determine number of interactions for the conversation
class InteractionLengthExpectation(dspy.Module):
class EnvironmentToExpectedLengthSignature(dspy.Signature):
"""
        Based on the given conversation environment, determine the expected number of interactions that would happen in a conversation in that setting.
"""
convo_setting = dspy.InputField(desc="The environment for the conversation among agents to take place in.")
num_interactions = dspy.OutputField(desc="The number of interactions expected (back and forth is counted as 1), typically 2-7.")
def forward(self, environment):
# logic to determine number of interactions
result = dspy.ChainOfThought(self.EnvironmentToExpectedLengthSignature)(
convo_setting=environment,
).num_interactions
return int(result)
# module to determine if a conversation should end or be extended
class ConversationMetric(dspy.Module):
class ConversationImportanceSignature(dspy.Signature):
"""
Determine how interested a party is expected to be in a conversation (based on the chat history), given a list of things that are important to the party.
Evaluate the interest in the conversation on a scale of 1.0 to 10.0.
"""
        chat_history = dspy.InputField(desc="The chat history between the agents to evaluate.")
        important_topics = dspy.InputField(desc="A list of information and topics the party finds interesting.")
        metric = dspy.OutputField(desc="The party's expected interest in the conversation, scored from 1.0 to 10.0.")
class ParseMetricSignature(dspy.Signature):
"""
Parse the float value (a metric) from a string.
"""
input_str = dspy.InputField(desc="The string to parse.")
metric = dspy.OutputField(desc="The parsed metric.")
def forward(self, chat_history, agent_memories):
metric_str = dspy.ChainOfThought(self.ConversationImportanceSignature)(
chat_history=chat_history,
important_topics=agent_memories,
).metric
metric = dspy.Predict(self.ParseMetricSignature)(input_str=metric_str).metric
return float(metric)
##################### USER MODULES #####################
# handles casual chat with the user
class UserChatModule(dspy.Module):
class UserChatSignature(dspy.Signature):
"""Your task is to be a casual texting buddy of the user, texting with abbreviations and common slang. You can ask questions, provide answers, or just chat. You must follow the settings/instructions given to you."""
guidelines = dspy.InputField(desc="User-set guidelines for the agent to follow during social interactions.")
prompt = dspy.InputField()
answer = dspy.OutputField(desc="A response to the user.")
update_command = dspy.OutputField(desc="Command to update settings, if detected. Optional")
def forward(self, prompt, instructions):
# default casual chat response
response = dspy.ChainOfThought(self.UserChatSignature)(guidelines=instructions, prompt=prompt).answer
return {"answer": response}
    # For now, we'll streamline setting instructions and toxicity_settings:
    # they can be set via their endpoints/dbObject methods.
    # Later, we can also make them updatable via the chat by adding modules/checks here. For now, keep it simple, stupid (KISS).
    # self.db.set_instructions(self.agent_id, instructions)
    # self.db.set_toxicity_settings(self.agent_id, toxicity_settings)
##################### CHAT HISTORY SUMMARIZER #####################
class ChatHistorySummarizer(dspy.Module):
class ChatHistorySummarySignature(dspy.Signature):
"""
        Summarize the chat history with very descriptive information about what was shared. You are writing a one-paragraph briefing to use when someone later asks you about the same topic.
"""
chat_history = dspy.InputField(desc="The chat history to be summarized.")
summary = dspy.OutputField(desc="A descriptive summary of the chat history.")
# I really think we should include some examples of the format we want them to be stored in. They should include things like the date, away_agent_id, etc. This will be decided by how we interpret it on the recall side.
# the module that turns the memories into conversation specific context should be able to interpret the format of the summary correctly.
    def forward(self, agent_id, chat_history_list):
        # this module is called as chat_summarizer(agent_id, chat_history); the summarizer has no agent_id of its own, so it is passed in
        away_agent_id = chat_history_list[0]['agent_id'] if chat_history_list[0]['agent_id'] != agent_id else chat_history_list[1]['agent_id']
rationale_type = dspy.OutputField(
prefix="Reasoning: Let's think step by step in order to",
desc=f"accurately summarize the chat history with very descriptive information about what was shared. I talked to {away_agent_id} about...",
)
chat_history_str = ""
for item in chat_history_list:
            label = 'Me' if item['agent_id'] == agent_id else f'Agent {item["agent_id"]}'
prompt_label = f'{label}: {item["prompt"]}' if label != 'Me' else label
chat_history_str += f"{prompt_label}\n{label}: {item['response']}\n\n"
response = dspy.ChainOfThought(self.ChatHistorySummarySignature, rationale_type=rationale_type)(
chat_history=chat_history_str
).summary
return response
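##################### SMOKE TEST (SKETCH) #####################
# Minimal usage sketch, not part of the library code. Assumptions: dbObject() can be
# constructed with no arguments (check dbObject.py for the real constructor), and the
# network/agent ids below are placeholders that already exist in the database with
# instructions and toxicity settings configured.
if __name__ == "__main__":
    db = dbObject()
    agent = Agent(network_id="network-1", agent_id="agent-1", db=db)
    summary, interest, toxicity = agent.handle_agent_interaction(
        away_agent_id="agent-2",
        init_prompt="Hey, talk to me",
        environment="In line at a coffee shop",
    )
    print(f"Summary: {summary}\nInterest: {interest}\nToxicity: {toxicity}")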