Skip to content

Commit

Permalink
Add endpoint to stream document on organisation
Browse files Browse the repository at this point in the history
  • Loading branch information
kevincarrogan committed Feb 6, 2024
1 parent 0ce06e9 commit 5f0bb3c
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 29 deletions.
30 changes: 2 additions & 28 deletions api/documents/tests/test_document_stream.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,17 @@
import boto3

from moto import mock_aws

from django.http import StreamingHttpResponse
from django.urls import reverse

from test_helpers.clients import DataTestClient

from api.conf.settings import (
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_REGION,
AWS_STORAGE_BUCKET_NAME,
)
from api.documents.libraries.s3_operations import init_s3_client


@mock_aws
class DocumentStream(DataTestClient):
def setUp(self):
super().setUp()
init_s3_client()
s3 = boto3.client(
"s3",
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_REGION,
)
s3.create_bucket(
Bucket=AWS_STORAGE_BUCKET_NAME,
CreateBucketConfiguration={
"LocationConstraint": AWS_REGION,
},
)
s3.put_object(
Bucket=AWS_STORAGE_BUCKET_NAME,
Key="thisisakey",
Body=b"test",
)
self.create_default_bucket()
self.put_object_in_default_bucket("thisisakey", b"test")

def test_document_stream_as_caseworker(self):
# given there is a case document
Expand Down
1 change: 1 addition & 0 deletions api/organisations/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def site_records_located_at(self, create, extracted, **kwargs):

class DocumentOnOrganisationFactory(factory.django.DjangoModelFactory):
document = factory.SubFactory(DocumentFactory)
expiry_date = factory.Faker("future_date")

class Meta:
model = models.DocumentOnOrganisation
89 changes: 89 additions & 0 deletions api/organisations/tests/test_documents.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import datetime

from unittest import mock

from moto import mock_aws

from django.http import FileResponse
from django.urls import reverse

from api.organisations.enums import OrganisationDocumentType
from api.organisations.models import DocumentOnOrganisation
from api.organisations.tests.factories import DocumentOnOrganisationFactory
from test_helpers.clients import DataTestClient


Expand Down Expand Up @@ -162,3 +167,87 @@ def test_update_organisation_documents(self, mock_virus_scan, mock_s3_operations
document_on_application.reference_code,
"567",
)


@mock_aws
class DocumentOnOrganisationStreamViewTests(DataTestClient):
def setUp(self):
super().setUp()
self.create_default_bucket()
self.put_object_in_default_bucket("thisisakey", b"test")

def test_document_stream_as_caseworker(self):
document_on_application = DocumentOnOrganisationFactory.create(
document__s3_key="thisisakey",
document__safe=True,
organisation=self.organisation,
)

url = reverse(
"organisations:document_stream",
kwargs={
"pk": self.organisation.pk,
"document_on_application_pk": document_on_application.pk,
},
)
response = self.client.get(url, **self.gov_headers)

self.assertEqual(response.status_code, 200)
self.assertIsInstance(response, FileResponse)
self.assertEqual(b"".join(response.streaming_content), b"test")

def test_document_stream_as_exporter(self):
document_on_application = DocumentOnOrganisationFactory.create(
document__s3_key="thisisakey",
document__safe=True,
organisation=self.organisation,
)

url = reverse(
"organisations:document_stream",
kwargs={
"pk": self.organisation.pk,
"document_on_application_pk": document_on_application.pk,
},
)
response = self.client.get(url, **self.exporter_headers)

self.assertEqual(response.status_code, 200)
self.assertIsInstance(response, FileResponse)
self.assertEqual(b"".join(response.streaming_content), b"test")

def test_unsafe_document_stream_as_caseworker(self):
document_on_application = DocumentOnOrganisationFactory.create(
document__s3_key="thisisakey",
document__safe=False,
organisation=self.organisation,
)

url = reverse(
"organisations:document_stream",
kwargs={
"pk": self.organisation.pk,
"document_on_application_pk": document_on_application.pk,
},
)
response = self.client.get(url, **self.gov_headers)

self.assertEqual(response.status_code, 404)

def test_unsafe_document_stream_as_exporter(self):
document_on_application = DocumentOnOrganisationFactory.create(
document__s3_key="thisisakey",
document__safe=False,
organisation=self.organisation,
)

url = reverse(
"organisations:document_stream",
kwargs={
"pk": self.organisation.pk,
"document_on_application_pk": document_on_application.pk,
},
)
response = self.client.get(url, **self.exporter_headers)

self.assertEqual(response.status_code, 404)
5 changes: 5 additions & 0 deletions api/organisations/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,9 @@
documents.DocumentOnOrganisationView.as_view({"get": "retrieve", "delete": "delete", "put": "update"}),
name="documents",
),
path(
"<uuid:pk>/document/<uuid:document_on_application_pk>/stream/",
documents.DocumentOnOrganisationStreamView.as_view(),
name="document_stream",
),
]
21 changes: 20 additions & 1 deletion api/organisations/views/documents.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from rest_framework import viewsets
from rest_framework.generics import RetrieveAPIView

from django.http import JsonResponse
from django.http import Http404, JsonResponse
from django.shortcuts import get_object_or_404

from api.audit_trail import service as audit_trail_service
from api.audit_trail.enums import AuditType
from api.core.authentication import SharedAuthentication
from api.documents.models import Document
from api.documents.libraries.s3_operations import document_download_stream
from api.organisations import models, serializers


Expand Down Expand Up @@ -75,3 +78,19 @@ def update(self, request, pk, document_on_application_pk):
},
)
return JsonResponse({"document": serializer.data}, status=200)


class DocumentOnOrganisationStreamView(RetrieveAPIView):
authentication_classes = (SharedAuthentication,)
lookup_url_kwarg = "document_on_application_pk"

def get_queryset(self):
return models.DocumentOnOrganisation.objects.filter(
organisation_id=self.kwargs["pk"],
document__safe=True,
)

def retrieve(self, request, *args, **kwargs):
document = self.get_object()
document = document.document
return document_download_stream(document)
18 changes: 18 additions & 0 deletions test_helpers/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from django.conf import settings
from api.core.constants import Roles
from api.conf.urls import urlpatterns
from api.documents.libraries.s3_operations import init_s3_client
from api.flags.enums import SystemFlags, FlagStatuses, FlagLevels
from api.flags.models import Flag, FlaggingRule
from api.flags.tests.factories import FlagFactory
Expand Down Expand Up @@ -1078,6 +1079,23 @@ def add_users(self, count=3):
out.append(user)
return out

def create_default_bucket(self):
s3 = init_s3_client()
s3.create_bucket(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
CreateBucketConfiguration={
"LocationConstraint": settings.AWS_REGION,
},
)

def put_object_in_default_bucket(self, key, body):
s3 = init_s3_client()
s3.put_object(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
Key=key,
Body=body,
)


@pytest.mark.performance
# we need to set debug to true otherwise we can't see the amount of queries
Expand Down

0 comments on commit 5f0bb3c

Please sign in to comment.