Merge branch 'dev' into LTD-remove-unsupported-hmrcquery
saruniitr authored Feb 7, 2024
2 parents ae38a51 + 0ce06e9 commit 63ce6a7
Showing 17 changed files with 666 additions and 169 deletions.
2 changes: 1 addition & 1 deletion Pipfile
@@ -20,6 +20,7 @@ ipdb = "*"
watchdog = {extras = ["watchmedo"], version = "*"}
diff-pdf-visually = "~=1.7.0"
pytest-circleci-parallelized = "~=0.1.0"
moto = {extras = ["s3"], version = "*"}

[packages]
factory-boy = "~=2.12.0"
@@ -72,7 +73,6 @@ django-silk = "~=5.0.3"
django = "~=4.2.8"
django-queryable-properties = "~=1.9.1"


[requires]
python_version = "3.8"
python_full_version = "3.8.18"
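
Of note in the dev dependencies: moto's mock_aws decorator (used throughout the new tests below) intercepts boto3 calls in-process. A minimal sketch of the pattern, with an illustrative bucket name that is not from this repo:

import boto3
from moto import mock_aws

@mock_aws
def test_s3_roundtrip():
    # Clients created while the mock is active hit moto's in-memory S3.
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="example-bucket")
    s3.put_object(Bucket="example-bucket", Key="k", Body=b"data")
    assert s3.get_object(Bucket="example-bucket", Key="k")["Body"].read() == b"data"
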
394 changes: 278 additions & 116 deletions Pipfile.lock

Large diffs are not rendered by default.

29 changes: 22 additions & 7 deletions api/cases/tests/test_case_documents.py
@@ -1,13 +1,17 @@
import uuid
from unittest import mock

from django.http import StreamingHttpResponse
from moto import mock_aws

from django.conf import settings
from django.http import FileResponse
from django.urls import reverse
from rest_framework import status

from lite_content.lite_api.strings import Documents
from test_helpers.clients import DataTestClient

from api.documents.libraries.s3_operations import init_s3_client


class CaseDocumentsTests(DataTestClient):
def setUp(self):
@@ -27,6 +31,7 @@ def test_can_view_all_documents_on_a_case(self):
self.assertEqual(len(response_data["documents"]), 2)


@mock_aws
class CaseDocumentDownloadTests(DataTestClient):
def setUp(self):
super().setUp()
@@ -35,16 +40,26 @@ def setUp(self):
self.file = self.create_case_document(self.case, self.gov_user, "Test")
self.path = "cases:document_download"

@mock.patch("api.documents.libraries.s3_operations.get_object")
def test_download_case_document_success(self, get_object_function):
get_object_function.return_value = None
s3 = init_s3_client()
s3.create_bucket(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
CreateBucketConfiguration={
"LocationConstraint": settings.AWS_REGION,
},
)
s3.put_object(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
Key=self.file.s3_key,
Body=b"test",
)

def test_download_case_document_success(self):
url = reverse(self.path, kwargs={"case_pk": self.case.id, "document_pk": self.file.id})

response = self.client.get(url, **self.exporter_headers)

get_object_function.assert_called_once()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(isinstance(response, StreamingHttpResponse))
self.assertTrue(isinstance(response, FileResponse))
self.assertEqual(response.headers["content-disposition"], 'attachment; filename="Test"')

def test_download_case_document_invalid_organisation_failure(self):
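
The setUp change above only works because the boto3 client is rebuilt after moto's patching is in place: a client created at import time may bypass the mock entirely, which is a known moto ordering gotcha. A minimal sketch of the constraint, assuming moto 5's mock_aws:

from moto import mock_aws

from api.documents.libraries.s3_operations import init_s3_client

with mock_aws():
    # Safe: boto3 is patched here, so the rebuilt module-level client
    # talks to moto's in-memory S3 rather than real AWS.
    s3 = init_s3_client()
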
2 changes: 2 additions & 0 deletions api/conf/settings.py
@@ -238,11 +238,13 @@
raise Exception("S3 Bucket not bound to environment")

aws_credentials = VCAP_SERVICES["aws-s3-bucket"][0]["credentials"]
AWS_ENDPOINT_URL = None
AWS_ACCESS_KEY_ID = aws_credentials["aws_access_key_id"]
AWS_SECRET_ACCESS_KEY = aws_credentials["aws_secret_access_key"]
AWS_REGION = aws_credentials["aws_region"]
AWS_STORAGE_BUCKET_NAME = aws_credentials["bucket_name"]
else:
AWS_ENDPOINT_URL = env("AWS_ENDPOINT_URL", default=None)
AWS_ACCESS_KEY_ID = env("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = env("AWS_SECRET_ACCESS_KEY")
AWS_REGION = env("AWS_REGION")
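
With this hook, a local environment can point boto3 at any S3-compatible stand-in (a localstack- or MinIO-style endpoint, for example). A hypothetical local override, every value illustrative rather than from this repo's config:

# Hypothetical local settings; the endpoint, credentials, region and
# bucket name are all placeholder assumptions.
AWS_ENDPOINT_URL = "http://localhost:4566"
AWS_ACCESS_KEY_ID = "test"
AWS_SECRET_ACCESS_KEY = "test"
AWS_REGION = "eu-west-2"
AWS_STORAGE_BUCKET_NAME = "local-bucket"
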
2 changes: 2 additions & 0 deletions api/conf/settings_test.py
@@ -6,3 +6,5 @@
LOGGING = {"version": 1, "disable_existing_loggers": True}

SUPPRESS_TEST_OUTPUT = True

AWS_ENDPOINT_URL = None
65 changes: 36 additions & 29 deletions api/documents/libraries/s3_operations.py
@@ -5,32 +5,40 @@
import boto3
from botocore.config import Config
from botocore.exceptions import BotoCoreError, ReadTimeoutError
from django.http import StreamingHttpResponse

from api.conf.settings import (
STREAMING_CHUNK_SIZE,
S3_CONNECT_TIMEOUT,
S3_REQUEST_TIMEOUT,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_REGION,
AWS_STORAGE_BUCKET_NAME,
)

_client = boto3.client(
"s3",
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_REGION,
config=Config(connect_timeout=S3_CONNECT_TIMEOUT, read_timeout=S3_REQUEST_TIMEOUT),
)

from django.conf import settings
from django.http import FileResponse


_client = None


def init_s3_client():
# We want to instantiate this once, ideally, but there may be cases where we
# want to explicitly re-instantiate the client, e.g. in tests.
global _client
additional_s3_params = {}
if settings.AWS_ENDPOINT_URL:
additional_s3_params["endpoint_url"] = settings.AWS_ENDPOINT_URL
_client = boto3.client(
"s3",
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
region_name=settings.AWS_REGION,
config=Config(connect_timeout=settings.S3_CONNECT_TIMEOUT, read_timeout=settings.S3_REQUEST_TIMEOUT),
**additional_s3_params,
)
return _client


init_s3_client()


def get_object(document_id, s3_key):
logging.info(f"Retrieving file '{s3_key}' on document '{document_id}'")

try:
return _client.get_object(Bucket=AWS_STORAGE_BUCKET_NAME, Key=s3_key)
return _client.get_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key)
except ReadTimeoutError:
logging.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'")
except BotoCoreError as exc:
@@ -44,14 +52,14 @@ def generate_s3_key(document_name, file_extension):


def upload_bytes_file(raw_file, s3_key):
_client.put_object(Bucket=AWS_STORAGE_BUCKET_NAME, Key=s3_key, Body=raw_file)
_client.put_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key, Body=raw_file)


def delete_file(document_id, s3_key):
logging.info(f"Deleting file '{s3_key}' on document '{document_id}'")

try:
_client.delete_object(Bucket=AWS_STORAGE_BUCKET_NAME, Key=s3_key)
_client.delete_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key)
except ReadTimeoutError:
logging.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'")
except BotoCoreError as exc:
@@ -60,16 +68,15 @@ def delete_file(document_id, s3_key):
)


def _stream_file(result):
for chunk in iter(lambda: result["Body"].read(STREAMING_CHUNK_SIZE), b""):
yield chunk


def document_download_stream(document):
s3_response = get_object(document.id, document.s3_key)
content_type = mimetypes.MimeTypes().guess_type(document.name)[0]

response = StreamingHttpResponse(streaming_content=_stream_file(s3_response), content_type=content_type)
response["Content-Disposition"] = f'attachment; filename="{document.name}"'
response = FileResponse(
s3_response["Body"],
as_attachment=True,
filename=document.name,
)
response["Content-Type"] = content_type

return response
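
The switch from a hand-rolled StreamingHttpResponse generator to Django's FileResponse delegates chunked streaming and the Content-Disposition header to Django itself. A minimal sketch of the behaviour being relied on, with an illustrative payload and filename:

import io

from django.http import FileResponse

# as_attachment=True produces: Content-Disposition: attachment; filename="example.pdf"
response = FileResponse(io.BytesIO(b"test"), as_attachment=True, filename="example.pdf")
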
Empty file.
161 changes: 161 additions & 0 deletions api/documents/libraries/tests/test_s3_operations.py
@@ -0,0 +1,161 @@
from contextlib import contextmanager
from unittest.mock import Mock, patch

from moto import mock_aws

from django.conf import settings
from django.http import FileResponse
from django.test import override_settings, SimpleTestCase

from ..s3_operations import (
delete_file,
document_download_stream,
init_s3_client,
get_object,
upload_bytes_file,
)


@patch("api.documents.libraries.s3_operations.boto3")
@patch("api.documents.libraries.s3_operations.Config")
@override_settings(
AWS_ENDPOINT_URL="AWS_ENDPOINT_URL",
AWS_ACCESS_KEY_ID="AWS_ACCESS_KEY_ID",
AWS_SECRET_ACCESS_KEY="AWS_SECRET_ACCESS_KEY",
AWS_REGION="AWS_REGION",
S3_CONNECT_TIMEOUT=22,
S3_REQUEST_TIMEOUT=44,
)
class S3OperationsTests(SimpleTestCase):
@override_settings(
AWS_ENDPOINT_URL=None,
)
def test_get_client_without_aws_endpoint_url(self, mock_Config, mock_boto3):
mock_client = Mock()
mock_boto3.client.return_value = mock_client

returned_client = init_s3_client()
self.assertEqual(returned_client, mock_client)

mock_Config.assert_called_with(
connect_timeout=22,
read_timeout=44,
)
config = mock_Config(
connect_timeout=22,
read_timeout=44,
)
mock_boto3.client.assert_called_with(
"s3",
aws_access_key_id="AWS_ACCESS_KEY_ID",
aws_secret_access_key="AWS_SECRET_ACCESS_KEY",
region_name="AWS_REGION",
config=config,
)

def test_get_client_with_aws_endpoint_url(self, mock_Config, mock_boto3):
mock_client = Mock()
mock_boto3.client.return_value = mock_client

returned_client = init_s3_client()
self.assertEqual(returned_client, mock_client)

mock_Config.assert_called_with(
connect_timeout=22,
read_timeout=44,
)
config = mock_Config(
connect_timeout=22,
read_timeout=44,
)
mock_boto3.client.assert_called_with(
"s3",
aws_access_key_id="AWS_ACCESS_KEY_ID",
aws_secret_access_key="AWS_SECRET_ACCESS_KEY",
region_name="AWS_REGION",
config=config,
endpoint_url="AWS_ENDPOINT_URL",
)


@override_settings(
AWS_STORAGE_BUCKET_NAME="test-bucket",
)
class S3OperationsGetObjectTests(SimpleTestCase):
@patch("api.documents.libraries.s3_operations._client")
def test_get_object(self, mock_client):
mock_object = Mock()
mock_client.get_object.return_value = mock_object

returned_object = get_object("document-id", "s3-key")

self.assertEqual(returned_object, mock_object)
mock_client.get_object.assert_called_with(Bucket="test-bucket", Key="s3-key")


@contextmanager
def _create_bucket(s3):
s3.create_bucket(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
CreateBucketConfiguration={
"LocationConstraint": settings.AWS_REGION,
},
)
yield


@mock_aws
class S3OperationsDeleteFileTests(SimpleTestCase):
def test_delete_file(self):
s3 = init_s3_client()
with _create_bucket(s3):
s3.put_object(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
Key="s3-key",
Body=b"test",
)

delete_file("document-id", "s3-key")

objs = s3.list_objects(Bucket=settings.AWS_STORAGE_BUCKET_NAME)
keys = [o["Key"] for o in objs.get("Contents", [])]
self.assertNotIn("s3-key", keys)


@mock_aws
class S3OperationsUploadBytesFileTests(SimpleTestCase):
def test_upload_bytes_file(self):
s3 = init_s3_client()
with _create_bucket(s3):
upload_bytes_file(b"test", "s3-key")

obj = s3.get_object(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
Key="s3-key",
)
self.assertEqual(obj["Body"].read(), b"test")


@mock_aws
class S3OperationsDocumentDownloadStreamTests(SimpleTestCase):
def test_document_download_stream(self):
s3 = init_s3_client()
with _create_bucket(s3):
s3.put_object(
Bucket=settings.AWS_STORAGE_BUCKET_NAME,
Key="s3-key",
Body=b"test",
)

mock_document = Mock()
mock_document.id = "document-id"
mock_document.s3_key = "s3-key"
mock_document.name = "test.doc"

response = document_download_stream(mock_document)

self.assertIsInstance(response, FileResponse)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/msword")
self.assertEqual(response["Content-Disposition"], 'attachment; filename="test.doc"')
self.assertEqual(b"".join(response.streaming_content), b"test")
@@ -1,10 +1,9 @@
from django.urls import reverse

from api.documents import permissions
from test_helpers.clients import DataTestClient


class CertificateDownload(DataTestClient):
class DocumentDetail(DataTestClient):
def test_document_detail_as_caseworker(self):
# given there is a case document
case = self.create_standard_application_case(self.organisation)