From 9ddfba8c7378dbd51457cd1896e3858ebc3cf302 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 6 Feb 2024 14:51:12 +0000 Subject: [PATCH 1/5] Add endpoint to stream document on organisation --- api/documents/tests/test_document_stream.py | 30 +------ api/organisations/tests/factories.py | 1 + api/organisations/tests/test_documents.py | 89 +++++++++++++++++++++ api/organisations/urls.py | 5 ++ api/organisations/views/documents.py | 18 +++++ test_helpers/clients.py | 18 +++++ 6 files changed, 133 insertions(+), 28 deletions(-) diff --git a/api/documents/tests/test_document_stream.py b/api/documents/tests/test_document_stream.py index e613770c85..d9d408a8e8 100644 --- a/api/documents/tests/test_document_stream.py +++ b/api/documents/tests/test_document_stream.py @@ -1,5 +1,3 @@ -import boto3 - from moto import mock_aws from django.http import StreamingHttpResponse @@ -7,37 +5,13 @@ from test_helpers.clients import DataTestClient -from api.conf.settings import ( - AWS_ACCESS_KEY_ID, - AWS_SECRET_ACCESS_KEY, - AWS_REGION, - AWS_STORAGE_BUCKET_NAME, -) -from api.documents.libraries.s3_operations import init_s3_client - @mock_aws class DocumentStream(DataTestClient): def setUp(self): super().setUp() - init_s3_client() - s3 = boto3.client( - "s3", - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION, - ) - s3.create_bucket( - Bucket=AWS_STORAGE_BUCKET_NAME, - CreateBucketConfiguration={ - "LocationConstraint": AWS_REGION, - }, - ) - s3.put_object( - Bucket=AWS_STORAGE_BUCKET_NAME, - Key="thisisakey", - Body=b"test", - ) + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") def test_document_stream_as_caseworker(self): # given there is a case document diff --git a/api/organisations/tests/factories.py b/api/organisations/tests/factories.py index 5f9a6cf408..90dba01190 100644 --- a/api/organisations/tests/factories.py +++ b/api/organisations/tests/factories.py @@ -58,6 +58,7 @@ def site_records_located_at(self, create, extracted, **kwargs): class DocumentOnOrganisationFactory(factory.django.DjangoModelFactory): document = factory.SubFactory(DocumentFactory) + expiry_date = factory.Faker("future_date") class Meta: model = models.DocumentOnOrganisation diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 0973986946..46b5765581 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -1,10 +1,15 @@ import datetime + from unittest import mock +from moto import mock_aws + +from django.http import FileResponse from django.urls import reverse from api.organisations.enums import OrganisationDocumentType from api.organisations.models import DocumentOnOrganisation +from api.organisations.tests.factories import DocumentOnOrganisationFactory from test_helpers.clients import DataTestClient @@ -162,3 +167,87 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations document_on_application.reference_code, "567", ) + + +@mock_aws +class DocumentOnOrganisationStreamViewTests(DataTestClient): + def setUp(self): + super().setUp() + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_unsafe_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 404) + + def test_unsafe_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 404) diff --git a/api/organisations/urls.py b/api/organisations/urls.py index 9ec3fd668e..e25df2112a 100644 --- a/api/organisations/urls.py +++ b/api/organisations/urls.py @@ -39,4 +39,9 @@ documents.DocumentOnOrganisationView.as_view({"get": "retrieve", "delete": "delete", "put": "update"}), name="documents", ), + path( + "/document//stream/", + documents.DocumentOnOrganisationStreamView.as_view(), + name="document_stream", + ), ] diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 46150023fa..1b30c5be61 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -1,4 +1,5 @@ from rest_framework import viewsets +from rest_framework.generics import RetrieveAPIView from django.http import JsonResponse from django.shortcuts import get_object_or_404 @@ -6,6 +7,7 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication +from api.documents.libraries.s3_operations import document_download_stream from api.organisations import models, serializers @@ -75,3 +77,19 @@ def update(self, request, pk, document_on_application_pk): }, ) return JsonResponse({"document": serializer.data}, status=200) + + +class DocumentOnOrganisationStreamView(RetrieveAPIView): + authentication_classes = (SharedAuthentication,) + lookup_url_kwarg = "document_on_application_pk" + + def get_queryset(self): + return models.DocumentOnOrganisation.objects.filter( + organisation_id=self.kwargs["pk"], + document__safe=True, + ) + + def retrieve(self, request, *args, **kwargs): + document = self.get_object() + document = document.document + return document_download_stream(document) diff --git a/test_helpers/clients.py b/test_helpers/clients.py index c684f2d6ea..73428b9d8f 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -55,6 +55,7 @@ from django.conf import settings from api.core.constants import Roles from api.conf.urls import urlpatterns +from api.documents.libraries.s3_operations import init_s3_client from api.flags.enums import SystemFlags, FlagStatuses, FlagLevels from api.flags.models import Flag, FlaggingRule from api.flags.tests.factories import FlagFactory @@ -1078,6 +1079,23 @@ def add_users(self, count=3): out.append(user) return out + def create_default_bucket(self): + s3 = init_s3_client() + s3.create_bucket( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + CreateBucketConfiguration={ + "LocationConstraint": settings.AWS_REGION, + }, + ) + + def put_object_in_default_bucket(self, key, body): + s3 = init_s3_client() + s3.put_object( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + Key=key, + Body=body, + ) + @pytest.mark.performance # we need to set debug to true otherwise we can't see the amount of queries From af00f5491b1ed00c3f8335074f8c4b15a86c2f9d Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 6 Feb 2024 16:07:01 +0000 Subject: [PATCH 2/5] Update permissions to check organisation against document on organisation --- api/organisations/permissions.py | 19 +++ api/organisations/tests/test_documents.py | 198 ++++++++++++++++++---- api/organisations/views/documents.py | 16 +- 3 files changed, 198 insertions(+), 35 deletions(-) create mode 100644 api/organisations/permissions.py diff --git a/api/organisations/permissions.py b/api/organisations/permissions.py new file mode 100644 index 0000000000..8a5804c6fd --- /dev/null +++ b/api/organisations/permissions.py @@ -0,0 +1,19 @@ +from rest_framework import permissions + +from api.organisations.libraries.get_organisation import get_request_user_organisation_id + + +class IsCaseworkerOrInDocumentOrganisation(permissions.BasePermission): + def has_permission(self, request, view): + if hasattr(request.user, "govuser"): + return True + elif hasattr(request.user, "exporteruser"): + return view.kwargs["pk"] == get_request_user_organisation_id(request) + raise NotImplementedError() + + def has_object_permission(self, request, view, obj): + if hasattr(request.user, "govuser"): + return True + elif hasattr(request.user, "exporteruser"): + return obj.organisation_id == get_request_user_organisation_id(request) + raise NotImplementedError() diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 46b5765581..75ff98146b 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -22,22 +22,20 @@ def setUp(self): "size": 123456, } - def create_document_on_organisation(self, name): + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk}) data = { - "document": {"name": name, "s3_key": name, "size": 476}, + "document": {"name": "some-document", "s3_key": "some-document", "size": 476}, "expiry_date": "2026-01-01", "reference_code": "123", "document_type": OrganisationDocumentType.FIREARM_SECTION_FIVE, } - return self.client.post(url, data, **self.exporter_headers) - - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document") + response = self.client.post(url, data, **self.exporter_headers) self.assertEqual(response.status_code, 201, msg=response.content) self.assertEqual(self.organisation.document_on_organisations.count(), 1) @@ -52,14 +50,47 @@ def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_ self.assertEqual(instance.document_type, OrganisationDocumentType.FIREARM_SECTION_FIVE) self.assertEqual(instance.organisation, self.organisation) + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_create_organisation_document_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + other_organisation, _ = self.create_organisation_with_exporter_user() + url = reverse("organisations:documents", kwargs={"pk": other_organisation.pk}) + + data = { + "document": {"name": "some-document", "s3_key": "some-document", "size": 476}, + "expiry_date": "2026-01-01", + "reference_code": "123", + "document_type": OrganisationDocumentType.FIREARM_SECTION_FIVE, + } + response = self.client.post(url, data, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) + @mock.patch("api.documents.libraries.s3_operations.get_object") @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): mock_s3_operations_get_object.return_value = self.document_data mock_virus_scan.return_value = False - self.assertEqual(self.create_document_on_organisation("some-document-one").status_code, 201) - self.assertEqual(self.create_document_on_organisation("some-document-two").status_code, 201) - self.assertEqual(self.create_document_on_organisation("some-document-three").status_code, 201) + DocumentOnOrganisationFactory.create( + document__name="some-document-one", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + DocumentOnOrganisationFactory.create( + document__name="some-document-two", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + DocumentOnOrganisationFactory.create( + document__name="some-document-three", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk}) @@ -70,17 +101,36 @@ def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_g @mock.patch("api.documents.libraries.s3_operations.get_object") @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): + def test_list_organisation_documents_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): mock_s3_operations_get_object.return_value = self.document_data mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) - document_on_application_pk = response.json()["document"]["id"] + other_organisation, _ = self.create_organisation_with_exporter_user() + url = reverse("organisations:documents", kwargs={"pk": other_organisation.pk}) + + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) + + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + document_on_application = DocumentOnOrganisationFactory.create( + organisation=self.organisation, + expiry_date=datetime.date(2026, 1, 1), + document_type=OrganisationDocumentType.FIREARM_SECTION_FIVE, + reference_code="123", + document__name="some-document-one", + document__s3_key="some-document-one", + document__size=476, + document__safe=True, + ) url = reverse( "organisations:documents", - kwargs={"pk": self.organisation.pk, "document_on_application_pk": document_on_application_pk}, + kwargs={"pk": self.organisation.pk, "document_on_application_pk": document_on_application.pk}, ) response = self.client.get(url, **self.exporter_headers) @@ -90,7 +140,7 @@ def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operatio self.assertEqual( response.json(), { - "id": document_on_application_pk, + "id": str(document_on_application.pk), "expiry_date": "01 January 2026", "document_type": "section-five-certificate", "organisation": str(self.organisation.id), @@ -107,44 +157,86 @@ def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operatio }, ) + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_retrieve_organisation_documents_invalid_organisation(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create( + organisation=other_organisation, + expiry_date=datetime.date(2026, 1, 1), + document_type=OrganisationDocumentType.FIREARM_SECTION_FIVE, + reference_code="123", + document__name="some-document-one", + document__s3_key="some-document-one", + document__size=476, + document__safe=True, + ) + + url = reverse( + "organisations:documents", + kwargs={"pk": other_organisation.pk, "document_on_application_pk": document_on_application.pk}, + ) + + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) + @mock.patch("api.documents.libraries.s3_operations.get_object") @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") def test_delete_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): mock_s3_operations_get_object.return_value = self.document_data mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) - document_on_application_pk = response.json()["document"]["id"] + document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( "organisations:documents", kwargs={ "pk": self.organisation.pk, - "document_on_application_pk": document_on_application_pk, + "document_on_application_pk": document_on_application.pk, }, ) response = self.client.delete(url, **self.exporter_headers) self.assertEqual(response.status_code, 204) with self.assertRaises(DocumentOnOrganisation.DoesNotExist): - DocumentOnOrganisation.objects.get(pk=document_on_application_pk) + DocumentOnOrganisation.objects.get(pk=document_on_application.pk) @mock.patch("api.documents.libraries.s3_operations.get_object") @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): + def test_delete_organisation_document_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): mock_s3_operations_get_object.return_value = self.document_data mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) - document_on_application_pk = response.json()["document"]["id"] + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation) + + url = reverse( + "organisations:documents", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + + response = self.client.delete(url, **self.exporter_headers) + self.assertEqual(response.status_code, 403) + self.assertTrue(DocumentOnOrganisation.objects.filter(pk=document_on_application.pk).exists()) + + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( "organisations:documents", kwargs={ "pk": self.organisation.pk, - "document_on_application_pk": document_on_application_pk, + "document_on_application_pk": document_on_application.pk, }, ) @@ -158,7 +250,7 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations ) self.assertEqual(response.status_code, 200) - document_on_application = DocumentOnOrganisation.objects.get(pk=document_on_application_pk) + document_on_application.refresh_from_db() self.assertEqual( document_on_application.expiry_date, datetime.date(2030, 12, 12), @@ -168,6 +260,33 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations "567", ) + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_update_organisation_documents_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation) + + url = reverse( + "organisations:documents", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + + response = self.client.put( + url, + data={ + "expiry_date": "2030-12-12", + "reference_code": "567", + }, + **self.exporter_headers, + ) + self.assertEqual(response.status_code, 403) + @mock_aws class DocumentOnOrganisationStreamViewTests(DataTestClient): @@ -251,3 +370,22 @@ def test_unsafe_document_stream_as_exporter(self): response = self.client.get(url, **self.exporter_headers) self.assertEqual(response.status_code, 404) + + def test_document_stream_as_exporter_on_other_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=other_organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 1b30c5be61..bd80e04138 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -2,17 +2,22 @@ from rest_framework.generics import RetrieveAPIView from django.http import JsonResponse -from django.shortcuts import get_object_or_404 from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication from api.documents.libraries.s3_operations import document_download_stream -from api.organisations import models, serializers +from api.organisations import ( + models, + permissions, + serializers, +) class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) + lookup_url_kwarg = "document_on_application_pk" + permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) serializer_class = serializers.DocumentOnOrganisationSerializer def get_queryset(self): @@ -23,7 +28,7 @@ def list(self, request, pk): return JsonResponse({"documents": serializer.data}) def retrieve(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() serializer = self.serializer_class(instance) return JsonResponse(serializer.data) @@ -45,7 +50,7 @@ def create(self, request, pk): return JsonResponse({"document": serializer.data}, status=201) def delete(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() instance.delete() organisation = models.Organisation.objects.get(pk=pk) audit_trail_service.create( @@ -60,7 +65,7 @@ def delete(self, request, pk, document_on_application_pk): return JsonResponse({}, status=204) def update(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() organisation = models.Organisation.objects.get(pk=pk) serializer = self.serializer_class( instance=instance, data=request.data, partial=True, context={"organisation": organisation} @@ -82,6 +87,7 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(RetrieveAPIView): authentication_classes = (SharedAuthentication,) lookup_url_kwarg = "document_on_application_pk" + permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) def get_queryset(self): return models.DocumentOnOrganisation.objects.filter( From aa7dfcb833c9386fdcf70df254462dcb61eeb9e1 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 6 Feb 2024 16:49:15 +0000 Subject: [PATCH 3/5] Exclude checking coverage for `NotImplementedError` code --- .coveragerc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index fa0434254d..fe7ee3c453 100644 --- a/.coveragerc +++ b/.coveragerc @@ -16,5 +16,9 @@ omit = ./api/staticdata/management/* ./runtime.txt ./lite_routing/management/commands/generate_rules_docs.py - branch = True + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError From 518f75505e35355e6d7c3487cd4a1ba8da933fa7 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 6 Feb 2024 18:53:27 +0000 Subject: [PATCH 4/5] Use a filter backend to filter documents by organisation --- api/organisations/filters.py | 6 ++++++ api/organisations/tests/test_documents.py | 7 +++++++ api/organisations/views/documents.py | 17 +++++++---------- 3 files changed, 20 insertions(+), 10 deletions(-) create mode 100644 api/organisations/filters.py diff --git a/api/organisations/filters.py b/api/organisations/filters.py new file mode 100644 index 0000000000..58a86fab04 --- /dev/null +++ b/api/organisations/filters.py @@ -0,0 +1,6 @@ +from rest_framework import filters + + +class OrganisationFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(organisation_id=view.kwargs["pk"]) diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 75ff98146b..6f0bc7e475 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -91,6 +91,13 @@ def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_g document__safe=True, organisation=self.organisation, ) + other_organisation, _ = self.create_organisation_with_exporter_user() + DocumentOnOrganisationFactory.create( + document__name="other-organisation-some-document-three", + document__s3_key="thisisakey", + document__safe=True, + organisation=other_organisation, + ) url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk}) diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index bd80e04138..8d1807fcc2 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -8,6 +8,7 @@ from api.core.authentication import SharedAuthentication from api.documents.libraries.s3_operations import document_download_stream from api.organisations import ( + filters, models, permissions, serializers, @@ -16,15 +17,15 @@ class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) + filter_backends = (filters.OrganisationFilter,) lookup_url_kwarg = "document_on_application_pk" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) + queryset = models.DocumentOnOrganisation.objects.all() serializer_class = serializers.DocumentOnOrganisationSerializer - def get_queryset(self): - return models.DocumentOnOrganisation.objects.filter(organisation_id=self.kwargs["pk"]) - def list(self, request, pk): - serializer = self.serializer_class(self.get_queryset(), many=True) + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.serializer_class(queryset, many=True) return JsonResponse({"documents": serializer.data}) def retrieve(self, request, pk, document_on_application_pk): @@ -86,14 +87,10 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(RetrieveAPIView): authentication_classes = (SharedAuthentication,) + filter_backends = (filters.OrganisationFilter,) lookup_url_kwarg = "document_on_application_pk" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) - - def get_queryset(self): - return models.DocumentOnOrganisation.objects.filter( - organisation_id=self.kwargs["pk"], - document__safe=True, - ) + queryset = models.DocumentOnOrganisation.objects.filter(document__safe=True) def retrieve(self, request, *args, **kwargs): document = self.get_object() From 9caf6ace77dbcc65e16f5f968c54731e50c745c1 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 6 Feb 2024 18:59:51 +0000 Subject: [PATCH 5/5] Remove unnecessary mocks in organisation document tests --- api/organisations/tests/test_documents.py | 52 ++++------------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 6f0bc7e475..0908215396 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -68,11 +68,7 @@ def test_create_organisation_document_other_organisation(self, mock_virus_scan, self.assertEqual(response.status_code, 403) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False + def test_list_organisation_documents(self): DocumentOnOrganisationFactory.create( document__name="some-document-one", document__s3_key="thisisakey", @@ -106,12 +102,7 @@ def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_g self.assertEqual(response.status_code, 200) self.assertEqual(len(response.json()["documents"]), 3) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_list_organisation_documents_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - + def test_list_organisation_documents_other_organisation(self): other_organisation, _ = self.create_organisation_with_exporter_user() url = reverse("organisations:documents", kwargs={"pk": other_organisation.pk}) @@ -119,11 +110,7 @@ def test_list_organisation_documents_other_organisation(self, mock_virus_scan, m self.assertEqual(response.status_code, 403) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False + def test_retrieve_organisation_documents(self): document_on_application = DocumentOnOrganisationFactory.create( organisation=self.organisation, expiry_date=datetime.date(2026, 1, 1), @@ -164,11 +151,7 @@ def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operatio }, ) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_retrieve_organisation_documents_invalid_organisation(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False + def test_retrieve_organisation_documents_invalid_organisation(self): other_organisation, _ = self.create_organisation_with_exporter_user() document_on_application = DocumentOnOrganisationFactory.create( organisation=other_organisation, @@ -190,12 +173,7 @@ def test_retrieve_organisation_documents_invalid_organisation(self, mock_virus_s self.assertEqual(response.status_code, 403) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_delete_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - + def test_delete_organisation_documents(self): document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( @@ -211,12 +189,7 @@ def test_delete_organisation_documents(self, mock_virus_scan, mock_s3_operations with self.assertRaises(DocumentOnOrganisation.DoesNotExist): DocumentOnOrganisation.objects.get(pk=document_on_application.pk) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_delete_organisation_document_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - + def test_delete_organisation_document_other_organisation(self): other_organisation, _ = self.create_organisation_with_exporter_user() document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation) @@ -232,11 +205,7 @@ def test_delete_organisation_document_other_organisation(self, mock_virus_scan, self.assertEqual(response.status_code, 403) self.assertTrue(DocumentOnOrganisation.objects.filter(pk=document_on_application.pk).exists()) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False + def test_update_organisation_documents(self): document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( @@ -267,12 +236,7 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations "567", ) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_update_organisation_documents_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - + def test_update_organisation_documents_other_organisation(self): other_organisation, _ = self.create_organisation_with_exporter_user() document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation)