diff --git a/backend/alembic/versions/4f9ec96fc98e_add_unverified_tier.py b/backend/alembic/versions/4f9ec96fc98e_add_unverified_tier.py new file mode 100644 index 00000000..316e4a59 --- /dev/null +++ b/backend/alembic/versions/4f9ec96fc98e_add_unverified_tier.py @@ -0,0 +1,62 @@ +"""Add unverified tier + +Revision ID: 4f9ec96fc98e +Revises: 651ed2d244c5 +Create Date: 2024-10-31 14:30:56.099043 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlalchemy.orm as orm +from alembic_postgresql_enum import TableReference +from src.limits.models import Tier + +# revision identifiers, used by Alembic. +revision: str = "4f9ec96fc98e" +down_revision: Union[str, None] = "651ed2d244c5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + "public", + "tiernames", + ["FREE", "ADMIN", "PREMIUM", "UNVERIFIED"], + [ + TableReference( + table_schema="public", table_name="tier", column_name="tier_name" + ) + ], + enum_values_to_rename=[], + ) + session = orm.Session(bind=op.get_bind()) + session.add(Tier(tier_name="UNVERIFIED", label="Unverified", gp_question_limit=0)) + session.commit() + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + session = orm.Session(bind=op.get_bind()) + unverified = session.scalar(sa.select(Tier).where(Tier.label == "Unverified")) + session.delete(unverified) + session.commit() + + op.sync_enum_values( + "public", + "tiernames", + ["FREE", "ADMIN", "PREMIUM"], + [ + TableReference( + table_schema="public", table_name="tier", column_name="tier_name" + ) + ], + enum_values_to_rename=[], + ) + + # ### end Alembic commands ### diff --git a/backend/src/auth/models.py b/backend/src/auth/models.py index d2e6e7fd..9d831e14 100644 --- a/backend/src/auth/models.py +++ b/backend/src/auth/models.py @@ -27,6 +27,10 @@ class Role(str, Enum): ADMIN = "admin" +# TODO: it's probably safer to check with the db but it'll do for now +UNVERIFIED_TIER_ID = 4 + + class User(Base): __tablename__ = "user" diff --git a/backend/src/auth/router.py b/backend/src/auth/router.py index fd68714b..a4e24d35 100644 --- a/backend/src/auth/router.py +++ b/backend/src/auth/router.py @@ -37,7 +37,13 @@ get_password_hash, verify_password, ) -from .models import AccountType, EmailVerification, PasswordReset, User +from .models import ( + UNVERIFIED_TIER_ID, + AccountType, + EmailVerification, + PasswordReset, + User, +) router = APIRouter(prefix="/auth", tags=["auth"]) routerWithAuth = APIRouter( @@ -67,6 +73,7 @@ def sign_up( hashed_password=get_password_hash(data.password), account_type=AccountType.NORMAL, verified=False, + tier_id=UNVERIFIED_TIER_ID, ) session.add(new_user) session.commit() diff --git a/backend/src/limits/models.py b/backend/src/limits/models.py index fd99eb60..f9e2b8d2 100644 --- a/backend/src/limits/models.py +++ b/backend/src/limits/models.py @@ -8,6 +8,7 @@ class TierNames(str, Enum): FREE = "FREE" ADMIN = "ADMIN" PREMIUM = "PREMIUM" + UNVERIFIED = "UNVERIFIED" class Usage(Base):