diff --git a/.coveragerc b/.coveragerc index fa0434254d..fe7ee3c453 100644 --- a/.coveragerc +++ b/.coveragerc @@ -16,5 +16,9 @@ omit = ./api/staticdata/management/* ./runtime.txt ./lite_routing/management/commands/generate_rules_docs.py - branch = True + +[report] +exclude_lines = + pragma: no cover + raise NotImplementedError diff --git a/api/documents/tests/test_document_stream.py b/api/documents/tests/test_document_stream.py index e613770c85..d9d408a8e8 100644 --- a/api/documents/tests/test_document_stream.py +++ b/api/documents/tests/test_document_stream.py @@ -1,5 +1,3 @@ -import boto3 - from moto import mock_aws from django.http import StreamingHttpResponse @@ -7,37 +5,13 @@ from test_helpers.clients import DataTestClient -from api.conf.settings import ( - AWS_ACCESS_KEY_ID, - AWS_SECRET_ACCESS_KEY, - AWS_REGION, - AWS_STORAGE_BUCKET_NAME, -) -from api.documents.libraries.s3_operations import init_s3_client - @mock_aws class DocumentStream(DataTestClient): def setUp(self): super().setUp() - init_s3_client() - s3 = boto3.client( - "s3", - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION, - ) - s3.create_bucket( - Bucket=AWS_STORAGE_BUCKET_NAME, - CreateBucketConfiguration={ - "LocationConstraint": AWS_REGION, - }, - ) - s3.put_object( - Bucket=AWS_STORAGE_BUCKET_NAME, - Key="thisisakey", - Body=b"test", - ) + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") def test_document_stream_as_caseworker(self): # given there is a case document diff --git a/api/organisations/filters.py b/api/organisations/filters.py new file mode 100644 index 0000000000..58a86fab04 --- /dev/null +++ b/api/organisations/filters.py @@ -0,0 +1,6 @@ +from rest_framework import filters + + +class OrganisationFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(organisation_id=view.kwargs["pk"]) diff --git a/api/organisations/permissions.py b/api/organisations/permissions.py new file mode 100644 index 0000000000..8a5804c6fd --- /dev/null +++ b/api/organisations/permissions.py @@ -0,0 +1,19 @@ +from rest_framework import permissions + +from api.organisations.libraries.get_organisation import get_request_user_organisation_id + + +class IsCaseworkerOrInDocumentOrganisation(permissions.BasePermission): + def has_permission(self, request, view): + if hasattr(request.user, "govuser"): + return True + elif hasattr(request.user, "exporteruser"): + return view.kwargs["pk"] == get_request_user_organisation_id(request) + raise NotImplementedError() + + def has_object_permission(self, request, view, obj): + if hasattr(request.user, "govuser"): + return True + elif hasattr(request.user, "exporteruser"): + return obj.organisation_id == get_request_user_organisation_id(request) + raise NotImplementedError() diff --git a/api/organisations/tests/factories.py b/api/organisations/tests/factories.py index 5f9a6cf408..90dba01190 100644 --- a/api/organisations/tests/factories.py +++ b/api/organisations/tests/factories.py @@ -58,6 +58,7 @@ def site_records_located_at(self, create, extracted, **kwargs): class DocumentOnOrganisationFactory(factory.django.DjangoModelFactory): document = factory.SubFactory(DocumentFactory) + expiry_date = factory.Faker("future_date") class Meta: model = models.DocumentOnOrganisation diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 0973986946..0908215396 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -1,10 +1,15 @@ import datetime + from unittest import mock +from moto import mock_aws + +from django.http import FileResponse from django.urls import reverse from api.organisations.enums import OrganisationDocumentType from api.organisations.models import DocumentOnOrganisation +from api.organisations.tests.factories import DocumentOnOrganisationFactory from test_helpers.clients import DataTestClient @@ -17,22 +22,20 @@ def setUp(self): "size": 123456, } - def create_document_on_organisation(self, name): + @mock.patch("api.documents.libraries.s3_operations.get_object") + @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") + def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_get_object): + mock_s3_operations_get_object.return_value = self.document_data + mock_virus_scan.return_value = False + url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk}) data = { - "document": {"name": name, "s3_key": name, "size": 476}, + "document": {"name": "some-document", "s3_key": "some-document", "size": 476}, "expiry_date": "2026-01-01", "reference_code": "123", "document_type": OrganisationDocumentType.FIREARM_SECTION_FIVE, } - return self.client.post(url, data, **self.exporter_headers) - - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document") + response = self.client.post(url, data, **self.exporter_headers) self.assertEqual(response.status_code, 201, msg=response.content) self.assertEqual(self.organisation.document_on_organisations.count(), 1) @@ -49,12 +52,48 @@ def test_create_organisation_document(self, mock_virus_scan, mock_s3_operations_ @mock.patch("api.documents.libraries.s3_operations.get_object") @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): + def test_create_organisation_document_other_organisation(self, mock_virus_scan, mock_s3_operations_get_object): mock_s3_operations_get_object.return_value = self.document_data mock_virus_scan.return_value = False - self.assertEqual(self.create_document_on_organisation("some-document-one").status_code, 201) - self.assertEqual(self.create_document_on_organisation("some-document-two").status_code, 201) - self.assertEqual(self.create_document_on_organisation("some-document-three").status_code, 201) + other_organisation, _ = self.create_organisation_with_exporter_user() + url = reverse("organisations:documents", kwargs={"pk": other_organisation.pk}) + + data = { + "document": {"name": "some-document", "s3_key": "some-document", "size": 476}, + "expiry_date": "2026-01-01", + "reference_code": "123", + "document_type": OrganisationDocumentType.FIREARM_SECTION_FIVE, + } + response = self.client.post(url, data, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) + + def test_list_organisation_documents(self): + DocumentOnOrganisationFactory.create( + document__name="some-document-one", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + DocumentOnOrganisationFactory.create( + document__name="some-document-two", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + DocumentOnOrganisationFactory.create( + document__name="some-document-three", + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + other_organisation, _ = self.create_organisation_with_exporter_user() + DocumentOnOrganisationFactory.create( + document__name="other-organisation-some-document-three", + document__s3_key="thisisakey", + document__safe=True, + organisation=other_organisation, + ) url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk}) @@ -63,19 +102,29 @@ def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_g self.assertEqual(response.status_code, 200) self.assertEqual(len(response.json()["documents"]), 3) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) + def test_list_organisation_documents_other_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + url = reverse("organisations:documents", kwargs={"pk": other_organisation.pk}) + + response = self.client.get(url, **self.exporter_headers) - document_on_application_pk = response.json()["document"]["id"] + self.assertEqual(response.status_code, 403) + + def test_retrieve_organisation_documents(self): + document_on_application = DocumentOnOrganisationFactory.create( + organisation=self.organisation, + expiry_date=datetime.date(2026, 1, 1), + document_type=OrganisationDocumentType.FIREARM_SECTION_FIVE, + reference_code="123", + document__name="some-document-one", + document__s3_key="some-document-one", + document__size=476, + document__safe=True, + ) url = reverse( "organisations:documents", - kwargs={"pk": self.organisation.pk, "document_on_application_pk": document_on_application_pk}, + kwargs={"pk": self.organisation.pk, "document_on_application_pk": document_on_application.pk}, ) response = self.client.get(url, **self.exporter_headers) @@ -85,7 +134,7 @@ def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operatio self.assertEqual( response.json(), { - "id": document_on_application_pk, + "id": str(document_on_application.pk), "expiry_date": "01 January 2026", "document_type": "section-five-certificate", "organisation": str(self.organisation.id), @@ -102,44 +151,68 @@ def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operatio }, ) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_delete_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) + def test_retrieve_organisation_documents_invalid_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create( + organisation=other_organisation, + expiry_date=datetime.date(2026, 1, 1), + document_type=OrganisationDocumentType.FIREARM_SECTION_FIVE, + reference_code="123", + document__name="some-document-one", + document__s3_key="some-document-one", + document__size=476, + document__safe=True, + ) + + url = reverse( + "organisations:documents", + kwargs={"pk": other_organisation.pk, "document_on_application_pk": document_on_application.pk}, + ) + + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) - document_on_application_pk = response.json()["document"]["id"] + def test_delete_organisation_documents(self): + document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( "organisations:documents", kwargs={ "pk": self.organisation.pk, - "document_on_application_pk": document_on_application_pk, + "document_on_application_pk": document_on_application.pk, }, ) response = self.client.delete(url, **self.exporter_headers) self.assertEqual(response.status_code, 204) with self.assertRaises(DocumentOnOrganisation.DoesNotExist): - DocumentOnOrganisation.objects.get(pk=document_on_application_pk) + DocumentOnOrganisation.objects.get(pk=document_on_application.pk) - @mock.patch("api.documents.libraries.s3_operations.get_object") - @mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses") - def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object): - mock_s3_operations_get_object.return_value = self.document_data - mock_virus_scan.return_value = False - response = self.create_document_on_organisation("some-document-one") - self.assertEqual(response.status_code, 201) + def test_delete_organisation_document_other_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation) + + url = reverse( + "organisations:documents", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) - document_on_application_pk = response.json()["document"]["id"] + response = self.client.delete(url, **self.exporter_headers) + self.assertEqual(response.status_code, 403) + self.assertTrue(DocumentOnOrganisation.objects.filter(pk=document_on_application.pk).exists()) + + def test_update_organisation_documents(self): + document_on_application = DocumentOnOrganisationFactory.create(organisation=self.organisation) url = reverse( "organisations:documents", kwargs={ "pk": self.organisation.pk, - "document_on_application_pk": document_on_application_pk, + "document_on_application_pk": document_on_application.pk, }, ) @@ -153,7 +226,7 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations ) self.assertEqual(response.status_code, 200) - document_on_application = DocumentOnOrganisation.objects.get(pk=document_on_application_pk) + document_on_application.refresh_from_db() self.assertEqual( document_on_application.expiry_date, datetime.date(2030, 12, 12), @@ -162,3 +235,128 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations document_on_application.reference_code, "567", ) + + def test_update_organisation_documents_other_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create(organisation=other_organisation) + + url = reverse( + "organisations:documents", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + + response = self.client.put( + url, + data={ + "expiry_date": "2030-12-12", + "reference_code": "567", + }, + **self.exporter_headers, + ) + self.assertEqual(response.status_code, 403) + + +@mock_aws +class DocumentOnOrganisationStreamViewTests(DataTestClient): + def setUp(self): + super().setUp() + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_unsafe_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 404) + + def test_unsafe_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 404) + + def test_document_stream_as_exporter_on_other_organisation(self): + other_organisation, _ = self.create_organisation_with_exporter_user() + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=other_organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": other_organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 403) diff --git a/api/organisations/urls.py b/api/organisations/urls.py index 9ec3fd668e..e25df2112a 100644 --- a/api/organisations/urls.py +++ b/api/organisations/urls.py @@ -39,4 +39,9 @@ documents.DocumentOnOrganisationView.as_view({"get": "retrieve", "delete": "delete", "put": "update"}), name="documents", ), + path( + "/document//stream/", + documents.DocumentOnOrganisationStreamView.as_view(), + name="document_stream", + ), ] diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 46150023fa..8d1807fcc2 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -1,27 +1,35 @@ from rest_framework import viewsets +from rest_framework.generics import RetrieveAPIView from django.http import JsonResponse -from django.shortcuts import get_object_or_404 from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication -from api.organisations import models, serializers +from api.documents.libraries.s3_operations import document_download_stream +from api.organisations import ( + filters, + models, + permissions, + serializers, +) class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) + filter_backends = (filters.OrganisationFilter,) + lookup_url_kwarg = "document_on_application_pk" + permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) + queryset = models.DocumentOnOrganisation.objects.all() serializer_class = serializers.DocumentOnOrganisationSerializer - def get_queryset(self): - return models.DocumentOnOrganisation.objects.filter(organisation_id=self.kwargs["pk"]) - def list(self, request, pk): - serializer = self.serializer_class(self.get_queryset(), many=True) + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.serializer_class(queryset, many=True) return JsonResponse({"documents": serializer.data}) def retrieve(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() serializer = self.serializer_class(instance) return JsonResponse(serializer.data) @@ -43,7 +51,7 @@ def create(self, request, pk): return JsonResponse({"document": serializer.data}, status=201) def delete(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() instance.delete() organisation = models.Organisation.objects.get(pk=pk) audit_trail_service.create( @@ -58,7 +66,7 @@ def delete(self, request, pk, document_on_application_pk): return JsonResponse({}, status=204) def update(self, request, pk, document_on_application_pk): - instance = get_object_or_404(self.get_queryset(), pk=document_on_application_pk) + instance = self.get_object() organisation = models.Organisation.objects.get(pk=pk) serializer = self.serializer_class( instance=instance, data=request.data, partial=True, context={"organisation": organisation} @@ -75,3 +83,16 @@ def update(self, request, pk, document_on_application_pk): }, ) return JsonResponse({"document": serializer.data}, status=200) + + +class DocumentOnOrganisationStreamView(RetrieveAPIView): + authentication_classes = (SharedAuthentication,) + filter_backends = (filters.OrganisationFilter,) + lookup_url_kwarg = "document_on_application_pk" + permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) + queryset = models.DocumentOnOrganisation.objects.filter(document__safe=True) + + def retrieve(self, request, *args, **kwargs): + document = self.get_object() + document = document.document + return document_download_stream(document) diff --git a/test_helpers/clients.py b/test_helpers/clients.py index c684f2d6ea..73428b9d8f 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -55,6 +55,7 @@ from django.conf import settings from api.core.constants import Roles from api.conf.urls import urlpatterns +from api.documents.libraries.s3_operations import init_s3_client from api.flags.enums import SystemFlags, FlagStatuses, FlagLevels from api.flags.models import Flag, FlaggingRule from api.flags.tests.factories import FlagFactory @@ -1078,6 +1079,23 @@ def add_users(self, count=3): out.append(user) return out + def create_default_bucket(self): + s3 = init_s3_client() + s3.create_bucket( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + CreateBucketConfiguration={ + "LocationConstraint": settings.AWS_REGION, + }, + ) + + def put_object_in_default_bucket(self, key, body): + s3 = init_s3_client() + s3.put_object( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + Key=key, + Body=body, + ) + @pytest.mark.performance # we need to set debug to true otherwise we can't see the amount of queries