-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from cs3216-a3-group-4/seeleng/add-method-to-a…
…dd-event-to-db feat: add method to add to db after llm batch
- Loading branch information
Showing
3 changed files
with
110 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from pydantic import BaseModel | ||
from sqlalchemy import select | ||
from src.events.generate_events import CategoryAnalysis | ||
from src.events.models import Analysis, Article, Category, Event | ||
from src.common.database import engine | ||
from sqlalchemy.orm import Session | ||
from functools import cache | ||
|
||
|
||
class EventLLM(BaseModel): | ||
"""This is EventPublic from generate_events.py without the id field minus date plus rating""" | ||
|
||
title: str | ||
description: str | ||
analysis_list: list[CategoryAnalysis] | ||
duplicate: bool | ||
is_singapore: bool | ||
original_article_id: int | ||
# categories: list[str] --> derived from analysis list, not needed | ||
rating: int | ||
|
||
|
||
@cache | ||
def get_categories(): | ||
with Session(engine) as session: | ||
categories = session.scalars(select(Category)) | ||
if not categories: | ||
raise ValueError( | ||
"You don't have any categories in your db. Type `uv run src/scripts/seed.py`" | ||
) | ||
return {category.name: category.id for category in categories} | ||
|
||
|
||
def add_event_to_db(event: EventLLM) -> bool: | ||
"""Returns whether adding the event was successful. | ||
Can fail if category does not exist/article is invalid.""" | ||
categories = get_categories() | ||
|
||
with Session(engine) as session: | ||
try: | ||
article = session.get(Article, event.original_article_id) | ||
if not article: | ||
print(f"article {event.original_article_id} does not exist") | ||
raise ValueError() | ||
|
||
eventORM = Event( | ||
title=event.title, | ||
description=event.description, | ||
duplicate=False, | ||
date=article.date, | ||
is_singapore=event.is_singapore, | ||
original_article_id=event.original_article_id, | ||
rating=event.rating, | ||
) | ||
|
||
for analysis in event.analysis_list: | ||
category = analysis.category | ||
content = analysis.analysis | ||
|
||
# noticed a mismatch between seed and llm prompt | ||
if category == "Economic": | ||
category = "Economics" | ||
|
||
analysisORM = Analysis( | ||
category_id=categories[category], content=content | ||
) | ||
|
||
eventORM.analysises.append(analysisORM) | ||
|
||
session.add(eventORM) | ||
session.commit() | ||
|
||
except Exception as e: # noqa: E722 | ||
print("something went wrong:", event) | ||
print(e) | ||
return False | ||
|
||
|
||
if __name__ == "__main__": | ||
# example usage | ||
add_event_to_db( | ||
EventLLM( | ||
title="test", | ||
description="test", | ||
analysis_list=[ | ||
CategoryAnalysis(analysis="a", category="Science & Tech"), | ||
CategoryAnalysis(analysis="b", category="Media"), | ||
], | ||
duplicate=False, | ||
rating=5, | ||
original_article_id=14, | ||
is_singapore=False, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,4 +97,4 @@ def test_associations(): | |
# session.delete(original_article) | ||
|
||
|
||
test_associations() | ||
# test_associations() |