Skip to content

Commit

Permalink
Add support beanie backend for Mongo
Browse files Browse the repository at this point in the history
  • Loading branch information
shepilov-vladislav committed Jun 29, 2024
1 parent 6a04bfb commit 648f24b
Show file tree
Hide file tree
Showing 10 changed files with 982 additions and 5 deletions.
6 changes: 4 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Add querystring filters to your api endpoints and show them in the swagger UI.

The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy) and
[MongoEngine](https://github.com/MongoEngine/mongoengine).
The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy),
[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie).

## Example

Expand All @@ -20,6 +20,8 @@ as well as the type of operator, then tie your filter to a specific model.

[MongoEngine](https://github.com/arthurio/fastapi-filter/blob/main/examples/fastapi_filter_mongoengine.py)

[beanie](https://github.com/arthurio/fastapi-filter/blob/main/examples/fastapi_filter_beanie.py)

### Operators

By default, **fastapi_filter** supports the following operators:
Expand Down
141 changes: 141 additions & 0 deletions examples/fastapi_filter_beanie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any, Optional

import click
import uvicorn
from beanie import Document, Link, PydanticObjectId, init_beanie
from beanie.odm.fields import WriteRules
from faker import Faker
from fastapi import FastAPI, Query
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel, ConfigDict, EmailStr, Field

from fastapi_filter import FilterDepends, with_prefix
from fastapi_filter.contrib.beanie import Filter

fake = Faker()

logger = logging.getLogger("uvicorn")


class Address(Document):
street: str
city: str
country: str


class User(Document):
name: str
email: EmailStr
age: int
address: Link[Address]


class AddressOut(BaseModel):
id: PydanticObjectId = Field(alias="_id", description="MongoDB document ObjectID")
street: str
city: str
country: str

class Config:
orm_mode = True


class UserIn(BaseModel):
name: str
email: EmailStr
age: int


class UserOut(UserIn):
model_config = ConfigDict(from_attributes=True)

id: PydanticObjectId = Field(alias="_id", description="MongoDB document ObjectID")
name: str
email: EmailStr
age: int
address: Optional[AddressOut] = None


class AddressFilter(Filter):
street: Optional[str] = None
country: Optional[str] = None
city: Optional[str] = None
city__in: Optional[list[str]] = None
custom_order_by: Optional[list[str]] = None
custom_search: Optional[str] = None

class Constants(Filter.Constants):
model = Address
ordering_field_name = "custom_order_by"
search_field_name = "custom_search"
search_model_fields = ["street", "country", "city"]


class UserFilter(Filter):
name: Optional[str] = None
address: Optional[AddressFilter] = FilterDepends(with_prefix("address", AddressFilter))
age__lt: Optional[int] = None
age__gte: int = Field(Query(description="this is a nice description"))
"""Required field with a custom description.
See: https://github.com/tiangolo/fastapi/issues/4700 for why we need to wrap `Query` in `Field`.
"""
order_by: list[str] = ["age"]
search: Optional[str] = None

class Constants(Filter.Constants):
model = User
search_model_fields = ["name"]


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
message = "Open http://127.0.0.1:8000/docs to start exploring 🎒 🧭 🗺️"
color_message = "Open " + click.style("http://127.0.0.1:8000/docs", bold=True) + " to start exploring 🎒 🧭 🗺️"
logger.info(message, extra={"color_message": color_message})

client: AsyncIOMotorClient = AsyncIOMotorClient("mongodb://localhost:27017/fastapi_filter")
# https://github.com/tiangolo/fastapi/issues/3855#issuecomment-1013148113
client.get_io_loop = asyncio.get_event_loop # type: ignore[method-assign]
db = client.fastapi_filter
await init_beanie(database=db, document_models=[Address, User])

for _ in range(100):
address = Address(street=fake.street_address(), city=fake.city(), country=fake.country())
await address.save()
user = User(name=fake.name(), email=fake.email(), age=fake.random_int(min=5, max=120), address=address)
await user.save(link_rule=WriteRules.WRITE)

yield

Address.find_all().delete()
User.find_all().delete()
client.close()


app = FastAPI(lifespan=lifespan)


@app.get("/users", response_model=list[UserOut])
async def get_users(user_filter: UserFilter = FilterDepends(UserFilter)) -> Any:
query = user_filter.filter(User.find({}))
query = user_filter.sort(query)
query = query.find(fetch_links=True)
return await query.project(UserOut).to_list()


@app.get("/addresses", response_model=list[AddressOut])
async def get_addresses(
address_filter: AddressFilter = FilterDepends(with_prefix("my_custom_prefix", AddressFilter), by_alias=True),
) -> Any:
query = address_filter.filter(Address.find({}))
query = address_filter.sort(query)
return await query.project(AddressOut).to_list()


if __name__ == "__main__":
uvicorn.run("fastapi_filter_beanie:app", reload=True)
3 changes: 3 additions & 0 deletions fastapi_filter/contrib/beanie/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .filter import Filter

__all__ = ("Filter",)
115 changes: 115 additions & 0 deletions fastapi_filter/contrib/beanie/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from collections.abc import Callable, Mapping
from typing import Any

from beanie.odm.interfaces.find import FindType
from beanie.odm.queries.find import FindMany
from pydantic import ValidationInfo, field_validator

from fastapi_filter.base.filter import BaseFilterModel

_odm_operator_transformer: dict[str, Callable[[str | None], dict[str, Any] | None]] = {
"neq": lambda value: {"$ne": value},
"gt": lambda value: {"$gt": value},
"gte": lambda value: {"$gte": value},
"in": lambda value: {"$in": value},
"isnull": lambda value: None if value else {"$ne": None},
"lt": lambda value: {"$lt": value},
"lte": lambda value: {"$lte": value},
"not": lambda value: {"$ne": value},
"ne": lambda value: {"$ne": value},
"not_in": lambda value: {"$nin": value},
"nin": lambda value: {"$nin": value},
"like": lambda value: {"$regex": f".*{value}.*"},
"ilike": lambda value: {"$regex": f".*{value}.*", "$options": "i"},
}


class Filter(BaseFilterModel):
"""Base filter for beanie related filters.
Example:
```python
class MyModel:
id: PrimaryKey()
name: StringField(null=True)
count: IntField()
created_at: DatetimeField()
class MyModelFilter(Filter):
id: Optional[int]
id__in: Optional[str]
count: Optional[int]
count__lte: Optional[int]
created_at__gt: Optional[datetime]
name__ne: Optional[str]
name__nin: Optional[list[str]]
name__isnull: Optional[bool]
```
"""

def sort(self, query: FindMany[FindType]) -> FindMany[FindType]:
if not self.ordering_values:
return query
return query.sort(*self.ordering_values)

@field_validator("*", mode="before")
@classmethod
def split_str(cls: type["BaseFilterModel"], value: str | None, field: ValidationInfo) -> list[str] | str | None:
if (
field.field_name is not None
and (
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__nin")
)
and isinstance(value, str)
):
if not value:
# Empty string should return [] not ['']
return []
return list(value.split(","))
return value

def _get_filter_conditions(self, nesting_depth: int = 1) -> list[tuple[Mapping[str, Any], Mapping[str, Any]]]:
filter_conditions: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = []
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
if not field_value.model_dump(exclude_none=True, exclude_unset=True):
continue

filter_conditions.append(
(
{field_name: _odm_operator_transformer["neq"](None)},
{"fetch_links": True, "nesting_depth": nesting_depth},
)
)
for part, part_options in field_value._get_filter_conditions(nesting_depth=nesting_depth + 1): # noqa: SLF001
for sub_field_name, sub_value in part.items():
filter_conditions.append(
(
{f"{field_name}.{sub_field_name}": sub_value},
{"fetch_links": True, "nesting_depth": nesting_depth, **part_options},
)
)

elif "__" in field_name:
stripped_field_name, operator = field_name.split("__")
search_criteria = _odm_operator_transformer[operator](value)
filter_conditions.append(({stripped_field_name: search_criteria}, {}))
elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"):
search_conditions = [
{search_field: _odm_operator_transformer["ilike"](value)}
for search_field in self.Constants.search_model_fields
]
filter_conditions.append(({"$or": search_conditions}, {}))
else:
filter_conditions.append(({field_name: value}, {}))

return filter_conditions

def filter(self, query: FindMany[FindType]) -> FindMany[FindType]:
data = self._get_filter_conditions()
for filter_condition, filter_kwargs in data:
query = query.find(filter_condition, **filter_kwargs)
return query.find(fetch_links=True)
78 changes: 76 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 648f24b

Please sign in to comment.