-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_embeddings.py
95 lines (76 loc) · 2.52 KB
/
generate_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import asyncio
import logging
import os
import asyncpg
from pypika import PostgreSQLQuery as Query
from core.context_loader import load_context
from core.embed import embed_text
from core.tables import tweet_embeds_table
from core.tables import tweets_table
# Load runtime configuration; generate_embeddings reads its Postgres
# connection URL and schema from this mapping.
context = load_context()

# Set up logging to <cwd>/logs/embeddings.log.
log_file_path = os.path.join(os.getcwd(), "logs", "embeddings.log")
# basicConfig(filename=...) opens the file right away and raises
# FileNotFoundError when the parent directory is missing, so create it first.
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
logging.basicConfig(
    filename=log_file_path,
    filemode="w",  # truncate the log file on every run
    format="%(asctime)s - %(message)s",
    datefmt="%d-%b-%y %H:%M:%S",
    level=logging.INFO,
)
async def _embed_tweet(tweet_id, tweet_text):
    """Return the single embedding vector for one tweet, or None to skip it.

    Failures are logged and turned into ``None`` so that one bad tweet does
    not abort the whole run.
    """
    try:
        embedding = await embed_text(tweet_text)
    except Exception as exc:
        logging.error(f"Error generating embedding {tweet_id}: {exc}")
        return None
    if len(embedding.data) > 1:
        logging.warning(f"Multiple embeddings for tweet {tweet_id}")
        return None
    if not embedding.data:
        # Previously this fell through to embedding.data[0] and raised
        # IndexError, killing the run.
        logging.warning(f"No embedding returned for tweet {tweet_id}")
        return None
    return embedding.data[0].embedding


async def _insert_batch(conn, insert_query, rows):
    """Insert ``rows`` (a list of ``(tweet_id, vector)`` tuples) in one INSERT."""
    sql = insert_query.insert(*rows).get_sql()
    # asyncpg's Connection.execute returns the command status tag
    # (e.g. "INSERT 0 10"), not a row count, so log both.
    status = await conn.execute(sql)
    logging.info(f"Inserted {len(rows)} tweet embeddings ({status})")


async def generate_embeddings(batch_size=10):
    """
    Generate embeddings for all tweets that do not have one yet.

    Tweets whose ``tweet_id`` is not already present in ``tweet_embeds_table``
    are fetched, embedded one at a time via ``embed_text``, and written to
    ``tweet_embeds_table`` in multi-row INSERTs.

    Args:
        batch_size: Number of embeddings to accumulate before issuing an
            INSERT (default 10, the previously hard-coded value).
    """
    conn = await asyncpg.connect(context["postgres_url"])
    try:
        # The schema name comes from configuration and is interpolated
        # unquoted, so it must already be a valid search_path value.
        await conn.execute(f"SET search_path TO {context['postgres_schema']}")

        # Select only tweets without an existing embedding row.
        # NOTE(review): NOT IN yields no rows if the subquery returns any NULL
        # tweet_id; presumably tweet_embeds.tweet_id is NOT NULL — confirm.
        select_query = (
            Query.from_(tweets_table)
            .select("tweet_id", "tweet_text")
            .where(
                tweets_table.field("tweet_id").notin(
                    tweet_embeds_table.select("tweet_id"),
                ),
            )
        )
        tweets = await conn.fetch(select_query.get_sql())
        logging.info(f"Fetched {len(tweets)} tweets for embedding generation")

        insert_query = Query.into(tweet_embeds_table).columns(
            "tweet_id",
            "embedding",
        )

        batch = []
        for tweet in tweets:
            tweet_id = tweet["tweet_id"]
            vector = await _embed_tweet(tweet_id, tweet["tweet_text"])
            if vector is None:
                continue
            batch.append((tweet_id, vector))
            if len(batch) >= batch_size:
                await _insert_batch(conn, insert_query, batch)
                batch = []

        # Flush the final, partially filled batch.
        if batch:
            await _insert_batch(conn, insert_query, batch)
    finally:
        # Release the connection even when a query or embedding call raises.
        await conn.close()
if __name__ == "__main__":
    # Executed as a script: drive the async pipeline on a fresh event loop.
    pipeline = generate_embeddings()
    asyncio.run(pipeline)