From fb74aa2806100f4026e290dab0cb7164262ff142 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:00:11 +0300 Subject: [PATCH] Added support for ADDSCORES modifier (#3329) * Added support for ADDSCORES modifier * Fixed codestyle issues * More codestyle fixes * Updated test cases and testing image to represent latest * Codestyle issues * Added handling for dict responses --- .github/workflows/integration.yaml | 2 +- redis/commands/search/aggregation.py | 11 +++++++++++ tests/test_asyncio/test_search.py | 26 ++++++++++++++++++++++++++ tests/test_search.py | 26 ++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 94fe8f35b6..5342238dd3 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -28,7 +28,7 @@ env: # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon REDIS_IMAGE: redis:7.4-rc2 - REDIS_STACK_IMAGE: redis/redis-stack-server:7.4.0-rc2 + REDIS_STACK_IMAGE: redis/redis-stack-server:latest jobs: dependency-audit: diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 50d18f476a..42c3547b0b 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None: self._verbatim = False self._cursor = [] self._dialect = None + self._add_scores = False def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest": self._with_schema = True return self + def add_scores(self) -> "AggregateRequest": + """ + If set, includes the score as an ordinary field of the row. + """ + self._add_scores = True + return self + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self @@ -315,6 +323,9 @@ def build_args(self) -> List[str]: if self._verbatim: ret.append("VERBATIM") + if self._add_scores: + ret.append("ADDSCORES") + if self._cursor: ret += self._cursor diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 68560d1f2a..0e6fe22131 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1530,6 +1530,32 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_add_scores(decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True), + ) + ) + + assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"}) + assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = aggregations.AggregateRequest("*").add_scores() + res = await decoded_r.ft().aggregate(req) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index e84f03c0e4..dde59f0f87 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1440,6 +1440,32 @@ def test_aggregations_filter(client): assert res["results"][1]["extra_attributes"] == {"age": "25"} +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +def test_aggregations_add_scores(client): + client.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True), + ) + ) + + client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = aggregations.AggregateRequest("*").add_scores() + res = client.ft().aggregate(req) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client):