Skip to content

Commit

Permalink
Merge pull request #383 from cs3216-a3-group-4/seeleng/email-validati…
Browse files Browse the repository at this point in the history
…on-v2

feat: email verification
  • Loading branch information
haoyangw authored Oct 31, 2024
2 parents d902844 + 1839e9d commit ca59cba
Show file tree
Hide file tree
Showing 28 changed files with 1,068 additions and 177 deletions.
62 changes: 62 additions & 0 deletions backend/alembic/versions/4f9ec96fc98e_add_unverified_tier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Add unverified tier
Revision ID: 4f9ec96fc98e
Revises: 651ed2d244c5
Create Date: 2024-10-31 14:30:56.099043
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlalchemy.orm as orm
from alembic_postgresql_enum import TableReference
from src.limits.models import Tier

# revision identifiers, used by Alembic.
revision: str = "4f9ec96fc98e"
down_revision: Union[str, None] = "651ed2d244c5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
"public",
"tiernames",
["FREE", "ADMIN", "PREMIUM", "UNVERIFIED"],
[
TableReference(
table_schema="public", table_name="tier", column_name="tier_name"
)
],
enum_values_to_rename=[],
)
session = orm.Session(bind=op.get_bind())
session.add(Tier(tier_name="UNVERIFIED", label="Unverified", gp_question_limit=0))
session.commit()
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
session = orm.Session(bind=op.get_bind())
unverified = session.scalar(sa.select(Tier).where(Tier.label == "Unverified"))
session.delete(unverified)
session.commit()

op.sync_enum_values(
"public",
"tiernames",
["FREE", "ADMIN", "PREMIUM"],
[
TableReference(
table_schema="public", table_name="tier", column_name="tier_name"
)
],
enum_values_to_rename=[],
)

# ### end Alembic commands ###
48 changes: 48 additions & 0 deletions backend/alembic/versions/63af7264fba3_add_essay_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Add essay rate limit
Revision ID: 63af7264fba3
Revises: 4f9ec96fc98e
Create Date: 2024-10-31 14:54:32.307467
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlalchemy.orm as orm
from src.limits.models import Tier


# revision identifiers, used by Alembic.
revision: str = "63af7264fba3"
down_revision: Union[str, None] = "4f9ec96fc98e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

ESSAY_LIMITS = {"Free": 3, "Premium": 10, "Unverified": 0, "Admin": 1000}


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"tier",
sa.Column("essay_limit", sa.Integer(), server_default="0", nullable=False),
)
op.add_column(
"usage", sa.Column("essays", sa.Integer(), server_default="0", nullable=False)
)
session = orm.Session(bind=op.get_bind())
for tier_type, essay_limit in ESSAY_LIMITS.items():
tier = session.scalar(sa.select(Tier).where(Tier.label == tier_type))
tier.essay_limit = essay_limit
session.add(tier)
session.commit()
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("usage", "essays")
op.drop_column("tier", "essay_limit")
# ### end Alembic commands ###
54 changes: 54 additions & 0 deletions backend/alembic/versions/651ed2d244c5_add_email_verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Add email verification
Revision ID: 651ed2d244c5
Revises: 59cef91d2fa1
Create Date: 2024-10-31 13:46:07.360330
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "651ed2d244c5"
down_revision: Union[str, None] = "59cef91d2fa1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"email_verification",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("code", sa.String(), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
),
sa.Column(
"updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False
),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"user",
sa.Column("verified", sa.Boolean(), server_default="true", nullable=False),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user", "verified")
op.drop_table("email_verification")
# ### end Alembic commands ###
24 changes: 18 additions & 6 deletions backend/src/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class Role(str, Enum):
ADMIN = "admin"


# TODO: it's probably safer to check with the db but it'll do for now
UNVERIFIED_TIER_ID = 4


class User(Base):
__tablename__ = "user"

Expand All @@ -35,18 +39,17 @@ class User(Base):
hashed_password: Mapped[str]
account_type: Mapped[AccountType]
last_accessed: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
top_events_period: Mapped[int] = mapped_column(Integer, default=7)
tier_id: Mapped[int] = mapped_column(
ForeignKey("tier.id"), default=1, server_default="1"
)
verified: Mapped[bool] = mapped_column(server_default="true")

role: Mapped[Role] = mapped_column(server_default="NORMAL")

categories: Mapped[list[Category]] = relationship(secondary=user_category_table)
notes: Mapped[list[Note]] = relationship("Note", backref="user")
top_events_period: Mapped[int] = mapped_column(Integer, default=7)

bookmarks: Mapped[list[Bookmark]] = relationship(backref="user")

tier_id: Mapped[int] = mapped_column(
ForeignKey("tier.id"), default=1, server_default="1"
)
subscription: Mapped[Subscription] = relationship(
"Subscription", backref="user", lazy="selectin", uselist=False
)
Expand All @@ -62,3 +65,12 @@ class PasswordReset(Base):
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
code: Mapped[str]
used: Mapped[bool]


class EmailVerification(Base):
__tablename__ = "email_verification"

id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
code: Mapped[str]
used: Mapped[bool] = mapped_column(default=False)
110 changes: 102 additions & 8 deletions backend/src/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
import httpx
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from src.auth.utils import create_token, send_reset_password_email
from src.auth.utils import (
create_token,
send_reset_password_email,
send_verification_email,
)
from src.common.constants import (
FRONTEND_URL,
GOOGLE_CLIENT_ID,
GOOGLE_CLIENT_SECRET,
GOOGLE_REDIRECT_URI,
Expand All @@ -32,9 +37,18 @@
get_password_hash,
verify_password,
)
from .models import AccountType, PasswordReset, User
from .models import (
UNVERIFIED_TIER_ID,
AccountType,
EmailVerification,
PasswordReset,
User,
)

router = APIRouter(prefix="/auth", tags=["auth"])
routerWithAuth = APIRouter(
prefix="/auth", tags=["auth"], dependencies=[Depends(add_current_user)]
)

#######################
# username & password #
Expand All @@ -43,7 +57,10 @@

@router.post("/signup")
def sign_up(
data: SignUpData, response: Response, session=Depends(get_session)
data: SignUpData,
response: Response,
background_task: BackgroundTasks,
session=Depends(get_session),
) -> Token:
existing_user = session.scalars(
select(User).where(User.email == data.email)
Expand All @@ -55,6 +72,8 @@ def sign_up(
email=data.email,
hashed_password=get_password_hash(data.password),
account_type=AccountType.NORMAL,
verified=False,
tier_id=UNVERIFIED_TIER_ID,
)
session.add(new_user)
session.commit()
Expand All @@ -70,6 +89,13 @@ def sign_up(
)
)

code = str(uuid4())
email_validation = EmailVerification(user_id=new_user.id, code=code, used=False)
session.add(email_validation)
session.commit()
verification_link = f"{FRONTEND_URL}/verify-email?code={code}"
background_task.add_task(send_verification_email, data.email, verification_link)

return create_token(new_user, response)


Expand All @@ -85,6 +111,79 @@ def log_in(
return create_token(user, response)


@routerWithAuth.put("/email-verification")
def complete_email_verification(
user: Annotated[User, Depends(get_current_user)],
code: str,
response: Response,
session=Depends(get_session),
) -> Token:
email_verification = session.scalar(
select(EmailVerification)
.where(EmailVerification.code == code)
.where(EmailVerification.user_id == user.id) # noqa: E712
)
if not email_verification:
raise HTTPException(HTTPStatus.NOT_FOUND)
elif email_verification.used:
print(
f"""ERROR: Attempt to reuse an old email verification code {code} for user with ID {email_verification.user_id}"""
)
raise HTTPException(HTTPStatus.BAD_REQUEST)

user = session.scalar(
select(User)
.where(User.id == email_verification.user_id)
.options(
selectinload(User.categories),
selectinload(User.tier),
selectinload(User.usage),
)
)

if user.verified and user.tier_id != UNVERIFIED_TIER_ID:
print(
f"""ERROR: Attempt to verify email of user with ID {user.id} who is already verified"""
)
raise HTTPException(HTTPStatus.CONFLICT)

user.verified = True
user.tier_id = 1
email_verification.used = True
session.add(user)
session.add(email_verification)
session.commit()
session.refresh(user)

token = create_token(user, response)

return token


@routerWithAuth.post("/email-verification")
def resend_verification_email(
user: Annotated[User, Depends(get_current_user)],
background_task: BackgroundTasks,
session=Depends(get_session),
):
existing_email_verifications = session.scalars(
select(EmailVerification).where(EmailVerification.user_id == user.id)
)
for email_verification in existing_email_verifications:
email_verification.used = True
session.add(email_verification)
session.commit()

code = str(uuid4())
email_validation = EmailVerification(user_id=user.id, code=code, used=False)
session.add(email_validation)
session.commit()
verification_link = f"{FRONTEND_URL}/verify-email?code={code}"
background_task.add_task(send_verification_email, user.email, verification_link)

return


#######################
# google auth #
#######################
Expand Down Expand Up @@ -144,11 +243,6 @@ def auth_google(
return token


routerWithAuth = APIRouter(
prefix="/auth", tags=["auth"], dependencies=[Depends(add_current_user)]
)


#######################
# Reset password #
#######################
Expand Down
1 change: 1 addition & 0 deletions backend/src/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class UserPublic(BaseModel):

usage: UsageDTO | None = None
tier: TierDTO
verified: bool


class Token(BaseModel):
Expand Down
Loading

0 comments on commit ca59cba

Please sign in to comment.