-
Notifications
You must be signed in to change notification settings - Fork 0
/
dbObject.py
375 lines (311 loc) · 14.7 KB
/
dbObject.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import os
from datetime import datetime, timezone
from typing import Optional

import weaviate
from dotenv import load_dotenv
from weaviate.exceptions import WeaviateBaseError
load_dotenv()
# This class is the main interface to the Weaviate database. It handles all interactions with the database.
class dbObject:
    """Interface to the Weaviate database used by the agent network.

    Wraps a ``weaviate.Client`` and exposes helpers for three kinds of data:

    * ``Networks`` objects (``networkID``, ``name``, ``description``),
    * ``Agents`` objects (``agentID``, ``network`` reference, ``createdAt``,
      ``instructions``, ``toxicitySettings``),
    * one ``AgentData_<agentID>`` class per agent holding its memories.

    Apart from the constructor, methods report database errors
    (``WeaviateBaseError``) with ``print`` and return ``None`` instead of
    propagating them.
    """

    # Cluster URL that used to be hard-coded; only used when WCS_URL is unset.
    DEFAULT_URL = "https://sadie-testing-jzq2se3h.weaviate.network"

    def __init__(self):
        """Connect to Weaviate and make sure the base schema exists.

        Settings are read from the environment after loading a ``.env`` file:
        ``WCS_API_KEY`` (required), ``WCS_URL`` (optional, falls back to
        ``DEFAULT_URL``) and ``OPENAI_API_KEY`` (sent to Weaviate as the
        ``X-OpenAI-Api-Key`` header when set).

        Raises:
            ValueError: if ``WCS_API_KEY`` is not set.
            ConnectionError: if the Weaviate instance does not report ready.
        """
        load_dotenv()

        # Credentials must come from the environment, never from source
        # control. The key that was previously committed here should be
        # considered leaked and rotated.
        api_key = os.getenv("WCS_API_KEY")
        if not api_key:
            raise ValueError(
                "WCS_API_KEY is not set; add it to the environment or the .env file."
            )

        headers = {}
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            headers["X-OpenAI-Api-Key"] = openai_key

        self.client = weaviate.Client(
            url=os.getenv("WCS_URL", self.DEFAULT_URL),
            auth_client_secret=weaviate.auth.AuthApiKey(api_key),
            additional_headers=headers,
        )

        # An explicit check instead of `assert`, which is removed under -O.
        if not self.client.is_ready():
            raise ConnectionError("Weaviate client is not ready.")
        self._ensure_base_classes_exist()

    def _ensure_base_classes_exist(self):
        """Create the ``Networks`` and ``Agents`` classes if they are missing."""
        network_class = {
            "class": "Networks",
            "properties": [
                {"name": "networkID", "dataType": ["string"]},
                {"name": "name", "dataType": ["string"]},
                {"name": "description", "dataType": ["text"]},
            ],
        }
        agent_class = {
            "class": "Agents",
            "properties": [
                {"name": "agentID", "dataType": ["string"]},
                # Cross-reference: must name the class exactly as created above.
                {"name": "network", "dataType": ["Networks"]},
                {"name": "createdAt", "dataType": ["date"]},
                {"name": "instructions", "dataType": ["text"]},
                {"name": "toxicitySettings", "dataType": ["text"]},
            ],
        }

        # If the database has not been set up yet, create the base classes.
        # Networks goes first because Agents holds a reference to it.
        try:
            schema = self.client.schema.get()
            existing_class_names = {cls["class"] for cls in schema.get("classes") or []}
            for definition in (network_class, agent_class):
                class_name = definition["class"]
                if class_name in existing_class_names:
                    print(f"{class_name} class already exists.")
                else:
                    self.client.schema.create_class(definition)
                    print(f"{class_name} class created successfully.")
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")

    # *** Network methods ***

    # TODO: network IDs are not checked for uniqueness; decide whether to
    # auto-generate them or to reject a network ID that is already taken.
    def create_network(self, network_id, name, description):
        """Store a new ``Networks`` object with the given ID, name and description."""
        network = {
            "networkID": network_id,
            "name": name,
            "description": description,
        }
        try:
            self.client.data_object.create(network, "Networks")
            print(f"Network {network_id} created successfully.")
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")

    def get_network(self, network_name) -> Optional[str]:
        """Return the ``networkID`` of the first network named ``network_name``.

        Returns ``None`` when no network matches, when the response does not
        have the expected shape, or when the query fails.
        """
        try:
            response = (
                self.client.query.get("Networks", ["networkID", "name", "description"])
                .with_where({
                    "path": ["name"],
                    "operator": "Equal",
                    "valueString": network_name,
                })
                .do()
            )
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")
            return None

        networks = self._get_results(response, "Networks")
        if networks is None:
            print("Unexpected response structure.")
            return None
        if not networks:
            print("No networks found with the provided name.")
            return None
        return networks[0].get("networkID")

    # *** Agent methods ***

    def create_agent(self, network_id, instructions=None, toxicitySettings=None) -> Optional[str]:
        """Create an agent in the given network together with its memory class.

        Args:
            network_id: ``networkID`` of an existing ``Networks`` object.
            instructions: Instructions prompt; a default prompt is used when ``None``.
            toxicitySettings: Toxicity prompt; a default prompt is used when ``None``.

        Returns:
            The new agent's ID as a string, or ``None`` if the next ID could
            not be determined, the network does not exist, or the database
            reported an error.
        """
        next_id = self._get_next_agent_id()
        if next_id is None:
            # Without this guard an agent literally named "None" would be created.
            print("Could not determine the next agent ID; agent not created.")
            return None
        agent_id = str(next_id)
        print(f"Creating agent with ID: {agent_id}")

        default_instructions = "You are a Gen Z texter. You can use abbreviations and common slang. You can ask questions, provide answers, or just chat. You should not say anything offensive, toxic, ignorant, or malicious."
        default_toxicitySettings = "You are moderate and not overly sensitive, yet do not tolerate any form of hate speech, racism, or discrimination. You are open to learning and growing."

        try:
            # A beacon must point at the referenced object's UUID, not at the
            # user-facing networkID, so resolve the UUID first.
            network_uuid = self._get_network_uuid(network_id)
            if network_uuid is None:
                print(f"No network found with networkID '{network_id}'; agent not created.")
                return None

            agent_object = {
                "agentID": agent_id,
                # Reference values are a list of beacons.
                "network": [{"beacon": f"weaviate://localhost/Networks/{network_uuid}"}],
                "createdAt": datetime.now(timezone.utc).isoformat(),
                "instructions": instructions if instructions is not None else default_instructions,
                "toxicitySettings": toxicitySettings if toxicitySettings is not None else default_toxicitySettings,
            }

            # Create the agent, then the class that stores its memories.
            self.client.data_object.create(agent_object, "Agents")
            self._create_agent_class(agent_id)
            print(f"Agent '{agent_id}' created successfully.")
            return agent_id
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")
            return None

    def add_agent_data(self, agent_id: str, data_content, toxicity_flag=False, interest_metric=0.0):
        """Append one memory entry to the agent's ``AgentData_<agent_id>`` class."""
        agent_data_class = f"AgentData_{agent_id}"
        agent_data_object = {
            "dataContent": data_content,
            "createdAt": datetime.now(timezone.utc).isoformat(),
            "toxicityFlag": toxicity_flag,
            "interestMetric": interest_metric,
        }
        print(f"Inserting data into {agent_data_class}: {agent_data_object}")
        try:
            self.client.data_object.create(agent_data_object, agent_data_class)
            print(f"Data for agent '{agent_id}' added successfully.")
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")

    def get_agent_memory(self, agent_id, query_string=None):
        """Return the memories stored for an agent.

        Args:
            agent_id: ID of the agent whose ``AgentData_<agent_id>`` class is read.
            query_string: Accepted for API compatibility but currently unused.
                TODO: relevance search (BM25 with a score threshold) is
                disabled, so every stored memory is returned regardless of
                the query.

        Returns:
            A list of dicts with ``dataContent`` and ``toxicityFlag``; an empty
            list when nothing is stored; ``None`` if the query fails.
            NOTE(review): no limit is set on the query, so the server's default
            result limit presumably applies - confirm.
        """
        agent_data_class = f"AgentData_{agent_id}"
        try:
            response = self.client.query.get(
                agent_data_class, ["dataContent", "toxicityFlag"]
            ).do()
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")
            return None

        memories = self._get_results(response, agent_data_class)
        if memories is None:
            print(f"No data found for agent '{agent_id}'.")
            return []
        return list(memories)

    ############################ INSTRUCTIONS #######################################

    def get_instructions(self, agent_id):
        """Return the instructions prompt of an agent, or ``None`` if unavailable."""
        return self._get_agent_property(agent_id, "instructions")

    def update_instructions(self, agent_id, new_instructions):
        """Replace an agent's instructions prompt.

        Returns the new instructions on success, ``None`` when
        ``new_instructions`` is ``None``, the agent is not found, or the
        update fails.
        """
        if new_instructions is None:
            print("Validation error: New instructions is None, cannot update instructions prompt.")
            return None
        if not self._update_agent_property(agent_id, "instructions", new_instructions):
            return None
        print(f"Instructions prompt for agent '{agent_id}' updated successfully.")
        return new_instructions

    ############################ TOXICITY SETTINGS #######################################

    def get_toxicity_settings(self, agent_id):
        """Return the toxicity settings of an agent, or ``None`` if unavailable."""
        return self._get_agent_property(agent_id, "toxicitySettings")

    # Original (misspelled) public name, kept so existing callers keep working.
    get_toxicty_settings = get_toxicity_settings

    def update_toxicity_settings(self, agent_id, new_toxicity_settings):
        """Replace an agent's toxicity settings.

        Returns the new settings on success, ``None`` when
        ``new_toxicity_settings`` is ``None``, the agent is not found, or the
        update fails.
        """
        if new_toxicity_settings is None:
            print("Validation error: New toxicity settings is None, cannot update toxicity settings.")
            return None
        if not self._update_agent_property(agent_id, "toxicitySettings", new_toxicity_settings):
            return None
        print(f"Toxicity settings for agent '{agent_id}' updated successfully.")
        return new_toxicity_settings

    # *** Private methods ***

    @staticmethod
    def _get_results(response, class_name):
        """Return the result list at ``response['data']['Get'][class_name]``.

        Returns ``None`` when the path is missing or does not hold a list,
        which is the case when the response carries GraphQL errors.
        """
        try:
            results = response["data"]["Get"][class_name]
        except (KeyError, TypeError):
            return None
        return results if isinstance(results, list) else None

    def _query_agents(self, agent_id, properties):
        """Query ``Agents`` objects whose ``agentID`` equals ``agent_id``.

        Returns the list of matches, or ``None`` for an unexpected response.
        """
        response = (
            self.client.query.get("Agents", properties)
            .with_where({
                "path": ["agentID"],
                "operator": "Equal",
                # agentID is stored as a string, so compare as a string.
                "valueString": str(agent_id),
            })
            .do()
        )
        return self._get_results(response, "Agents")

    def _get_agent_property(self, agent_id, property_name):
        """Read one property of the first agent matching ``agent_id``."""
        try:
            agents = self._query_agents(agent_id, [property_name])
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")
            return None
        if agents is None:
            print("Unexpected response structure.")
            return None
        if not agents:
            print("No agents found with the provided agentID.")
            return None
        return agents[0].get(property_name)

    def _update_agent_property(self, agent_id, property_name, value):
        """Set one property on the agent matching ``agent_id``.

        Returns ``True`` when the update was sent, ``False`` when the agent
        was not found or the database reported an error.
        """
        try:
            # The update call needs the object's UUID, not its agentID.
            agents = self._query_agents(agent_id, ["_additional { id }"])
            if not agents:
                print(f"No agent found with agentID '{agent_id}'")
                return False
            uuid = agents[0]["_additional"]["id"]
            self.client.data_object.update(
                {property_name: value}, class_name="Agents", uuid=uuid
            )
            return True
        except WeaviateBaseError as e:
            print(f"An error occurred: {e}")
            return False

    def _get_network_uuid(self, network_id):
        """Return the UUID of the network with this ``networkID``, or ``None``.

        ``WeaviateBaseError`` is left to the caller to handle.
        """
        response = (
            self.client.query.get("Networks", ["_additional { id }"])
            .with_where({
                "path": ["networkID"],
                "operator": "Equal",
                "valueString": str(network_id),
            })
            .do()
        )
        networks = self._get_results(response, "Networks")
        if not networks:
            return None
        return networks[0]["_additional"]["id"]

    def _create_agent_class(self, agent_id):
        """Create the ``AgentData_<agent_id>`` class that stores an agent's memories."""
        agent_data_class = {
            "class": f"AgentData_{agent_id}",
            "properties": [
                {"name": "dataContent", "dataType": ["text"]},
                {"name": "createdAt", "dataType": ["date"]},
                {"name": "toxicityFlag", "dataType": ["boolean"]},
                # Written by add_agent_data, so it is declared explicitly here.
                {"name": "interestMetric", "dataType": ["number"]},
            ],
            "vectorizer": "text2vec-openai",
            "moduleConfig": {
                "text2vec-openai": {
                    "vectorizeClassName": True,
                    "vectorizeProperties": ["dataContent"],
                }
            },
        }
        try:
            self.client.schema.create_class(agent_data_class)
            print(f"Class (memory) for agent '{agent_id}' created successfully.")
        except WeaviateBaseError as e:
            print(f"An error occurred while creating class for agent '{agent_id}': {e}")

    def _get_next_agent_id(self):
        """Return the current number of ``Agents`` objects plus one.

        Returns ``None`` when the count cannot be obtained.
        NOTE(review): a count-based ID can collide after deletions or when two
        agents are created concurrently - consider a generated unique ID.
        """
        # GraphQL query for the number of agents.
        query = """
        {
            Aggregate {
                Agents {
                    meta {
                        count
                    }
                }
            }
        }
        """
        try:
            response = self.client.query.raw(query)
        except WeaviateBaseError as e:
            print(f"An error occurred while getting next agent ID: {e}")
            return None

        try:
            count = response["data"]["Aggregate"]["Agents"][0]["meta"]["count"]
        except (KeyError, IndexError, TypeError):
            # E.g. the response carries GraphQL errors instead of data.
            print("An error occurred while getting next agent ID: unexpected response structure.")
            return None
        return count + 1