diff --git a/backend/src/events/models.py b/backend/src/events/models.py index 0e00ac6b..07fed17b 100644 --- a/backend/src/events/models.py +++ b/backend/src/events/models.py @@ -21,7 +21,7 @@ class Article(Base): source: Mapped[ArticleSource] date: Mapped[datetime] - event: Mapped["Event"] = relationship(back_populates="original_article") + events: Mapped[list["Event"]] = relationship(back_populates="original_article") class Event(Base): @@ -40,7 +40,7 @@ class Event(Base): back_populates="events", secondary="event_category" ) - original_article: Mapped[Article] = relationship(back_populates="event") + original_article: Mapped[Article] = relationship(back_populates="events") class Category(Base): diff --git a/backend/src/scripts/seed.py b/backend/src/scripts/seed.py new file mode 100644 index 00000000..834e1b78 --- /dev/null +++ b/backend/src/scripts/seed.py @@ -0,0 +1,84 @@ +from datetime import datetime +from sqlalchemy import select +from src.events.models import Article, ArticleSource, Category, Event +from sqlalchemy.orm import Session +from src.common.database import engine + + +def add_categories(): + CATEGORIES = [ + "Arts & Humanities", + "Science & Tech", + "Politics", + "Media", + "Environment", + "Education", + "Sports", + "Gender & Equality", + "Religion", + "Society & Culture", + "Economics", + ] + categories = [Category(name=category_name) for category_name in CATEGORIES] + with Session(engine) as session: + # If categories are already added, return + if session.scalars(select(Category)).first() is not None: + return + session.add_all(categories) + session.commit() + + +add_categories() + + +def test_associations(): + with Session(engine) as session: + article = Article( + title="test article", + summary="test summary", + url="https://whatever.com", + source=ArticleSource.CNA, + body="test body", + date="2024-02-05", + ) + event = Event( + title="test event 1", + description="x", + analysis="x", + duplicate=False, + date=datetime.now(), + is_singapore=False, + ) + article.events.append(event) + session.add(article) + session.commit() + + session.refresh(article) + session.refresh(event) + print(article) + print(event) + event_id = event.id + + with Session(engine) as session: + event_again = session.scalar(select(Event).where(Event.id == event_id)) + categories = session.scalars( + select(Category).where(Category.name.in_(["Environment", "Media"])) + ) + event_again.categories.extend(categories) + session.add(event_again) + session.commit() + + with Session(engine) as session: + event_again = session.scalar(select(Event).where(Event.id == event_id)) + print(event_again) + print(event_again.original_article) + print(event_again.categories) + event_again.categories.clear() + original_article = event_again.original_article + session.add(event_again) + session.commit() + session.delete(event_again) + session.delete(original_article) + + +# test_associations()