Skip to content

Commit

Permalink
Extract out common filter class for parent object filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
kevincarrogan committed Feb 9, 2024
1 parent 6ad06ab commit d3afc51
Show file tree
Hide file tree
Showing 14 changed files with 187 additions and 17 deletions.
4 changes: 4 additions & 0 deletions api/conf/settings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@
SUPPRESS_TEST_OUTPUT = True

AWS_ENDPOINT_URL = None

INSTALLED_APPS += [
"api.core.tests.apps.CoreTestsConfig",
]
17 changes: 17 additions & 0 deletions api/core/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from rest_framework import filters

from django.core.exceptions import ImproperlyConfigured


class ParentFilter(filters.BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
parent_id_lookup_field = getattr(view, "parent_id_lookup_field", None)
if not parent_id_lookup_field:
raise ImproperlyConfigured(
f"Cannot use {self.__class__.__name__} on a view which does not have a parent_id_lookup_field attribute"
)

lookup = {
parent_id_lookup_field: view.kwargs["pk"],
}
return queryset.filter(**lookup)
6 changes: 6 additions & 0 deletions api/core/tests/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class CoreTestsConfig(AppConfig):
name = "api.core.tests"
label = "api_core_tests"
32 changes: 32 additions & 0 deletions api/core/tests/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Generated by Django 4.2.9 on 2024-02-09 14:48

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

initial = True

dependencies = []

operations = [
migrations.CreateModel(
name="ParentModel",
fields=[
("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(max_length=255)),
],
),
migrations.CreateModel(
name="ChildModel",
fields=[
("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(max_length=255)),
(
"parent",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="api_core_tests.parentmodel"),
),
],
),
]
Empty file.
10 changes: 10 additions & 0 deletions api/core/tests/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.db import models


class ParentModel(models.Model):
name = models.CharField(max_length=255)


class ChildModel(models.Model):
name = models.CharField(max_length=255)
parent = models.ForeignKey(ParentModel, on_delete=models.CASCADE)
12 changes: 12 additions & 0 deletions api/core/tests/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from rest_framework import serializers

from api.core.tests.models import ChildModel


class ChildModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildModel
fields = (
"id",
"name",
)
62 changes: 62 additions & 0 deletions api/core/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import uuid

from django.core.exceptions import ImproperlyConfigured
from django.test import (
override_settings,
SimpleTestCase,
TestCase,
)
from django.urls import reverse

from api.core.tests.models import (
ChildModel,
ParentModel,
)


@override_settings(
ROOT_URLCONF="api.core.tests.urls",
)
class TestMisconfiguredParentFilter(SimpleTestCase):
def test_misconfigured_parent_filter(self):
url = reverse(
"test-misconfigured-parent-filter",
kwargs={
"pk": str(uuid.uuid4()),
"child_pk": str(uuid.uuid4()),
},
)
with self.assertRaises(ImproperlyConfigured):
self.client.get(url)


@override_settings(
ROOT_URLCONF="api.core.tests.urls",
)
class TestParentFilter(TestCase):
def test_parent_filter(self):
parent = ParentModel.objects.create(name="parent")
child = ChildModel.objects.create(parent=parent, name="child")
url = reverse(
"test-parent-filter",
kwargs={
"pk": str(parent.pk),
"child_pk": str(child.pk),
},
)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)

def test_parent_other_parent_filter(self):
parent = ParentModel.objects.create(name="parent")
child = ChildModel.objects.create(parent=parent, name="child")
other_parent = ParentModel.objects.create(name="other_parent")
url = reverse(
"test-parent-filter",
kwargs={
"pk": str(other_parent.pk),
"child_pk": str(child.pk),
},
)
response = self.client.get(url)
self.assertEqual(response.status_code, 404)
19 changes: 19 additions & 0 deletions api/core/tests/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from django.urls import path

from .views import (
MisconfiguredParentFilterView,
ParentFilterView,
)

urlpatterns = [
path(
"misconfigured-parent/<str:pk>/child/<str:child_pk>/",
MisconfiguredParentFilterView.as_view(),
name="test-misconfigured-parent-filter",
),
path(
"parent/<str:pk>/child/<str:child_pk>/",
ParentFilterView.as_view(),
name="test-parent-filter",
),
]
17 changes: 17 additions & 0 deletions api/core/tests/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from rest_framework.generics import RetrieveAPIView

from api.core.filters import ParentFilter
from api.core.tests.models import ChildModel
from api.core.tests.serializers import ChildModelSerializer


class MisconfiguredParentFilterView(RetrieveAPIView):
filter_backends = (ParentFilter,)
queryset = ChildModel.objects.all()


class ParentFilterView(RetrieveAPIView):
filter_backends = (ParentFilter,)
parent_id_lookup_field = "parent_id"
queryset = ChildModel.objects.all()
serializer_class = ChildModelSerializer
6 changes: 0 additions & 6 deletions api/goods/filters.py

This file was deleted.

5 changes: 3 additions & 2 deletions api/goods/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from api.core.authentication import ExporterAuthentication, SharedAuthentication, GovAuthentication
from api.core.exceptions import BadRequestError
from api.core.helpers import str_to_bool
from api.core.filters import ParentFilter
from api.core.views import DocumentStreamAPIView
from api.documents.libraries.delete_documents_on_bad_request import delete_documents_on_bad_request
from api.documents.models import Document
from api.goods.enums import GoodStatus, GoodPvGraded, ItemCategory
from api.goods.filters import GoodFilter
from api.goods.goods_paginator import GoodListPaginator
from api.goods.helpers import (
FIREARMS_CORE_TYPES,
Expand Down Expand Up @@ -547,7 +547,8 @@ def delete(self, request, pk, doc_pk):

class GoodDocumentStream(DocumentStreamAPIView):
authentication_classes = (ExporterAuthentication,)
filter_backends = (GoodFilter,)
filter_backends = (ParentFilter,)
parent_id_lookup_field = "good_id"
lookup_url_kwarg = "doc_pk"
queryset = GoodDocument.objects.all()
permission_classes = (
Expand Down
6 changes: 0 additions & 6 deletions api/organisations/filters.py

This file was deleted.

8 changes: 5 additions & 3 deletions api/organisations/views/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from api.audit_trail import service as audit_trail_service
from api.audit_trail.enums import AuditType
from api.core.authentication import SharedAuthentication
from api.core.filters import ParentFilter
from api.core.views import DocumentStreamAPIView
from api.organisations import (
filters,
models,
permissions,
serializers,
Expand All @@ -16,8 +16,9 @@

class DocumentOnOrganisationView(viewsets.ModelViewSet):
authentication_classes = (SharedAuthentication,)
filter_backends = (filters.OrganisationFilter,)
filter_backends = (ParentFilter,)
lookup_url_kwarg = "document_on_application_pk"
parent_id_lookup_field = "organisation_id"
permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,)
queryset = models.DocumentOnOrganisation.objects.all()
serializer_class = serializers.DocumentOnOrganisationSerializer
Expand Down Expand Up @@ -86,8 +87,9 @@ def update(self, request, pk, document_on_application_pk):

class DocumentOnOrganisationStreamView(DocumentStreamAPIView):
authentication_classes = (SharedAuthentication,)
filter_backends = (filters.OrganisationFilter,)
filter_backends = (ParentFilter,)
lookup_url_kwarg = "document_on_application_pk"
parent_id_lookup_field = "organisation_id"
permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,)
queryset = models.DocumentOnOrganisation.objects.all()

Expand Down

0 comments on commit d3afc51

Please sign in to comment.