-
Notifications
You must be signed in to change notification settings - Fork 10
/
build_graph.py
103 lines (81 loc) · 2.85 KB
/
build_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import json
from pydantic import BaseModel
import chromadb
# Number of nearest neighbours requested per person, i.e. the maximum number
# of outgoing links (matches) a single person can get.
N_RESULTS: int = 10
# A link to a neighbour is drawn only when its distance is below this value,
# so a larger threshold produces more links. The top match is assigned
# regardless of this threshold.
DISTANCE_THRESHOLD: float = 1.1
# Name of the ChromaDB collection holding the embeddings of the prompt to
# process; the collection name is defined in generate_embeddings.py.
COLLECTION_TO_PROCESS: str = "time_prompt_embeddings"
class NodeData(BaseModel):
    """Per-person payload attached to a graph node.

    Field names (including camelCase ``topMatch``) are dumped verbatim into
    graphData.json, so renaming them changes the output format.
    """
    # Person's name, taken from the record's "name" metadata.
    name: str
    # Taken from the record's "program" metadata.
    major: str
    # The stored document text of this person's record.
    response: str
    # Id of the nearest neighbour whose "name" differs from this person's.
    topMatch: str
class Node(BaseModel):
    """One graph node, built from a single record of the processed collection."""
    # Record id from the collection; Link.source / Link.target refer to it.
    id: str
    # Display and matching data for this node.
    data: NodeData
class Link(BaseModel):
    """Directed edge from a person to one of their sufficiently close neighbours."""
    # Node id of the person whose neighbours were queried.
    source: str
    # Node id of the neighbour that fell under the distance threshold.
    target: str
def process_collection(collection: chromadb.Collection, nodes: list[Node], links: list[Link]) -> None:
    """Build graph nodes and links from every record in a ChromaDB collection.

    For each stored record, the collection is queried for up to ``N_RESULTS``
    nearest neighbours whose ``name`` metadata differs from the record's own
    name. The record becomes a ``Node`` whose ``topMatch`` is the nearest such
    neighbour, and a directed ``Link`` is added to every neighbour whose
    distance is strictly below ``DISTANCE_THRESHOLD``.

    Args:
        collection: Collection whose records carry ``name`` and ``program``
            metadata, with the person's response stored as the document.
        nodes: Output list; one ``Node`` is appended per record.
        links: Output list; qualifying ``Link`` objects are appended.

    Raises:
        ValueError: If the collection returns no embeddings, metadatas or
            documents; if a query returns no distances; or if a record has no
            neighbour with a different name, so no top match can be assigned.
    """
    results = collection.get(
        ids=None, include=["metadatas", "documents", "embeddings"])
    embeddings = results["embeddings"]
    metadatas = results["metadatas"]
    documents = results["documents"]
    if embeddings is None or metadatas is None:
        raise ValueError("No embeddings found in the collection")
    if documents is None:
        raise ValueError("No documents found in the collection")

    total = len(embeddings)
    for index, embedding in enumerate(embeddings):
        self_id = results["ids"][index]
        metadata = metadatas[index]
        self_name = metadata["name"]
        self_major = metadata["program"]
        self_response = documents[index]

        # Exclude entries with this person's own name so nobody is matched
        # with themself. Note this also excludes other people sharing the name.
        query = collection.query(
            n_results=N_RESULTS,
            query_embeddings=[embedding],
            include=["metadatas", "distances", "documents"],
            where={
                "name": {
                    "$ne": self_name,
                }
            }
        )
        nearest_ids = query["ids"][0]
        if not query["distances"]:
            raise ValueError("No distances found in the query")
        distances = query["distances"][0]

        name = str(self_name)
        print(f"{index+1}/{total}: Processing {name} ({self_major})")

        # Always give a match: the nearest neighbour is the top match even if
        # it is farther away than DISTANCE_THRESHOLD.
        if not nearest_ids:
            raise ValueError(
                f"No other entries found to match {name} against")
        top_match = nearest_ids[0]

        nodes.append(
            Node(
                id=self_id,
                data=NodeData(
                    name=name,
                    response=self_response,
                    major=self_major,
                    topMatch=top_match
                )
            )
        )

        # Links are only drawn to neighbours closer than the threshold.
        # zip() pairs each neighbour id with its distance without reusing the
        # outer loop index.
        links.extend(
            Link(source=self_id, target=target_id)
            for target_id, distance in zip(nearest_ids, distances)
            if distance < DISTANCE_THRESHOLD
        )
def main():
    """Build the match graph for COLLECTION_TO_PROCESS and write graphData.json.

    Opens the persistent ChromaDB store in the ./chromadb directory, converts
    the configured collection into nodes and links, and dumps both lists as
    JSON into graphData.json in the current working directory.
    """
    graph_nodes: list[Node] = []
    graph_links: list[Link] = []

    client = chromadb.PersistentClient(path="chromadb")
    collection = client.get_collection(COLLECTION_TO_PROCESS)
    process_collection(collection, graph_nodes, graph_links)
    print(f"Graph constructed from collection: {COLLECTION_TO_PROCESS}")

    with open("graphData.json", "w") as out_file:
        payload = {
            "nodes": [node.model_dump() for node in graph_nodes],
            "links": [link.model_dump() for link in graph_links],
        }
        json.dump(payload, out_file)
# Build the graph only when run as a script, not when imported.
if __name__ == "__main__":
    main()