Skip to content

Commit

Permalink
wip(crud): added exist & exist_n methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jd-solanki committed Dec 17, 2024
1 parent 16dc2fd commit a6f75fa
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 1 deletion.
37 changes: 37 additions & 0 deletions examples/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,43 @@ async def get_one_user(
) from e


@app.get("/users/exist")
async def user_exist(
db: Annotated[AsyncSession, Depends(get_db)],
user_id: PositiveInt | None = None,
first_name: str = "",
first_name__contains: str = "",
):
select_statement = select(User)
if user_id:
select_statement = select_statement.where(User.id == user_id)
if first_name:
select_statement = select_statement.where(User.first_name == first_name)
if first_name__contains:
select_statement = select_statement.where(User.first_name.contains(first_name__contains))

return await user_crud.exist(db, select_statement=lambda _: select_statement)


@app.get("/users/exist_n")
async def user_exist_n(
db: Annotated[AsyncSession, Depends(get_db)],
n: int,
user_id: PositiveInt | None = None,
first_name: str = "",
first_name__contains: str = "",
):
select_statement = select(User)
if user_id:
select_statement = select_statement.where(User.id == user_id)
if first_name:
select_statement = select_statement.where(User.first_name == first_name)
if first_name__contains:
select_statement = select_statement.where(User.first_name.contains(first_name__contains))

return await user_crud.exist_n(db, select_statement=lambda _: select_statement, n=n)


@app.get("/users/{user_id}")
async def get_user(user_id: PositiveInt, db: Annotated[AsyncSession, Depends(get_db)]):
return await user_crud.get_or_404(db, user_id)
Expand Down
76 changes: 75 additions & 1 deletion src/fastapi_batteries/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastapi import status
from pydantic import BaseModel, RootModel
from sqlalchemy import ScalarResult, Select, delete, func, insert, select
from sqlalchemy import ScalarResult, Select, delete, exists, func, insert, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -202,6 +202,7 @@ async def get_multi(
return records, total
return records

# TODO: Instead of all columns, fetch specific columns
async def get_one_or_none(
self,
db: AsyncSession,
Expand Down Expand Up @@ -299,12 +300,85 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) -

# TODO: Use callable for select_statement like other methods
async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int:
"""Count the number of records for given select statement.
TIP: If you just want to know if n records exist, use `exist_n` method.
Using `count` method is not recommended for checking existence.
Args:
db: SQLAlchemy AsyncSession
select_statement: Select statement to count the records
Returns:
Number of records
"""
count_select_from = select_statement.subquery() if select_statement is not None else self.model
count_statement = select(func.count()).select_from(count_select_from)

result = await db.scalars(count_statement)
return result.first() or 0

async def exist(
self,
db: AsyncSession,
*,
select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s,
):
base_statement = select_statement(select(1))

# Perf: Replace columns with `SELECT 1` to optimize the query
base_statement = base_statement.with_only_columns(1)

exist_statement = select(exists(base_statement))

result = await db.scalar(exist_statement)

# NOTE: We added `or False` to ensure it don't return `None` value from `.scalar()`
return result or False

async def exist_n(
self,
db: AsyncSession,
*,
select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]],
n: int,
) -> bool:
"""Check if exactly n records exist for given select statement.
Args:
db: SQLAlchemy AsyncSession
select_statement: Function to modify select statement (e.g. add where clause)
n: Number of records to check for exact match
Returns:
bool: True if exactly n records exist, False otherwise
Raises:
ValueError: If n is less than 0
"""
if n < 0:
msg = "n must be greater than or equal to 0"
raise ValueError(msg)

# Start with basic SELECT 1 for performance
base_statement = select_statement(select(1))

# Replace columns with SELECT 1 to optimize
base_statement = base_statement.with_only_columns(1)

# Add LIMIT n+1 to optimize by not fetching all records
# We fetch n+1 to check if more than n records exist
base_statement = base_statement.limit(n + 1)

# Get all records up to n+1
result = await db.scalars(base_statement)
records = result.all()

# Compare length to check exact match
return len(records) == n

async def upsert(
self,
db: AsyncSession,
Expand Down

0 comments on commit a6f75fa

Please sign in to comment.