From 7bb42984ea69c8ac8ce3f77d49a3fabc03c3068a Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Thu, 9 Nov 2023 05:16:45 +0300 Subject: [PATCH] Simplify related query to remove nesting and make more performant (#3307) * Add related test Signed-off-by: Olga Bulat * Simplify related query to remove nesting Signed-off-by: Olga Bulat * Use terms query for tags in related Signed-off-by: Olga Bulat * Update the unit test Signed-off-by: Olga Bulat * Add excluded providers cache Signed-off-by: Olga Bulat * Test number of related results in integration Signed-off-by: Olga Bulat * Test that the results are related in integration Signed-off-by: Olga Bulat --------- Signed-off-by: Olga Bulat --- api/api/controllers/elasticsearch/__init__.py | 0 api/api/controllers/elasticsearch/helpers.py | 107 +++++++++++ api/api/controllers/elasticsearch/related.py | 73 ++++++++ api/api/controllers/search_controller.py | 169 +----------------- api/api/views/media_views.py | 3 +- api/test/factory/es_http.py | 29 ++- api/test/media_integration.py | 20 ++- api/test/test_dead_link_filter.py | 2 +- api/test/test_image_integration.py | 2 +- .../controllers/elasticsearch/test_related.py | 106 +++++++++++ .../controllers/test_search_controller.py | 7 +- .../test_search_controller_search_query.py | 18 +- 12 files changed, 361 insertions(+), 175 deletions(-) create mode 100644 api/api/controllers/elasticsearch/__init__.py create mode 100644 api/api/controllers/elasticsearch/related.py create mode 100644 api/test/unit/controllers/elasticsearch/test_related.py diff --git a/api/api/controllers/elasticsearch/__init__.py b/api/api/controllers/elasticsearch/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/api/api/controllers/elasticsearch/helpers.py b/api/api/controllers/elasticsearch/helpers.py index c4bbdf9386e..808a7b81b29 100644 --- a/api/api/controllers/elasticsearch/helpers.py +++ b/api/api/controllers/elasticsearch/helpers.py @@ -4,10 +4,15 @@ import logging as log import 
pprint import time +from itertools import accumulate +from math import ceil from django.conf import settings from elasticsearch import BadRequestError, NotFoundError +from elasticsearch_dsl import Search + +from api.utils.dead_link_mask import get_query_hash, get_query_mask def log_timing_info(func): @@ -55,3 +60,105 @@ def get_es_response(s, *args, **kwargs): @log_timing_info def get_raw_es_response(index, body, *args, **kwargs): return settings.ES.search(index=index, body=body, *args, **kwargs) + + +ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 +DEAD_LINK_RATIO = 1 / 2 +DEEP_PAGINATION_ERROR = "Deep pagination is not allowed." + + +def _unmasked_query_end(page_size, page): + """ + Calculate the upper index of results to retrieve from Elasticsearch. + + Used to retrieve the upper index of results to retrieve from Elasticsearch under the + following conditions: + 1. There is no query mask + 2. The lower index is beyond the scope of the existing query mask + 3. The lower index is within the scope of the existing query mask + but the upper index exceeds it + + In all these cases, the query mask is not used to calculate the upper index. + """ + return ceil(page_size * page / (1 - DEAD_LINK_RATIO)) + + +def _paginate_with_dead_link_mask( + s: Search, page_size: int, page: int +) -> tuple[int, int]: + """ + Return the start and end of the results slice, given the query, page and page size. + + In almost all cases the ``DEAD_LINK_RATIO`` will effectively double + the page size (given the current configuration of 0.5). + + The "branch X" labels are for cross-referencing with the tests. + + :param s: The elasticsearch Search object + :param page_size: How big the page should be. + :param page: The page number. + :return: Tuple of start and end. 
+ """ + query_hash = get_query_hash(s) + query_mask = get_query_mask(query_hash) + if not query_mask: # branch 1 + start = 0 + end = _unmasked_query_end(page_size, page) + elif page_size * (page - 1) > sum(query_mask): # branch 2 + start = len(query_mask) + end = _unmasked_query_end(page_size, page) + else: # branch 3 + # query_mask is a list of 0 and 1 where 0 indicates the result position + # for the given query will be an invalid link. If we accumulate a query + # mask you end up, at each index, with the number of live results you + # will get back when you query that deeply. + # We then query for the start and end index _of the results_ in ES based + # on the number of results that we think will be valid based on the query mask. + # If we're requesting `page=2 page_size=3` and the mask is [0, 1, 0, 1, 0, 1], + # then we know that we have to _start_ with at least the sixth result of the + # overall query to skip the first page of 3 valid results. The "end" of the + # query will then follow the same pattern to reach the number of valid results + # required to fill the requested page. 
If the mask is not deep enough to + # account for the entire range, then we follow the typical assumption when + # a mask is not available that the end should be `page * page_size / 0.5` + # (i.e., double the page size) + accu_query_mask = list(accumulate(query_mask)) + start = 0 + if page > 1: + try: # branch 3_start_A + # find the index at which we can skip N valid results where N = all + # the results that would be skipped to arrive at the start of the + # requested page + # This will effectively be the index at which we have the number of + # previous valid results + 1 because we don't want to include the + # last valid result from the previous page + start = accu_query_mask.index(page_size * (page - 1) + 1) + except ValueError: # branch 3_start_B + # Cannot fail because of the check on branch 2 which verifies that + # the query mask already includes at least enough masked valid + # results to fulfill the requested page size + start = accu_query_mask.index(page_size * (page - 1)) + 1 + # else: branch 3_start_C + # Always start page=1 queries at 0 + + if page_size * page > sum(query_mask): # branch 3_end_A + end = _unmasked_query_end(page_size, page) + else: # branch 3_end_B + end = accu_query_mask.index(page_size * page) + 1 + return start, end + + +def get_query_slice( + s: Search, page_size: int, page: int, filter_dead: bool | None = False +) -> tuple[int, int]: + """Select the start and end of the search results for this query.""" + + if filter_dead: + start_slice, end_slice = _paginate_with_dead_link_mask(s, page_size, page) + else: + # Paginate search query. 
+ start_slice = page_size * (page - 1) + end_slice = page_size * page + if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: + raise ValueError(DEEP_PAGINATION_ERROR) + return start_slice, end_slice diff --git a/api/api/controllers/elasticsearch/related.py b/api/api/controllers/elasticsearch/related.py new file mode 100644 index 00000000000..52f542acb9a --- /dev/null +++ b/api/api/controllers/elasticsearch/related.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from elasticsearch_dsl import Search +from elasticsearch_dsl.query import Match, Q, Term +from elasticsearch_dsl.response import Hit + +from api.controllers.elasticsearch.helpers import get_es_response, get_query_slice +from api.controllers.search_controller import ( + _post_process_results, + get_excluded_providers_query, +) + + +def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]: + """ + Given a UUID, finds 10 related search results based on title and tags. + + Uses a Match query for the title and a Terms query for the tags. + If the item has no title and no tags, returns items by the same creator. + If the item has no title, no tags or no creator, returns empty list. + + :param uuid: The UUID of the item to find related results for. + :param index: The Elasticsearch index to search (e.g. 'image') + :param filter_dead: Whether dead links should be removed. + :return: List of related results. + """ + + # Search the default index for the item itself as it might be sensitive. + item_search = Search(index=index) + item_hit = item_search.query(Term(identifier=uuid)).execute().hits[0] + + # Match related using title. 
+ title = getattr(item_hit, "title", None) + tags = getattr(item_hit, "tags", None) + creator = getattr(item_hit, "creator", None) + + related_query = {"must_not": [], "must": [], "should": []} + + if not title and not tags: + if not creator: + return [] + else: + # Only use `creator` query if there are no `title` and `tags` + related_query["should"].append(Term(creator=creator)) + else: + if title: + related_query["should"].append(Match(title=title)) + + # Match related using tags, if the item has any. + # Only use the first 10 tags + if tags: + tags = [tag["name"] for tag in tags[:10]] + related_query["should"].append(Q("terms", tags__name=tags)) + + # Exclude the dynamically disabled sources. + if excluded_providers_query := get_excluded_providers_query(): + related_query["must_not"].append(excluded_providers_query) + # Exclude the current item and mature content. + related_query["must_not"].extend( + [Q("term", mature=True), Q("term", identifier=uuid)] + ) + + # Search the filtered index for related items. 
+ s = Search(index=f"{index}-filtered") + s = s.query("bool", **related_query) + + page, page_size = 1, 10 + start, end = get_query_slice(s, page_size, page, filter_dead) + s = s[start:end] + + response = get_es_response(s, es_query="related_media") + results = _post_process_results(s, start, end, page_size, response, filter_dead) + return results or [] diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 8ce00c3ab90..414b8ebcc4e 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging as log -from itertools import accumulate from math import ceil from typing import Literal @@ -10,130 +9,35 @@ from elasticsearch.exceptions import NotFoundError from elasticsearch_dsl import Q, Search -from elasticsearch_dsl.query import EMPTY_QUERY, Match, SimpleQueryString, Term +from elasticsearch_dsl.query import EMPTY_QUERY from elasticsearch_dsl.response import Hit, Response import api.models as models from api.constants.media_types import OriginIndex from api.constants.sorting import INDEXED_ON -from api.controllers.elasticsearch.helpers import get_es_response, get_raw_es_response +from api.controllers.elasticsearch.helpers import ( + ELASTICSEARCH_MAX_RESULT_WINDOW, + get_es_response, + get_query_slice, + get_raw_es_response, +) from api.serializers import media_serializers from api.utils import tallies from api.utils.check_dead_links import check_dead_links -from api.utils.dead_link_mask import get_query_hash, get_query_mask +from api.utils.dead_link_mask import get_query_hash from api.utils.search_context import SearchContext -ELASTICSEARCH_MAX_RESULT_WINDOW = 10000 SOURCE_CACHE_TIMEOUT = 60 * 60 * 4 # 4 hours FILTER_CACHE_TIMEOUT = 30 -DEAD_LINK_RATIO = 1 / 2 THUMBNAIL = "thumbnail" URL = "url" PROVIDER = "provider" -DEEP_PAGINATION_ERROR = "Deep pagination is not allowed." 
QUERY_SPECIAL_CHARACTER_ERROR = "Unescaped special characters are not allowed." DEFAULT_BOOST = 10000 DEFAULT_SEARCH_FIELDS = ["title", "description", "tags.name"] -def _unmasked_query_end(page_size, page): - """ - Calculate the upper index of results to retrieve from Elasticsearch. - - Used to retrieve the upper index of results to retrieve from Elasticsearch under the - following conditions: - 1. There is no query mask - 2. The lower index is beyond the scope of the existing query mask - 3. The lower index is within the scope of the existing query mask - but the upper index exceeds it - - In all these cases, the query mask is not used to calculate the upper index. - """ - return ceil(page_size * page / (1 - DEAD_LINK_RATIO)) - - -def _paginate_with_dead_link_mask( - s: Search, page_size: int, page: int -) -> tuple[int, int]: - """ - Return the start and end of the results slice, given the query, page and page size. - - In almost all cases the ``DEAD_LINK_RATIO`` will effectively double - the page size (given the current configuration of 0.5). - - The "branch X" labels are for cross-referencing with the tests. - - :param s: The elasticsearch Search object - :param page_size: How big the page should be. - :param page: The page number. - :return: Tuple of start and end. - """ - query_hash = get_query_hash(s) - query_mask = get_query_mask(query_hash) - if not query_mask: # branch 1 - start = 0 - end = _unmasked_query_end(page_size, page) - elif page_size * (page - 1) > sum(query_mask): # branch 2 - start = len(query_mask) - end = _unmasked_query_end(page_size, page) - else: # branch 3 - # query_mask is a list of 0 and 1 where 0 indicates the result position - # for the given query will be an invalid link. If we accumulate a query - # mask you end up, at each index, with the number of live results you - # will get back when you query that deeply. 
- # We then query for the start and end index _of the results_ in ES based - # on the number of results that we think will be valid based on the query mask. - # If we're requesting `page=2 page_size=3` and the mask is [0, 1, 0, 1, 0, 1], - # then we know that we have to _start_ with at least the sixth result of the - # overall query to skip the first page of 3 valid results. The "end" of the - # query will then follow the same pattern to reach the number of valid results - # required to fill the requested page. If the mask is not deep enough to - # account for the entire range, then we follow the typical assumption when - # a mask is not available that the end should be `page * page_size / 0.5` - # (i.e., double the page size) - accu_query_mask = list(accumulate(query_mask)) - start = 0 - if page > 1: - try: # branch 3_start_A - # find the index at which we can skip N valid results where N = all - # the results that would be skipped to arrive at the start of the - # requested page - # This will effectively be the index at which we have the number of - # previous valid results + 1 because we don't want to include the - # last valid result from the previous page - start = accu_query_mask.index(page_size * (page - 1) + 1) - except ValueError: # branch 3_start_B - # Cannot fail because of the check on branch 2 which verifies that - # the query mask already includes at least enough masked valid - # results to fulfill the requested page size - start = accu_query_mask.index(page_size * (page - 1)) + 1 - # else: branch 3_start_C - # Always start page=1 queries at 0 - - if page_size * page > sum(query_mask): # branch 3_end_A - end = _unmasked_query_end(page_size, page) - else: # branch 3_end_B - end = accu_query_mask.index(page_size * page) + 1 - return start, end - - -def _get_query_slice( - s: Search, page_size: int, page: int, filter_dead: bool | None = False -) -> tuple[int, int]: - """Select the start and end of the search results for this query.""" - - if filter_dead: 
- start_slice, end_slice = _paginate_with_dead_link_mask(s, page_size, page) - else: - # Paginate search query. - start_slice = page_size * (page - 1) - end_slice = page_size * page - if start_slice + end_slice > ELASTICSEARCH_MAX_RESULT_WINDOW: - raise ValueError(DEEP_PAGINATION_ERROR) - return start_slice, end_slice - - def _quote_escape(query_string): """Ignore any unmatched quotes in the query supplied by the user.""" @@ -446,7 +350,7 @@ def search( s = s.sort({"created_on": {"order": search_params.validated_data["sort_dir"]}}) # Paginate - start, end = _get_query_slice(s, page_size, page, filter_dead) + start, end = get_query_slice(s, page_size, page, filter_dead) s = s[start:end] search_response = get_es_response(s, es_query="search") @@ -495,61 +399,6 @@ def search( return results, page_count, result_count, search_context.asdict() -def related_media(uuid: str, index: str, filter_dead: bool) -> list[Hit]: - """ - Given a UUID, finds 10 related search results based on title and tags. - - Uses Match query for title or SimpleQueryString for tags. - If the item has no title and no tags, returns items by the same creator. - If the item has no title, no tags or no creator, returns empty list. - - :param uuid: The UUID of the item to find related results for. - :param index: The Elasticsearch index to search (e.g. 'image') - :param filter_dead: Whether dead links should be removed. - :return: List of related results. - """ - - # Search the default index for the item itself as it might be sensitive. - item_search = Search(index=index) - item_hit = item_search.query(Term(identifier=uuid)).execute().hits[0] - - # Match related using title. 
- title = getattr(item_hit, "title", None) - tags = getattr(item_hit, "tags", None) - creator = getattr(item_hit, "creator", None) - - if not title and not tags: - if not creator: - return [] - related_query = Term(creator__keyword=creator) - else: - related_query = None if not title else Match(title=title) - - # Match related using tags, if the item has any. - if tags: - # Only use the first 10 tags - tags = " | ".join([tag.name for tag in tags[:10]]) - tags_query = SimpleQueryString(fields=["tags.name"], query=tags) - related_query = related_query | tags_query if related_query else tags_query - - # Search the filtered index for related items. - s = Search(index=f"{index}-filtered") - - # Exclude the current item and mature content. - s = s.query(related_query & ~Term(identifier=uuid) & ~Term(mature=True)) - # Exclude the dynamically disabled sources. - if excluded_providers_query := get_excluded_providers_query(): - s = s.exclude(excluded_providers_query) - - page, page_size = 1, 10 - start, end = _get_query_slice(s, page_size, page, filter_dead) - s = s[start:end] - - response = get_es_response(s, es_query="related_media") - results = _post_process_results(s, start, end, page_size, response, filter_dead) - return results or [] - - def get_sources(index): """ Given an index, find all available data sources and return their counts. 
diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index 96278b7a3b3..20aebf8ef44 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -7,6 +7,7 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from api.controllers import search_controller +from api.controllers.elasticsearch.related import related_media from api.models import ContentProvider from api.models.media import AbstractMedia from api.serializers.provider_serializers import ProviderSerializer @@ -158,7 +159,7 @@ def stats(self, *_, **__): @action(detail=True) def related(self, request, identifier=None, *_, **__): try: - results = search_controller.related_media( + results = related_media( uuid=identifier, index=self.default_index, filter_dead=True, diff --git a/api/test/factory/es_http.py b/api/test/factory/es_http.py index f80de2e0680..d784271db9a 100644 --- a/api/test/factory/es_http.py +++ b/api/test/factory/es_http.py @@ -5,7 +5,9 @@ MOCK_DEAD_RESULT_URL_PREFIX = "https://example.com/openverse-dead-image-result-url" -def create_mock_es_http_image_hit(_id: str, index: str, live: bool = True): +def create_mock_es_http_image_hit( + _id: str, index: str, live: bool = True, identifier: str | None = None +): return { "_index": index, "_type": "_doc", @@ -20,7 +22,7 @@ def create_mock_es_http_image_hit(_id: str, index: str, live: bool = True): "max_boost": 85, "min_boost": 1, "id": _id, - "identifier": str(uuid4()), + "identifier": identifier or str(uuid4()), "title": "Bird Nature Photo", "foreign_landing_url": "https://example.com/photo/LYTN21EBYO", "creator": "Nature's Beauty", @@ -79,3 +81,26 @@ def create_mock_es_http_image_search_response( "hits": base_hits + live_hits + dead_hits, }, } + + +def create_mock_es_http_image_response_with_identifier( + index: str, + identifier: str, +): + return { + "took": 3, + "timed_out": False, + "_shards": {"total": 18, "successful": 18, "skipped": 0, "failed": 0}, + "hits": { + "total": {"value": 1, "relation": 
"eq"}, + "max_score": 11.0007305, + "hits": [ + create_mock_es_http_image_hit( + _id="1", + index=index, + live=True, + identifier=identifier, + ) + ], + }, + } diff --git a/api/test/media_integration.py b/api/test/media_integration.py index 952353425d2..227d3bf31a2 100644 --- a/api/test/media_integration.py +++ b/api/test/media_integration.py @@ -150,9 +150,23 @@ def uuid_validation(media_type, identifier): def related(fixture): - related_url = fixture["results"][0]["related_url"] - response = requests.get(related_url) - assert response.status_code == 200 + item = fixture["results"][0] + + def get_terms_set(item): + return set([t["name"] for t in item["tags"]] + item["title"].split(" ")) + + terms_set = get_terms_set(item) + related_url = item["related_url"] + raw_response = requests.get(related_url) + response = raw_response.json() + + assert raw_response.status_code == 200 + assert response["result_count"] == len(response["results"]) == 10 + assert response["page_count"] == 1 + + # Make sure each result has at least one word in common with the original item + for result in response["results"]: + assert len(terms_set.intersection(get_terms_set(result))) > 0 def sensitive_search_and_detail(media_type): diff --git a/api/test/test_dead_link_filter.py b/api/test/test_dead_link_filter.py index 4c682dcc687..676f804faac 100644 --- a/api/test/test_dead_link_filter.py +++ b/api/test/test_dead_link_filter.py @@ -8,7 +8,7 @@ import requests from fakeredis import FakeRedis -from api.controllers.search_controller import DEAD_LINK_RATIO +from api.controllers.elasticsearch.helpers import DEAD_LINK_RATIO @pytest.fixture(autouse=True) diff --git a/api/test/test_image_integration.py b/api/test/test_image_integration.py index e8b7d479913..36fe7a33f67 100644 --- a/api/test/test_image_integration.py +++ b/api/test/test_image_integration.py @@ -149,5 +149,5 @@ def test_image_related(image_fixture): related(image_fixture) -def test_audio_sensitive_search_and_detail(): +def 
test_image_sensitive_search_and_detail(): sensitive_search_and_detail("images") diff --git a/api/test/unit/controllers/elasticsearch/test_related.py b/api/test/unit/controllers/elasticsearch/test_related.py new file mode 100644 index 00000000000..31f7eeab62f --- /dev/null +++ b/api/test/unit/controllers/elasticsearch/test_related.py @@ -0,0 +1,106 @@ +from test.factory.es_http import ( + MOCK_LIVE_RESULT_URL_PREFIX, + create_mock_es_http_image_response_with_identifier, + create_mock_es_http_image_search_response, +) +from test.factory.models import ImageFactory +from unittest import mock + +from django.core.cache import cache + +import pook +import pytest + +from api.controllers.elasticsearch import related + + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def excluded_providers_cache(): + cache_key = "filtered_providers" + excluded_provider = "excluded_provider" + cache_value = [{"provider_identifier": excluded_provider}] + cache.set(cache_key, cache_value, timeout=1) + + yield excluded_provider + + cache.delete(cache_key) + + +@mock.patch( + "api.controllers.elasticsearch.related.related_media", + wraps=related.related_media, +) +@pook.on +def test_related_media( + wrapped_related_results, + image_media_type_config, + settings, + excluded_providers_cache, +): + image = ImageFactory.create() + + # Mock the ES response for the item itself + es_original_index_endpoint = ( + f"{settings.ES_ENDPOINT}/{image_media_type_config.origin_index}/_search" + ) + mock_es_hit_response = create_mock_es_http_image_response_with_identifier( + index=image_media_type_config.origin_index, + identifier=image.identifier, + ) + pook.post(es_original_index_endpoint).times(1).reply(200).header( + "x-elastic-product", "Elasticsearch" + ).json(mock_es_hit_response) + + # Mock the post process ES requests + pook.head(pook.regex(rf"{MOCK_LIVE_RESULT_URL_PREFIX}/\d")).times(20).reply(200) + + # Related only queries the filtered index, so we mock that. 
+ es_filtered_index_endpoint = ( + f"{settings.ES_ENDPOINT}/{image_media_type_config.filtered_index}/_search" + ) + mock_es_response = create_mock_es_http_image_search_response( + index=image_media_type_config.origin_index, + total_hits=20, + live_hit_count=20, + hit_count=10, + ) + + # Testing the ES query + es_related_query = { + "from": 0, + "query": { + "bool": { + "must_not": [ + {"terms": {"provider": [excluded_providers_cache]}}, + {"term": {"mature": True}}, + {"term": {"identifier": image.identifier}}, + ], + "should": [ + {"match": {"title": "Bird Nature Photo"}}, + {"terms": {"tags.name": ["bird"]}}, + ], + } + }, + "size": 20, + } + mock_related = ( + pook.post(es_filtered_index_endpoint) + .json(es_related_query) # Testing that ES query is correct + .times(1) + .reply(200) + .header("x-elastic-product", "Elasticsearch") + .json(mock_es_response) + .mock + ) + + results = related.related_media( + uuid=image.identifier, + index=image_media_type_config.origin_index, + filter_dead=True, + ) + assert len(results) == 10 + assert wrapped_related_results.call_count == 1 + assert mock_related.total_matches == 1 diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index e802be4920e..43318fa4812 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -16,6 +16,7 @@ from elasticsearch_dsl import Search from api.controllers import search_controller +from api.controllers.elasticsearch import helpers as es_helpers from api.utils import tallies from api.utils.dead_link_mask import get_query_hash, save_query_mask from api.utils.search_context import SearchContext @@ -150,7 +151,7 @@ def test_paginate_with_dead_link_mask_new_search( """ start = 0 - assert search_controller._paginate_with_dead_link_mask( + assert es_helpers._paginate_with_dead_link_mask( s=unique_search, page_size=page_size, page=page ) == (start, expected_end) @@ -260,7 
+261,7 @@ def test_paginate_with_dead_link_mask_query_mask_is_not_large_enough( """ start = mask_size create_mask(s=unique_search, mask_size=mask_size, liveness_count=liveness_count) - assert search_controller._paginate_with_dead_link_mask( + assert es_helpers._paginate_with_dead_link_mask( s=unique_search, page_size=page_size, page=page ) == (start, expected_end) @@ -383,7 +384,7 @@ def test_paginate_with_dead_link_mask_query_mask_overlaps_query_window( create_mask_kwargs.update(mask=mask_or_mask_size) create_mask(**create_mask_kwargs) - actual_range = search_controller._paginate_with_dead_link_mask( + actual_range = es_helpers._paginate_with_dead_link_mask( s=unique_search, page_size=page_size, page=page ) assert ( diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 755e4c6ac76..005c1bbbe10 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -8,6 +8,18 @@ pytestmark = pytest.mark.django_db +@pytest.fixture +def excluded_providers_cache(): + cache_key = "filtered_providers" + excluded_provider = "excluded_provider" + cache_value = [{"provider_identifier": excluded_provider}] + cache.set(cache_key, cache_value, timeout=1) + + yield excluded_provider + + cache.delete(cache_key) + + def test_create_search_query_empty(media_type_config): serializer = media_type_config.search_request_serializer(data={}) serializer.is_valid(raise_exception=True) @@ -218,10 +230,8 @@ def test_create_search_query_q_search_license_license_type_creates_2_terms_filte def test_create_search_query_empty_with_dynamically_excluded_providers( image_media_type_config, + excluded_providers_cache, ): - excluded = {"provider_identifier": "flickr"} - cache.set(key="filtered_providers", timeout=1, value=[excluded]) - serializer = image_media_type_config.search_request_serializer(data={}) 
serializer.is_valid(raise_exception=True) @@ -231,7 +241,7 @@ def test_create_search_query_empty_with_dynamically_excluded_providers( assert actual_query_clauses == { "must_not": [ {"term": {"mature": True}}, - {"terms": {"provider": [excluded["provider_identifier"]]}}, + {"terms": {"provider": [excluded_providers_cache]}}, ], "must": [{"match_all": {}}], "should": [