Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fallback and change format of user-questions routes #101

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backend/alembic/versions/65b186bae301_add_fallback_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Add fallback table

Revision ID: 65b186bae301
Revises: c1022b3f1de5
Create Date: 2024-09-27 01:37:55.705475

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "65b186bae301"
down_revision: Union[str, None] = "c1022b3f1de5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"fallback",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("point_id", sa.Integer(), nullable=False),
sa.Column("alt_approach", sa.String(), nullable=False),
sa.Column("general_argument", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["point_id"],
["point.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("fallback")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Add elaboration for third layer

Revision ID: 951ac4411b0d
Revises: feeb1c78c0a2
Create Date: 2024-09-26 21:39:24.431100

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "951ac4411b0d"
down_revision: Union[str, None] = "feeb1c78c0a2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("point", sa.Column("positive", sa.Boolean(), nullable=False))
op.add_column(
"point_analysis", sa.Column("elaboration", sa.String(), nullable=False)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("point_analysis", "elaboration")
op.drop_column("point", "positive")
# ### end Alembic commands ###
24 changes: 24 additions & 0 deletions backend/alembic/versions/c1022b3f1de5_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""empty message

Revision ID: c1022b3f1de5
Revises: 7f60f0211af1, 951ac4411b0d
Create Date: 2024-09-27 01:37:39.117941

"""

from typing import Sequence, Union


# revision identifiers, used by Alembic.
revision: str = "c1022b3f1de5"
down_revision: Union[str, None] = ("7f60f0211af1", "951ac4411b0d")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
25 changes: 24 additions & 1 deletion backend/src/user_questions/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sqlalchemy import ForeignKey, and_
from src.common.base import Base
from sqlalchemy.orm import Mapped, mapped_column, relationship, foreign
from src.events.models import Analysis
from src.notes.models import Note


Expand Down Expand Up @@ -39,8 +40,14 @@ class Point(Base):
primaryjoin=and_(id == foreign(Note.parent_id), Note.parent_type == "point"),
backref="point",
)
positive: Mapped[bool]

analysises = relationship("Analysis", secondary="point_analysis")
# analysises = relationship("Analysis", secondary="point_analysis")
point_analysises: Mapped[list["PointAnalysis"]] = relationship(
back_populates="point"
)

fallback: Mapped["Fallback"] = relationship(back_populates="point")


class PointAnalysis(Base):
Expand All @@ -50,3 +57,19 @@ class PointAnalysis(Base):
ForeignKey("analysis.id"), primary_key=True
)
point_id: Mapped[int] = mapped_column(ForeignKey("point.id"), primary_key=True)
elaboration: Mapped[str]

point: Mapped[Point] = relationship(back_populates="point_analysises")
analysis: Mapped[Analysis] = relationship(backref="point_analysises")


class Fallback(Base):
__tablename__ = "fallback"

id: Mapped[int] = mapped_column(primary_key=True)

point_id: Mapped[int] = mapped_column(ForeignKey("point.id"))
alt_approach: Mapped[str]
general_argument: Mapped[str]

point: Mapped[Point] = relationship(back_populates="fallback")
136 changes: 96 additions & 40 deletions backend/src/user_questions/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.orm import with_polymorphic, aliased, selectinload
from sqlalchemy.orm import with_polymorphic, selectinload
from src.auth.dependencies import get_current_user
from src.auth.models import User
from src.common.dependencies import get_session
from src.events.models import Analysis, Event
from src.likes.models import Like
from src.notes.models import Note
from src.user_questions.models import Answer, Point, UserQuestion
from src.user_questions.models import (
Answer,
Fallback,
Point,
PointAnalysis,
UserQuestion,
)
from src.user_questions.schemas import CreateUserQuestion, UserQuestionMiniDTO
from src.lm.generate_response import generate_response
from src.lm.generate_points import get_relevant_analyses
Expand All @@ -26,37 +31,29 @@ def get_user_questions(
session=Depends(get_session),
) -> list[UserQuestionMiniDTO]:
# Create an alias for the Point table to use for the Like condition
point_alias = aliased(Point)
user_questions = session.scalars(
select(UserQuestion)
.where(UserQuestion.user_id == user.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(point_alias.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.join(Analysis.likes)
.where(Like.point_id == point_alias.id)
.options(
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Analysis.category,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points.of_type(point_alias),
point_alias.analysises,
Analysis.likes,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)
Expand All @@ -83,13 +80,28 @@ def get_user_question(
select(UserQuestion)
.where(UserQuestion.id == id)
.where(UserQuestion.user_id == user.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(Point.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.join(Analysis.likes)
.options(
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)
if not user_question:
raise HTTPException(HTTPStatus.NOT_FOUND)
Expand All @@ -108,16 +120,43 @@ def create_user_question(
answer = Answer()
user_question.answer = answer

results = get_relevant_analyses(data.question)
for row in results["for_points"] + results["against_points"]:
results = generate_response(data.question)

for row in results["for_points"]:
point = row["point"]
analyses = row["analyses"]
point = Point(title=point, body="")
analysis_id = [analysis["id"] for analysis in analyses]
point = Point(title=point, body="", positive=True)

for analysis in analyses:
point.point_analysises.append(
PointAnalysis(
elaboration=analysis["elaborations"],
analysis_id=analysis["id"],
)
)
if not analyses:
point.fallback = Fallback(
alt_approach=row["fall_back_response"]["alt_approach"],
general_argument=row["fall_back_response"]["general_argument"],
)
answer.points.append(point)

point.analysises = list(
session.scalars(select(Analysis).where(Analysis.id.in_(analysis_id)))
)
for row in results["against_points"]:
point = row["point"]
analyses = row["analyses"]
point = Point(title=point, body="", positive=False)
for analysis in analyses:
point.point_analysises.append(
PointAnalysis(
elaboration=analysis["elaborations"],
analysis_id=analysis["id"],
)
)
if not analyses:
point.fallback = Fallback(
alt_approach=row["fall_back_response"]["alt_approach"],
general_argument=row["fall_back_response"]["general_argument"],
)
answer.points.append(point)

session.add(user_question)
Expand All @@ -126,11 +165,28 @@ def create_user_question(
same_user_question = session.scalar(
select(UserQuestion)
.where(UserQuestion.id == user_question.id)
.join(UserQuestion.answer)
.join(Answer.points)
.join(Point.analysises)
.join(Analysis.event)
.join(Event.original_article)
.join(Analysis.category)
.options(
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.event,
Event.original_article,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.fallback,
),
selectinload(
UserQuestion.answer,
Answer.points,
Point.point_analysises,
PointAnalysis.analysis,
Analysis.category,
),
)
)

return same_user_question
32 changes: 23 additions & 9 deletions backend/src/user_questions/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,37 @@ class AnalysisDTO(BaseModel):
likes: list[LikeDTO]


class PointMiniDTO(BaseModel):
class PointAnalysisDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
title: str
body: str
analysises: list[AnalysisDTO]

analysis: AnalysisDTO
elaboration: str

@model_validator(mode="after")
def filter(self):
# i gave up on using the orm to filter the ones relevant to the point
for analysis in self.analysises:
analysis.likes = [
like for like in analysis.likes if like.point_id == self.id
]
self.analysis.likes = [
like for like in self.analysis.likes if like.point_id == self.id
]
return self


class FallbackDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

alt_approach: str
general_argument: str


class PointMiniDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
title: str
body: str
point_analysises: list[PointAnalysisDTO]
fallback: FallbackDTO | None = None


class AnswerDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
Expand Down
Loading