Skip to content

Commit

Permalink
Merge pull request #12 from merll/fix-sort-on-nullable-relationship
Browse files Browse the repository at this point in the history
Fix sorts across nullable relationships
  • Loading branch information
merll committed Dec 20, 2022
2 parents 6e3fcd2 + 7d51d62 commit 063cab4
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 10 deletions.
11 changes: 8 additions & 3 deletions sqlalchemy_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
from sqlalchemy import and_, or_, not_, func

from .exceptions import BadFilterFormat
from .models import Field, auto_join, get_model_from_spec, get_relationship_models, \
should_outer_join_relationship
from .models import (
Field,
auto_join,
get_model_from_spec,
get_relationship_models,
should_filter_outer_join_relationship,
)

BooleanFunction = namedtuple(
'BooleanFunction', ('key', 'sqlalchemy_fn', 'only_one_arg')
Expand Down Expand Up @@ -98,7 +103,7 @@ def get_named_models(self, model):
operator = self.filter_spec['op'] if 'op' in self.filter_spec else None

models = get_relationship_models(model, field)
return (list(), models) if should_outer_join_relationship(operator) else (models, list())
return (list(), models) if should_filter_outer_join_relationship(operator) else (models, list())

def format_for_sqlalchemy(self, query, default_model):
filter_spec = self.filter_spec
Expand Down
11 changes: 10 additions & 1 deletion sqlalchemy_filters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,19 @@ def get_relationship_models(model, field):
return list()


def should_outer_join_relationship(operator):
def should_filter_outer_join_relationship(operator):
return operator == 'is_null'


def should_sort_outer_join_relationship(models):
for rel_model in models:
if rel_model.prop.direction == symbol('ONETOMANY'):
return True
elif any(column.nullable for column in rel_model.prop.local_columns):
return True
return False


def find_nested_relationship_model(mapper, field):
parts = field if isinstance(field, list) else field.split(".")

Expand Down
13 changes: 8 additions & 5 deletions sqlalchemy_filters/sorting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-

from .exceptions import BadSortFormat
from .models import Field, auto_join, get_model_from_spec, get_default_model, get_relationship_models, \
should_outer_join_relationship
from .models import (
Field,
auto_join,
get_model_from_spec,
get_relationship_models,
should_sort_outer_join_relationship,
)

SORT_ASCENDING = 'asc'
SORT_DESCENDING = 'desc'
Expand Down Expand Up @@ -35,11 +40,9 @@ def __init__(self, sort_spec):

def get_named_models(self, model):
field = self.sort_spec['field']
operator = self.sort_spec['op'] if 'op' in self.sort_spec else None

models = get_relationship_models(model, field)

return (list(), models) if should_outer_join_relationship(operator) else (models, list())
return (list(), models) if should_sort_outer_join_relationship(models) else (models, list())

def format_for_sqlalchemy(self, query, default_model):
sort_spec = self.sort_spec
Expand Down
47 changes: 46 additions & 1 deletion test/interface/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from test import error_value
from test.models import Foo, Bar, Qux


NULLSFIRST_NOT_SUPPORTED = (
"'nullsfirst' only supported by PostgreSQL in the current tests"
)
Expand Down Expand Up @@ -245,6 +244,52 @@ def test_multiple_models(self, session):
assert results[2].id == 4
assert results[3].id == 2

def test_nullable_relationships(self, session):
bar_1 = Bar(id=1, name='name_1', count=5)
bar_2 = Bar(id=2, name='name_2', count=20)
bar_3 = Bar(id=3, name='name_1', count=None)
bar_4 = Bar(id=4, name='name_4', count=10)
foo_1 = Foo(id=1, bar_id=1, name='name_1', count=1)
foo_2 = Foo(id=2, bar_id=2, name='name_2', count=1)
foo_3 = Foo(id=3, bar_id=3, name='name_1', count=1)
foo_4 = Foo(id=4, bar_id=4, name='name_4', count=1)
foo_5 = Foo(id=5, bar_id=None, name='name_1', count=2)
foo_6 = Foo(id=6, bar_id=None, name='name_4', count=2)
foo_7 = Foo(id=7, bar_id=None, name='name_2', count=2)
foo_8 = Foo(id=8, bar_id=None, name='name_5', count=2)
session.add_all([
bar_1, bar_2, bar_3, bar_4,
foo_1, foo_2, foo_3, foo_4, foo_5, foo_6, foo_7, foo_8,
])
session.commit()

query = session.query(Foo)
sort_spec = [
{'field': 'bar.count', 'direction': 'desc'},
{'field': 'name', 'direction': 'asc'},
]
sorted_query = apply_sort(Foo, query, sort_spec)
results = sorted_query.all()
assert len(results) == 8
results_with_bar = results[:4]
results_without_bar = results[4:]
assert [
(result.id, result.bar.count, result.name) for result in results_with_bar
] == [
(2, 20, 'name_2'),
(4, 10, 'name_4'),
(1, 5, 'name_1'),
(3, None, 'name_1'),
]
assert [
(result.id, result.bar, result.name) for result in results_without_bar
] == [
(5, None, 'name_1'),
(7, None, 'name_2'),
(6, None, 'name_4'),
(8, None, 'name_5'),
]

@pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted')
def test_a_single_dict_can_be_supplied_as_sort_spec(self, session):
query = session.query(Bar)
Expand Down

0 comments on commit 063cab4

Please sign in to comment.