From d3afc515af50606a5b29635f7458e4e0ee952e6d Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Fri, 9 Feb 2024 12:50:17 +0000 Subject: [PATCH] Extract out common filter class for parent object filtering --- api/conf/settings_test.py | 4 ++ api/core/filters.py | 17 +++++++ api/core/tests/apps.py | 6 +++ api/core/tests/migrations/0001_initial.py | 32 ++++++++++++ api/core/tests/migrations/__init__.py | 0 api/core/tests/models.py | 10 ++++ api/core/tests/serializers.py | 12 +++++ api/core/tests/test_filters.py | 62 +++++++++++++++++++++++ api/core/tests/urls.py | 19 +++++++ api/core/tests/views.py | 17 +++++++ api/goods/filters.py | 6 --- api/goods/views.py | 5 +- api/organisations/filters.py | 6 --- api/organisations/views/documents.py | 8 +-- 14 files changed, 187 insertions(+), 17 deletions(-) create mode 100644 api/core/filters.py create mode 100644 api/core/tests/apps.py create mode 100644 api/core/tests/migrations/0001_initial.py create mode 100644 api/core/tests/migrations/__init__.py create mode 100644 api/core/tests/models.py create mode 100644 api/core/tests/serializers.py create mode 100644 api/core/tests/test_filters.py create mode 100644 api/core/tests/urls.py create mode 100644 api/core/tests/views.py delete mode 100644 api/goods/filters.py delete mode 100644 api/organisations/filters.py diff --git a/api/conf/settings_test.py b/api/conf/settings_test.py index e28fcb845a..30b35d95be 100644 --- a/api/conf/settings_test.py +++ b/api/conf/settings_test.py @@ -8,3 +8,7 @@ SUPPRESS_TEST_OUTPUT = True AWS_ENDPOINT_URL = None + +INSTALLED_APPS += [ + "api.core.tests.apps.CoreTestsConfig", +] diff --git a/api/core/filters.py b/api/core/filters.py new file mode 100644 index 0000000000..de1948fecc --- /dev/null +++ b/api/core/filters.py @@ -0,0 +1,17 @@ +from rest_framework import filters + +from django.core.exceptions import ImproperlyConfigured + + +class ParentFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + parent_id_lookup_field = getattr(view, "parent_id_lookup_field", None) + if not parent_id_lookup_field: + raise ImproperlyConfigured( + f"Cannot use {self.__class__.__name__} on a view which does not have a parent_id_lookup_field attribute" + ) + + lookup = { + parent_id_lookup_field: view.kwargs["pk"], + } + return queryset.filter(**lookup) diff --git a/api/core/tests/apps.py b/api/core/tests/apps.py new file mode 100644 index 0000000000..b293bee58c --- /dev/null +++ b/api/core/tests/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CoreTestsConfig(AppConfig): + name = "api.core.tests" + label = "api_core_tests" diff --git a/api/core/tests/migrations/0001_initial.py b/api/core/tests/migrations/0001_initial.py new file mode 100644 index 0000000000..33d79b1f55 --- /dev/null +++ b/api/core/tests/migrations/0001_initial.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.9 on 2024-02-09 14:48 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="ParentModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ], + ), + migrations.CreateModel( + name="ChildModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ( + "parent", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="api_core_tests.parentmodel"), + ), + ], + ), + ] diff --git a/api/core/tests/migrations/__init__.py b/api/core/tests/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tests/models.py b/api/core/tests/models.py new file mode 100644 index 0000000000..899a3b7455 --- /dev/null +++ b/api/core/tests/models.py @@ -0,0 +1,10 @@ +from django.db import models + + +class ParentModel(models.Model): + name = models.CharField(max_length=255) + + +class ChildModel(models.Model): + name = models.CharField(max_length=255) + parent = models.ForeignKey(ParentModel, on_delete=models.CASCADE) diff --git a/api/core/tests/serializers.py b/api/core/tests/serializers.py new file mode 100644 index 0000000000..c029a50307 --- /dev/null +++ b/api/core/tests/serializers.py @@ -0,0 +1,12 @@ +from rest_framework import serializers + +from api.core.tests.models import ChildModel + + +class ChildModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + fields = ( + "id", + "name", + ) diff --git a/api/core/tests/test_filters.py b/api/core/tests/test_filters.py new file mode 100644 index 0000000000..2ac8dcefad --- /dev/null +++ b/api/core/tests/test_filters.py @@ -0,0 +1,62 @@ +import uuid + +from django.core.exceptions import ImproperlyConfigured +from django.test import ( + override_settings, + SimpleTestCase, + TestCase, +) +from django.urls import reverse + +from api.core.tests.models import ( + ChildModel, + ParentModel, +) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestMisconfiguredParentFilter(SimpleTestCase): + def test_misconfigured_parent_filter(self): + url = reverse( + "test-misconfigured-parent-filter", + kwargs={ + "pk": str(uuid.uuid4()), + "child_pk": str(uuid.uuid4()), + }, + ) + with self.assertRaises(ImproperlyConfigured): + self.client.get(url) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestParentFilter(TestCase): + def test_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_parent_other_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + other_parent = ParentModel.objects.create(name="other_parent") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(other_parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 404) diff --git a/api/core/tests/urls.py b/api/core/tests/urls.py new file mode 100644 index 0000000000..90accb20f8 --- /dev/null +++ b/api/core/tests/urls.py @@ -0,0 +1,19 @@ +from django.urls import path + +from .views import ( + MisconfiguredParentFilterView, + ParentFilterView, +) + +urlpatterns = [ + path( + "misconfigured-parent//child//", + MisconfiguredParentFilterView.as_view(), + name="test-misconfigured-parent-filter", + ), + path( + "parent//child//", + ParentFilterView.as_view(), + name="test-parent-filter", + ), +] diff --git a/api/core/tests/views.py b/api/core/tests/views.py new file mode 100644 index 0000000000..e0da0c00fd --- /dev/null +++ b/api/core/tests/views.py @@ -0,0 +1,17 @@ +from rest_framework.generics import RetrieveAPIView + +from api.core.filters import ParentFilter +from api.core.tests.models import ChildModel +from api.core.tests.serializers import ChildModelSerializer + + +class MisconfiguredParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + queryset = ChildModel.objects.all() + + +class ParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + parent_id_lookup_field = "parent_id" + queryset = ChildModel.objects.all() + serializer_class = ChildModelSerializer diff --git a/api/goods/filters.py b/api/goods/filters.py deleted file mode 100644 index 45cefe7253..0000000000 --- a/api/goods/filters.py +++ /dev/null @@ -1,6 +0,0 @@ -from rest_framework import filters - - -class GoodFilter(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(good_id=view.kwargs["pk"]) diff --git a/api/goods/views.py b/api/goods/views.py index 05546ad342..8445dcdb98 100644 --- a/api/goods/views.py +++ b/api/goods/views.py @@ -18,11 +18,11 @@ from api.core.authentication import ExporterAuthentication, SharedAuthentication, GovAuthentication from api.core.exceptions import BadRequestError from api.core.helpers import str_to_bool +from api.core.filters import ParentFilter from api.core.views import DocumentStreamAPIView from api.documents.libraries.delete_documents_on_bad_request import delete_documents_on_bad_request from api.documents.models import Document from api.goods.enums import GoodStatus, GoodPvGraded, ItemCategory -from api.goods.filters import GoodFilter from api.goods.goods_paginator import GoodListPaginator from api.goods.helpers import ( FIREARMS_CORE_TYPES, @@ -547,7 +547,8 @@ def delete(self, request, pk, doc_pk): class GoodDocumentStream(DocumentStreamAPIView): authentication_classes = (ExporterAuthentication,) - filter_backends = (GoodFilter,) + filter_backends = (ParentFilter,) + parent_id_lookup_field = "good_id" lookup_url_kwarg = "doc_pk" queryset = GoodDocument.objects.all() permission_classes = ( diff --git a/api/organisations/filters.py b/api/organisations/filters.py deleted file mode 100644 index 58a86fab04..0000000000 --- a/api/organisations/filters.py +++ /dev/null @@ -1,6 +0,0 @@ -from rest_framework import filters - - -class OrganisationFilter(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(organisation_id=view.kwargs["pk"]) diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 27ab30ba49..090bddb785 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -5,9 +5,9 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication +from api.core.filters import ParentFilter from api.core.views import DocumentStreamAPIView from api.organisations import ( - filters, models, permissions, serializers, @@ -16,8 +16,9 @@ class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) lookup_url_kwarg = "document_on_application_pk" + parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() serializer_class = serializers.DocumentOnOrganisationSerializer @@ -86,8 +87,9 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(DocumentStreamAPIView): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) lookup_url_kwarg = "document_on_application_pk" + parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all()