diff --git a/api/documents/tests/test_document_stream.py b/api/documents/tests/test_document_stream.py index e613770c85..d9d408a8e8 100644 --- a/api/documents/tests/test_document_stream.py +++ b/api/documents/tests/test_document_stream.py @@ -1,5 +1,3 @@ -import boto3 - from moto import mock_aws from django.http import StreamingHttpResponse @@ -7,37 +5,13 @@ from test_helpers.clients import DataTestClient -from api.conf.settings import ( - AWS_ACCESS_KEY_ID, - AWS_SECRET_ACCESS_KEY, - AWS_REGION, - AWS_STORAGE_BUCKET_NAME, -) -from api.documents.libraries.s3_operations import init_s3_client - @mock_aws class DocumentStream(DataTestClient): def setUp(self): super().setUp() - init_s3_client() - s3 = boto3.client( - "s3", - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION, - ) - s3.create_bucket( - Bucket=AWS_STORAGE_BUCKET_NAME, - CreateBucketConfiguration={ - "LocationConstraint": AWS_REGION, - }, - ) - s3.put_object( - Bucket=AWS_STORAGE_BUCKET_NAME, - Key="thisisakey", - Body=b"test", - ) + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") def test_document_stream_as_caseworker(self): # given there is a case document diff --git a/api/organisations/tests/factories.py b/api/organisations/tests/factories.py index 5f9a6cf408..90dba01190 100644 --- a/api/organisations/tests/factories.py +++ b/api/organisations/tests/factories.py @@ -58,6 +58,7 @@ def site_records_located_at(self, create, extracted, **kwargs): class DocumentOnOrganisationFactory(factory.django.DjangoModelFactory): document = factory.SubFactory(DocumentFactory) + expiry_date = factory.Faker("future_date") class Meta: model = models.DocumentOnOrganisation diff --git a/api/organisations/tests/test_documents.py b/api/organisations/tests/test_documents.py index 0973986946..46b5765581 100644 --- a/api/organisations/tests/test_documents.py +++ b/api/organisations/tests/test_documents.py @@ -1,10 +1,15 @@ import datetime + from unittest import mock +from moto import mock_aws + +from django.http import FileResponse from django.urls import reverse from api.organisations.enums import OrganisationDocumentType from api.organisations.models import DocumentOnOrganisation +from api.organisations.tests.factories import DocumentOnOrganisationFactory from test_helpers.clients import DataTestClient @@ -162,3 +167,87 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations document_on_application.reference_code, "567", ) + + +@mock_aws +class DocumentOnOrganisationStreamViewTests(DataTestClient): + def setUp(self): + super().setUp() + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=True, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_unsafe_document_stream_as_caseworker(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.gov_headers) + + self.assertEqual(response.status_code, 404) + + def test_unsafe_document_stream_as_exporter(self): + document_on_application = DocumentOnOrganisationFactory.create( + document__s3_key="thisisakey", + document__safe=False, + organisation=self.organisation, + ) + + url = reverse( + "organisations:document_stream", + kwargs={ + "pk": self.organisation.pk, + "document_on_application_pk": document_on_application.pk, + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, 404) diff --git a/api/organisations/urls.py b/api/organisations/urls.py index 9ec3fd668e..e25df2112a 100644 --- a/api/organisations/urls.py +++ b/api/organisations/urls.py @@ -39,4 +39,9 @@ documents.DocumentOnOrganisationView.as_view({"get": "retrieve", "delete": "delete", "put": "update"}), name="documents", ), + path( + "/document//stream/", + documents.DocumentOnOrganisationStreamView.as_view(), + name="document_stream", + ), ] diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 46150023fa..0015aec71d 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -1,11 +1,14 @@ from rest_framework import viewsets +from rest_framework.generics import RetrieveAPIView -from django.http import JsonResponse +from django.http import Http404, JsonResponse from django.shortcuts import get_object_or_404 from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication +from api.documents.models import Document +from api.documents.libraries.s3_operations import document_download_stream from api.organisations import models, serializers @@ -75,3 +78,19 @@ def update(self, request, pk, document_on_application_pk): }, ) return JsonResponse({"document": serializer.data}, status=200) + + +class DocumentOnOrganisationStreamView(RetrieveAPIView): + authentication_classes = (SharedAuthentication,) + lookup_url_kwarg = "document_on_application_pk" + + def get_queryset(self): + return models.DocumentOnOrganisation.objects.filter( + organisation_id=self.kwargs["pk"], + document__safe=True, + ) + + def retrieve(self, request, *args, **kwargs): + document = self.get_object() + document = document.document + return document_download_stream(document) diff --git a/test_helpers/clients.py b/test_helpers/clients.py index c684f2d6ea..73428b9d8f 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -55,6 +55,7 @@ from django.conf import settings from api.core.constants import Roles from api.conf.urls import urlpatterns +from api.documents.libraries.s3_operations import init_s3_client from api.flags.enums import SystemFlags, FlagStatuses, FlagLevels from api.flags.models import Flag, FlaggingRule from api.flags.tests.factories import FlagFactory @@ -1078,6 +1079,23 @@ def add_users(self, count=3): out.append(user) return out + def create_default_bucket(self): + s3 = init_s3_client() + s3.create_bucket( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + CreateBucketConfiguration={ + "LocationConstraint": settings.AWS_REGION, + }, + ) + + def put_object_in_default_bucket(self, key, body): + s3 = init_s3_client() + s3.put_object( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + Key=key, + Body=body, + ) + @pytest.mark.performance # we need to set debug to true otherwise we can't see the amount of queries