feat: set up workflow
for populating vector_db
marcus-ny committed Sep 24, 2024
1 parent be9f93c commit 3257f82
Showing 7 changed files with 5,700 additions and 159 deletions.
5,760 changes: 5,641 additions & 119 deletions backend/lm_events_output.json

Large diffs are not rendered by default.

65 changes: 35 additions & 30 deletions backend/src/embeddings/vector_store.py
@@ -10,6 +10,8 @@
 from src.common.constants import OPENAI_API_KEY
 from src.common.constants import PINECONE_API_KEY
 
+from src.scrapers.guardian.get_analyses import get_analyses
+
 import os
 import time
 
@@ -23,7 +25,7 @@
 
 
 def create_vector_store():
-    index_name = "langchain-test-index-3" # change to create a new index
+    index_name = "langchain-test-index-4" # change to create a new index
 
     existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
 
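The body of create_vector_store() is collapsed in this view. Based on what the diff does show (the existing_indexes check, the time import, and the OpenAI/Pinecone keys), it presumably follows the standard langchain-pinecone pattern. The sketch below is a hypothetical reconstruction under that assumption; the embedding model, dimension, metric, and serverless region are guesses, not values from this commit.

    # Hypothetical reconstruction of the collapsed function body; not the commit's code.
    import time

    from pinecone import Pinecone, ServerlessSpec
    from langchain_openai import OpenAIEmbeddings
    from langchain_pinecone import PineconeVectorStore
    from src.common.constants import OPENAI_API_KEY, PINECONE_API_KEY

    pc = Pinecone(api_key=PINECONE_API_KEY)


    def create_vector_store():
        index_name = "langchain-test-index-4" # change to create a new index

        existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

        if index_name not in existing_indexes:
            pc.create_index(
                name=index_name,
                dimension=1536, # assumed: dimension of OpenAI's text-embedding-3-small
                metric="cosine", # assumed
                spec=ServerlessSpec(cloud="aws", region="us-east-1"), # assumed
            )
            while not pc.describe_index(index_name).status["ready"]:
                time.sleep(1) # poll until the index can accept vectors

        return PineconeVectorStore(
            index=pc.Index(index_name),
            embedding=OpenAIEmbeddings(model="text-embedding-3-small", api_key=OPENAI_API_KEY),
        )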
@@ -51,37 +53,40 @@ def create_vector_store():
 vector_store = create_vector_store()
 
 
-def store_documents(events: list[dict]):
-    # print(events)
-    id = 0
+def store_documents():
     documents = []
-    for event in events:
-        analysis_list = event.get("analysis_list")
-        for analysis in analysis_list:
-            document = Document(
-                page_content=analysis.get("analysis"),
-                metadata={
-                    "title": event.get("title"),
-                    "description": event.get("description"),
-                    "categories": str(event.get("categories")),
-                    "is_singapore": event.get("is_singapore"),
-                    "questions": event.get("questions"),
-                },
-                id=id,
-            )
-            documents.append(document)
-            id += 1
+    analysis_list = get_analyses()
+    for analysis in analysis_list:
+        document = Document(
+            page_content=analysis.get("content"),
+            metadata={
+                "id": analysis.get("id"),
+                "event_id": analysis.get("event_id"),
+                "category_id": analysis.get("category_id"),
+            },
+        )
+        documents.append(document)
 
     uuids = [str(uuid4()) for _ in range(len(documents))]
     vector_store.add_documents(documents=documents, ids=uuids)
-    print("Job done")
-
-
-# if __name__ == "__main__":
-#     # store_documents()
-#     query = "There is a need for more stringent regulations on social media platforms."
-#     docs = vector_store.similarity_search_with_relevance_scores(query, k=3)
-#     if len(docs) == 0:
-#         print("No documents found")
-#     for doc in docs:
-#         print(doc)
+    print(f"Stored {len(documents)} documents")
+
+
+def get_similar_results(query: str, top_k: int = 5):
+    documents = vector_store.similarity_search_with_relevance_scores(
+        query=query, k=top_k
+    )
+    results = []
+    for document, score in documents:
+        results.append(
+            {
+                "id": document.metadata["id"],
+                "event_id": document.metadata["event_id"],
+                "category_id": document.metadata["category_id"],
+                "content": document.page_content,
+                "score": score,
+            }
+        )
+    return results
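
Together, the rewritten store_documents() and the new get_similar_results() give the index a write path and a read path: analyses are embedded with their relational ids as metadata, and scored hits are mapped back into plain dicts. A minimal usage sketch, assuming the API keys are configured and the interpreter runs from backend/ so the src package resolves; the query string is illustrative:

    from src.embeddings.vector_store import store_documents, get_similar_results

    store_documents() # embed every analysis returned by get_analyses() into Pinecone

    for hit in get_similar_results("regulation of social media platforms", top_k=3):
        # Each hit carries the stored metadata plus a relevance score.
        print(f'{hit["score"]:.3f} event={hit["event_id"]} {hit["content"][:60]}')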
8 changes: 8 additions & 0 deletions backend/src/events/router.py
@@ -11,6 +11,7 @@
 from src.events.schemas import EventDTO, EventIndexResponse
 from src.notes.models import Note, NoteType
 from src.notes.schemas import NoteDTO
+from src.embeddings.vector_store import get_similar_results
 
 
 router = APIRouter(prefix="/events", tags=["events"])
@@ -108,3 +109,10 @@ def read_event(
     session.add(read_event)
     session.commit()
     return
+
+
+@router.get("/search")
+def search_whatever(query: str):
+    # call your function and return the result
+    results = get_similar_results(query)
+    return results
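
Because the router is mounted with the /events prefix, the new route is served at /events/search. A client-side sketch; the host and port are assumptions (the usual local uvicorn defaults), not something this diff shows:

    # Hypothetical client call against the new endpoint.
    import requests

    resp = requests.get(
        "http://localhost:8000/events/search", # assumed host/port
        params={"query": "regulation of social media platforms"},
    )
    resp.raise_for_status()
    # Expected shape, per get_similar_results:
    # [{"id": ..., "event_id": ..., "category_id": ..., "content": ..., "score": ...}, ...]
    print(resp.json())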
3 changes: 3 additions & 0 deletions backend/src/lm/generate_events.py
@@ -95,10 +95,13 @@ class EventDetails(BaseModel):
 
 def generate_events(articles: list[dict]) -> List[EventPublic]:
     res = []
+    count = 1
     for article in articles:
         event_details = generate_events_from_article(article)
         for example in event_details.get("examples"):
             res.append(form_event_json(example, article))
+        print(f"Generated {count} events")
+        count += 1
 
     with open(file_path, "w") as json_file:
         json.dump(res, json_file, indent=4)
12 changes: 6 additions & 6 deletions backend/src/scrapers/guardian/get_analyses.py
@@ -7,16 +7,16 @@
 def get_analyses():
     with Session(engine) as session:
         # Select the first 5 articles
-        result = session.scalars(select(Analysis).limit(5))
+        result = session.scalars(select(Analysis))
 
         analyses = []
         # Iterate over the result and print each article
-        for article in result:
+        for analysis in result:
             data_dict = {
-                "id": article.id,
-                "event_id": article.event_id,
-                "category_id": article.category_id,
-                "content": article.content,
+                "id": analysis.id,
+                "event_id": analysis.event_id,
+                "category_id": analysis.category_id,
+                "content": analysis.content,
             }
             analyses.append(data_dict)
 
2 changes: 1 addition & 1 deletion backend/src/scrapers/guardian/get_articles.py
@@ -7,7 +7,7 @@
 def get_articles():
     with Session(engine) as session:
         # Select the first 5 articles
-        result = session.scalars(select(Article).limit(5))
+        result = session.scalars(select(Article).limit(100))
 
         articles = []
         # Iterate over the result and print each article
9 changes: 6 additions & 3 deletions backend/src/scripts/populate.py
@@ -1,12 +1,12 @@
 import json
 from src.events.process import add_event_to_db
-from src.embeddings.vector_store import store_documents
 from src.events.process import EventLLM
+from src.embeddings.vector_store import store_documents
 
 
 # Populate the db with events from lm_events_output.json
 def populate():
-    with open("backend/lm_events_output.json", "r") as f:
+    with open("lm_events_output.json", "r") as f:
         events = json.load(f)
     for event in events:
         event_obj = EventLLM(
@@ -21,5 +21,8 @@ def populate():
         add_event_to_db(event_obj)
 
 
-if __name__ == "__main__":
+def set_up():
+    # add events + analyses to db
     populate()
+    # store analyses in vector store
+    store_documents()

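Note that set_up() replaces the old if __name__ == "__main__": guard, so running populate.py directly no longer does anything; something has to import and call set_up(). A sketch of a minimal driver, assuming it is executed from backend/ so both the relative lm_events_output.json path and the src package resolve:

    # Hypothetical entry point for the new workflow; not part of this commit.
    from src.scripts.populate import set_up

    if __name__ == "__main__":
        set_up() # populate() fills the db, then store_documents() embeds the analyses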