diff --git a/backend/src/events/process.py b/backend/src/events/process.py
new file mode 100644
index 00000000..3c957b03
--- /dev/null
+++ b/backend/src/events/process.py
@@ -0,0 +1,108 @@
+from pydantic import BaseModel
+from sqlalchemy import select
+from src.events.generate_events import CategoryAnalysis
+from src.events.models import Analysis, Article, Category, Event
+from src.common.database import engine
+from sqlalchemy.orm import Session
+from functools import cache
+
+
+class EventLLM(BaseModel):
+    """This is EventPublic from generate_events.py without the id field minus date plus rating"""
+
+    title: str
+    description: str
+    analysis_list: list[CategoryAnalysis]
+    duplicate: bool
+    is_singapore: bool
+    original_article_id: int
+    # categories: list[str] --> derived from analysis list, not needed
+    rating: int
+
+
+@cache
+def get_categories():
+    """Return a mapping of category name to category id, loaded once from the db.
+
+    Raises:
+        ValueError: if the category table is empty (db has not been seeded).
+    """
+    with Session(engine) as session:
+        # Materialize the rows: the result object returned by scalars() is
+        # always truthy, so an emptiness check on it can never fire.
+        categories = session.scalars(select(Category)).all()
+        if not categories:
+            raise ValueError(
+                "You don't have any categories in your db. Type `uv run src/scripts/seed.py`"
+            )
+        return {category.name: category.id for category in categories}
+
+
+def add_event_to_db(event: EventLLM) -> bool:
+    """Persist an LLM-generated event together with its per-category analyses.
+
+    Returns True when the event was committed and False when anything went
+    wrong while building or saving it, e.g. the category does not exist or
+    the article is invalid. Failures are printed rather than raised.
+    """
+    categories = get_categories()
+
+    with Session(engine) as session:
+        try:
+            article = session.get(Article, event.original_article_id)
+            if not article:
+                print(f"article {event.original_article_id} does not exist")
+                raise ValueError()
+
+            eventORM = Event(
+                title=event.title,
+                description=event.description,
+                # NOTE(review): event.duplicate is ignored here -- confirm intended
+                duplicate=False,
+                date=article.date,
+                is_singapore=event.is_singapore,
+                original_article_id=event.original_article_id,
+                rating=event.rating,
+            )
+
+            for analysis in event.analysis_list:
+                category = analysis.category
+                content = analysis.analysis
+
+                # noticed a mismatch between seed and llm prompt
+                if category == "Economic":
+                    category = "Economics"
+
+                analysisORM = Analysis(
+                    category_id=categories[category], content=content
+                )
+
+                eventORM.analysises.append(analysisORM)
+
+            session.add(eventORM)
+            session.commit()
+
+        except Exception as e:  # noqa: E722
+            print("something went wrong:", event)
+            print(e)
+            return False
+
+        return True
+
+
+if __name__ == "__main__":
+    # example usage
+    add_event_to_db(
+        EventLLM(
+            title="test",
+            description="test",
+            analysis_list=[
+                CategoryAnalysis(analysis="a", category="Science & Tech"),
+                CategoryAnalysis(analysis="b", category="Media"),
+            ],
+            duplicate=False,
+            rating=5,
+            original_article_id=14,
+            is_singapore=False,
+        )
+    )
diff --git a/backend/src/scrapers/guardian/scrape.py b/backend/src/scrapers/guardian/scrape.py
index 4b735820..007be6ba 100644
--- a/backend/src/scrapers/guardian/scrape.py
+++ b/backend/src/scrapers/guardian/scrape.py
@@ -25,17 +25,20 @@ def query_page(page: int):
     return data["results"]
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument("-o", "--output", help="output file path")
-parser.add_argument("-s", "--start", type=int, help="start index of page", default=0)
-parser.add_argument("-n", "--number", type=int, help="number of pages", default=50)
-args = parser.parse_args()
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-o", "--output", help="output file path")
+    parser.add_argument(
+        "-s", "--start", type=int, help="start index of page", default=0
+    )
+    parser.add_argument("-n", "--number", type=int, help="number of pages", default=50)
+    args = parser.parse_args()
 
-result = []
-for i in range(args.start, args.start + args.number):
-    result += query_page(i)
-    print("scraped:", i)
-    time.sleep(1)
+    result = []
+    for i in range(args.start, args.start + args.number):
+        result += query_page(i)
+        print("scraped:", i)
+        time.sleep(1)
 
-with open(args.output, "w") as f:
-    json.dump(result, f)
+    with open(args.output, "w") as f:
+        json.dump(result, f)
diff --git a/backend/src/scripts/seed.py b/backend/src/scripts/seed.py
index 387b6da2..ee76df83 100644
--- a/backend/src/scripts/seed.py
+++ b/backend/src/scripts/seed.py
@@ -97,4 +97,4 @@ def test_associations():
     # session.delete(original_article)
 
 
-test_associations()
+# test_associations()