From 445360d6ed071aa7dad560a02c45bfcfb28ba9e6 Mon Sep 17 00:00:00 2001 From: Matthias Erll Date: Tue, 13 Dec 2022 11:31:32 +0100 Subject: [PATCH] Added test case and fix for handling sorts across joined tables --- sqlalchemy_filters/filters.py | 11 +++++--- sqlalchemy_filters/models.py | 10 +++++++- sqlalchemy_filters/sorting.py | 13 ++++++---- test/interface/test_sorting.py | 47 +++++++++++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/sqlalchemy_filters/filters.py b/sqlalchemy_filters/filters.py index 329e21a..54c83b6 100644 --- a/sqlalchemy_filters/filters.py +++ b/sqlalchemy_filters/filters.py @@ -19,8 +19,13 @@ from sqlalchemy import and_, or_, not_, func from .exceptions import BadFilterFormat -from .models import Field, auto_join, get_model_from_spec, get_relationship_models, \ - should_outer_join_relationship +from .models import ( + Field, + auto_join, + get_model_from_spec, + get_relationship_models, + should_filter_outer_join_relationship, +) BooleanFunction = namedtuple( 'BooleanFunction', ('key', 'sqlalchemy_fn', 'only_one_arg') @@ -98,7 +103,7 @@ def get_named_models(self, model): operator = self.filter_spec['op'] if 'op' in self.filter_spec else None models = get_relationship_models(model, field) - return (list(), models) if should_outer_join_relationship(operator) else (models, list()) + return (list(), models) if should_filter_outer_join_relationship(operator) else (models, list()) def format_for_sqlalchemy(self, query, default_model): filter_spec = self.filter_spec diff --git a/sqlalchemy_filters/models.py b/sqlalchemy_filters/models.py index 50a4c0d..49a737f 100644 --- a/sqlalchemy_filters/models.py +++ b/sqlalchemy_filters/models.py @@ -56,10 +56,18 @@ def get_relationship_models(model, field): return list() -def should_outer_join_relationship(operator): +def should_filter_outer_join_relationship(operator): return operator == 'is_null' +def should_sort_outer_join_relationship(models): + return any( + column.nullable + for rel_model in models + for column in rel_model.prop.local_columns + ) + + def find_nested_relationship_model(mapper, field): parts = field if isinstance(field, list) else field.split(".") diff --git a/sqlalchemy_filters/sorting.py b/sqlalchemy_filters/sorting.py index 4fff198..680d46e 100644 --- a/sqlalchemy_filters/sorting.py +++ b/sqlalchemy_filters/sorting.py @@ -1,8 +1,13 @@ # -*- coding: utf-8 -*- from .exceptions import BadSortFormat -from .models import Field, auto_join, get_model_from_spec, get_default_model, get_relationship_models, \ - should_outer_join_relationship +from .models import ( + Field, + auto_join, + get_model_from_spec, + get_relationship_models, + should_sort_outer_join_relationship, +) SORT_ASCENDING = 'asc' SORT_DESCENDING = 'desc' @@ -35,11 +40,9 @@ def __init__(self, sort_spec): def get_named_models(self, model): field = self.sort_spec['field'] - operator = self.sort_spec['op'] if 'op' in self.sort_spec else None - models = get_relationship_models(model, field) - return (list(), models) if should_outer_join_relationship(operator) else (models, list()) + return (list(), models) if should_sort_outer_join_relationship(models) else (models, list()) def format_for_sqlalchemy(self, query, default_model): sort_spec = self.sort_spec diff --git a/test/interface/test_sorting.py b/test/interface/test_sorting.py index 33b19f9..694b67c 100644 --- a/test/interface/test_sorting.py +++ b/test/interface/test_sorting.py @@ -10,7 +10,6 @@ from test import error_value from test.models import Foo, Bar, Qux - NULLSFIRST_NOT_SUPPORTED = ( "'nullsfirst' only supported by PostgreSQL in the current tests" ) @@ -245,6 +244,52 @@ def test_multiple_models(self, session): assert results[2].id == 4 assert results[3].id == 2 + def test_nullable_relationships(self, session): + bar_1 = Bar(id=1, name='name_1', count=5) + bar_2 = Bar(id=2, name='name_2', count=20) + bar_3 = Bar(id=3, name='name_1', count=None) + bar_4 = Bar(id=4, name='name_4', count=10) + foo_1 = Foo(id=1, bar_id=1, name='name_1', count=1) + foo_2 = Foo(id=2, bar_id=2, name='name_2', count=1) + foo_3 = Foo(id=3, bar_id=3, name='name_1', count=1) + foo_4 = Foo(id=4, bar_id=4, name='name_4', count=1) + foo_5 = Foo(id=5, bar_id=None, name='name_1', count=2) + foo_6 = Foo(id=6, bar_id=None, name='name_4', count=2) + foo_7 = Foo(id=7, bar_id=None, name='name_2', count=2) + foo_8 = Foo(id=8, bar_id=None, name='name_5', count=2) + session.add_all([ + bar_1, bar_2, bar_3, bar_4, + foo_1, foo_2, foo_3, foo_4, foo_5, foo_6, foo_7, foo_8, + ]) + session.commit() + + query = session.query(Foo) + sort_spec = [ + {'field': 'bar.count', 'direction': 'desc'}, + {'field': 'name', 'direction': 'asc'}, + ] + sorted_query = apply_sort(Foo, query, sort_spec) + results = sorted_query.all() + assert len(results) == 8 + results_with_bar = results[:4] + results_without_bar = results[4:] + assert [ + (result.id, result.bar.count, result.name) for result in results_with_bar + ] == [ + (2, 20, 'name_2'), + (4, 10, 'name_4'), + (1, 5, 'name_1'), + (3, None, 'name_1'), + ] + assert [ + (result.id, result.bar, result.name) for result in results_without_bar + ] == [ + (5, None, 'name_1'), + (7, None, 'name_2'), + (6, None, 'name_4'), + (8, None, 'name_5'), + ] + @pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted') def test_a_single_dict_can_be_supplied_as_sort_spec(self, session): query = session.query(Bar)