diff --git a/api/conf/settings_test.py b/api/conf/settings_test.py index e28fcb845a..30b35d95be 100644 --- a/api/conf/settings_test.py +++ b/api/conf/settings_test.py @@ -8,3 +8,7 @@ SUPPRESS_TEST_OUTPUT = True AWS_ENDPOINT_URL = None + +INSTALLED_APPS += [ + "api.core.tests.apps.CoreTestsConfig", +] diff --git a/api/core/filters.py b/api/core/filters.py new file mode 100644 index 0000000000..16eaf164cb --- /dev/null +++ b/api/core/filters.py @@ -0,0 +1,17 @@ +from rest_framework import filters + +from django.core.exceptions import ImproperlyConfigured + + +class ParentFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + parent_filter_id_lookup_field = getattr(view, "parent_filter_id_lookup_field", None) + if not parent_filter_id_lookup_field: + raise ImproperlyConfigured( + f"Cannot use {self.__class__.__name__} on a view which does not have a parent_filter_id_lookup_field attribute" + ) + + lookup = { + parent_filter_id_lookup_field: view.kwargs["pk"], + } + return queryset.filter(**lookup) diff --git a/api/core/tests/apps.py b/api/core/tests/apps.py new file mode 100644 index 0000000000..b293bee58c --- /dev/null +++ b/api/core/tests/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CoreTestsConfig(AppConfig): + name = "api.core.tests" + label = "api_core_tests" diff --git a/api/core/tests/migrations/0001_initial.py b/api/core/tests/migrations/0001_initial.py new file mode 100644 index 0000000000..33d79b1f55 --- /dev/null +++ b/api/core/tests/migrations/0001_initial.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.9 on 2024-02-09 14:48 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="ParentModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ], + ), + migrations.CreateModel( + name="ChildModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ( + "parent", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="api_core_tests.parentmodel"), + ), + ], + ), + ] diff --git a/api/core/tests/migrations/__init__.py b/api/core/tests/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tests/models.py b/api/core/tests/models.py new file mode 100644 index 0000000000..899a3b7455 --- /dev/null +++ b/api/core/tests/models.py @@ -0,0 +1,10 @@ +from django.db import models + + +class ParentModel(models.Model): + name = models.CharField(max_length=255) + + +class ChildModel(models.Model): + name = models.CharField(max_length=255) + parent = models.ForeignKey(ParentModel, on_delete=models.CASCADE) diff --git a/api/core/tests/serializers.py b/api/core/tests/serializers.py new file mode 100644 index 0000000000..c029a50307 --- /dev/null +++ b/api/core/tests/serializers.py @@ -0,0 +1,12 @@ +from rest_framework import serializers + +from api.core.tests.models import ChildModel + + +class ChildModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + fields = ( + "id", + "name", + ) diff --git a/api/core/tests/test_filters.py b/api/core/tests/test_filters.py new file mode 100644 index 0000000000..2ac8dcefad --- /dev/null +++ b/api/core/tests/test_filters.py @@ -0,0 +1,62 @@ +import uuid + +from django.core.exceptions import ImproperlyConfigured +from django.test import ( + override_settings, + SimpleTestCase, + TestCase, +) +from django.urls import reverse + +from api.core.tests.models import ( + ChildModel, + ParentModel, +) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestMisconfiguredParentFilter(SimpleTestCase): + def test_misconfigured_parent_filter(self): + url = reverse( + "test-misconfigured-parent-filter", + kwargs={ + "pk": str(uuid.uuid4()), + "child_pk": str(uuid.uuid4()), + }, + ) + with self.assertRaises(ImproperlyConfigured): + self.client.get(url) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestParentFilter(TestCase): + def test_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_parent_other_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + other_parent = ParentModel.objects.create(name="other_parent") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(other_parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 404) diff --git a/api/core/tests/urls.py b/api/core/tests/urls.py new file mode 100644 index 0000000000..90accb20f8 --- /dev/null +++ b/api/core/tests/urls.py @@ -0,0 +1,19 @@ +from django.urls import path + +from .views import ( + MisconfiguredParentFilterView, + ParentFilterView, +) + +urlpatterns = [ + path( + "misconfigured-parent//child//", + MisconfiguredParentFilterView.as_view(), + name="test-misconfigured-parent-filter", + ), + path( + "parent//child//", + ParentFilterView.as_view(), + name="test-parent-filter", + ), +] diff --git a/api/core/tests/views.py b/api/core/tests/views.py new file mode 100644 index 0000000000..75964d3f84 --- /dev/null +++ b/api/core/tests/views.py @@ -0,0 +1,17 @@ +from rest_framework.generics import RetrieveAPIView + +from api.core.filters import ParentFilter +from api.core.tests.models import ChildModel +from api.core.tests.serializers import ChildModelSerializer + + +class MisconfiguredParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + queryset = ChildModel.objects.all() + + +class ParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "parent_id" + queryset = ChildModel.objects.all() + serializer_class = ChildModelSerializer diff --git a/api/goods/permissions.py b/api/goods/permissions.py new file mode 100644 index 0000000000..576a62ceab --- /dev/null +++ b/api/goods/permissions.py @@ -0,0 +1,14 @@ +from rest_framework import permissions + +from api.goods.enums import GoodStatus +from api.organisations.libraries.get_organisation import get_request_user_organisation_id + + +class IsDocumentInOrganisation(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return obj.good.organisation_id == get_request_user_organisation_id(request) + + +class IsGoodDraft(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return obj.good.status == GoodStatus.DRAFT diff --git a/api/goods/tests/test_goods_documents.py b/api/goods/tests/test_goods_documents.py index b9e70a7550..0f60582469 100644 --- a/api/goods/tests/test_goods_documents.py +++ b/api/goods/tests/test_goods_documents.py @@ -1,10 +1,16 @@ from django.urls import reverse from rest_framework import status +from moto import mock_aws +from parameterized import parameterized + +from django.http import FileResponse + from api.applications.models import GoodOnApplication from test_helpers.clients import DataTestClient from api.applications.tests.factories import StandardApplicationFactory +from api.goods.enums import GoodStatus from api.goods.tests.factories import GoodFactory from api.organisations.tests.factories import OrganisationFactory @@ -121,3 +127,110 @@ def test_edit_product_document_description(self): response = self.client.get(url, **self.exporter_headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.json()["document"]["description"], "Updated document description") + + +@mock_aws +class GoodDocumentStreamTests(DataTestClient): + def setUp(self): + super().setUp() + self.good = GoodFactory( + organisation=self.organisation, + status=GoodStatus.DRAFT, + ) + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_get_good_document_stream(self): + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_get_good_document_stream_invalid_good_pk(self): + another_good = GoodFactory(organisation=self.organisation) + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(another_good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_get_good_document_stream_other_organisation(self): + other_organisation = self.create_organisation_with_exporter_user()[0] + self.good.organisation = other_organisation + self.good.save() + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=other_organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @parameterized.expand( + [ + GoodStatus.SUBMITTED, + GoodStatus.QUERY, + GoodStatus.VERIFIED, + ], + ) + def test_get_good_document_stream_good_not_draft(self, good_status): + self.good.status = good_status + self.good.save() + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/api/goods/urls.py b/api/goods/urls.py index 1cda5dc63f..d51dffa0c7 100644 --- a/api/goods/urls.py +++ b/api/goods/urls.py @@ -23,6 +23,11 @@ views.GoodDocumentDetail.as_view(), name="document", ), + path( + "/documents//stream/", + views.GoodDocumentStream.as_view(), + name="document_stream", + ), path( "document_internal_good_on_application//", views.DocumentGoodOnApplicationInternalView.as_view(), diff --git a/api/goods/views.py b/api/goods/views.py index 9bd33c6d69..7c55d26d7f 100644 --- a/api/goods/views.py +++ b/api/goods/views.py @@ -18,6 +18,8 @@ from api.core.authentication import ExporterAuthentication, SharedAuthentication, GovAuthentication from api.core.exceptions import BadRequestError from api.core.helpers import str_to_bool +from api.core.filters import ParentFilter +from api.core.views import DocumentStreamAPIView from api.documents.libraries.delete_documents_on_bad_request import delete_documents_on_bad_request from api.documents.models import Document from api.goods.enums import GoodStatus, GoodPvGraded, ItemCategory @@ -31,6 +33,10 @@ from api.goods.libraries.get_goods import get_good, get_good_document from api.goods.libraries.save_good import create_or_update_good from api.goods.models import Good, GoodDocument +from api.goods.permissions import ( + IsDocumentInOrganisation, + IsGoodDraft, +) from api.goods.serializers import ( GoodAttachingSerializer, GoodCreateSerializer, @@ -539,6 +545,21 @@ def delete(self, request, pk, doc_pk): return JsonResponse({"document": "deleted success"}) +class GoodDocumentStream(DocumentStreamAPIView): + authentication_classes = (ExporterAuthentication,) + filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "good_id" + lookup_url_kwarg = "doc_pk" + queryset = GoodDocument.objects.all() + permission_classes = ( + IsDocumentInOrganisation, + IsGoodDraft, + ) + + def get_document(self, instance): + return instance + + class DocumentGoodOnApplicationInternalView(APIView): authentication_classes = (GovAuthentication,) serializer_class = GoodOnApplicationInternalDocumentCreateSerializer diff --git a/api/organisations/filters.py b/api/organisations/filters.py deleted file mode 100644 index 58a86fab04..0000000000 --- a/api/organisations/filters.py +++ /dev/null @@ -1,6 +0,0 @@ -from rest_framework import filters - - -class OrganisationFilter(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(organisation_id=view.kwargs["pk"]) diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 27ab30ba49..c0ff504aa8 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -5,9 +5,9 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication +from api.core.filters import ParentFilter from api.core.views import DocumentStreamAPIView from api.organisations import ( - filters, models, permissions, serializers, @@ -16,7 +16,8 @@ class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "organisation_id" lookup_url_kwarg = "document_on_application_pk" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() @@ -86,7 +87,8 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(DocumentStreamAPIView): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "organisation_id" lookup_url_kwarg = "document_on_application_pk" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all()