From 59dbb12c4dce9bdd8cc4da63df34c1ae91897be8 Mon Sep 17 00:00:00 2001 From: KonradUdoHannes Date: Fri, 28 Jul 2023 19:54:21 +0200 Subject: [PATCH] refactored project routers to encapsulate db interaction a ProjectStore dependency --- backend/src/api/routers/projects.py | 130 ++++++++++++++++++---------- backend/src/settings.py | 1 + backend/tests/fixtures/__init__.py | 1 + backend/tests/fixtures/database.py | 64 +++++++++----- backend/tests/fixtures/overrides.py | 52 +++++++++++ backend/tests/test_projects.py | 37 ++++++++ 6 files changed, 214 insertions(+), 71 deletions(-) create mode 100644 backend/tests/fixtures/overrides.py diff --git a/backend/src/api/routers/projects.py b/backend/src/api/routers/projects.py index 4284399..39d5ae5 100644 --- a/backend/src/api/routers/projects.py +++ b/backend/src/api/routers/projects.py @@ -32,16 +32,6 @@ class UpdateProject(pydantic.BaseModel): status: str | None = None -async def get_project(project_id: uuid.UUID, session: AsyncSession) -> Project: - try: - project_entry = ( - await session.execute(select(Project).filter_by(id=project_id)) - ).scalar_one() - except NoResultFound: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - return project_entry - - async def transactional_session() -> AsyncIterator[AsyncSession]: async with async_session_maker.begin() as session: yield session @@ -50,20 +40,81 @@ async def transactional_session() -> AsyncIterator[AsyncSession]: SessionWithTransactionContext = Annotated[AsyncSession, Depends(transactional_session)] +class ProjectStore: + def __init__(self, session: SessionWithTransactionContext): + self.session = session + + async def _get_orm_project(self, project_id: uuid.UUID) -> Project: + try: + project_entry = ( + await self.session.execute(select(Project).filter_by(id=project_id)) + ).scalar_one() + except NoResultFound: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + return project_entry + + async def get(self, project_id: uuid.UUID) -> ReadProject: + project_entry = await self._get_orm_project(project_id) + return ReadProject( + id=project_entry.id, + title=project_entry.title, + summary=project_entry.summary, + status=project_entry.status, + ) + + async def get_all(self) -> list[ReadProject]: + stmt = select(Project) + projects = (await self.session.scalars(stmt)).all() + return [ + ReadProject( + id=p.id, + title=p.title, + summary=p.summary, + status=p.status, + ) + for p in projects + ] + + async def create(self, project: CreateProject) -> ReadProject: + new_project = Project(id=uuid.uuid4(), **project.dict()) + self.session.add(new_project) + return ReadProject( + id=new_project.id, + title=new_project.title, + summary=new_project.summary, + status=new_project.status, + ) + + async def delete(self, project_id: uuid.UUID) -> None: + project_entry = await self._get_orm_project(project_id) + await self.session.delete(project_entry) + return None + + async def update( + self, project_id: uuid.UUID, project: UpdateProject + ) -> ReadProject: + project_entry = await self._get_orm_project(project_id) + for key, value in project.model_dump(exclude_defaults=True).items(): + if getattr(project_entry, key) != value: + setattr(project_entry, key, value) + return ReadProject( + id=project_entry.id, + title=project_entry.title, + summary=project_entry.summary, + status=project_entry.status, + ) + + @router.get( "/", response_description="List all projects", dependencies=[Depends(current_active_user)], ) async def list_projects( - session: SessionWithTransactionContext, + project_store: Annotated[ProjectStore, Depends()], ) -> list[ReadProject]: - stmt = select(Project) - projects = (await session.scalars(stmt)).all() - return [ - ReadProject(id=p.id, title=p.title, summary=p.summary, status=p.status) - for p in projects - ] + projects = await project_store.get_all() + return projects @router.post( @@ -73,16 +124,11 @@ async def list_projects( dependencies=[Depends(current_active_user)], ) async def create_project( - project: CreateProject, session: SessionWithTransactionContext + project: CreateProject, + project_store: Annotated[ProjectStore, Depends()], ) -> ReadProject: - new_project = Project(id=uuid.uuid4(), **project.dict()) - session.add(new_project) - return ReadProject( - id=new_project.id, - title=new_project.title, - summary=new_project.summary, - status=new_project.status, - ) + new_project = await project_store.create(project) + return new_project @router.get( @@ -91,15 +137,11 @@ async def create_project( dependencies=[Depends(current_active_user)], ) async def read_project( - project_id: uuid.UUID, session: SessionWithTransactionContext + project_id: uuid.UUID, + project_store: Annotated[ProjectStore, Depends()], ) -> ReadProject: - project_entry = await get_project(project_id, session) - return ReadProject( - id=project_entry.id, - title=project_entry.title, - summary=project_entry.summary, - status=project_entry.status, - ) + project_entry = await project_store.get(project_id) + return project_entry @router.delete( @@ -109,10 +151,10 @@ async def read_project( status_code=status.HTTP_204_NO_CONTENT, ) async def delete_project( - project_id: uuid.UUID, session: SessionWithTransactionContext + project_id: uuid.UUID, + project_store: Annotated[ProjectStore, Depends()], ) -> None: - project_entry = await get_project(project_id, session) - await session.delete(project_entry) + await project_store.delete(project_id) return None @@ -124,15 +166,7 @@ async def delete_project( async def update_project( project_id: uuid.UUID, project: UpdateProject, - session: SessionWithTransactionContext, + project_store: Annotated[ProjectStore, Depends()], ) -> ReadProject: - project_entry = await get_project(project_id, session) - for key, value in project.model_dump(exclude_defaults=True).items(): - if getattr(project_entry, key) != value: - setattr(project_entry, key, value) - return ReadProject( - id=project_entry.id, - title=project_entry.title, - summary=project_entry.summary, - status=project_entry.status, - ) + project_entry = await project_store.update(project_id, project) + return project_entry diff --git a/backend/src/settings.py b/backend/src/settings.py index d97fc3d..b9ff52f 100644 --- a/backend/src/settings.py +++ b/backend/src/settings.py @@ -29,6 +29,7 @@ def dsn(self) -> str: class TestConfiguration(BaseModel): test_database: bool = False + use_latest_migration: bool = True def is_src_root_dir(directory: Path): diff --git a/backend/tests/fixtures/__init__.py b/backend/tests/fixtures/__init__.py index 88697a3..7f141c2 100644 --- a/backend/tests/fixtures/__init__.py +++ b/backend/tests/fixtures/__init__.py @@ -1,2 +1,3 @@ from .client import * # noqa: F403, F401 from .database import * # noqa: F403, F401 +from .overrides import * # noqa: F403, F401 diff --git a/backend/tests/fixtures/database.py b/backend/tests/fixtures/database.py index 305197f..2787212 100644 --- a/backend/tests/fixtures/database.py +++ b/backend/tests/fixtures/database.py @@ -1,21 +1,35 @@ import contextlib import uuid +from pathlib import Path import pytest +from alembic import command +from alembic.config import Config from api.auth.users import get_user_manager from api.routers.auth import UserCreate from database.session import get_async_session, get_user_service from models.project import Project from models.user import User +from settings import settings from sqlalchemy import delete -from ..base import database_test - get_async_session_context = contextlib.asynccontextmanager(get_async_session) get_user_service_context = contextlib.asynccontextmanager(get_user_service) get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) +@pytest.fixture(scope="session") +def alembic_upgrade(): + if settings.tests.test_database and settings.tests.use_latest_migration: + backend_root = Path(__file__).parent.parent.parent + alembic_cfg = Config(backend_root / "src" / "alembic.ini") + migration_directory = alembic_cfg.get_main_option("script_location") + alembic_cfg.set_main_option( + "script_location", str(backend_root / "src" / migration_directory) + ) + command.upgrade(alembic_cfg, "head") + + @pytest.fixture(scope="session") def admin_details(): return { @@ -26,9 +40,8 @@ def admin_details(): } -@database_test @pytest.fixture(scope="session", autouse=True) -async def admin_user(admin_details): +async def admin_user(admin_details, alembic_upgrade): """Creates a user with admin privileges for testing purposes. The test does "tear down" up front because asyncio errors in tests @@ -37,14 +50,17 @@ async def admin_user(admin_details): in the database are idempotent and also succeed if the user does not exist. """ - async with get_async_session_context() as session: - delete_stmt = delete(User).where(User.email == admin_details.get("email")) - await session.execute(delete_stmt) - await session.commit() - - async with get_user_service_context(session) as user_db: - async with get_user_manager_context(user_db) as user_manager: - await user_manager.create(UserCreate(**admin_details)) + if settings.tests.test_database: + async with get_async_session_context() as session: + delete_stmt = delete(User).where(User.email == admin_details.get("email")) + await session.execute(delete_stmt) + await session.commit() + + async with get_user_service_context(session) as user_db: + async with get_user_manager_context(user_db) as user_manager: + await user_manager.create(UserCreate(**admin_details)) + yield + else: yield @@ -57,18 +73,20 @@ async def project_details(): } -@database_test @pytest.fixture(scope="session", autouse=True) -async def example_project(project_details): - async with get_async_session_context() as session: - delete_stmt = delete(Project).where( - Project.title == project_details.get("title") - ) - await session.execute(delete_stmt) - await session.commit() +async def example_project(project_details, alembic_upgrade): + if settings.tests.test_database: + async with get_async_session_context() as session: + delete_stmt = delete(Project).where( + Project.title == project_details.get("title") + ) + await session.execute(delete_stmt) + await session.commit() - project = Project(**project_details, id=uuid.uuid4()) - session.add(project) - await session.commit() + project = Project(**project_details, id=uuid.uuid4()) + session.add(project) + await session.commit() + yield + else: yield diff --git a/backend/tests/fixtures/overrides.py b/backend/tests/fixtures/overrides.py new file mode 100644 index 0000000..f87aedc --- /dev/null +++ b/backend/tests/fixtures/overrides.py @@ -0,0 +1,52 @@ +import uuid + +import pytest +from api.app import app +from api.auth.users import current_active_user +from api.routers.projects import ProjectStore, ReadProject +from fastapi import HTTPException, status + + +class ProjectDictStore: + projects: dict[uuid.UUID, ReadProject] = {} + + async def get(self, project_id): + try: + return self.projects[project_id] + except KeyError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + async def get_all(self): + return self.projects.values() + + async def create(self, project): + project_id = uuid.uuid4() + project_entry = ReadProject(**project.dict(), id=project_id) + self.projects[project_id] = project_entry + return project_entry + + async def update(self, project_id, project): + project_entry = await self.get(project_id) + project_entry.title = project.title + project_entry.summary = project.summary + project_entry.status = project.status + return project_entry + + async def delete(self, project_id): + await self.get(project_id) + del self.projects[project_id] + return None + + +@pytest.fixture +def override_project_store(): + app.dependency_overrides[ProjectStore] = ProjectDictStore + yield + del app.dependency_overrides[ProjectStore] + + +@pytest.fixture +def override_active_user(): + app.dependency_overrides[current_active_user] = lambda: {} + yield + del app.dependency_overrides[current_active_user] diff --git a/backend/tests/test_projects.py b/backend/tests/test_projects.py index 6987290..92696af 100644 --- a/backend/tests/test_projects.py +++ b/backend/tests/test_projects.py @@ -62,3 +62,40 @@ async def test_project_crud(async_client, admin_details): response = await ac.get(f"/projects/{project_id}", headers=auth_cookie_header) assert response.status_code == 404 + + +async def test_project_crud_no_db( + async_client, admin_details, override_project_store, override_active_user +): + project = { + "title": "Inserted Project", + "status": "running", + } + + async with async_client as ac: + response = await ac.post("/projects/", json=project) + assert response.status_code == 201 + + project_id = response.json()["id"] + + response = await ac.get(f"/projects/{project_id}") + assert response.status_code == 200 + assert response.json()["title"] == project["title"] + assert response.json()["status"] == project["status"] + assert response.json()["summary"] is None + + project["title"] = "Updated Project" + project["summary"] = "Updated Summary" + + response = await ac.patch(f"/projects/{project_id}", json=project) + response = await ac.get(f"/projects/{project_id}") + assert response.status_code == 200 + assert response.json()["title"] == project["title"] + assert response.json()["status"] == project["status"] + assert response.json()["summary"] == project["summary"] + + response = await ac.delete(f"/projects/{project_id}") + assert response.status_code == 204 + + response = await ac.get(f"/projects/{project_id}") + assert response.status_code == 404