diff --git a/backend/src/events/generate_events.py b/backend/src/events/generate_events.py index 1ca7f5b2..209c2366 100644 --- a/backend/src/events/generate_events.py +++ b/backend/src/events/generate_events.py @@ -1,7 +1,9 @@ +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.output_parsers import JsonOutputParser from typing import List from pydantic import BaseModel from src.scrapers.guardian.scrape import query_page - from src.common.constants import LANGCHAIN_API_KEY from src.common.constants import LANGCHAIN_TRACING_V2 from src.common.constants import OPENAI_API_KEY @@ -12,15 +14,14 @@ os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2 os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY -from langchain_openai import ChatOpenAI -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.output_parsers import JsonOutputParser - lm_model = ChatOpenAI(model="gpt-4o-mini") + + class CategoryAnalysis(BaseModel): category: str analysis: str + class EventPublic(BaseModel): id: int title: str @@ -45,6 +46,7 @@ class Example(BaseModel): class EventDetails(BaseModel): examples: list[Example] + SYSPROMPT = """ You are a helpful assistant helping students find examples for their A Level general paper to substantiate their arguments in their essays. Given an article, you should provide examples that can be used to support or refute arguments in a General Paper essay. @@ -125,6 +127,7 @@ class EventDetails(BaseModel): “But the team reacted to what we had to do playing at home with 10 men. We didn’t want to be so deep defending like this, but we read the game and we played the game that we had to play and we should have got rewarded.”""" + def generate_events() -> List[EventPublic]: articles = query_page(1) articles = articles[:50] @@ -136,7 +139,7 @@ def generate_events() -> List[EventPublic]: for example in event_details.get("examples"): res.append(form_event_json(example)) - return res; + return res def form_event_json(event_details) -> dict: @@ -149,23 +152,18 @@ def form_event_json(event_details) -> dict: date="", is_singapore=event_details.get("in_singapore", False), categories=event_details.get("category", []), - original_article_id=0 + original_article_id=0, ) def generate_events_from_article(article: str) -> dict: - messages = [ - SystemMessage(content=SYSPROMPT), - HumanMessage(content=article) - ] + messages = [SystemMessage(content=SYSPROMPT), HumanMessage(content=article)] - result = lm_model.invoke(messages); + result = lm_model.invoke(messages) parser = JsonOutputParser(pydantic_object=EventDetails) events = parser.invoke(result) return events + # if __name__ == "__main__": # print(generate_events()) - - - diff --git a/backend/src/events/vector_store.py b/backend/src/events/vector_store.py index 09f2078a..4e8056bd 100644 --- a/backend/src/events/vector_store.py +++ b/backend/src/events/vector_store.py @@ -1,27 +1,27 @@ from uuid import uuid4 from langchain_openai import OpenAIEmbeddings +from langchain_pinecone import PineconeVectorStore +from pinecone import Pinecone, ServerlessSpec +from langchain_core.documents import Document +from src.events.generate_events import generate_events + from src.common.constants import LANGCHAIN_API_KEY from src.common.constants import LANGCHAIN_TRACING_V2 from src.common.constants import OPENAI_API_KEY from src.common.constants import PINECONE_API_KEY import os +import time os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2 os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY -from pinecone import Pinecone, ServerlessSpec - -from langchain_core.documents import Document - -from src.events.generate_events import generate_events pc = Pinecone(api_key=PINECONE_API_KEY) -import time index_name = "langchain-test-index" # change if desired @@ -41,10 +41,9 @@ embeddings = OpenAIEmbeddings(model="text-embedding-ada-002") -from langchain_pinecone import PineconeVectorStore - vector_store = PineconeVectorStore(index=index, embedding=embeddings) + def store_documents(): events = generate_events() # print(events) @@ -59,21 +58,20 @@ def store_documents(): "categories": str(event.categories), "is_singapore": event.is_singapore, }, - id=id + id=id, ) documents.append(document) id += 1 uuids = [str(uuid4()) for _ in range(len(documents))] vector_store.add_documents(documents=documents, ids=uuids) print("Job done") - if __name__ == "__main__": # store_documents() query = "No, the use of performance enhancing drugs undermines the integrity of sport: The use of performance-enhancing drugs violates the principle of fair play. It creates an uneven playing field where success depends more on pharmaceutical intervention than talent or hard work, eroding the values that sports should represent." - docs = vector_store.similarity_search_with_relevance_scores(query, k = 3) - if (len(docs) == 0): + docs = vector_store.similarity_search_with_relevance_scores(query, k=3) + if len(docs) == 0: print("No documents found") for doc in docs: - print(doc) \ No newline at end of file + print(doc)