-
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support beanie backend for Mongo
- Loading branch information
1 parent
6a04bfb
commit 648f24b
Showing
10 changed files
with
982 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import asyncio | ||
import logging | ||
from collections.abc import AsyncGenerator | ||
from contextlib import asynccontextmanager | ||
from typing import Any, Optional | ||
|
||
import click | ||
import uvicorn | ||
from beanie import Document, Link, PydanticObjectId, init_beanie | ||
from beanie.odm.fields import WriteRules | ||
from faker import Faker | ||
from fastapi import FastAPI, Query | ||
from motor.motor_asyncio import AsyncIOMotorClient | ||
from pydantic import BaseModel, ConfigDict, EmailStr, Field | ||
|
||
from fastapi_filter import FilterDepends, with_prefix | ||
from fastapi_filter.contrib.beanie import Filter | ||
|
||
fake = Faker() | ||
|
||
logger = logging.getLogger("uvicorn") | ||
|
||
|
||
class Address(Document): | ||
street: str | ||
city: str | ||
country: str | ||
|
||
|
||
class User(Document): | ||
name: str | ||
email: EmailStr | ||
age: int | ||
address: Link[Address] | ||
|
||
|
||
class AddressOut(BaseModel): | ||
id: PydanticObjectId = Field(alias="_id", description="MongoDB document ObjectID") | ||
street: str | ||
city: str | ||
country: str | ||
|
||
class Config: | ||
orm_mode = True | ||
|
||
|
||
class UserIn(BaseModel): | ||
name: str | ||
email: EmailStr | ||
age: int | ||
|
||
|
||
class UserOut(UserIn): | ||
model_config = ConfigDict(from_attributes=True) | ||
|
||
id: PydanticObjectId = Field(alias="_id", description="MongoDB document ObjectID") | ||
name: str | ||
email: EmailStr | ||
age: int | ||
address: Optional[AddressOut] = None | ||
|
||
|
||
class AddressFilter(Filter): | ||
street: Optional[str] = None | ||
country: Optional[str] = None | ||
city: Optional[str] = None | ||
city__in: Optional[list[str]] = None | ||
custom_order_by: Optional[list[str]] = None | ||
custom_search: Optional[str] = None | ||
|
||
class Constants(Filter.Constants): | ||
model = Address | ||
ordering_field_name = "custom_order_by" | ||
search_field_name = "custom_search" | ||
search_model_fields = ["street", "country", "city"] | ||
|
||
|
||
class UserFilter(Filter): | ||
name: Optional[str] = None | ||
address: Optional[AddressFilter] = FilterDepends(with_prefix("address", AddressFilter)) | ||
age__lt: Optional[int] = None | ||
age__gte: int = Field(Query(description="this is a nice description")) | ||
"""Required field with a custom description. | ||
See: https://github.com/tiangolo/fastapi/issues/4700 for why we need to wrap `Query` in `Field`. | ||
""" | ||
order_by: list[str] = ["age"] | ||
search: Optional[str] = None | ||
|
||
class Constants(Filter.Constants): | ||
model = User | ||
search_model_fields = ["name"] | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: | ||
message = "Open http://127.0.0.1:8000/docs to start exploring 🎒 🧭 🗺️" | ||
color_message = "Open " + click.style("http://127.0.0.1:8000/docs", bold=True) + " to start exploring 🎒 🧭 🗺️" | ||
logger.info(message, extra={"color_message": color_message}) | ||
|
||
client: AsyncIOMotorClient = AsyncIOMotorClient("mongodb://localhost:27017/fastapi_filter") | ||
# https://github.com/tiangolo/fastapi/issues/3855#issuecomment-1013148113 | ||
client.get_io_loop = asyncio.get_event_loop # type: ignore[method-assign] | ||
db = client.fastapi_filter | ||
await init_beanie(database=db, document_models=[Address, User]) | ||
|
||
for _ in range(100): | ||
address = Address(street=fake.street_address(), city=fake.city(), country=fake.country()) | ||
await address.save() | ||
user = User(name=fake.name(), email=fake.email(), age=fake.random_int(min=5, max=120), address=address) | ||
await user.save(link_rule=WriteRules.WRITE) | ||
|
||
yield | ||
|
||
Address.find_all().delete() | ||
User.find_all().delete() | ||
client.close() | ||
|
||
|
||
app = FastAPI(lifespan=lifespan) | ||
|
||
|
||
@app.get("/users", response_model=list[UserOut]) | ||
async def get_users(user_filter: UserFilter = FilterDepends(UserFilter)) -> Any: | ||
query = user_filter.filter(User.find({})) | ||
query = user_filter.sort(query) | ||
query = query.find(fetch_links=True) | ||
return await query.project(UserOut).to_list() | ||
|
||
|
||
@app.get("/addresses", response_model=list[AddressOut]) | ||
async def get_addresses( | ||
address_filter: AddressFilter = FilterDepends(with_prefix("my_custom_prefix", AddressFilter), by_alias=True), | ||
) -> Any: | ||
query = address_filter.filter(Address.find({})) | ||
query = address_filter.sort(query) | ||
return await query.project(AddressOut).to_list() | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run("fastapi_filter_beanie:app", reload=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .filter import Filter | ||
|
||
__all__ = ("Filter",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from collections.abc import Callable, Mapping | ||
from typing import Any | ||
|
||
from beanie.odm.interfaces.find import FindType | ||
from beanie.odm.queries.find import FindMany | ||
from pydantic import ValidationInfo, field_validator | ||
|
||
from fastapi_filter.base.filter import BaseFilterModel | ||
|
||
_odm_operator_transformer: dict[str, Callable[[str | None], dict[str, Any] | None]] = { | ||
"neq": lambda value: {"$ne": value}, | ||
"gt": lambda value: {"$gt": value}, | ||
"gte": lambda value: {"$gte": value}, | ||
"in": lambda value: {"$in": value}, | ||
"isnull": lambda value: None if value else {"$ne": None}, | ||
"lt": lambda value: {"$lt": value}, | ||
"lte": lambda value: {"$lte": value}, | ||
"not": lambda value: {"$ne": value}, | ||
"ne": lambda value: {"$ne": value}, | ||
"not_in": lambda value: {"$nin": value}, | ||
"nin": lambda value: {"$nin": value}, | ||
"like": lambda value: {"$regex": f".*{value}.*"}, | ||
"ilike": lambda value: {"$regex": f".*{value}.*", "$options": "i"}, | ||
} | ||
|
||
|
||
class Filter(BaseFilterModel): | ||
"""Base filter for beanie related filters. | ||
Example: | ||
```python | ||
class MyModel: | ||
id: PrimaryKey() | ||
name: StringField(null=True) | ||
count: IntField() | ||
created_at: DatetimeField() | ||
class MyModelFilter(Filter): | ||
id: Optional[int] | ||
id__in: Optional[str] | ||
count: Optional[int] | ||
count__lte: Optional[int] | ||
created_at__gt: Optional[datetime] | ||
name__ne: Optional[str] | ||
name__nin: Optional[list[str]] | ||
name__isnull: Optional[bool] | ||
``` | ||
""" | ||
|
||
def sort(self, query: FindMany[FindType]) -> FindMany[FindType]: | ||
if not self.ordering_values: | ||
return query | ||
return query.sort(*self.ordering_values) | ||
|
||
@field_validator("*", mode="before") | ||
@classmethod | ||
def split_str(cls: type["BaseFilterModel"], value: str | None, field: ValidationInfo) -> list[str] | str | None: | ||
if ( | ||
field.field_name is not None | ||
and ( | ||
field.field_name == cls.Constants.ordering_field_name | ||
or field.field_name.endswith("__in") | ||
or field.field_name.endswith("__nin") | ||
) | ||
and isinstance(value, str) | ||
): | ||
if not value: | ||
# Empty string should return [] not [''] | ||
return [] | ||
return list(value.split(",")) | ||
return value | ||
|
||
def _get_filter_conditions(self, nesting_depth: int = 1) -> list[tuple[Mapping[str, Any], Mapping[str, Any]]]: | ||
filter_conditions: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = [] | ||
for field_name, value in self.filtering_fields: | ||
field_value = getattr(self, field_name) | ||
if isinstance(field_value, Filter): | ||
if not field_value.model_dump(exclude_none=True, exclude_unset=True): | ||
continue | ||
|
||
filter_conditions.append( | ||
( | ||
{field_name: _odm_operator_transformer["neq"](None)}, | ||
{"fetch_links": True, "nesting_depth": nesting_depth}, | ||
) | ||
) | ||
for part, part_options in field_value._get_filter_conditions(nesting_depth=nesting_depth + 1): # noqa: SLF001 | ||
for sub_field_name, sub_value in part.items(): | ||
filter_conditions.append( | ||
( | ||
{f"{field_name}.{sub_field_name}": sub_value}, | ||
{"fetch_links": True, "nesting_depth": nesting_depth, **part_options}, | ||
) | ||
) | ||
|
||
elif "__" in field_name: | ||
stripped_field_name, operator = field_name.split("__") | ||
search_criteria = _odm_operator_transformer[operator](value) | ||
filter_conditions.append(({stripped_field_name: search_criteria}, {})) | ||
elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"): | ||
search_conditions = [ | ||
{search_field: _odm_operator_transformer["ilike"](value)} | ||
for search_field in self.Constants.search_model_fields | ||
] | ||
filter_conditions.append(({"$or": search_conditions}, {})) | ||
else: | ||
filter_conditions.append(({field_name: value}, {})) | ||
|
||
return filter_conditions | ||
|
||
def filter(self, query: FindMany[FindType]) -> FindMany[FindType]: | ||
data = self._get_filter_conditions() | ||
for filter_condition, filter_kwargs in data: | ||
query = query.find(filter_condition, **filter_kwargs) | ||
return query.find(fetch_links=True) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.