From b5bd1c403eafe50d661cae8bb148e2ae21c34c76 Mon Sep 17 00:00:00 2001 From: "K. Allagbe" Date: Mon, 4 Nov 2024 13:33:38 -0500 Subject: [PATCH] issue #178: connection manager no longer needed --- app/config.py | 5 +- app/connection_manager.py | 110 ---------------------------- app/controllers/inspections.py | 20 ++--- app/controllers/users.py | 16 ++-- app/dependencies.py | 13 ++-- app/main.py | 24 +++--- tests/test_app.py | 13 ++-- tests/test_connection_manager.py | 122 ------------------------------- tests/test_inspections.py | 82 ++++++++++----------- tests/test_users_controller.py | 57 ++++++++------- 10 files changed, 112 insertions(+), 350 deletions(-) delete mode 100644 app/connection_manager.py delete mode 100644 tests/test_connection_manager.py diff --git a/app/config.py b/app/config.py index ea3ba14..d082e62 100644 --- a/app/config.py +++ b/app/config.py @@ -5,8 +5,6 @@ from pydantic import Field, PostgresDsn from pydantic_settings import BaseSettings -from app.connection_manager import ConnectionManager - load_dotenv() @@ -34,8 +32,7 @@ def configure(app: FastAPI, settings: Settings): conninfo=settings.fertiscan_db_url.unicode_string(), kwargs={"options": f"-c search_path={settings.fertiscan_schema},public"}, ) - connection_manager = ConnectionManager(pool, settings.testing.lower() == "true") - app.connection_manager = connection_manager + app.pool = pool # Initialize OCR ocr = OCR(api_endpoint=settings.api_endpoint, api_key=settings.api_key) diff --git a/app/connection_manager.py b/app/connection_manager.py deleted file mode 100644 index ad59b5a..0000000 --- a/app/connection_manager.py +++ /dev/null @@ -1,110 +0,0 @@ -from psycopg_pool import ConnectionPool - - -class ConnectionManager: - """ - Manages database connections using a connection pool for PostgreSQL. - - This class is used for managing database connections, transactions, and pooling - efficiently. It can be utilized as a context manager, making sure that resources - are properly managed (connections are committed, rolled back, and released). - - In testing mode, it avoids committing transactions and releasing connections - back to the pool, enabling more control over connections. - - Attributes: - testing (bool): Indicates if the application is in testing mode. - pool (ConnectionPool): The connection pool for managing database connections. - connection (Connection): The active database connection, if any. - """ - - def __init__(self, pool: ConnectionPool, testing: bool = False): - """ - Initializes the ConnectionManager with a connection pool. - - Args: - pool (ConnectionPool): The connection pool used to manage database connections. - """ - - if not isinstance(pool, ConnectionPool): - raise ValueError("A connection pool is required.") - - self.testing = testing - self.pool = pool - self.connection = None - - def get(self): - """ - Retrieves an active connection from the connection pool. - - If no connection is active, it acquires a new one from the pool. - - Returns: - Connection: An active database connection. - """ - if self.connection is None: - self.connection = self.pool.getconn() - return self.connection - - def get_cursor(self): - """ - Retrieves a cursor from the active connection. - - If no connection is active, it obtains a new one from the pool - and creates a cursor from it. - - Returns: - Cursor: A cursor for executing SQL queries. - """ - return self.get().cursor() - - def put(self): - """ - Releases the active connection back to the pool if not in testing mode. - - If in testing mode, the connection is not released to the pool. - """ - if not self.testing and self.connection is not None: - self.pool.putconn(self.connection) - self.connection = None - - def commit(self): - """ - Commits the current transaction if not in testing mode. - """ - if not self.testing and self.connection is not None: - self.connection.commit() - - def rollback(self): - """ - Rolls back the current transaction, if an active connection exists. - """ - if self.connection is not None: - self.connection.rollback() - - def __enter__(self): - """ - Enters the context and returns the ConnectionManager instance. - - Returns: - ConnectionManager: The current instance of ConnectionManager. - """ - return self - - def __exit__(self, exc_type, exc_value, tb): - """ - Exits the context, managing transactions based on exceptions. - - Commits the transaction if no exception occurred, otherwise rolls it back. - Finally, releases the connection back to the pool. - - Args: - exc_type (type): The type of the exception (if any). - exc_value (Exception): The exception instance (if any). - tb (traceback): The traceback object (if any). - """ - if exc_type is None: - self.commit() - else: - self.rollback() - self.put() diff --git a/app/controllers/inspections.py b/app/controllers/inspections.py index 1bd8358..6b4da8a 100644 --- a/app/controllers/inspections.py +++ b/app/controllers/inspections.py @@ -10,20 +10,20 @@ from fertiscan.db.queries.inspection import ( InspectionNotFoundError as DBInspectionNotFoundError, ) +from psycopg_pool import ConnectionPool -from app.connection_manager import ConnectionManager from app.exceptions import InspectionNotFoundError, MissingUserAttributeError, log_error from app.models.inspections import Inspection, InspectionData from app.models.label_data import LabelData from app.models.users import User -async def read_all(cm: ConnectionManager, user: User): +async def read_all(cp: ConnectionPool, user: User): """ Retrieves all inspections associated with a user, both verified and unverified. Args: - cm (ConnectionManager): An instance managing the database connection. + cp (ConnectionPool): The connection pool to manage database connections. user (User): User instance containing user details, including the user ID. Returns: @@ -37,7 +37,7 @@ async def read_all(cm: ConnectionManager, user: User): if not user.id: raise MissingUserAttributeError("User ID is required for fetching inspections.") - with cm, cm.get_cursor() as cursor: + with cp.connection() as conn, conn.cursor() as cursor: inspections = await asyncio.gather( get_user_analysis_by_verified(cursor, user.id, True), get_user_analysis_by_verified(cursor, user.id, False), @@ -64,12 +64,12 @@ async def read_all(cm: ConnectionManager, user: User): return inspections -async def read(cm: ConnectionManager, user: User, id: UUID | str): +async def read(cp: ConnectionPool, user: User, id: UUID | str): """ Retrieves a specific inspection associated with a user by inspection ID. Args: - cm (ConnectionManager): An instance managing the database connection. + cp (ConnectionPool): The connection pool to manage database connections. user (User): User instance containing user details, including the user ID. id (UUID | str): Unique identifier of the inspection, as a UUID or a string convertible to UUID. @@ -90,7 +90,7 @@ async def read(cm: ConnectionManager, user: User, id: UUID | str): if not isinstance(id, UUID): id = UUID(id) - with cm, cm.get_cursor() as cursor: + with cp.connection() as conn, conn.cursor() as cursor: try: inspection = await get_full_inspection_json(cursor, id, user.id) except DBInspectionNotFoundError as e: @@ -100,7 +100,7 @@ async def read(cm: ConnectionManager, user: User, id: UUID | str): async def create( - cm: ConnectionManager, + cp: ConnectionPool, user: User, label_data: LabelData, label_images: list[bytes], @@ -110,7 +110,7 @@ async def create( Creates a new inspection record associated with a user. Args: - cm (ConnectionManager): An instance managing the database connection. + cp (ConnectionPool): The connection pool to manage database connections. user (User): User instance containing user details, including the user ID. label_data (LabelData): Data model containing label information required for the inspection. label_images (list[bytes]): List of images (in byte format) to be associated with the inspection. @@ -125,7 +125,7 @@ async def create( if not user.id: raise MissingUserAttributeError("User ID is required for creating inspections.") - with cm, cm.get_cursor() as cursor: + with cp.connection() as conn, conn.cursor() as cursor: container_client = ContainerClient.from_connection_string( connection_string, container_name=f"user-{user.id}" ) diff --git a/app/controllers/users.py b/app/controllers/users.py index 4df0079..2ccb5a9 100644 --- a/app/controllers/users.py +++ b/app/controllers/users.py @@ -2,8 +2,8 @@ from datastore import get_user, new_user from datastore.db.queries.user import UserNotFoundError as DBUserNotFoundError from fastapi.logger import logger +from psycopg_pool import ConnectionPool -from app.connection_manager import ConnectionManager from app.exceptions import ( MissingUserAttributeError, UserConflictError, @@ -13,14 +13,14 @@ from app.models.users import User -async def sign_up(cm: ConnectionManager, user: User, connection_string: str) -> User: +async def sign_up(cp: ConnectionPool, user: User, connection_string: str) -> User: """ Registers a new user in the system. Args: - cm (ConnectionManager): An instance managing the database connection. + cp (ConnectionPool): The connection pool to manage database connections. user (User): The User instance containing the user's details. - connection_string (str): Connection string for database setup. + connection_string (str): The database connection string for setup. Raises: MissingUserAttributeError: Raised if the username is not provided. @@ -33,7 +33,7 @@ async def sign_up(cm: ConnectionManager, user: User, connection_string: str) -> raise MissingUserAttributeError("Username is required for sign-up.") try: - with cm, cm.get_cursor() as cursor: + with cp.connection() as conn, conn.cursor() as cursor: logger.debug(f"Creating user: {user.username}") user_db = await new_user(cursor, user.username, connection_string) except DBUserAlreadyExistsError as e: @@ -43,12 +43,12 @@ async def sign_up(cm: ConnectionManager, user: User, connection_string: str) -> return user.model_copy(update={"id": user_db.id}) -async def sign_in(cm: ConnectionManager, user: User) -> User: +async def sign_in(cp: ConnectionPool, user: User) -> User: """ Authenticates an existing user in the system. Args: - cm (ConnectionManager): An instance managing the database connection. + cp (ConnectionPool): The connection pool to manage database connections. user (User): The User instance containing the user's details. Raises: @@ -62,7 +62,7 @@ async def sign_in(cm: ConnectionManager, user: User) -> User: raise MissingUserAttributeError("Username is required for sign-in.") try: - with cm, cm.get_cursor() as cursor: + with cp.connection() as conn, conn.cursor() as cursor: logger.debug(f"Fetching user ID for username: {user.username}") user_db = await get_user(cursor, user.username) except DBUserNotFoundError as e: diff --git a/app/dependencies.py b/app/dependencies.py index 9e2e1b2..f82e0da 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -4,9 +4,10 @@ from fastapi import Depends, File, HTTPException, Request, UploadFile from fastapi.security import HTTPBasic, HTTPBasicCredentials from pipeline import GPT, OCR +from psycopg_pool import ConnectionPool from app.config import Settings -from app.connection_manager import ConnectionManager + from app.controllers.users import sign_in from app.exceptions import UserNotFoundError from app.models.users import User @@ -19,11 +20,11 @@ def get_settings(): return Settings() -def get_connection_manager(request: Request) -> ConnectionManager: +def get_connection_pool(request: Request) -> ConnectionPool: """ - Returns the app's ConnectionManager instance. + Returns the app's connection pool. """ - return request.app.connection_manager + return request.app.pool def get_ocr(request: Request) -> OCR: @@ -54,13 +55,13 @@ def authenticate_user(credentials: HTTPBasicCredentials = Depends(auth)): async def fetch_user( auth_user: User = Depends(authenticate_user), - cm: ConnectionManager = Depends(get_connection_manager), + cp: ConnectionPool = Depends(get_connection_pool), ) -> User: """ Fetches the authenticated user's info from db. """ try: - return await sign_in(cm, auth_user) + return await sign_in(cp, auth_user) except UserNotFoundError: raise HTTPException( status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid username or password" diff --git a/app/main.py b/app/main.py index 22f0a8a..6b06eb6 100644 --- a/app/main.py +++ b/app/main.py @@ -5,17 +5,17 @@ from fastapi.concurrency import asynccontextmanager from fastapi.responses import JSONResponse from pipeline import GPT, OCR +from psycopg_pool import ConnectionPool from pydantic import UUID4 from app.config import Settings, configure -from app.connection_manager import ConnectionManager from app.controllers.data_extraction import extract_data from app.controllers.inspections import create, read, read_all from app.controllers.users import sign_up from app.dependencies import ( authenticate_user, fetch_user, - get_connection_manager, + get_connection_pool, get_gpt, get_ocr, get_settings, @@ -32,9 +32,9 @@ @asynccontextmanager async def lifespan(app: FastAPI): app = configure(app, get_settings()) - app.connection_manager.pool.open() + app.pool.open() yield - app.connection_manager.pool.close() + app.pool.close() app = FastAPI(lifespan=lifespan) @@ -64,12 +64,12 @@ async def analyze_document( @app.post("/signup", tags=["Users"], status_code=201, response_model=User) async def signup( - cm: Annotated[ConnectionManager, Depends(get_connection_manager)], + cp: Annotated[ConnectionPool, Depends(get_connection_pool)], user: Annotated[User, Depends(authenticate_user)], settings: Annotated[Settings, Depends(get_settings)], ): try: - return await sign_up(cm, user, settings.fertiscan_storage_url) + return await sign_up(cp, user, settings.fertiscan_storage_url) except UserConflictError: raise HTTPException(status_code=HTTPStatus.CONFLICT, detail="User exists!") @@ -81,20 +81,20 @@ async def login(user: User = Depends(fetch_user)): @app.get("/inspections", tags=["Inspections"], response_model=list[InspectionData]) async def get_inspections( - cm: Annotated[ConnectionManager, Depends(get_connection_manager)], + cp: Annotated[ConnectionPool, Depends(get_connection_pool)], user: User = Depends(fetch_user), ): - return await read_all(cm, user) + return await read_all(cp, user) @app.get("/inspections/{id}", tags=["Inspections"], response_model=Inspection) async def get_inspection( - cm: Annotated[ConnectionManager, Depends(get_connection_manager)], + cp: Annotated[ConnectionPool, Depends(get_connection_pool)], user: Annotated[User, Depends(fetch_user)], id: UUID4, ): try: - return await read(cm, user, id) + return await read(cp, user, id) except InspectionNotFoundError: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="Inspection not found" @@ -103,7 +103,7 @@ async def get_inspection( @app.post("/inspections", tags=["Inspections"], response_model=Inspection) async def create_inspection( - cm: Annotated[ConnectionManager, Depends(get_connection_manager)], + cp: Annotated[ConnectionPool, Depends(get_connection_pool)], user: Annotated[User, Depends(fetch_user)], settings: Annotated[Settings, Depends(get_settings)], label_data: Annotated[LabelData, Form(...)], @@ -112,4 +112,4 @@ async def create_inspection( # Note: later on, we might handle label images as their own domain label_images = [await f.read() for f in files] conn_string = settings.fertiscan_storage_url - return await create(cm, user, label_data, label_images, conn_string) + return await create(cp, user, label_data, label_images, conn_string) diff --git a/tests/test_app.py b/tests/test_app.py index d70d50b..7a1aca6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -12,7 +12,7 @@ from app.dependencies import ( authenticate_user, fetch_user, - get_connection_manager, + get_connection_pool, get_gpt, get_ocr, get_settings, @@ -127,7 +127,7 @@ def override_dep(): self.test_user = User(username="test_user", id=uuid.uuid4()) app.dependency_overrides.clear() - app.dependency_overrides[get_connection_manager] = override_dep + app.dependency_overrides[get_connection_pool] = override_dep app.dependency_overrides[authenticate_user] = override_dep app.dependency_overrides[get_settings] = override_dep app.dependency_overrides[fetch_user] = lambda: self.test_user @@ -156,7 +156,7 @@ def test_signup_bad_authentication(self, _): response = self.client.post( "/signup", headers={ - "Authorization": f"Basic {self.credentials(empty_username, "password")}", + "Authorization": f'Basic {self.credentials(empty_username, "password")}', }, ) self.assertEqual(response.status_code, 400) @@ -168,7 +168,7 @@ def test_signup_authentication_success(self, mock_sign_up): response = self.client.post( "/signup", headers={ - "Authorization": f"Basic {self.credentials("test_user", "password")}", + "Authorization": f'Basic {self.credentials("test_user", "password")}', }, ) self.assertEqual(response.status_code, 201) @@ -222,13 +222,10 @@ class TestAPIInspections(unittest.TestCase): def setUp(self) -> None: self.client = TestClient(app) - def override_dep(): - return Mock() - self.test_user = User(username="test_user", id=uuid.uuid4()) app.dependency_overrides.clear() - app.dependency_overrides[get_connection_manager] = override_dep + app.dependency_overrides[get_connection_pool] = lambda: Mock() app.dependency_overrides[fetch_user] = lambda: self.test_user self.mock_inspection_data = [ diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py deleted file mode 100644 index c4425c5..0000000 --- a/tests/test_connection_manager.py +++ /dev/null @@ -1,122 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from psycopg import Connection -from psycopg_pool import ConnectionPool - -from app.connection_manager import ConnectionManager - - -class TestConnectionManager(unittest.TestCase): - def setUp(self): - self.pool = MagicMock(spec=ConnectionPool) - self.connection = MagicMock(spec=Connection) - self.pool.getconn.return_value = self.connection - self.conn_manager = ConnectionManager(self.pool) - - def tearDown(self): - self.conn_manager = None - - def test_init_with_none(self): - """Test that initializing ConnectionManager with None raises ValueError.""" - with self.assertRaises(ValueError) as context: - ConnectionManager(None) - self.assertEqual(str(context.exception), "A connection pool is required.") - - def test_init_with_invalid_type(self): - """Test that initializing ConnectionManager with invalid type raises ValueError.""" - with self.assertRaises(ValueError) as context: - ConnectionManager("invalid_type") - self.assertEqual(str(context.exception), "A connection pool is required.") - - def test_get_connection(self): - """Test that the 'get' method retrieves a connection from the pool.""" - conn = self.conn_manager.get() - # Assert that the connection returned is the one from the pool - self.assertEqual(conn, self.connection) - # Ensure that getconn() was called once to retrieve the connection - self.pool.getconn.assert_called_once() - - # Call get() again to ensure it does not call getconn() again - conn_again = self.conn_manager.get() - # Assert that the same connection is returned - self.assertEqual(conn_again, self.connection) - # Ensure getconn() is not called again, confirming reuse of the existing connection - self.pool.getconn.assert_called_once() - - def test_get_cursor(self): - """Test that the 'get_cursor' method retrieves a cursor from the active connection.""" - cursor = MagicMock() - self.connection.cursor.return_value = cursor - result_cursor = self.conn_manager.get_cursor() - # Verify that the cursor returned by get_cursor is the same as the mock cursor - self.assertEqual(result_cursor, cursor) - # Ensure that the connection's cursor method was called exactly once - self.connection.cursor.assert_called_once() - # Verify that the connection was obtained from the pool exactly once - self.pool.getconn.assert_called_once() - - def test_put_connection_not_testing(self): - """Test that the 'put' method releases the connection when not in testing mode.""" - self.conn_manager.get() - self.conn_manager.testing = False - self.conn_manager.put() - # Assert that the connection was released back to the pool - self.pool.putconn.assert_called_once_with(self.connection) - # Assert that the connection is set to None after releasing - self.assertIsNone(self.conn_manager.connection) - - def test_put_connection_in_testing(self): - """Test that the 'put' method does not release the connection when in testing mode.""" - self.conn_manager.get() - self.conn_manager.testing = True - self.conn_manager.put() - # Assert that the connection was not released back to the pool - self.pool.putconn.assert_not_called() - # Assert that the connection is still active - self.assertIsNotNone(self.conn_manager.connection) - - def test_commit_not_testing(self): - """Test that the 'commit' method commits the transaction when not in testing mode.""" - self.conn_manager.get() - self.conn_manager.testing = False - self.conn_manager.commit() - # Assert that the commit was called on the connection - self.connection.commit.assert_called_once() - - def test_commit_in_testing(self): - """Test that the 'commit' method does not commit the transaction when in testing mode.""" - self.conn_manager.get() - self.conn_manager.testing = True - self.conn_manager.commit() - # Assert that the commit was not called on the connection - self.connection.commit.assert_not_called() - - def test_rollback(self): - """Test that the 'rollback' method rolls back the current transaction.""" - self.conn_manager.get() # Activate a connection - self.conn_manager.rollback() - self.connection.rollback.assert_called_once() - - @patch.object(ConnectionManager, "commit") - @patch.object(ConnectionManager, "put") - def test_context_manager_exit_success(self, mock_put, mock_commit): - """Test that the context manager commits the transaction and calls put() on successful exit.""" - with self.conn_manager as _: - pass - # After exiting, commit() should be called - mock_commit.assert_called_once() - # Ensure put() was called - mock_put.assert_called_once() - - @patch.object(ConnectionManager, "rollback") - @patch.object(ConnectionManager, "put") - def test_context_manager_exit_failure(self, mock_put, mock_rollback): - """Test that the context manager rolls back the transaction and calls put() on exception.""" - with self.assertRaises(Exception): - with self.conn_manager as _: - raise Exception("Test Exception") - # After exiting, rollback() should be called - mock_rollback.assert_called_once() - # Ensure put() was called - mock_put.assert_called_once() diff --git a/tests/test_inspections.py b/tests/test_inspections.py index 5c86f5f..6dd8d6b 100644 --- a/tests/test_inspections.py +++ b/tests/test_inspections.py @@ -2,13 +2,12 @@ import unittest import uuid from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch from fertiscan.db.queries.inspection import ( InspectionNotFoundError as DBInspectionNotFoundError, ) -from app.connection_manager import ConnectionManager from app.controllers.inspections import create, read, read_all from app.exceptions import InspectionNotFoundError, MissingUserAttributeError from app.models.inspections import Inspection @@ -18,17 +17,16 @@ class TestReadAll(unittest.IsolatedAsyncioTestCase): async def test_missing_user_id_raises_error(self): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() user = User(id=None) with self.assertRaises(MissingUserAttributeError): - await read_all(cm, user) + await read_all(cp, user) @patch("app.controllers.inspections.get_user_analysis_by_verified") async def test_calls_analysis_twice_and_combines_results( self, mock_get_user_analysis_by_verified ): - # Set up mock return values for verified and unverified data mock_get_user_analysis_by_verified.side_effect = [ [ ( @@ -60,20 +58,19 @@ async def test_calls_analysis_twice_and_combines_results( ], ] - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() + conn_mock = MagicMock() cursor_mock = MagicMock() - cm.get_cursor.return_value.__enter__.return_value = cursor_mock + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock user = User(id=uuid.uuid4()) - # Execute the function - inspections = await read_all(cm, user) + inspections = await read_all(cp, user) - # Check that get_user_analysis_by_verified was called twice self.assertEqual(mock_get_user_analysis_by_verified.call_count, 2) mock_get_user_analysis_by_verified.assert_any_call(cursor_mock, user.id, True) mock_get_user_analysis_by_verified.assert_any_call(cursor_mock, user.id, False) - # Verify that the result is a list of InspectionData objects with correct length and data self.assertEqual(len(inspections), 2) self.assertEqual(inspections[0].product_name, "Product A") self.assertEqual(inspections[1].product_name, "Product B") @@ -82,53 +79,52 @@ async def test_calls_analysis_twice_and_combines_results( async def test_no_inspections_for_verified_and_unverified( self, mock_get_user_analysis_by_verified ): - # Mock empty results for both verified and unverified inspections mock_get_user_analysis_by_verified.side_effect = [[], []] - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() + conn_mock = MagicMock() cursor_mock = MagicMock() - cm.get_cursor.return_value.__enter__.return_value = cursor_mock + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock user = User(id=uuid.uuid4()) - # Execute the function - inspections = await read_all(cm, user) - - # Verify that an empty list is returned when there are no inspections + inspections = await read_all(cp, user) self.assertEqual(len(inspections), 0) class TestRead(unittest.IsolatedAsyncioTestCase): async def test_missing_user_id_raises_error(self): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() user = User(id=None) inspection_id = uuid.uuid4() with self.assertRaises(MissingUserAttributeError): - await read(cm, user, inspection_id) + await read(cp, user, inspection_id) async def test_missing_inspection_id_raises_error(self): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() user = User(id=uuid.uuid4()) with self.assertRaises(ValueError): - await read(cm, user, None) + await read(cp, user, None) - @patch("app.controllers.inspections.get_full_inspection_json") - async def test_invalid_inspection_id_format(self, mock_get_full_inspection_json): - cm = AsyncMock(spec=ConnectionManager) + async def test_invalid_inspection_id_format(self): + cp = MagicMock() user = User(id=uuid.uuid4()) invalid_id = "not-a-uuid" with self.assertRaises(ValueError): - await read(cm, user, invalid_id) + await read(cp, user, invalid_id) @patch("app.controllers.inspections.get_full_inspection_json") async def test_valid_inspection_id_calls_get_full_inspection_json( self, mock_get_full_inspection_json ): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() + conn_mock = MagicMock() cursor_mock = MagicMock() - cm.get_cursor.return_value.__enter__.return_value = cursor_mock + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock user = User(id=uuid.uuid4()) inspection_id = uuid.uuid4() @@ -167,55 +163,54 @@ async def test_valid_inspection_id_calls_get_full_inspection_json( mock_get_full_inspection_json.return_value = json.dumps(sample_inspection) - # Execute the function - inspection = await read(cm, user, inspection_id) + inspection = await read(cp, user, inspection_id) - # Verify that get_full_inspection_json was called with the correct arguments mock_get_full_inspection_json.assert_called_once_with( cursor_mock, inspection_id, user.id ) - - # Check that the result is an Inspection object self.assertIsInstance(inspection, Inspection) @patch("app.controllers.inspections.get_full_inspection_json") async def test_inspection_not_found_raises_error( self, mock_get_full_inspection_json ): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() + conn_mock = MagicMock() cursor_mock = MagicMock() - cm.get_cursor.return_value.__enter__.return_value = cursor_mock + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock user = User(id=uuid.uuid4()) inspection_id = uuid.uuid4() - # Set up the mock to raise the InspectionNotFoundError mock_get_full_inspection_json.side_effect = DBInspectionNotFoundError( "Not found" ) with self.assertRaises(InspectionNotFoundError): - await read(cm, user, inspection_id) + await read(cp, user, inspection_id) class TestCreateFunction(unittest.IsolatedAsyncioTestCase): async def test_missing_user_id_raises_error(self): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() label_data = LabelData() label_images = [b"image_data"] user = User(id=None) with self.assertRaises(MissingUserAttributeError): - await create(cm, user, label_data, label_images, "fake_conn_str") + await create(cp, user, label_data, label_images, "fake_conn_str") @patch("app.controllers.inspections.register_analysis") @patch("app.controllers.inspections.ContainerClient") async def test_create_inspection_success( self, mock_container_client, mock_register_analysis ): - cm = AsyncMock(spec=ConnectionManager) + cp = MagicMock() + conn_mock = MagicMock() cursor_mock = MagicMock() - cm.get_cursor.return_value.__enter__.return_value = cursor_mock + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock user = User(id=uuid.uuid4()) label_data = LabelData() label_images = [b"image_data"] @@ -254,15 +249,12 @@ async def test_create_inspection_success( } mock_register_analysis.return_value = mock_inspection_data - # Instantiate the mock container client container_client_instance = ( mock_container_client.from_connection_string.return_value ) - # Call create function - inspection = await create(cm, user, label_data, label_images, fake_conn_str) + inspection = await create(cp, user, label_data, label_images, fake_conn_str) - # Assertions mock_container_client.from_connection_string.assert_called_once_with( fake_conn_str, container_name=f"user-{user.id}" ) diff --git a/tests/test_users_controller.py b/tests/test_users_controller.py index f7c6a5c..b75a913 100644 --- a/tests/test_users_controller.py +++ b/tests/test_users_controller.py @@ -4,7 +4,6 @@ from datastore import UserAlreadyExistsError as DBUserAlreadyExistsError from datastore.db.queries.user import UserNotFoundError as DBUserNotFoundError -from app.connection_manager import ConnectionManager from app.controllers.users import sign_in, sign_up from app.exceptions import ( MissingUserAttributeError, @@ -16,70 +15,78 @@ class TestSignUpSuccess(unittest.IsolatedAsyncioTestCase): async def test_successful_user_sign_up(self): - mock_cm = MagicMock(spec=ConnectionManager) - mock_cursor = MagicMock() - mock_cm.get_cursor.return_value.__enter__.return_value = mock_cursor + cp = MagicMock() + conn_mock = MagicMock() + cursor_mock = MagicMock() + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock mock_user = User(username="test_user") mock_new_user = AsyncMock(return_value=MagicMock(id=1)) mock_storage_url = "mocked_storage_url" with patch("app.controllers.users.new_user", mock_new_user): - result = await sign_up(mock_cm, mock_user, mock_storage_url) + result = await sign_up(cp, mock_user, mock_storage_url) - mock_cm.get_cursor.assert_called_once() + cp.connection.assert_called_once() mock_new_user.assert_awaited_once_with( - mock_cursor, "test_user", mock_storage_url + cursor_mock, "test_user", mock_storage_url ) self.assertEqual(result.id, 1) self.assertEqual(result.username, "test_user") async def test_sign_up_missing_username(self): - mock_cm = MagicMock(spec=ConnectionManager) + cp = MagicMock() mock_user = User(username="") with self.assertRaises(MissingUserAttributeError): - await sign_up(mock_cm, mock_user, "mock_storage_url") + await sign_up(cp, mock_user, "mock_storage_url") async def test_sign_up_user_already_exists(self): - mock_cm = MagicMock(spec=ConnectionManager) - mock_cursor = MagicMock() - mock_cm.get_cursor.return_value.__enter__.return_value = mock_cursor + cp = MagicMock() + conn_mock = MagicMock() + cursor_mock = MagicMock() + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock mock_user = User(username="existing_user") mock_new_user = AsyncMock(side_effect=DBUserAlreadyExistsError) with patch("app.controllers.users.new_user", mock_new_user): with self.assertRaises(UserConflictError): - await sign_up(mock_cm, mock_user, "mock_storage_url") + await sign_up(cp, mock_user, "mock_storage_url") async def test_successful_user_sign_in(self): - mock_cm = MagicMock(spec=ConnectionManager) - mock_cursor = MagicMock() - mock_cm.get_cursor.return_value.__enter__.return_value = mock_cursor + cp = MagicMock() + conn_mock = MagicMock() + cursor_mock = MagicMock() + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock mock_user = User(username="test_user") mock_get_user = AsyncMock(return_value=MagicMock(id=1)) with patch("app.controllers.users.get_user", mock_get_user): - result = await sign_in(mock_cm, mock_user) + result = await sign_in(cp, mock_user) - mock_cm.get_cursor.assert_called_once() - mock_get_user.assert_awaited_once_with(mock_cursor, "test_user") + cp.connection.assert_called_once() + mock_get_user.assert_awaited_once_with(cursor_mock, "test_user") self.assertEqual(result.id, 1) self.assertEqual(result.username, "test_user") async def test_sign_in_missing_username(self): - mock_cm = MagicMock(spec=ConnectionManager) + cp = MagicMock() mock_user = User(username="") with self.assertRaises(MissingUserAttributeError): - await sign_in(mock_cm, mock_user) + await sign_in(cp, mock_user) async def test_sign_in_user_not_found(self): - mock_cm = MagicMock(spec=ConnectionManager) - mock_cursor = MagicMock() - mock_cm.get_cursor.return_value.__enter__.return_value = mock_cursor + cp = MagicMock() + conn_mock = MagicMock() + cursor_mock = MagicMock() + conn_mock.cursor.return_value.__enter__.return_value = cursor_mock + cp.connection.return_value.__enter__.return_value = conn_mock mock_user = User(username="non_existent_user") mock_get_user = AsyncMock(side_effect=DBUserNotFoundError) with patch("app.controllers.users.get_user", mock_get_user): with self.assertRaises(UserNotFoundError): - await sign_in(mock_cm, mock_user) + await sign_in(cp, mock_user)