diff --git a/examples/crud.py b/examples/crud.py index b3ab97a..baa5013 100644 --- a/examples/crud.py +++ b/examples/crud.py @@ -10,8 +10,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column from fastapi_batteries.crud import CRUD -from fastapi_batteries.fastapi.exceptions import get_api_exception_handler -from fastapi_batteries.fastapi.exceptions.api_exception import APIException, get_api_exception_handler +from fastapi_batteries.fastapi.exceptions import APIException, get_api_exception_handler from fastapi_batteries.fastapi.middlewares import QueryCountMiddleware from fastapi_batteries.pydantic.schemas import Paginated, PaginationOffsetLimit from fastapi_batteries.sa.mixins import MixinId @@ -63,7 +62,6 @@ class UserPatch(UserBasePartial): ... class UserRead(UserBase): id: PositiveInt - is_active: bool # --- FastAPI @@ -191,7 +189,11 @@ async def get_one_user( select_statement = select_statement.where(User.first_name.contains(first_name__contains)) try: - return await user_crud.get_one_or_none(db, select_statement=lambda _: select_statement) + return await user_crud.get_one_or_404( + db, + select_statement=lambda _: select_statement, + msg_multiple_results_exc="Multiple users found", + ) except MultipleResultsFound as e: raise APIException( title="Multiple results found", @@ -199,6 +201,25 @@ async def get_one_user( ) from e +@app.get("/users/one/with-first-name-and-is-active", response_model=UserRead) +async def get_user_with_cols( + db: Annotated[AsyncSession, Depends(get_db)], + first_name: str = "", + first_name__contains: str = "", +): + select_statement = select(User.first_name, User.id) + if first_name: + select_statement = select_statement.where(User.first_name == first_name) + if first_name__contains: + select_statement = select_statement.where(User.first_name.contains(first_name__contains)) + + return await user_crud.get_one_for_cols( + db, + select_statement=select_statement, + as_mappings=True, + ) + + @app.get("/users/exist") async def user_exist( db: Annotated[AsyncSession, Depends(get_db)], diff --git a/src/fastapi_batteries/crud/__init__.py b/src/fastapi_batteries/crud/__init__.py index 8cb8b94..7de2052 100644 --- a/src/fastapi_batteries/crud/__init__.py +++ b/src/fastapi_batteries/crud/__init__.py @@ -1,11 +1,11 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from contextlib import suppress from logging import Logger from typing import Any, Literal, overload from fastapi import status from pydantic import BaseModel, RootModel -from sqlalchemy import RowMapping, ScalarResult, Select, delete, exists, func, insert, select +from sqlalchemy import MappingResult, Row, RowMapping, ScalarResult, Select, delete, exists, func, insert, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.asyncio import AsyncSession @@ -329,8 +329,7 @@ async def get_multi_for_cols[*T]( return records, total return records - # TODO: Instead of all columns, fetch specific columns - async def get_one_or_none( + async def get_one( self, db: AsyncSession, *, @@ -359,6 +358,72 @@ async def get_one_or_none( if not suppress_multiple_result_exc: raise + async def get_one_or_404( + self, + db: AsyncSession, + *, + select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s, + msg_404: str | None = None, + msg_multiple_results_exc: str, + ) -> ModelType: + try: + if result := await self.get_one(db, select_statement=select_statement): + return result + except MultipleResultsFound as e: + raise APIException( + status=status.HTTP_400_BAD_REQUEST, + title=msg_multiple_results_exc, + ) from e + + raise APIException( + status=status.HTTP_404_NOT_FOUND, + title=msg_404 or self.err_messages[404], + ) + + """ + - `as_mappings` is False + """ + + @overload + async def get_one_for_cols[*T]( + self, + db: AsyncSession, + *, + select_statement: Select[tuple[*T]], + suppress_multiple_result_exc: bool = False, + as_mappings: Literal[False] = False, + ) -> tuple[*T] | None: ... + + """ + - `as_mappings` is True + """ + + @overload + async def get_one_for_cols[*T]( + self, + db: AsyncSession, + *, + select_statement: Select[tuple[*T]], + suppress_multiple_result_exc: bool = False, + as_mappings: Literal[True], + ) -> RowMapping | None: ... + + async def get_one_for_cols[*T]( + self, + db: AsyncSession, + *, + select_statement: Select[tuple[*T]], + suppress_multiple_result_exc: bool = False, + as_mappings: bool = False, + ) -> tuple[*T] | RowMapping | None: + result = await db.execute(select_statement) + + try: + return result.mappings().one_or_none() if as_mappings else result.tuples().one_or_none() + except MultipleResultsFound: + if not suppress_multiple_result_exc: + raise + # TODO: Can we fetch TypedDict from SchemaPatch? Using `dict[str, Any]` is not good. async def patch( self,