Skip to content

Commit

Permalink
refactored project routers to encapsulate db interaction a ProjectSto…
Browse files Browse the repository at this point in the history
…re dependency
  • Loading branch information
KonradUdoHannes committed Jul 28, 2023
1 parent 039a92e commit 069219e
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 84 deletions.
127 changes: 67 additions & 60 deletions backend/src/api/routers/projects.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,81 @@
import uuid
from typing import Annotated, AsyncIterator
from typing import Annotated

import pydantic
from api.auth.users import current_active_user
from database.session import async_session_maker
from database.session import transactional_session
from fastapi import APIRouter, Depends, HTTPException, status
from models.project import Project
from pydantic import BaseModel, ConfigDict, TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter()


class ReadProject(pydantic.BaseModel):
class ReadProject(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: uuid.UUID
title: str
summary: str | None = None
status: str


class CreateProject(pydantic.BaseModel):
class CreateProject(BaseModel):
title: str
summary: str | None = None
status: str


class UpdateProject(pydantic.BaseModel):
class UpdateProject(BaseModel):
title: str | None = None
summary: str | None = None
status: str | None = None


async def get_project(project_id: uuid.UUID, session: AsyncSession) -> Project:
try:
project_entry = (
await session.execute(select(Project).filter_by(id=project_id))
).scalar_one()
except NoResultFound:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return project_entry


async def transactional_session() -> AsyncIterator[AsyncSession]:
async with async_session_maker.begin() as session:
yield session


SessionWithTransactionContext = Annotated[AsyncSession, Depends(transactional_session)]
class ProjectStore:
def __init__(
self, session: Annotated[AsyncSession, Depends(transactional_session)]
):
self.session = session

async def _get_orm_project(self, project_id: uuid.UUID) -> Project:
try:
project_entry = (
await self.session.execute(select(Project).filter_by(id=project_id))
).scalar_one()
except NoResultFound:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return project_entry

async def get(self, project_id: uuid.UUID) -> ReadProject:
project_entry = await self._get_orm_project(project_id)
return TypeAdapter(ReadProject).validate_python(project_entry)

async def get_all(self) -> list[ReadProject]:
stmt = select(Project)
projects = (await self.session.scalars(stmt)).all()
return TypeAdapter(list[ReadProject]).validate_python(projects)

async def create(self, project: CreateProject) -> ReadProject:
new_project = Project(id=uuid.uuid4(), **project.model_dump())
self.session.add(new_project)
return TypeAdapter(ReadProject).validate_python(new_project)

async def delete(self, project_id: uuid.UUID) -> None:
project_entry = await self._get_orm_project(project_id)
await self.session.delete(project_entry)
return None

async def update(
self, project_id: uuid.UUID, project: UpdateProject
) -> ReadProject:
project_entry = await self._get_orm_project(project_id)
for key, value in project.model_dump(exclude_defaults=True).items():
if getattr(project_entry, key) != value:
setattr(project_entry, key, value)
return TypeAdapter(ReadProject).validate_python(project_entry)


@router.get(
Expand All @@ -56,14 +84,10 @@ async def transactional_session() -> AsyncIterator[AsyncSession]:
dependencies=[Depends(current_active_user)],
)
async def list_projects(
session: SessionWithTransactionContext,
project_store: Annotated[ProjectStore, Depends()],
) -> list[ReadProject]:
stmt = select(Project)
projects = (await session.scalars(stmt)).all()
return [
ReadProject(id=p.id, title=p.title, summary=p.summary, status=p.status)
for p in projects
]
projects = await project_store.get_all()
return projects


@router.post(
Expand All @@ -73,16 +97,11 @@ async def list_projects(
dependencies=[Depends(current_active_user)],
)
async def create_project(
project: CreateProject, session: SessionWithTransactionContext
project: CreateProject,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
new_project = Project(id=uuid.uuid4(), **project.dict())
session.add(new_project)
return ReadProject(
id=new_project.id,
title=new_project.title,
summary=new_project.summary,
status=new_project.status,
)
new_project = await project_store.create(project)
return new_project


@router.get(
Expand All @@ -91,15 +110,11 @@ async def create_project(
dependencies=[Depends(current_active_user)],
)
async def read_project(
project_id: uuid.UUID, session: SessionWithTransactionContext
project_id: uuid.UUID,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
project_entry = await get_project(project_id, session)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)
project_entry = await project_store.get(project_id)
return project_entry


@router.delete(
Expand All @@ -109,10 +124,10 @@ async def read_project(
status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_project(
project_id: uuid.UUID, session: SessionWithTransactionContext
project_id: uuid.UUID,
project_store: Annotated[ProjectStore, Depends()],
) -> None:
project_entry = await get_project(project_id, session)
await session.delete(project_entry)
await project_store.delete(project_id)
return None


Expand All @@ -124,15 +139,7 @@ async def delete_project(
async def update_project(
project_id: uuid.UUID,
project: UpdateProject,
session: SessionWithTransactionContext,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
project_entry = await get_project(project_id, session)
for key, value in project.model_dump(exclude_defaults=True).items():
if getattr(project_entry, key) != value:
setattr(project_entry, key, value)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)
project_entry = await project_store.update(project_id, project)
return project_entry
7 changes: 6 additions & 1 deletion backend/src/database/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncGenerator
from typing import AsyncGenerator, AsyncIterator

from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase
Expand Down Expand Up @@ -26,3 +26,8 @@ async def get_user_service(session: AsyncSession = Depends(get_async_session)):

async def get_access_token_service(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyAccessTokenDatabase(session, AccessToken)


async def transactional_session() -> AsyncIterator[AsyncSession]:
async with async_session_maker.begin() as session:
yield session
1 change: 1 addition & 0 deletions backend/src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def dsn(self) -> str:

class TestConfiguration(BaseModel):
test_database: bool = False
use_latest_migration: bool = True


def is_src_root_dir(directory: Path):
Expand Down
1 change: 1 addition & 0 deletions backend/tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .client import * # noqa: F403, F401
from .database import * # noqa: F403, F401
from .overrides import * # noqa: F403, F401
64 changes: 41 additions & 23 deletions backend/tests/fixtures/database.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
import contextlib
import uuid
from pathlib import Path

import pytest
from alembic import command
from alembic.config import Config
from api.auth.users import get_user_manager
from api.routers.auth import UserCreate
from database.session import get_async_session, get_user_service
from models.project import Project
from models.user import User
from settings import settings
from sqlalchemy import delete

from ..base import database_test

get_async_session_context = contextlib.asynccontextmanager(get_async_session)
get_user_service_context = contextlib.asynccontextmanager(get_user_service)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)


@pytest.fixture(scope="session")
def alembic_upgrade():
if settings.tests.test_database and settings.tests.use_latest_migration:
backend_root = Path(__file__).parent.parent.parent
alembic_cfg = Config(backend_root / "src" / "alembic.ini")
migration_directory = alembic_cfg.get_main_option("script_location")
alembic_cfg.set_main_option(
"script_location", str(backend_root / "src" / migration_directory)
)
command.upgrade(alembic_cfg, "head")


@pytest.fixture(scope="session")
def admin_details():
return {
Expand All @@ -26,9 +40,8 @@ def admin_details():
}


@database_test
@pytest.fixture(scope="session", autouse=True)
async def admin_user(admin_details):
async def admin_user(admin_details, alembic_upgrade):
"""Creates a user with admin privileges for testing purposes.
The test does "tear down" up front because asyncio errors in tests
Expand All @@ -37,14 +50,17 @@ async def admin_user(admin_details):
in the database are idempotent and also succeed if the user does
not exist.
"""
async with get_async_session_context() as session:
delete_stmt = delete(User).where(User.email == admin_details.get("email"))
await session.execute(delete_stmt)
await session.commit()

async with get_user_service_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
await user_manager.create(UserCreate(**admin_details))
if settings.tests.test_database:
async with get_async_session_context() as session:
delete_stmt = delete(User).where(User.email == admin_details.get("email"))
await session.execute(delete_stmt)
await session.commit()

async with get_user_service_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
await user_manager.create(UserCreate(**admin_details))
yield
else:
yield


Expand All @@ -57,18 +73,20 @@ async def project_details():
}


@database_test
@pytest.fixture(scope="session", autouse=True)
async def example_project(project_details):
async with get_async_session_context() as session:
delete_stmt = delete(Project).where(
Project.title == project_details.get("title")
)
await session.execute(delete_stmt)
await session.commit()
async def example_project(project_details, alembic_upgrade):
if settings.tests.test_database:
async with get_async_session_context() as session:
delete_stmt = delete(Project).where(
Project.title == project_details.get("title")
)
await session.execute(delete_stmt)
await session.commit()

project = Project(**project_details, id=uuid.uuid4())
session.add(project)
await session.commit()
project = Project(**project_details, id=uuid.uuid4())
session.add(project)
await session.commit()

yield
else:
yield
52 changes: 52 additions & 0 deletions backend/tests/fixtures/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import uuid

import pytest
from api.app import app
from api.auth.users import current_active_user
from api.routers.projects import ProjectStore, ReadProject
from fastapi import HTTPException, status


class ProjectDictStore:
projects: dict[uuid.UUID, ReadProject] = {}

async def get(self, project_id):
try:
return self.projects[project_id]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

async def get_all(self):
return self.projects.values()

async def create(self, project):
project_id = uuid.uuid4()
project_entry = ReadProject(**project.model_dump(), id=project_id)
self.projects[project_id] = project_entry
return project_entry

async def update(self, project_id, project):
project_entry = await self.get(project_id)
project_entry.title = project.title
project_entry.summary = project.summary
project_entry.status = project.status
return project_entry

async def delete(self, project_id):
await self.get(project_id)
del self.projects[project_id]
return None


@pytest.fixture
def override_project_store():
app.dependency_overrides[ProjectStore] = ProjectDictStore
yield
del app.dependency_overrides[ProjectStore]


@pytest.fixture
def override_active_user():
app.dependency_overrides[current_active_user] = lambda: {}
yield
del app.dependency_overrides[current_active_user]
Loading

0 comments on commit 069219e

Please sign in to comment.