From 06b81fb879a02eba21012bd3ed0d98c9412ec39a Mon Sep 17 00:00:00 2001 From: seeleng Date: Thu, 31 Oct 2024 15:19:03 +0800 Subject: [PATCH] fix: rate limit essays --- .../63af7264fba3_add_essay_rate_limit.py | 48 +++++++++++++++++++ backend/src/essays/dependencies.py | 29 +++++++++++ backend/src/essays/router.py | 2 + backend/src/limits/models.py | 4 +- 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 backend/alembic/versions/63af7264fba3_add_essay_rate_limit.py create mode 100644 backend/src/essays/dependencies.py diff --git a/backend/alembic/versions/63af7264fba3_add_essay_rate_limit.py b/backend/alembic/versions/63af7264fba3_add_essay_rate_limit.py new file mode 100644 index 00000000..10a7577c --- /dev/null +++ b/backend/alembic/versions/63af7264fba3_add_essay_rate_limit.py @@ -0,0 +1,48 @@ +"""Add essay rate limit + +Revision ID: 63af7264fba3 +Revises: 4f9ec96fc98e +Create Date: 2024-10-31 14:54:32.307467 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlalchemy.orm as orm +from src.limits.models import Tier + + +# revision identifiers, used by Alembic. +revision: str = "63af7264fba3" +down_revision: Union[str, None] = "4f9ec96fc98e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +ESSAY_LIMITS = {"Free": 3, "Premium": 10, "Unverified": 0, "Admin": 1000} + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "tier", + sa.Column("essay_limit", sa.Integer(), server_default="0", nullable=False), + ) + op.add_column( + "usage", sa.Column("essays", sa.Integer(), server_default="0", nullable=False) + ) + session = orm.Session(bind=op.get_bind()) + for tier_type, essay_limit in ESSAY_LIMITS.items(): + tier = session.scalar(sa.select(Tier).where(Tier.label == tier_type)) + tier.essay_limit = essay_limit + session.add(tier) + session.commit() + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("usage", "essays") + op.drop_column("tier", "essay_limit") + # ### end Alembic commands ### diff --git a/backend/src/essays/dependencies.py b/backend/src/essays/dependencies.py new file mode 100644 index 00000000..c6fb1e52 --- /dev/null +++ b/backend/src/essays/dependencies.py @@ -0,0 +1,29 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import Depends, HTTPException +from src.auth.dependencies import get_current_user +from src.auth.models import User +from src.common.dependencies import get_session +from src.limits.models import Usage + + +def has_essay_tries_left( + user: Annotated[User, Depends(get_current_user)], + session=Depends(get_session), +): + usage = session.get(Usage, user.id) + if not usage: + usage = Usage(user_id=user.id) + # This is inefficient, refactor in the future. + session.add(usage) + session.commit() + + user_tier_limit = user.tier.essay_limit + user_essay_usage = usage.essays + if user_tier_limit - user_essay_usage <= 0: + raise HTTPException(HTTPStatus.TOO_MANY_REQUESTS) + + usage.essays += 1 + session.add(usage) + session.commit() diff --git a/backend/src/essays/router.py b/backend/src/essays/router.py index 07810f0d..0f392027 100644 --- a/backend/src/essays/router.py +++ b/backend/src/essays/router.py @@ -6,6 +6,7 @@ from src.auth.dependencies import get_current_user from src.auth.models import User from src.common.dependencies import get_session +from src.essays.dependencies import has_essay_tries_left from src.essays.models import ( Comment, CommentAnalysis, @@ -31,6 +32,7 @@ def create_essay( data: EssayCreate, user: Annotated[User, Depends(get_current_user)], session: Annotated[Session, Depends(get_session)], + _=Depends(has_essay_tries_left), ) -> EssayCreateDTO: essay = Essay(question=data.question, user_id=user.id) diff --git a/backend/src/limits/models.py b/backend/src/limits/models.py index 3f9df7df..e9be11a4 100644 --- a/backend/src/limits/models.py +++ b/backend/src/limits/models.py @@ -15,8 +15,8 @@ class Usage(Base): __tablename__ = "usage" user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), primary_key=True) - gp_question_asked: Mapped[int] = mapped_column(server_default="0") - essays: Mapped[int] = mapped_column(server_default="0") + gp_question_asked: Mapped[int] = mapped_column(default=0, server_default="0") + essays: Mapped[int] = mapped_column(default=0, server_default="0") class Tier(Base):