Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add some routes #21

Merged
merged 11 commits into from
Sep 22, 2024
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: local
hooks:
- id: ruff
name: ruff
language: system
entry: bash -c 'cd backend && uvx ruff check --fix; uvx ruff format'
43 changes: 43 additions & 0 deletions backend/alembic/versions/021f1bdc162b_add_categories_to_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Add categories to users

Revision ID: 021f1bdc162b
Revises: f3e847c3ee9d
Create Date: 2024-09-22 11:48:13.802703

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "021f1bdc162b"
down_revision: Union[str, None] = "f3e847c3ee9d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"user_category",
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("category_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["category_id"],
["category.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user_category")
# ### end Alembic commands ###
30 changes: 30 additions & 0 deletions backend/alembic/versions/a73902039c96_make_email_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Make email unique

Revision ID: a73902039c96
Revises: 021f1bdc162b
Create Date: 2024-09-22 11:56:00.470507

"""

from typing import Sequence, Union

from alembic import op


# revision identifiers, used by Alembic.
revision: str = "a73902039c96"
down_revision: Union[str, None] = "021f1bdc162b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint(None, "user", ["email"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "user", type_="unique")
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ build-backend = "hatchling.build"
dev-dependencies = [
"alembic>=1.13.2",
"alembic-postgresql-enum>=1.3.0",
"pre-commit>=3.8.0",
]
6 changes: 6 additions & 0 deletions backend/src/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.auth.router import router as auth_router
from src.categories.router import router as category_router
from src.profile.router import router as profile_router
from src.events.router import router as events_router
from contextlib import asynccontextmanager

import logging
Expand Down Expand Up @@ -30,3 +33,6 @@ async def lifespan(app: FastAPI):
)

server.include_router(auth_router)
server.include_router(category_router)
server.include_router(profile_router)
server.include_router(events_router)
16 changes: 11 additions & 5 deletions backend/src/auth/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta, timezone
from fastapi import Cookie, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, selectinload
from src.common.database import engine
from .models import User
import jwt
Expand Down Expand Up @@ -46,7 +46,11 @@ def get_password_hash(password: str):

def authenticate_user(email: str, password: str):
with Session(engine) as session:
user = session.scalars(select(User).where(User.email == email)).first()
user = session.scalars(
select(User)
.where(User.email == email)
.options(selectinload(User.categories))
).first()
if not user:
return False
if not verify_password(password, user.hashed_password):
Expand Down Expand Up @@ -90,11 +94,13 @@ async def get_current_user(
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
id = payload.get("sub")
with Session(engine) as session:
staff = session.get(User, id)
if not staff:
user = session.scalar(
select(User).where(User.id == id).options(selectinload(User.categories))
)
if not user:
raise InvalidTokenError()

return staff
return user

except InvalidTokenError:
raise credentials_exception
16 changes: 14 additions & 2 deletions backend/src/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
from enum import Enum
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Column, ForeignKey, Table
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.common.base import Base
from src.events.models import Category


class AccountType(str, Enum):
NORMAL = "normal"
GOOGLE = "google"


user_category_table = Table(
"user_category",
Base.metadata,
Column("user_id", ForeignKey("user.id")),
Column("category_id", ForeignKey("category.id")),
)


class User(Base):
__tablename__ = "user"

id: Mapped[int] = mapped_column(primary_key=True)
email: Mapped[str]
email: Mapped[str] = mapped_column(unique=True)
hashed_password: Mapped[str]
account_type: Mapped[AccountType]

categories: Mapped[list[Category]] = relationship(secondary=user_category_table)
7 changes: 5 additions & 2 deletions backend/src/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi.security import OAuth2PasswordRequestForm
import httpx
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from src.auth.utils import create_token
from src.common.constants import (
GOOGLE_CLIENT_ID,
Expand All @@ -23,7 +24,7 @@
)
from .models import AccountType, User

router = APIRouter(prefix="/auth")
router = APIRouter(prefix="/auth", tags=["auth"])

#######################
# username & password #
Expand Down Expand Up @@ -97,7 +98,9 @@ def auth_google(
).json()
# 2. Check for existing user.
email = user_info["email"]
user = session.scalars(select(User).where(User.email == email)).first()
user = session.scalars(
select(User).where(User.email == email).options(selectinload(User.categories))
).first()
if user:
if user.account_type == AccountType.NORMAL:
raise HTTPException(
Expand Down
3 changes: 3 additions & 0 deletions backend/src/auth/schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel, ConfigDict, EmailStr, Field
from src.categories.schemas import CategoryDTO


class UserPublic(BaseModel):
Expand All @@ -7,6 +8,8 @@ class UserPublic(BaseModel):
id: int
email: EmailStr

categories: list[CategoryDTO]


class Token(BaseModel):
access_token: str
Expand Down
15 changes: 15 additions & 0 deletions backend/src/categories/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from fastapi import APIRouter, Depends
from sqlalchemy import select
from src.events.models import Category
from src.categories.schemas import CategoryDTO
from src.common.dependencies import get_session


router = APIRouter(prefix="/categories", tags=["categories"])


@router.get("/")
def get_categories(session=Depends(get_session)) -> list[CategoryDTO]:
categories = session.scalars(select(Category))
category_dtos = [CategoryDTO.model_validate(category) for category in categories]
return category_dtos
8 changes: 8 additions & 0 deletions backend/src/categories/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel, ConfigDict


class CategoryDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int
name: str
2 changes: 1 addition & 1 deletion backend/src/common/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from src.common.constants import DATABASE_URL


engine = create_engine(DATABASE_URL)
engine = create_engine(DATABASE_URL, echo=True)
53 changes: 53 additions & 0 deletions backend/src/events/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from src.auth.dependencies import get_current_user
from src.auth.models import User
from src.events.models import Category, Event
from src.common.dependencies import get_session
from src.events.schemas import EventDTO, EventIndexResponse


router = APIRouter(prefix="/events", tags=["events"])


@router.get("/")
def get_events(
_: Annotated[User, Depends(get_current_user)],
session=Depends(get_session),
category_ids: Annotated[list[int] | None, Query()] = None,
limit: int | None = None,
offset: int | None = None,
) -> EventIndexResponse:
query = select(Event.id).distinct()
if category_ids:
query = query.join(Event.categories.and_(Category.id.in_(category_ids)))
relevant_ids = [id for id in session.scalars(query)]

total_count = len(relevant_ids)
event_query = (
select(Event)
.options(selectinload(Event.categories))
.where(Event.id.in_(relevant_ids))
)
if limit is not None:
event_query = event_query.limit(limit)
if offset is not None:
event_query = event_query.offset(offset)

events = list(session.scalars(event_query))
return EventIndexResponse(total_count=total_count, count=len(events), data=events)


@router.get("/:id")
def get_event(
id: int,
_: Annotated[User, Depends(get_current_user)],
session=Depends(get_session),
) -> EventDTO:
event = session.scalar(
select(Event).where(Event.id == id).options(selectinload(Event.categories))
)
# TODO: link to more models, give more data
return event
21 changes: 21 additions & 0 deletions backend/src/events/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from src.categories.schemas import CategoryDTO


class EventDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
title: str
description: str
analysis: str
is_singapore: bool
date: datetime

categories: list[CategoryDTO]


class EventIndexResponse(BaseModel):
total_count: int
count: int
data: list[EventDTO]
32 changes: 32 additions & 0 deletions backend/src/profile/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import select
from src.auth.dependencies import get_current_user
from src.auth.models import User
from src.auth.schemas import UserPublic
from src.events.models import Category
from src.common.dependencies import get_session
from src.profile.schemas import ProfileUpdate


router = APIRouter(prefix="/profile", tags=["profile"])


@router.put("/")
def update_profile(
data: ProfileUpdate,
user: Annotated[User, Depends(get_current_user)],
session=Depends(get_session),
) -> UserPublic:
categories = session.scalars(
select(Category).where(Category.id.in_(data.category_ids))
).all()

user = session.get(User, user.id)
user.categories = categories

session.add(user)
session.commit()
session.refresh(user)

return user
5 changes: 5 additions & 0 deletions backend/src/profile/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel


class ProfileUpdate(BaseModel):
category_ids: list[int]
Loading
Loading