From da7541c57af464b4375a6a73351fa6e9d3b17c3b Mon Sep 17 00:00:00 2001 From: seeleng Date: Fri, 27 Sep 2024 02:18:00 +0800 Subject: [PATCH] feat: add fallback to alembic --- .../65b186bae301_add_fallback_table.py | 42 ++++++ ...4411b0d_add_elaboration_for_third_layer.py | 35 +++++ backend/alembic/versions/c1022b3f1de5_.py | 24 ++++ backend/src/user_questions/models.py | 25 +++- backend/src/user_questions/router.py | 136 ++++++++++++------ backend/src/user_questions/schemas.py | 32 +++-- 6 files changed, 244 insertions(+), 50 deletions(-) create mode 100644 backend/alembic/versions/65b186bae301_add_fallback_table.py create mode 100644 backend/alembic/versions/951ac4411b0d_add_elaboration_for_third_layer.py create mode 100644 backend/alembic/versions/c1022b3f1de5_.py diff --git a/backend/alembic/versions/65b186bae301_add_fallback_table.py b/backend/alembic/versions/65b186bae301_add_fallback_table.py new file mode 100644 index 00000000..aa607624 --- /dev/null +++ b/backend/alembic/versions/65b186bae301_add_fallback_table.py @@ -0,0 +1,42 @@ +"""Add fallback table + +Revision ID: 65b186bae301 +Revises: c1022b3f1de5 +Create Date: 2024-09-27 01:37:55.705475 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "65b186bae301" +down_revision: Union[str, None] = "c1022b3f1de5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "fallback", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("point_id", sa.Integer(), nullable=False), + sa.Column("alt_approach", sa.String(), nullable=False), + sa.Column("general_argument", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["point_id"], + ["point.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("fallback") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/951ac4411b0d_add_elaboration_for_third_layer.py b/backend/alembic/versions/951ac4411b0d_add_elaboration_for_third_layer.py new file mode 100644 index 00000000..aab57a4c --- /dev/null +++ b/backend/alembic/versions/951ac4411b0d_add_elaboration_for_third_layer.py @@ -0,0 +1,35 @@ +"""Add elaboration for third layer + +Revision ID: 951ac4411b0d +Revises: feeb1c78c0a2 +Create Date: 2024-09-26 21:39:24.431100 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "951ac4411b0d" +down_revision: Union[str, None] = "feeb1c78c0a2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("point", sa.Column("positive", sa.Boolean(), nullable=False)) + op.add_column( + "point_analysis", sa.Column("elaboration", sa.String(), nullable=False) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("point_analysis", "elaboration") + op.drop_column("point", "positive") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/c1022b3f1de5_.py b/backend/alembic/versions/c1022b3f1de5_.py new file mode 100644 index 00000000..5ed445d1 --- /dev/null +++ b/backend/alembic/versions/c1022b3f1de5_.py @@ -0,0 +1,24 @@ +"""empty message + +Revision ID: c1022b3f1de5 +Revises: 7f60f0211af1, 951ac4411b0d +Create Date: 2024-09-27 01:37:39.117941 + +""" + +from typing import Sequence, Union + + +# revision identifiers, used by Alembic. +revision: str = "c1022b3f1de5" +down_revision: Union[str, None] = ("7f60f0211af1", "951ac4411b0d") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/backend/src/user_questions/models.py b/backend/src/user_questions/models.py index 3cab702d..7b620ceb 100644 --- a/backend/src/user_questions/models.py +++ b/backend/src/user_questions/models.py @@ -1,6 +1,7 @@ from sqlalchemy import ForeignKey, and_ from src.common.base import Base from sqlalchemy.orm import Mapped, mapped_column, relationship, foreign +from src.events.models import Analysis from src.notes.models import Note @@ -39,8 +40,14 @@ class Point(Base): primaryjoin=and_(id == foreign(Note.parent_id), Note.parent_type == "point"), backref="point", ) + positive: Mapped[bool] - analysises = relationship("Analysis", secondary="point_analysis") + # analysises = relationship("Analysis", secondary="point_analysis") + point_analysises: Mapped[list["PointAnalysis"]] = relationship( + back_populates="point" + ) + + fallback: Mapped["Fallback"] = relationship(back_populates="point") class PointAnalysis(Base): @@ -50,3 +57,19 @@ class PointAnalysis(Base): ForeignKey("analysis.id"), primary_key=True ) point_id: Mapped[int] = mapped_column(ForeignKey("point.id"), primary_key=True) + elaboration: Mapped[str] + + point: Mapped[Point] = relationship(back_populates="point_analysises") + analysis: Mapped[Analysis] = relationship(backref="point_analysises") + + +class Fallback(Base): + __tablename__ = "fallback" + + id: Mapped[int] = mapped_column(primary_key=True) + + point_id: Mapped[int] = mapped_column(ForeignKey("point.id")) + alt_approach: Mapped[str] + general_argument: Mapped[str] + + point: Mapped[Point] = relationship(back_populates="fallback") diff --git a/backend/src/user_questions/router.py b/backend/src/user_questions/router.py index 33140b66..40fd8e3b 100644 --- a/backend/src/user_questions/router.py +++ b/backend/src/user_questions/router.py @@ -2,14 +2,19 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import select -from sqlalchemy.orm import with_polymorphic, aliased, selectinload +from sqlalchemy.orm import with_polymorphic, selectinload from src.auth.dependencies import get_current_user from src.auth.models import User from src.common.dependencies import get_session from src.events.models import Analysis, Event -from src.likes.models import Like from src.notes.models import Note -from src.user_questions.models import Answer, Point, UserQuestion +from src.user_questions.models import ( + Answer, + Fallback, + Point, + PointAnalysis, + UserQuestion, +) from src.user_questions.schemas import CreateUserQuestion, UserQuestionMiniDTO from src.lm.generate_response import generate_response from src.lm.generate_points import get_relevant_analyses @@ -26,37 +31,29 @@ def get_user_questions( session=Depends(get_session), ) -> list[UserQuestionMiniDTO]: # Create an alias for the Point table to use for the Like condition - point_alias = aliased(Point) user_questions = session.scalars( select(UserQuestion) .where(UserQuestion.user_id == user.id) - .join(UserQuestion.answer) - .join(Answer.points) - .join(point_alias.analysises) - .join(Analysis.event) - .join(Event.original_article) - .join(Analysis.category) - .join(Analysis.likes) - .where(Like.point_id == point_alias.id) .options( selectinload( UserQuestion.answer, - Answer.points.of_type(point_alias), - point_alias.analysises, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, Analysis.event, Event.original_article, ), selectinload( UserQuestion.answer, - Answer.points.of_type(point_alias), - point_alias.analysises, - Analysis.category, + Answer.points, + Point.fallback, ), selectinload( UserQuestion.answer, - Answer.points.of_type(point_alias), - point_alias.analysises, - Analysis.likes, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, + Analysis.category, ), ) ) @@ -83,13 +80,28 @@ def get_user_question( select(UserQuestion) .where(UserQuestion.id == id) .where(UserQuestion.user_id == user.id) - .join(UserQuestion.answer) - .join(Answer.points) - .join(Point.analysises) - .join(Analysis.event) - .join(Event.original_article) - .join(Analysis.category) - .join(Analysis.likes) + .options( + selectinload( + UserQuestion.answer, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, + Analysis.event, + Event.original_article, + ), + selectinload( + UserQuestion.answer, + Answer.points, + Point.fallback, + ), + selectinload( + UserQuestion.answer, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, + Analysis.category, + ), + ) ) if not user_question: raise HTTPException(HTTPStatus.NOT_FOUND) @@ -108,16 +120,43 @@ def create_user_question( answer = Answer() user_question.answer = answer - results = get_relevant_analyses(data.question) - for row in results["for_points"] + results["against_points"]: + results = generate_response(data.question) + + for row in results["for_points"]: point = row["point"] analyses = row["analyses"] - point = Point(title=point, body="") - analysis_id = [analysis["id"] for analysis in analyses] + point = Point(title=point, body="", positive=True) + + for analysis in analyses: + point.point_analysises.append( + PointAnalysis( + elaboration=analysis["elaborations"], + analysis_id=analysis["id"], + ) + ) + if not analyses: + point.fallback = Fallback( + alt_approach=row["fall_back_response"]["alt_approach"], + general_argument=row["fall_back_response"]["general_argument"], + ) + answer.points.append(point) - point.analysises = list( - session.scalars(select(Analysis).where(Analysis.id.in_(analysis_id))) - ) + for row in results["against_points"]: + point = row["point"] + analyses = row["analyses"] + point = Point(title=point, body="", positive=False) + for analysis in analyses: + point.point_analysises.append( + PointAnalysis( + elaboration=analysis["elaborations"], + analysis_id=analysis["id"], + ) + ) + if not analyses: + point.fallback = Fallback( + alt_approach=row["fall_back_response"]["alt_approach"], + general_argument=row["fall_back_response"]["general_argument"], + ) answer.points.append(point) session.add(user_question) @@ -126,11 +165,28 @@ def create_user_question( same_user_question = session.scalar( select(UserQuestion) .where(UserQuestion.id == user_question.id) - .join(UserQuestion.answer) - .join(Answer.points) - .join(Point.analysises) - .join(Analysis.event) - .join(Event.original_article) - .join(Analysis.category) + .options( + selectinload( + UserQuestion.answer, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, + Analysis.event, + Event.original_article, + ), + selectinload( + UserQuestion.answer, + Answer.points, + Point.fallback, + ), + selectinload( + UserQuestion.answer, + Answer.points, + Point.point_analysises, + PointAnalysis.analysis, + Analysis.category, + ), + ) ) + return same_user_question diff --git a/backend/src/user_questions/schemas.py b/backend/src/user_questions/schemas.py index daac58a7..010836e3 100644 --- a/backend/src/user_questions/schemas.py +++ b/backend/src/user_questions/schemas.py @@ -11,23 +11,37 @@ class AnalysisDTO(BaseModel): likes: list[LikeDTO] -class PointMiniDTO(BaseModel): +class PointAnalysisDTO(BaseModel): model_config = ConfigDict(from_attributes=True) - id: int - title: str - body: str - analysises: list[AnalysisDTO] + + analysis: AnalysisDTO + elaboration: str @model_validator(mode="after") def filter(self): # i gave up on using the orm to filter the ones relevant to the point - for analysis in self.analysises: - analysis.likes = [ - like for like in analysis.likes if like.point_id == self.id - ] + self.analysis.likes = [ + like for like in self.analysis.likes if like.point_id == self.id + ] return self +class FallbackDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + alt_approach: str + general_argument: str + + +class PointMiniDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: int + title: str + body: str + point_analysises: list[PointAnalysisDTO] + fallback: FallbackDTO | None = None + + class AnswerDTO(BaseModel): model_config = ConfigDict(from_attributes=True) id: int