diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 34879de6..0ab7f76e 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from enum import Enum -from typing import Union +from typing import Any, List, Tuple, Union from warnings import warn from pydantic import ValidationInfo, field_validator from sqlalchemy import or_ -from sqlalchemy.orm import Query +from sqlalchemy.orm import Query, RelationshipProperty, class_mapper from sqlalchemy.sql.selectable import Select from ...base.filter import BaseFilterModel @@ -105,6 +105,43 @@ def split_str(cls, value, field: ValidationInfo): return value def filter(self, query: Union[Query, Select]): + relationships = _get_relationships(self.Constants.model) + query = self._join_relationships(query, relationships) + query = self._apply_filters(query) + return query + + def _join_relationships(self, query: Union[Query, Select], relationships: List[Tuple[Any, RelationshipProperty]]): + """Joins the specified relationships in the query. + + Args: + query (Union[Query, Select]): The SQLAlchemy query object. + relationships (List[Tuple[Any, RelationshipProperty]]): The list of relationships to join. + + Returns: + Union[Query, Select]: The modified query object with the joined relationships. + """ + for rel_class, rel in relationships: + related_filter = next( + (value for key, value in self.filtering_fields if key == rel.key), + None, + ) + if related_filter is not None and _any_field_not_none(related_filter): + if rel.secondary is not None: + query = query.outerjoin(rel.secondary, rel.primaryjoin) + query = query.outerjoin(rel_class, rel.secondaryjoin) + else: + query = query.outerjoin(rel_class, rel.primaryjoin) + return query + + def _apply_filters(self, query: Union[Query, Select]): + """Apply the filtering fields to the given query. + + Args: + query (Union[Query, Select]): The query to apply the filters to. + + Returns: + Union[Query, Select]: The modified query with the filters applied. + """ for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) if isinstance(field_value, Filter): @@ -143,3 +180,38 @@ def sort(self, query: Union[Query, Select]): query = query.order_by(getattr(order_by_field, direction)()) return query + + +def _get_relationships(model: Any) -> List[Tuple[Any, RelationshipProperty]]: + """Get the related classes and relationship attributes of a SQLAlchemy model. + + Args: + model (Any): The SQLAlchemy model. + + Returns: + List[Tuple[Any, RelationshipProperty]]: A list of tuples, where each tuple contains a SQLAlchemy ORM class + related to the model and the corresponding relationship attribute. + """ + mapper = class_mapper(model) + relationships = [(rel.mapper.class_, rel) for rel in mapper.relationships] + return relationships + + +def _any_field_not_none(model: dict) -> bool: + """Check if any field in a Pydantic model or any of its nested models is not None. + + Args: + model (BaseModel): The dict representation of a model. + + Returns: + bool: True if any field is not None, False otherwise. + """ + for _field, value in model.items(): + if value is not None: + if isinstance(value, dict): + # If the value is a nested model, check if any field in the nested model is not None + if _any_field_not_none(value): + return True + else: + return True + return False diff --git a/tests/test_sqlalchemy/conftest.py b/tests/test_sqlalchemy/conftest.py index f250b9f7..d84e0ae9 100644 --- a/tests/test_sqlalchemy/conftest.py +++ b/tests/test_sqlalchemy/conftest.py @@ -344,7 +344,7 @@ async def get_users( user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -353,7 +353,7 @@ async def get_users_by_alias( user_filter: UserFilter = FilterDepends(UserFilterByAlias, by_alias=True), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -362,7 +362,7 @@ async def get_users_with_order_by( user_filter: UserFilterOrderBy = FilterDepends(UserFilterOrderBy), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.sort(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.sort(select(User)) # type: ignore[attr-defined] query = user_filter.filter(query) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() diff --git a/tests/test_sqlalchemy/test_filter.py b/tests/test_sqlalchemy/test_filter.py index 3217a60b..16373ac3 100644 --- a/tests/test_sqlalchemy/test_filter.py +++ b/tests/test_sqlalchemy/test_filter.py @@ -40,7 +40,7 @@ @pytest.mark.usefixtures("users") @pytest.mark.asyncio async def test_filter(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) + query = select(User) query = UserFilter(**filter_).filter(query) result = await session.execute(query) assert len(result.scalars().unique().all()) == expected_count @@ -56,7 +56,7 @@ async def test_filter(session, Address, User, UserFilter, filter_, expected_coun @pytest.mark.usefixtures("users") @pytest.mark.asyncio async def test_filter_deprecation_like_and_ilike(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) + query = select(User) with pytest.warns(DeprecationWarning, match="like and ilike operators."): query = UserFilter(**filter_).filter(query) result = await session.execute(query)