Skip to content

Commit

Permalink
Add optional sorting option to RedisVL queries (#148)
Browse files Browse the repository at this point in the history
This PR adds the ability to sort results by a hash field when the field
name is passed as the `sort_by` parameter to filter, vector, and range queries.
  • Loading branch information
justin-cechmanek authored May 15, 2024
1 parent 5e845f2 commit 458be76
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 3 deletions.
27 changes: 24 additions & 3 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
return_fields: Optional[List[str]] = None,
num_results: int = 10,
dialect: int = 2,
sort_by: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
):
"""A query for a running a filtered search with a filter expression.
Expand All @@ -146,6 +147,8 @@ def __init__(
return_fields (Optional[List[str]], optional): The fields to return.
num_results (Optional[int], optional): The number of results to
return. Defaults to 10.
sort_by (Optional[str]): The field to order the results by. Defaults
to None, in which case no explicit sort order is applied.
params (Optional[Dict[str, Any]], optional): The parameters for the
query. Defaults to None.
Expand All @@ -164,6 +167,7 @@ def __init__(
"""
super().__init__(return_fields, num_results, dialect)
self.set_filter(filter_expression)
self._sort_by = sort_by
self._params = params or {}

@property
Expand All @@ -180,6 +184,8 @@ def query(self) -> Query:
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
return query


Expand All @@ -201,12 +207,14 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
super().__init__(return_fields, num_results, dialect)
self.set_filter(filter_expression)
self._vector = vector
self._field = vector_field_name
self._dtype = dtype.lower()
self._sort_by = sort_by

if return_score:
self._return_fields.append(self.DISTANCE_ID)
Expand All @@ -223,6 +231,7 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
"""A query for running a vector search along with an optional filter
expression.
Expand All @@ -243,6 +252,8 @@ def __init__(
distance. Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Defaults to 2.
sort_by (Optional[str]): The field to order the results by. Defaults
to None, in which case results are ordered by vector distance.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
Expand All @@ -259,6 +270,7 @@ def __init__(
num_results,
return_score,
dialect,
sort_by,
)

@property
Expand All @@ -272,10 +284,13 @@ def query(self) -> Query:
query = (
Query(base_query)
.return_fields(*self._return_fields)
.sort_by(self.DISTANCE_ID)
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
else:
query = query.sort_by(self.DISTANCE_ID)
return query

@property
Expand Down Expand Up @@ -307,6 +322,7 @@ def __init__(
num_results: int = 10,
return_score: bool = True,
dialect: int = 2,
sort_by: Optional[str] = None,
):
"""A query for running a filtered vector search based on semantic
distance threshold.
Expand All @@ -330,7 +346,8 @@ def __init__(
distance. Defaults to True.
dialect (int, optional): The RediSearch query dialect.
Defaults to 2.
sort_by (Optional[str]): The field to order the results by. Defaults
to None, in which case results are ordered by vector distance.
Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
Expand All @@ -347,6 +364,7 @@ def __init__(
num_results,
return_score,
dialect,
sort_by,
)
self.set_distance_threshold(distance_threshold)

Expand Down Expand Up @@ -390,10 +408,13 @@ def query(self) -> Query:
query = (
Query(base_query)
.return_fields(*self._return_fields)
.sort_by(self.DISTANCE_ID)
.paging(self._first, self._limit)
.dialect(self._dialect)
)
if self._sort_by:
query = query.sort_by(self._sort_by)
else:
query = query.sort_by(self.DISTANCE_ID)
return query

@property
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def vector_query():
)


@pytest.fixture
def sorted_vector_query():
    """Build a ``VectorQuery`` on ``user_embedding`` that requests ``sort_by="age"``."""
    fields_to_return = ["user", "credit_score", "age", "job", "location"]
    query = VectorQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        return_fields=fields_to_return,
        sort_by="age",
    )
    return query


@pytest.fixture
def filter_query():
return FilterQuery(
Expand All @@ -26,6 +36,15 @@ def filter_query():
)


@pytest.fixture
def sorted_filter_query():
    """Build a ``FilterQuery`` for high ``credit_score`` tags that requests ``sort_by="age"``."""
    fields_to_return = ["user", "credit_score", "age", "job", "location"]
    high_credit = Tag("credit_score") == "high"
    query = FilterQuery(
        return_fields=fields_to_return,
        filter_expression=high_credit,
        sort_by="age",
    )
    return query


@pytest.fixture
def range_query():
return RangeQuery(
Expand All @@ -36,6 +55,17 @@ def range_query():
)


@pytest.fixture
def sorted_range_query():
    """Build a ``RangeQuery`` on ``user_embedding`` (threshold 0.2) that requests ``sort_by="age"``."""
    fields_to_return = ["user", "credit_score", "age", "job", "location"]
    query = RangeQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        return_fields=fields_to_return,
        distance_threshold=0.2,
        sort_by="age",
    )
    return query


@pytest.fixture
def index(sample_data, redis_url):
# construct a search index from the schema
Expand Down Expand Up @@ -160,6 +190,7 @@ def search(
age_range=None,
location=None,
distance_threshold=0.2,
sort=False,
):
"""Utility function to test filters."""

Expand Down Expand Up @@ -199,6 +230,21 @@ def search(
else:
assert len(results.docs) == expected_count

# check results are in sorted order
if sort:
if isinstance(query, RangeQuery):
assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100]
else:
assert [int(doc.age) for doc in results.docs] == [
12,
14,
15,
18,
35,
94,
100,
]


@pytest.fixture(
params=["vector_query", "filter_query", "range_query"],
Expand Down Expand Up @@ -339,3 +385,18 @@ def test_paginate_range_query(index, range_query):
assert len(all_results) == expected_count
assert i == expected_iterations
assert all(float(item["vector_distance"]) <= 0.2 for item in all_results)


def test_sort_filter_query(index, sorted_filter_query):
    """Run the shared ``search`` helper on the sorted filter query with ``sort=True``.

    The helper is given an expected result count of 7.
    """
    # NOTE(review): an empty text pattern presumably matches every document
    # -- confirm against the Text filter implementation.
    job_filter = Text("job") % ""
    search(sorted_filter_query, index, job_filter, 7, sort=True)


def test_sort_vector_query(index, sorted_vector_query):
    """Run the shared ``search`` helper on the sorted vector query with ``sort=True``.

    The helper is given an expected result count of 7.
    """
    # NOTE(review): an empty text pattern presumably matches every document
    # -- confirm against the Text filter implementation.
    job_filter = Text("job") % ""
    search(sorted_vector_query, index, job_filter, 7, sort=True)


def test_sort_range_query(index, sorted_range_query):
    """Run the shared ``search`` helper on the sorted range query with ``sort=True``.

    The helper is given an expected result count of 7.
    """
    # NOTE(review): an empty text pattern presumably matches every document
    # -- confirm against the Text filter implementation.
    job_filter = Text("job") % ""
    search(sorted_range_query, index, job_filter, 7, sort=True)
31 changes: 31 additions & 0 deletions tests/unit/test_query_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_filter_query():
assert isinstance(filter_query.params, dict)
assert filter_query.params == {}
assert filter_query._dialect == 2
assert filter_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Sportswear"
Expand All @@ -57,6 +58,12 @@ def test_filter_query():
assert filter_query._limit == 7
assert filter_query._num_results == 10

# Test sort_by functionality
filter_query = FilterQuery(
filter_expression, return_fields, num_results=10, sort_by="price"
)
assert filter_query._sort_by == "price"


def test_vector_query():
# Create a vector query
Expand All @@ -73,6 +80,7 @@ def test_vector_query():
assert isinstance(vector_query.params, dict)
assert vector_query.params != {}
assert vector_query._dialect == 3
assert vector_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Sportswear"
Expand All @@ -85,6 +93,17 @@ def test_vector_query():
assert vector_query._limit == 7
assert vector_query._num_results == 10

# Test sort_by functionality
vector_query = VectorQuery(
sample_vector,
"vector_field",
["field1", "field2"],
dialect=3,
num_results=10,
sort_by="field2",
)
assert vector_query._sort_by == "field2"


def test_range_query():
# Create a filter expression
Expand All @@ -104,6 +123,7 @@ def test_range_query():
assert isinstance(range_query.query, Query)
assert isinstance(range_query.params, dict)
assert range_query.params != {}
assert range_query._sort_by == None

# Test set_filter functionality
new_filter_expression = Tag("category") == "Outdoor"
Expand All @@ -115,3 +135,14 @@ def test_range_query():
assert range_query._first == 5
assert range_query._limit == 7
assert range_query._num_results == 10

# Test sort_by functionality
range_query = RangeQuery(
sample_vector,
"vector_field",
["field1"],
filter_expression,
num_results=10,
sort_by="field1",
)
assert range_query._sort_by == "field1"

0 comments on commit 458be76

Please sign in to comment.