Skip to content

Commit

Permalink
Mock AWS in tests at session level
Browse files Browse the repository at this point in the history
  • Loading branch information
currycoder committed Feb 7, 2024
1 parent 1cb0383 commit aed91dd
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 35 deletions.
7 changes: 0 additions & 7 deletions api/cases/tests/test_case_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_can_view_all_documents_on_a_case(self):
self.assertEqual(len(response_data["documents"]), 2)


@mock_aws
class CaseDocumentDownloadTests(DataTestClient):
def setUp(self):
super().setUp()
Expand All @@ -41,12 +40,6 @@ def setUp(self):
self.path = "cases:document_download"

s3 = init_s3_client()["processed"]
s3.create_bucket(
Bucket=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
CreateBucketConfiguration={
"LocationConstraint": settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_REGION"],
},
)
s3.put_object(
Bucket=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
Key=self.file.s3_key,
Expand Down
2 changes: 2 additions & 0 deletions api/conf/settings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
SUPPRESS_TEST_OUTPUT = True

AWS_ENDPOINT_URL = None
CELERY_TASK_ALWAYS_EAGER = True
CELERY_TASK_STORE_EAGER_RESULT = True
27 changes: 19 additions & 8 deletions api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django import db

from celery import Celery
from moto import mock_aws

import re
import glob
Expand All @@ -13,6 +14,8 @@
import pytest # noqa
from django.conf import settings

from api.documents.libraries.s3_operations import init_s3_client


def camelcase_to_underscore(string):
"""SRC: https://djangosnippets.org/snippets/585/"""
Expand Down Expand Up @@ -130,11 +133,19 @@ def setup(settings):


@pytest.fixture(autouse=True)
def celery_app():
# Setup the celery worker to run in process for tests
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "api.conf.settings")
celeryapp = Celery("api")
celeryapp.autodiscover_tasks(related_name="celery_tasks")
celeryapp.conf.update(CELERY_ALWAYS_EAGER=True)
celeryapp.conf.update(CELERY_TASK_STORE_EAGER_RESULT=True)
return celeryapp
def mock_aws_calls():
with mock_aws():
clients = init_s3_client()
clients["processed"].create_bucket(
Bucket=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
CreateBucketConfiguration={
"LocationConstraint": settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_REGION"],
},
)
clients["staged"].create_bucket(
Bucket=settings.FILE_UPLOAD_STAGED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
CreateBucketConfiguration={
"LocationConstraint": settings.FILE_UPLOAD_STAGED_BUCKET["AWS_REGION"],
},
)
yield
1 change: 1 addition & 0 deletions api/documents/libraries/s3_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def move_staged_document_to_processed(document_id, s3_key):
# Grab the document from the staged S3 bucket
try:
staged_document = get_object(document_id, s3_key, "staged")

except ClientError as exc:
logger.warning(f"An error occurred when retrieving file '{s3_key}' on document '{document_id}': {exc}")
# TODO: When we move over to using two S3 buckets, we should make this raise an exception.
Expand Down
15 changes: 1 addition & 14 deletions api/documents/tests/test_document_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,10 @@
from api.documents.libraries.s3_operations import init_s3_client


@mock_aws
class DocumentStream(DataTestClient):
def setUp(self):
super().setUp()
init_s3_client()
s3 = boto3.client(
"s3",
aws_access_key_id=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_SECRET_ACCESS_KEY"],
region_name=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_REGION"],
)
s3.create_bucket(
Bucket=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
CreateBucketConfiguration={
"LocationConstraint": settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_REGION"],
},
)
s3 = init_s3_client()["processed"]
s3.put_object(
Bucket=settings.FILE_UPLOAD_PROCESSED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
Key="thisisakey",
Expand Down
12 changes: 9 additions & 3 deletions api/organisations/tests/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from unittest import mock

from django.urls import reverse
from django.conf import settings

from api.organisations.enums import OrganisationDocumentType
from api.organisations.models import DocumentOnOrganisation
from test_helpers.clients import DataTestClient
from api.documents.libraries.s3_operations import init_s3_client


class OrganisationDocumentViewTests(DataTestClient):
Expand All @@ -19,6 +21,12 @@ def setUp(self):

def create_document_on_organisation(self, name):
url = reverse("organisations:documents", kwargs={"pk": self.organisation.pk})
s3 = init_s3_client()["staged"]
s3.put_object(
Bucket=settings.FILE_UPLOAD_STAGED_BUCKET["AWS_STORAGE_BUCKET_NAME"],
Key=name,
Body=b"test",
)
data = {
"document": {"name": name, "s3_key": name, "size": 476},
"expiry_date": "2026-01-01",
Expand Down Expand Up @@ -63,10 +71,8 @@ def test_list_organisation_documents(self, mock_virus_scan, mock_s3_operations_g
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.json()["documents"]), 3)

@mock.patch("api.documents.libraries.s3_operations.get_object")
@mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses")
def test_retrieve_organisation_documents(self, mock_virus_scan, mock_s3_operations_get_object):
mock_s3_operations_get_object.return_value = self.document_data
def test_retrieve_organisation_documents(self, mock_virus_scan):
mock_virus_scan.return_value = False
response = self.create_document_on_organisation("some-document-one")
self.assertEqual(response.status_code, 201)
Expand Down
5 changes: 2 additions & 3 deletions api/test_more_documents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from unittest import mock

from django.conf import settings
from django.urls import reverse
from parameterized import parameterized
from rest_framework import status
Expand Down Expand Up @@ -82,12 +83,10 @@ def test_upload_multiple_documents_on_unsubmitted_application(self, mock_virus_s
for document in data:
self.assertTrue(document in response_data)

@mock.patch("api.documents.libraries.s3_operations.get_object")
@mock.patch("api.documents.models.Document.delete_s3")
@mock.patch("api.documents.libraries.av_operations.scan_file_for_viruses")
def test_delete_individual_draft_document(self, mock_virus_scan, mock_delete_s3, mock_s3_operations_get_object):
def test_delete_individual_draft_document(self, mock_virus_scan, mock_delete_s3):
"""Test success in deleting a document from an unsubmitted application."""
mock_s3_operations_get_object.return_value = self.data
mock_virus_scan.return_value = False
self.client.post(self.url_draft, data=self.data, **self.exporter_headers)
response = self.client.get(self.url_draft, **self.exporter_headers)
Expand Down

0 comments on commit aed91dd

Please sign in to comment.