diff --git a/examples/crud.py b/examples/crud.py index 4bc45f2..811db02 100644 --- a/examples/crud.py +++ b/examples/crud.py @@ -171,6 +171,43 @@ async def get_one_user( ) from e +@app.get("/users/exist") +async def user_exist( + db: Annotated[AsyncSession, Depends(get_db)], + user_id: PositiveInt | None = None, + first_name: str = "", + first_name__contains: str = "", +): + select_statement = select(User) + if user_id: + select_statement = select_statement.where(User.id == user_id) + if first_name: + select_statement = select_statement.where(User.first_name == first_name) + if first_name__contains: + select_statement = select_statement.where(User.first_name.contains(first_name__contains)) + + return await user_crud.exist(db, select_statement=lambda _: select_statement) + + +@app.get("/users/exist_n") +async def user_exist_n( + db: Annotated[AsyncSession, Depends(get_db)], + n: int, + user_id: PositiveInt | None = None, + first_name: str = "", + first_name__contains: str = "", +): + select_statement = select(User) + if user_id: + select_statement = select_statement.where(User.id == user_id) + if first_name: + select_statement = select_statement.where(User.first_name == first_name) + if first_name__contains: + select_statement = select_statement.where(User.first_name.contains(first_name__contains)) + + return await user_crud.exist_n(db, select_statement=lambda _: select_statement, n=n) + + @app.get("/users/{user_id}") async def get_user(user_id: PositiveInt, db: Annotated[AsyncSession, Depends(get_db)]): return await user_crud.get_or_404(db, user_id) diff --git a/src/fastapi_batteries/crud/__init__.py b/src/fastapi_batteries/crud/__init__.py index 5463f6e..8202182 100644 --- a/src/fastapi_batteries/crud/__init__.py +++ b/src/fastapi_batteries/crud/__init__.py @@ -5,7 +5,7 @@ from fastapi import status from pydantic import BaseModel, RootModel -from sqlalchemy import ScalarResult, Select, delete, func, insert, select +from sqlalchemy import ScalarResult, Select, delete, exists, func, insert, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.asyncio import AsyncSession @@ -202,6 +202,7 @@ async def get_multi( return records, total return records + # TODO: Instead of all columns, fetch specific columns async def get_one_or_none( self, db: AsyncSession, @@ -299,12 +300,85 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) - # TODO: Use callable for select_statement like other methods async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int: + """Count the number of records for given select statement. + + TIP: If you just want to know if n records exist, use `exist_n` method. + Using `count` method is not recommended for checking existence. + + Args: + db: SQLAlchemy AsyncSession + select_statement: Select statement to count the records + + Returns: + Number of records + + """ count_select_from = select_statement.subquery() if select_statement is not None else self.model count_statement = select(func.count()).select_from(count_select_from) result = await db.scalars(count_statement) return result.first() or 0 + async def exist( + self, + db: AsyncSession, + *, + select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s, + ): + base_statement = select_statement(select(1)) + + # Perf: Replace columns with `SELECT 1` to optimize the query + base_statement = base_statement.with_only_columns(1) + + exist_statement = select(exists(base_statement)) + + result = await db.scalar(exist_statement) + + # NOTE: We added `or False` to ensure it don't return `None` value from `.scalar()` + return result or False + + async def exist_n( + self, + db: AsyncSession, + *, + select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]], + n: int, + ) -> bool: + """Check if exactly n records exist for given select statement. + + Args: + db: SQLAlchemy AsyncSession + select_statement: Function to modify select statement (e.g. add where clause) + n: Number of records to check for exact match + + Returns: + bool: True if exactly n records exist, False otherwise + + Raises: + ValueError: If n is less than 0 + + """ + if n < 0: + msg = "n must be greater than or equal to 0" + raise ValueError(msg) + + # Start with basic SELECT 1 for performance + base_statement = select_statement(select(1)) + + # Replace columns with SELECT 1 to optimize + base_statement = base_statement.with_only_columns(1) + + # Add LIMIT n+1 to optimize by not fetching all records + # We fetch n+1 to check if more than n records exist + base_statement = base_statement.limit(n + 1) + + # Get all records up to n+1 + result = await db.scalars(base_statement) + records = result.all() + + # Compare length to check exact match + return len(records) == n + async def upsert( self, db: AsyncSession,