Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically join if filtering by relationship field #550

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions fastapi_filter/contrib/sqlalchemy/filter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
from enum import Enum
from typing import Union
from typing import Any, List, Tuple, Union
from warnings import warn

from pydantic import ValidationInfo, field_validator
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, RelationshipProperty, class_mapper
from sqlalchemy.sql.selectable import Select

from ...base.filter import BaseFilterModel
Expand Down Expand Up @@ -105,6 +105,43 @@
return value

def filter(self, query: Union[Query, Select]):
relationships = _get_relationships(self.Constants.model)
query = self._join_relationships(query, relationships)
query = self._apply_filters(query)
return query

def _join_relationships(self, query: Union[Query, Select], relationships: List[Tuple[Any, RelationshipProperty]]):
"""Joins the specified relationships in the query.

Args:
query (Union[Query, Select]): The SQLAlchemy query object.
relationships (List[Tuple[Any, RelationshipProperty]]): The list of relationships to join.

Returns:
Union[Query, Select]: The modified query object with the joined relationships.
"""
for rel_class, rel in relationships:
related_filter = next(
(value for key, value in self.filtering_fields if key == rel.key),
None,
)
if related_filter is not None and _any_field_not_none(related_filter):
if rel.secondary is not None:
query = query.outerjoin(rel.secondary, rel.primaryjoin)
query = query.outerjoin(rel_class, rel.secondaryjoin)

Check warning on line 131 in fastapi_filter/contrib/sqlalchemy/filter.py

View check run for this annotation

Codecov / codecov/patch

fastapi_filter/contrib/sqlalchemy/filter.py#L130-L131

Added lines #L130 - L131 were not covered by tests
else:
query = query.outerjoin(rel_class, rel.primaryjoin)
return query

def _apply_filters(self, query: Union[Query, Select]):
"""Apply the filtering fields to the given query.

Args:
query (Union[Query, Select]): The query to apply the filters to.

Returns:
Union[Query, Select]: The modified query with the filters applied.
"""
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
Expand Down Expand Up @@ -143,3 +180,38 @@
query = query.order_by(getattr(order_by_field, direction)())

return query


def _get_relationships(model: Any) -> List[Tuple[Any, RelationshipProperty]]:
"""Get the related classes and relationship attributes of a SQLAlchemy model.

Args:
model (Any): The SQLAlchemy model.

Returns:
List[Tuple[Any, RelationshipProperty]]: A list of tuples, where each tuple contains a SQLAlchemy ORM class
related to the model and the corresponding relationship attribute.
"""
mapper = class_mapper(model)
relationships = [(rel.mapper.class_, rel) for rel in mapper.relationships]
return relationships


def _any_field_not_none(model: dict) -> bool:
"""Check if any field in a Pydantic model or any of its nested models is not None.

Args:
model (BaseModel): The dict representation of a model.

Returns:
bool: True if any field is not None, False otherwise.
"""
for _field, value in model.items():
if value is not None:
if isinstance(value, dict):
# If the value is a nested model, check if any field in the nested model is not None
if _any_field_not_none(value):
return True

Check warning on line 214 in fastapi_filter/contrib/sqlalchemy/filter.py

View check run for this annotation

Codecov / codecov/patch

fastapi_filter/contrib/sqlalchemy/filter.py#L213-L214

Added lines #L213 - L214 were not covered by tests
else:
return True
return False
6 changes: 3 additions & 3 deletions tests/test_sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ async def get_users(
user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type]
db: AsyncSession = Depends(get_db),
):
query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined]
query = user_filter.filter(select(User)) # type: ignore[attr-defined]
result = await db.execute(query)
return result.scalars().unique().all()

Expand All @@ -353,7 +353,7 @@ async def get_users_by_alias(
user_filter: UserFilter = FilterDepends(UserFilterByAlias, by_alias=True), # type: ignore[valid-type]
db: AsyncSession = Depends(get_db),
):
query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined]
query = user_filter.filter(select(User)) # type: ignore[attr-defined]
result = await db.execute(query)
return result.scalars().unique().all()

Expand All @@ -362,7 +362,7 @@ async def get_users_with_order_by(
user_filter: UserFilterOrderBy = FilterDepends(UserFilterOrderBy), # type: ignore[valid-type]
db: AsyncSession = Depends(get_db),
):
query = user_filter.sort(select(User).outerjoin(Address)) # type: ignore[attr-defined]
query = user_filter.sort(select(User)) # type: ignore[attr-defined]
query = user_filter.filter(query) # type: ignore[attr-defined]
result = await db.execute(query)
return result.scalars().unique().all()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sqlalchemy/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@pytest.mark.usefixtures("users")
@pytest.mark.asyncio
async def test_filter(session, Address, User, UserFilter, filter_, expected_count):
query = select(User).outerjoin(Address)
query = select(User)
query = UserFilter(**filter_).filter(query)
result = await session.execute(query)
assert len(result.scalars().unique().all()) == expected_count
Expand All @@ -56,7 +56,7 @@ async def test_filter(session, Address, User, UserFilter, filter_, expected_coun
@pytest.mark.usefixtures("users")
@pytest.mark.asyncio
async def test_filter_deprecation_like_and_ilike(session, Address, User, UserFilter, filter_, expected_count):
query = select(User).outerjoin(Address)
query = select(User)
with pytest.warns(DeprecationWarning, match="like and ilike operators."):
query = UserFilter(**filter_).filter(query)
result = await session.execute(query)
Expand Down
Loading