Skip to content

Commit

Permalink
Add exclude_flags parameter to case search view
Browse files Browse the repository at this point in the history
  • Loading branch information
currycoder committed Dec 12, 2024
1 parent ab83f5e commit 937f9c4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
15 changes: 13 additions & 2 deletions api/cases/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,24 @@ def without_control_list_entries(self, control_list_entries):
def without_regime_entries(self, regime_entries):
return self.exclude(baseapplication__goods__regime_entries__id__in=regime_entries)

def with_flags(self, flags):
def _get_case_ids_with_any_flags(self, flags):
case_flag_ids = self.filter(flags__id__in=flags).values_list("id", flat=True)
org_flag_ids = self.filter(organisation__flags__id__in=flags).values_list("id", flat=True)
goods_flag_ids = self.filter(baseapplication__goods__good__flags__id__in=flags).values_list("id", flat=True)
parties_flag_ids = self.filter(baseapplication__parties__party__flags__id__in=flags).values_list(
"id", flat=True
)

case_ids = set(list(case_flag_ids) + list(org_flag_ids) + list(goods_flag_ids) + list(parties_flag_ids))
return set(list(case_flag_ids) + list(org_flag_ids) + list(goods_flag_ids) + list(parties_flag_ids))

def with_flags(self, flags):
case_ids = self._get_case_ids_with_any_flags(flags)
return self.filter(id__in=case_ids)

def without_flags(self, flags):
case_ids = self._get_case_ids_with_any_flags(flags)
return self.exclude(id__in=case_ids)

def with_country(self, country_id):
return self.filter(Q(baseapplication__parties__party__country_id=country_id))

Expand Down Expand Up @@ -267,6 +274,7 @@ def search( # noqa
exclude_regime_entry=None,
regime_entry=None,
flags=None,
exclude_flags=None,
country=None,
countries=None,
team_advice_type=None,
Expand Down Expand Up @@ -402,6 +410,9 @@ def search( # noqa
if flags:
case_qs = case_qs.with_flags(flags)

if exclude_flags:
case_qs = case_qs.without_flags(exclude_flags)

if country:
case_qs = case_qs.with_country(country)

Expand Down
53 changes: 51 additions & 2 deletions api/cases/tests/test_case_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from api.users.enums import UserStatuses
from api.users.models import GovUser
from api.cases.tests import factories
from api.cases.enums import AdviceType
from api.staticdata.statuses.enums import CaseStatusEnum
from api.teams.models import Team
from api.cases.views.search.service import (
get_case_status_list,
Expand Down Expand Up @@ -553,6 +551,57 @@ def test_filter_cases_by_flags(self, flag_id, flags_key):
for case in response_data["cases"]:
self.assertIn(flag_id, [item["id"] for item in case[flags_key]])

def test_filter_cases_exclude_flags_case_level(self):
# set required flags
application = self.application_cases[0]
case = Case.objects.get(id=application.id)
flag_id = FlagsEnum.GOODS_NOT_LISTED
flag = Flag.objects.get(id=flag_id)
case.flags.add(flag)

url = f"{self.url}?exclude_flags={flag_id}"

response = self.client.get(url, **self.gov_headers)
response_data = response.json()["results"]

self.assertEqual(response.status_code, status.HTTP_200_OK)
for search_result_case in response_data["cases"]:
self.assertNotEqual(search_result_case["id"], str(case.id))

def test_filter_cases_exclude_flags_good_level(self):
# set required flags
application = self.application_cases[0]
good = application.goods.all()[0].good
flag_id = FlagsEnum.WASSENAAR
flag = Flag.objects.get(id=flag_id)
good.flags.add(flag)

url = f"{self.url}?exclude_flags={flag_id}"

response = self.client.get(url, **self.gov_headers)
response_data = response.json()["results"]

self.assertEqual(response.status_code, status.HTTP_200_OK)
for search_result_case in response_data["cases"]:
self.assertNotEqual(search_result_case["id"], str(application.id))

def test_filter_cases_exclude_flags_destination_level(self):
# set required flags
application = self.application_cases[0]
destination = application.parties.all()[0].party
flag_id = FlagsEnum.MOD_DI_COUNTRY_OF_INTEREST
flag = Flag.objects.get(id=flag_id)
destination.flags.add(flag)

url = f"{self.url}?exclude_flags={flag_id}"

response = self.client.get(url, **self.gov_headers)
response_data = response.json()["results"]

self.assertEqual(response.status_code, status.HTTP_200_OK)
for search_result_case in response_data["cases"]:
self.assertNotEqual(search_result_case["id"], str(application.id))

@parameterized.expand(["permanent", "temporary"])
def test_get_cases_filter_by_export_type(self, export_type):
expected_id = str(self.application_cases[0].id)
Expand Down
7 changes: 6 additions & 1 deletion api/cases/views/search/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ def get_case_queryset(self, user, queue_id, is_work_queue, include_hidden, filte
)

def get_filters(self, request):
filters = {key: value for key, value in request.GET.items() if key not in ["hidden", "queue_id", "flags"]}
filters = {
key: value
for key, value in request.GET.items()
if key not in ["hidden", "queue_id", "flags", "exclude_flags"]
}

search_tabs = ("my_cases", "open_queries")
selected_tab = request.GET.get("selected_tab")
Expand All @@ -143,6 +147,7 @@ def get_filters(self, request):
del filters["max_total_value"]

filters["flags"] = request.GET.getlist("flags", [])
filters["exclude_flags"] = request.GET.getlist("exclude_flags", [])
filters["regime_entry"] = [regime for regime in request.GET.getlist("regime_entry", []) if regime]
filters["exclude_regime_entry"] = request.GET.get("exclude_regime_entry", False)
filters["control_list_entry"] = [cle for cle in request.GET.getlist("control_list_entry", []) if cle]
Expand Down

0 comments on commit 937f9c4

Please sign in to comment.