Skip to content

Commit

Permalink
fix: style checks
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-ny committed Sep 22, 2024
1 parent 6e4b380 commit 1882bdd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
28 changes: 13 additions & 15 deletions backend/src/events/generate_events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from typing import List
from pydantic import BaseModel
from src.scrapers.guardian.scrape import query_page

from src.common.constants import LANGCHAIN_API_KEY
from src.common.constants import LANGCHAIN_TRACING_V2
from src.common.constants import OPENAI_API_KEY
Expand All @@ -12,15 +14,14 @@
os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser

lm_model = ChatOpenAI(model="gpt-4o-mini")


# Pydantic model with two string fields.
# NOTE(review): presumably pairs a category label with the analysis text
# written for it -- confirm against the callers that build this model.
class CategoryAnalysis(BaseModel):
    category: str
    analysis: str


class EventPublic(BaseModel):
id: int
title: str
Expand All @@ -45,6 +46,7 @@ class Example(BaseModel):
# Schema handed to JsonOutputParser in generate_events_from_article: the model
# reply is expected to carry its entries under the "examples" key.
class EventDetails(BaseModel):
    examples: list[Example]


SYSPROMPT = """
You are a helpful assistant helping students find examples for their A Level general paper to substantiate their arguments in their essays.
Given an article, you should provide examples that can be used to support or refute arguments in a General Paper essay.
Expand Down Expand Up @@ -125,6 +127,7 @@ class EventDetails(BaseModel):
“But the team reacted to what we had to do playing at home with 10 men. We didn’t want to be so deep defending like this, but we read the game and we played the game that we had to play and we should have got rewarded.”"""


def generate_events() -> List[EventPublic]:
articles = query_page(1)
articles = articles[:50]
Expand All @@ -136,7 +139,7 @@ def generate_events() -> List[EventPublic]:
for example in event_details.get("examples"):
res.append(form_event_json(example))

return res;
return res


def form_event_json(event_details) -> dict:
Expand All @@ -149,23 +152,18 @@ def form_event_json(event_details) -> dict:
date="",
is_singapore=event_details.get("in_singapore", False),
categories=event_details.get("category", []),
original_article_id=0
original_article_id=0,
)


def generate_events_from_article(article: str) -> dict:
    """Ask the chat model for essay examples drawn from a single article.

    SYSPROMPT is sent as the system message and ``article`` as the human
    message; the reply is then parsed by a ``JsonOutputParser`` configured
    with the ``EventDetails`` schema.

    Args:
        article: Article text, passed verbatim as the human message.

    Returns:
        The parsed model reply (callers read its ``"examples"`` entry).
    """
    messages = [SystemMessage(content=SYSPROMPT), HumanMessage(content=article)]

    # Invoke the model exactly once per article: every call is a billable
    # request, so the prompt must not be built and sent twice.
    result = lm_model.invoke(messages)
    parser = JsonOutputParser(pydantic_object=EventDetails)
    return parser.invoke(result)


# if __name__ == "__main__":
# print(generate_events())



24 changes: 11 additions & 13 deletions backend/src/events/vector_store.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from uuid import uuid4
from langchain_openai import OpenAIEmbeddings

from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec
from langchain_core.documents import Document
from src.events.generate_events import generate_events

from src.common.constants import LANGCHAIN_API_KEY
from src.common.constants import LANGCHAIN_TRACING_V2
from src.common.constants import OPENAI_API_KEY
from src.common.constants import PINECONE_API_KEY

import os
import time

os.environ["LANGCHAIN_TRACING_V2"] = LANGCHAIN_TRACING_V2
os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY

from pinecone import Pinecone, ServerlessSpec

from langchain_core.documents import Document

from src.events.generate_events import generate_events

pc = Pinecone(api_key=PINECONE_API_KEY)

import time

index_name = "langchain-test-index" # change if desired

Expand All @@ -41,10 +41,9 @@

embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

from langchain_pinecone import PineconeVectorStore

vector_store = PineconeVectorStore(index=index, embedding=embeddings)


def store_documents():
events = generate_events()
# print(events)
Expand All @@ -59,21 +58,20 @@ def store_documents():
"categories": str(event.categories),
"is_singapore": event.is_singapore,
},
id=id
id=id,
)
documents.append(document)
id += 1
uuids = [str(uuid4()) for _ in range(len(documents))]
vector_store.add_documents(documents=documents, ids=uuids)
print("Job done")



if __name__ == "__main__":
    # Manual smoke test: run one similarity query against the vector store and
    # print each hit. Uncomment the next line to (re)populate the index first.
    # store_documents()
    query = "No, the use of performance enhancing drugs undermines the integrity of sport: The use of performance-enhancing drugs violates the principle of fair play. It creates an uneven playing field where success depends more on pharmaceutical intervention than talent or hard work, eroding the values that sports should represent."

    # Search once and reuse the result for both the emptiness check and the
    # printout; each search issues remote embedding and index requests.
    docs = vector_store.similarity_search_with_relevance_scores(query, k=3)
    if not docs:
        print("No documents found")
    for doc in docs:
        print(doc)

0 comments on commit 1882bdd

Please sign in to comment.