Skip to content

Commit

Permalink
feat: add method to add to db after llm batch
Browse files Browse the repository at this point in the history
  • Loading branch information
seelengxd committed Sep 23, 2024
1 parent 61a9f88 commit ccea71b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 13 deletions.
94 changes: 94 additions & 0 deletions backend/src/events/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pydantic import BaseModel
from sqlalchemy import select
from src.events.generate_events import CategoryAnalysis
from src.events.models import Analysis, Article, Category, Event
from src.common.database import engine
from sqlalchemy.orm import Session
from functools import cache


class EventLLM(BaseModel):
"""This is EventPublic from generate_events.py without the id field minus date plus rating"""

title: str
description: str
analysis_list: list[CategoryAnalysis]
duplicate: bool
is_singapore: bool
original_article_id: int
# categories: list[str] --> derived from analysis list, not needed
rating: int


@cache
def get_categories():
with Session(engine) as session:
categories = session.scalars(select(Category))
if not categories:
raise ValueError(
"You don't have any categories in your db. Type `uv run src/scripts/seed.py`"
)
return {category.name: category.id for category in categories}


def add_event_to_db(event: EventLLM) -> bool:
"""Returns whether adding the event was successful.
Can fail if category does not exist/article is invalid."""
categories = get_categories()

with Session(engine) as session:
try:
article = session.get(Article, event.original_article_id)
if not article:
print(f"article {event.original_article_id} does not exist")
raise ValueError()

eventORM = Event(
title=event.title,
description=event.description,
duplicate=False,
date=article.date,
is_singapore=event.is_singapore,
original_article_id=event.original_article_id,
rating=event.rating,
)

for analysis in event.analysis_list:
category = analysis.category
content = analysis.analysis

# noticed a mismatch between seed and llm prompt
if category == "Economic":
category = "Economics"

analysisORM = Analysis(
category_id=categories[category], content=content
)

eventORM.analysises.append(analysisORM)

session.add(eventORM)
session.commit()

except Exception as e: # noqa: E722
print("something went wrong:", event)
print(e)
return False


if __name__ == "__main__":
# example usage
add_event_to_db(
EventLLM(
title="test",
description="test",
analysis_list=[
CategoryAnalysis(analysis="a", category="Science & Tech"),
CategoryAnalysis(analysis="b", category="Media"),
],
duplicate=False,
rating=5,
original_article_id=14,
is_singapore=False,
)
)
27 changes: 15 additions & 12 deletions backend/src/scrapers/guardian/scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@ def query_page(page: int):
return data["results"]


parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output", help="output file path")
parser.add_argument("-s", "--start", type=int, help="start index of page", default=0)
parser.add_argument("-n", "--number", type=int, help="number of pages", default=50)
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output", help="output file path")
parser.add_argument(
"-s", "--start", type=int, help="start index of page", default=0
)
parser.add_argument("-n", "--number", type=int, help="number of pages", default=50)
args = parser.parse_args()

result = []
for i in range(args.start, args.start + args.number):
result += query_page(i)
print("scraped:", i)
time.sleep(1)
result = []
for i in range(args.start, args.start + args.number):
result += query_page(i)
print("scraped:", i)
time.sleep(1)

with open(args.output, "w") as f:
json.dump(result, f)
with open(args.output, "w") as f:
json.dump(result, f)
2 changes: 1 addition & 1 deletion backend/src/scripts/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ def test_associations():
# session.delete(original_article)


test_associations()
# test_associations()

0 comments on commit ccea71b

Please sign in to comment.