Skip to content

Commit

Permalink
refactored project routers to encapsulate db interaction a ProjectSto…
Browse files Browse the repository at this point in the history
…re dependency
  • Loading branch information
KonradUdoHannes committed Jul 28, 2023
1 parent 40a6385 commit 59dbb12
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 71 deletions.
130 changes: 82 additions & 48 deletions backend/src/api/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@ class UpdateProject(pydantic.BaseModel):
status: str | None = None


async def get_project(project_id: uuid.UUID, session: AsyncSession) -> Project:
try:
project_entry = (
await session.execute(select(Project).filter_by(id=project_id))
).scalar_one()
except NoResultFound:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return project_entry


async def transactional_session() -> AsyncIterator[AsyncSession]:
async with async_session_maker.begin() as session:
yield session
Expand All @@ -50,20 +40,81 @@ async def transactional_session() -> AsyncIterator[AsyncSession]:
SessionWithTransactionContext = Annotated[AsyncSession, Depends(transactional_session)]


class ProjectStore:
def __init__(self, session: SessionWithTransactionContext):
self.session = session

async def _get_orm_project(self, project_id: uuid.UUID) -> Project:
try:
project_entry = (
await self.session.execute(select(Project).filter_by(id=project_id))
).scalar_one()
except NoResultFound:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return project_entry

async def get(self, project_id: uuid.UUID) -> ReadProject:
project_entry = await self._get_orm_project(project_id)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)

async def get_all(self) -> list[ReadProject]:
stmt = select(Project)
projects = (await self.session.scalars(stmt)).all()
return [
ReadProject(
id=p.id,
title=p.title,
summary=p.summary,
status=p.status,
)
for p in projects
]

async def create(self, project: CreateProject) -> ReadProject:
new_project = Project(id=uuid.uuid4(), **project.dict())
self.session.add(new_project)
return ReadProject(
id=new_project.id,
title=new_project.title,
summary=new_project.summary,
status=new_project.status,
)

async def delete(self, project_id: uuid.UUID) -> None:
project_entry = await self._get_orm_project(project_id)
await self.session.delete(project_entry)
return None

async def update(
self, project_id: uuid.UUID, project: UpdateProject
) -> ReadProject:
project_entry = await self._get_orm_project(project_id)
for key, value in project.model_dump(exclude_defaults=True).items():
if getattr(project_entry, key) != value:
setattr(project_entry, key, value)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)


@router.get(
"/",
response_description="List all projects",
dependencies=[Depends(current_active_user)],
)
async def list_projects(
session: SessionWithTransactionContext,
project_store: Annotated[ProjectStore, Depends()],
) -> list[ReadProject]:
stmt = select(Project)
projects = (await session.scalars(stmt)).all()
return [
ReadProject(id=p.id, title=p.title, summary=p.summary, status=p.status)
for p in projects
]
projects = await project_store.get_all()
return projects


@router.post(
Expand All @@ -73,16 +124,11 @@ async def list_projects(
dependencies=[Depends(current_active_user)],
)
async def create_project(
project: CreateProject, session: SessionWithTransactionContext
project: CreateProject,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
new_project = Project(id=uuid.uuid4(), **project.dict())
session.add(new_project)
return ReadProject(
id=new_project.id,
title=new_project.title,
summary=new_project.summary,
status=new_project.status,
)
new_project = await project_store.create(project)
return new_project


@router.get(
Expand All @@ -91,15 +137,11 @@ async def create_project(
dependencies=[Depends(current_active_user)],
)
async def read_project(
project_id: uuid.UUID, session: SessionWithTransactionContext
project_id: uuid.UUID,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
project_entry = await get_project(project_id, session)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)
project_entry = await project_store.get(project_id)
return project_entry


@router.delete(
Expand All @@ -109,10 +151,10 @@ async def read_project(
status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_project(
project_id: uuid.UUID, session: SessionWithTransactionContext
project_id: uuid.UUID,
project_store: Annotated[ProjectStore, Depends()],
) -> None:
project_entry = await get_project(project_id, session)
await session.delete(project_entry)
await project_store.delete(project_id)
return None


Expand All @@ -124,15 +166,7 @@ async def delete_project(
async def update_project(
project_id: uuid.UUID,
project: UpdateProject,
session: SessionWithTransactionContext,
project_store: Annotated[ProjectStore, Depends()],
) -> ReadProject:
project_entry = await get_project(project_id, session)
for key, value in project.model_dump(exclude_defaults=True).items():
if getattr(project_entry, key) != value:
setattr(project_entry, key, value)
return ReadProject(
id=project_entry.id,
title=project_entry.title,
summary=project_entry.summary,
status=project_entry.status,
)
project_entry = await project_store.update(project_id, project)
return project_entry
1 change: 1 addition & 0 deletions backend/src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def dsn(self) -> str:

class TestConfiguration(BaseModel):
test_database: bool = False
use_latest_migration: bool = True


def is_src_root_dir(directory: Path):
Expand Down
1 change: 1 addition & 0 deletions backend/tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .client import * # noqa: F403, F401
from .database import * # noqa: F403, F401
from .overrides import * # noqa: F403, F401
64 changes: 41 additions & 23 deletions backend/tests/fixtures/database.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
import contextlib
import uuid
from pathlib import Path

import pytest
from alembic import command
from alembic.config import Config
from api.auth.users import get_user_manager
from api.routers.auth import UserCreate
from database.session import get_async_session, get_user_service
from models.project import Project
from models.user import User
from settings import settings
from sqlalchemy import delete

from ..base import database_test

get_async_session_context = contextlib.asynccontextmanager(get_async_session)
get_user_service_context = contextlib.asynccontextmanager(get_user_service)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)


@pytest.fixture(scope="session")
def alembic_upgrade():
if settings.tests.test_database and settings.tests.use_latest_migration:
backend_root = Path(__file__).parent.parent.parent
alembic_cfg = Config(backend_root / "src" / "alembic.ini")
migration_directory = alembic_cfg.get_main_option("script_location")
alembic_cfg.set_main_option(
"script_location", str(backend_root / "src" / migration_directory)
)
command.upgrade(alembic_cfg, "head")


@pytest.fixture(scope="session")
def admin_details():
return {
Expand All @@ -26,9 +40,8 @@ def admin_details():
}


@database_test
@pytest.fixture(scope="session", autouse=True)
async def admin_user(admin_details):
async def admin_user(admin_details, alembic_upgrade):
"""Creates a user with admin privileges for testing purposes.
The test does "tear down" up front because asyncio errors in tests
Expand All @@ -37,14 +50,17 @@ async def admin_user(admin_details):
in the database are idempotent and also succeed if the user does
not exist.
"""
async with get_async_session_context() as session:
delete_stmt = delete(User).where(User.email == admin_details.get("email"))
await session.execute(delete_stmt)
await session.commit()

async with get_user_service_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
await user_manager.create(UserCreate(**admin_details))
if settings.tests.test_database:
async with get_async_session_context() as session:
delete_stmt = delete(User).where(User.email == admin_details.get("email"))
await session.execute(delete_stmt)
await session.commit()

async with get_user_service_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
await user_manager.create(UserCreate(**admin_details))
yield
else:
yield


Expand All @@ -57,18 +73,20 @@ async def project_details():
}


@database_test
@pytest.fixture(scope="session", autouse=True)
async def example_project(project_details):
async with get_async_session_context() as session:
delete_stmt = delete(Project).where(
Project.title == project_details.get("title")
)
await session.execute(delete_stmt)
await session.commit()
async def example_project(project_details, alembic_upgrade):
if settings.tests.test_database:
async with get_async_session_context() as session:
delete_stmt = delete(Project).where(
Project.title == project_details.get("title")
)
await session.execute(delete_stmt)
await session.commit()

project = Project(**project_details, id=uuid.uuid4())
session.add(project)
await session.commit()
project = Project(**project_details, id=uuid.uuid4())
session.add(project)
await session.commit()

yield
else:
yield
52 changes: 52 additions & 0 deletions backend/tests/fixtures/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import uuid

import pytest
from api.app import app
from api.auth.users import current_active_user
from api.routers.projects import ProjectStore, ReadProject
from fastapi import HTTPException, status


class ProjectDictStore:
projects: dict[uuid.UUID, ReadProject] = {}

async def get(self, project_id):
try:
return self.projects[project_id]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

async def get_all(self):
return self.projects.values()

async def create(self, project):
project_id = uuid.uuid4()
project_entry = ReadProject(**project.dict(), id=project_id)
self.projects[project_id] = project_entry
return project_entry

async def update(self, project_id, project):
project_entry = await self.get(project_id)
project_entry.title = project.title
project_entry.summary = project.summary
project_entry.status = project.status
return project_entry

async def delete(self, project_id):
await self.get(project_id)
del self.projects[project_id]
return None


@pytest.fixture
def override_project_store():
app.dependency_overrides[ProjectStore] = ProjectDictStore
yield
del app.dependency_overrides[ProjectStore]


@pytest.fixture
def override_active_user():
app.dependency_overrides[current_active_user] = lambda: {}
yield
del app.dependency_overrides[current_active_user]
37 changes: 37 additions & 0 deletions backend/tests/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,40 @@ async def test_project_crud(async_client, admin_details):

response = await ac.get(f"/projects/{project_id}", headers=auth_cookie_header)
assert response.status_code == 404


async def test_project_crud_no_db(
async_client, admin_details, override_project_store, override_active_user
):
project = {
"title": "Inserted Project",
"status": "running",
}

async with async_client as ac:
response = await ac.post("/projects/", json=project)
assert response.status_code == 201

project_id = response.json()["id"]

response = await ac.get(f"/projects/{project_id}")
assert response.status_code == 200
assert response.json()["title"] == project["title"]
assert response.json()["status"] == project["status"]
assert response.json()["summary"] is None

project["title"] = "Updated Project"
project["summary"] = "Updated Summary"

response = await ac.patch(f"/projects/{project_id}", json=project)
response = await ac.get(f"/projects/{project_id}")
assert response.status_code == 200
assert response.json()["title"] == project["title"]
assert response.json()["status"] == project["status"]
assert response.json()["summary"] == project["summary"]

response = await ac.delete(f"/projects/{project_id}")
assert response.status_code == 204

response = await ac.get(f"/projects/{project_id}")
assert response.status_code == 404

0 comments on commit 59dbb12

Please sign in to comment.