Skip to content

Commit

Permalink
Merge pull request #101 from cs3216-a3-group-4/seeleng/modify-user-qu…
Browse files Browse the repository at this point in the history
…estions-backend

feat: add fallback and change format of user-questions routes
  • Loading branch information
seelengxd authored Sep 26, 2024
2 parents 704cc3c + da7541c commit 873a1b0
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 50 deletions.
42 changes: 42 additions & 0 deletions backend/alembic/versions/65b186bae301_add_fallback_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Add fallback table
Revision ID: 65b186bae301
Revises: c1022b3f1de5
Create Date: 2024-09-27 01:37:55.705475
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "65b186bae301"
down_revision: Union[str, None] = "c1022b3f1de5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"fallback",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("point_id", sa.Integer(), nullable=False),
sa.Column("alt_approach", sa.String(), nullable=False),
sa.Column("general_argument", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["point_id"],
["point.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("fallback")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Add elaboration for third layer
Revision ID: 951ac4411b0d
Revises: feeb1c78c0a2
Create Date: 2024-09-26 21:39:24.431100
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "951ac4411b0d"
down_revision: Union[str, None] = "feeb1c78c0a2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("point", sa.Column("positive", sa.Boolean(), nullable=False))
op.add_column(
"point_analysis", sa.Column("elaboration", sa.String(), nullable=False)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("point_analysis", "elaboration")
op.drop_column("point", "positive")
# ### end Alembic commands ###
24 changes: 24 additions & 0 deletions backend/alembic/versions/c1022b3f1de5_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""empty message
Revision ID: c1022b3f1de5
Revises: 7f60f0211af1, 951ac4411b0d
Create Date: 2024-09-27 01:37:39.117941
"""

from typing import Sequence, Union


# revision identifiers, used by Alembic.
revision: str = "c1022b3f1de5"
down_revision: Union[str, None] = ("7f60f0211af1", "951ac4411b0d")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
25 changes: 24 additions & 1 deletion backend/src/user_questions/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sqlalchemy import ForeignKey, and_
from src.common.base import Base
from sqlalchemy.orm import Mapped, mapped_column, relationship, foreign
from src.events.models import Analysis
from src.notes.models import Note


Expand Down Expand Up @@ -39,8 +40,14 @@ class Point(Base):
primaryjoin=and_(id == foreign(Note.parent_id), Note.parent_type == "point"),
backref="point",
)
positive: Mapped[bool]

analysises = relationship("Analysis", secondary="point_analysis")
# analysises = relationship("Analysis", secondary="point_analysis")
point_analysises: Mapped[list["PointAnalysis"]] = relationship(
back_populates="point"
)

fallback: Mapped["Fallback"] = relationship(back_populates="point")


class PointAnalysis(Base):
Expand All @@ -50,3 +57,19 @@ class PointAnalysis(Base):
ForeignKey("analysis.id"), primary_key=True
)
point_id: Mapped[int] = mapped_column(ForeignKey("point.id"), primary_key=True)
elaboration: Mapped[str]

point: Mapped[Point] = relationship(back_populates="point_analysises")
analysis: Mapped[Analysis] = relationship(backref="point_analysises")


class Fallback(Base):
__tablename__ = "fallback"

id: Mapped[int] = mapped_column(primary_key=True)

point_id: Mapped[int] = mapped_column(ForeignKey("point.id"))
alt_approach: Mapped[str]
general_argument: Mapped[str]

point: Mapped[Point] = relationship(back_populates="fallback")
136 changes: 96 additions & 40 deletions backend/src/user_questions/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.orm import with_polymorphic, aliased, selectinload
from sqlalchemy.orm import with_polymorphic, selectinload
from src.auth.dependencies import get_current_user
from src.auth.models import User
from src.common.dependencies import get_session
from src.events.models import Analysis, Event
from src.likes.models import Like
from src.notes.models import Note
from src.user_questions.models import Answer, Point, UserQuestion
from src.user_questions.models import (
Answer,
Fallback,
Point,
PointAnalysis,
UserQuestion,
)
from src.user_questions.schemas import CreateUserQuestion, UserQuestionMiniDTO
from src.lm.generate_response import generate_response
from src.lm.generate_points import get_relevant_analyses
Expand All @@ -26,37 +31,29 @@ def get_user_questions(
session=Depends(get_session),
) -> list[UserQuestionMiniDTO]:
# Create an alias for the Point table to use for the Like condition
point_alias = aliased(Point)
user_questions = session.scalars(
select(UserQuestion)
.where(UserQuestion.user_id == user.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(point_alias.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.join(Analysis.likes)
.where(Like.point_id == point_alias.id)
.options(
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Analysis.category,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Analysis.likes,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)
Expand All @@ -83,13 +80,28 @@ def get_user_question(
select(UserQuestion)
.where(UserQuestion.id == id)
.where(UserQuestion.user_id == user.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(Point.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.join(Analysis.likes)
.options(
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)
if not user_question:
raise HTTPException(HTTPStatus.NOT_FOUND)
Expand All @@ -108,16 +120,43 @@ def create_user_question(
answer = Answer()
user_question.answer = answer

results = get_relevant_analyses(data.question)
for row in results["for_points"] + results["against_points"]:
results = generate_response(data.question)

for row in results["for_points"]:
point = row["point"]
analyses = row["analyses"]
point = Point(title=point, body="")
analysis_id = [analysis["id"] for analysis in analyses]
point = Point(title=point, body="", positive=True)

for analysis in analyses:
point.point_analysises.append(
PointAnalysis(
elaboration=analysis["elaborations"],
analysis_id=analysis["id"],
)
)
if not analyses:
point.fallback = Fallback(
alt_approach=row["fall_back_response"]["alt_approach"],
general_argument=row["fall_back_response"]["general_argument"],
)
answer.points.append(point)

point.analysises = list(
session.scalars(select(Analysis).where(Analysis.id.in_(analysis_id)))
)
for row in results["against_points"]:
point = row["point"]
analyses = row["analyses"]
point = Point(title=point, body="", positive=False)
for analysis in analyses:
point.point_analysises.append(
PointAnalysis(
elaboration=analysis["elaborations"],
analysis_id=analysis["id"],
)
)
if not analyses:
point.fallback = Fallback(
alt_approach=row["fall_back_response"]["alt_approach"],
general_argument=row["fall_back_response"]["general_argument"],
)
answer.points.append(point)

session.add(user_question)
Expand All @@ -126,11 +165,28 @@ def create_user_question(
same_user_question = session.scalar(
select(UserQuestion)
.where(UserQuestion.id == user_question.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(Point.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.options(
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)

return same_user_question
32 changes: 23 additions & 9 deletions backend/src/user_questions/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,37 @@ class AnalysisDTO(BaseModel):
likes: list[LikeDTO]


class PointMiniDTO(BaseModel):
class PointAnalysisDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
title: str
body: str
analysises: list[AnalysisDTO]

analysis: AnalysisDTO
elaboration: str

@model_validator(mode="after")
def filter(self):
# i gave up on using the orm to filter the ones relevant to the point
for analysis in self.analysises:
analysis.likes = [
like for like in analysis.likes if like.point_id == self.id
]
self.analysis.likes = [
like for like in self.analysis.likes if like.point_id == self.id
]
return self


class FallbackDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

alt_approach: str
general_argument: str


class PointMiniDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
title: str
body: str
point_analysises: list[PointAnalysisDTO]
fallback: FallbackDTO | None = None


class AnswerDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
Expand Down

0 comments on commit 873a1b0

Please sign in to comment.