From 937f9c40021c9f477b0ba7e32fd0d1afae991afb Mon Sep 17 00:00:00 2001 From: Brendan Smith Date: Tue, 10 Dec 2024 11:22:28 +0000 Subject: [PATCH] Add exclude_flags parameter to case search view --- api/cases/managers.py | 15 ++++++-- api/cases/tests/test_case_search.py | 53 +++++++++++++++++++++++++++-- api/cases/views/search/views.py | 7 +++- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/api/cases/managers.py b/api/cases/managers.py index 406cd1b3e9..6d022ea223 100644 --- a/api/cases/managers.py +++ b/api/cases/managers.py @@ -123,7 +123,7 @@ def without_control_list_entries(self, control_list_entries): def without_regime_entries(self, regime_entries): return self.exclude(baseapplication__goods__regime_entries__id__in=regime_entries) - def with_flags(self, flags): + def _get_case_ids_with_any_flags(self, flags): case_flag_ids = self.filter(flags__id__in=flags).values_list("id", flat=True) org_flag_ids = self.filter(organisation__flags__id__in=flags).values_list("id", flat=True) goods_flag_ids = self.filter(baseapplication__goods__good__flags__id__in=flags).values_list("id", flat=True) @@ -131,9 +131,16 @@ def with_flags(self, flags): "id", flat=True ) - case_ids = set(list(case_flag_ids) + list(org_flag_ids) + list(goods_flag_ids) + list(parties_flag_ids)) + return set(list(case_flag_ids) + list(org_flag_ids) + list(goods_flag_ids) + list(parties_flag_ids)) + + def with_flags(self, flags): + case_ids = self._get_case_ids_with_any_flags(flags) return self.filter(id__in=case_ids) + def without_flags(self, flags): + case_ids = self._get_case_ids_with_any_flags(flags) + return self.exclude(id__in=case_ids) + def with_country(self, country_id): return self.filter(Q(baseapplication__parties__party__country_id=country_id)) @@ -267,6 +274,7 @@ def search( # noqa exclude_regime_entry=None, regime_entry=None, flags=None, + exclude_flags=None, country=None, countries=None, team_advice_type=None, @@ -402,6 +410,9 @@ def search( # noqa if flags: case_qs = case_qs.with_flags(flags) + if exclude_flags: + case_qs = case_qs.without_flags(exclude_flags) + if country: case_qs = case_qs.with_country(country) diff --git a/api/cases/tests/test_case_search.py b/api/cases/tests/test_case_search.py index 6b20508dc8..bf4bc5fbca 100644 --- a/api/cases/tests/test_case_search.py +++ b/api/cases/tests/test_case_search.py @@ -34,8 +34,6 @@ from api.users.enums import UserStatuses from api.users.models import GovUser from api.cases.tests import factories -from api.cases.enums import AdviceType -from api.staticdata.statuses.enums import CaseStatusEnum from api.teams.models import Team from api.cases.views.search.service import ( get_case_status_list, @@ -553,6 +551,57 @@ def test_filter_cases_by_flags(self, flag_id, flags_key): for case in response_data["cases"]: self.assertIn(flag_id, [item["id"] for item in case[flags_key]]) + def test_filter_cases_exclude_flags_case_level(self): + # set required flags + application = self.application_cases[0] + case = Case.objects.get(id=application.id) + flag_id = FlagsEnum.GOODS_NOT_LISTED + flag = Flag.objects.get(id=flag_id) + case.flags.add(flag) + + url = f"{self.url}?exclude_flags={flag_id}" + + response = self.client.get(url, **self.gov_headers) + response_data = response.json()["results"] + + self.assertEqual(response.status_code, status.HTTP_200_OK) + for search_result_case in response_data["cases"]: + self.assertNotEqual(search_result_case["id"], str(case.id)) + + def test_filter_cases_exclude_flags_good_level(self): + # set required flags + application = self.application_cases[0] + good = application.goods.all()[0].good + flag_id = FlagsEnum.WASSENAAR + flag = Flag.objects.get(id=flag_id) + good.flags.add(flag) + + url = f"{self.url}?exclude_flags={flag_id}" + + response = self.client.get(url, **self.gov_headers) + response_data = response.json()["results"] + + self.assertEqual(response.status_code, status.HTTP_200_OK) + for search_result_case in response_data["cases"]: + self.assertNotEqual(search_result_case["id"], str(application.id)) + + def test_filter_cases_exclude_flags_destination_level(self): + # set required flags + application = self.application_cases[0] + destination = application.parties.all()[0].party + flag_id = FlagsEnum.MOD_DI_COUNTRY_OF_INTEREST + flag = Flag.objects.get(id=flag_id) + destination.flags.add(flag) + + url = f"{self.url}?exclude_flags={flag_id}" + + response = self.client.get(url, **self.gov_headers) + response_data = response.json()["results"] + + self.assertEqual(response.status_code, status.HTTP_200_OK) + for search_result_case in response_data["cases"]: + self.assertNotEqual(search_result_case["id"], str(application.id)) + @parameterized.expand(["permanent", "temporary"]) def test_get_cases_filter_by_export_type(self, export_type): expected_id = str(self.application_cases[0].id) diff --git a/api/cases/views/search/views.py b/api/cases/views/search/views.py index a14a9486be..c57c984d0f 100644 --- a/api/cases/views/search/views.py +++ b/api/cases/views/search/views.py @@ -129,7 +129,11 @@ def get_case_queryset(self, user, queue_id, is_work_queue, include_hidden, filte ) def get_filters(self, request): - filters = {key: value for key, value in request.GET.items() if key not in ["hidden", "queue_id", "flags"]} + filters = { + key: value + for key, value in request.GET.items() + if key not in ["hidden", "queue_id", "flags", "exclude_flags"] + } search_tabs = ("my_cases", "open_queries") selected_tab = request.GET.get("selected_tab") @@ -143,6 +147,7 @@ def get_filters(self, request): del filters["max_total_value"] filters["flags"] = request.GET.getlist("flags", []) + filters["exclude_flags"] = request.GET.getlist("exclude_flags", []) filters["regime_entry"] = [regime for regime in request.GET.getlist("regime_entry", []) if regime] filters["exclude_regime_entry"] = request.GET.get("exclude_regime_entry", False) filters["control_list_entry"] = [cle for cle in request.GET.getlist("control_list_entry", []) if cle]