diff --git a/src/paper/services/storage_service.py b/src/paper/services/storage_service.py deleted file mode 100644 index 857a94189e..0000000000 --- a/src/paper/services/storage_service.py +++ /dev/null @@ -1,48 +0,0 @@ -import uuid -from typing import NamedTuple - -from django.conf import settings - -from utils import aws as aws_utils - - -class PresignedUrl(NamedTuple): - object_key: str - url: str - - -class StorageService: - """ - Service for interacting with S3 storage. - """ - - def create_presigned_url( - self, - filename: str, - user_id: str, - content_type: str = "application/pdf", - valid_for: int = 2, - ) -> PresignedUrl: - """ - Create a presigned URL for uploading a file to S3 that is time-limited. - """ - - s3_filename = f"uploads/papers/users/{user_id}/{uuid.uuid4()}/{filename}" - - s3_client = aws_utils.create_client("s3") - - url = s3_client.generate_presigned_url( - "put_object", - Params={ - "Bucket": settings.AWS_STORAGE_BUCKET_NAME, - "Key": s3_filename, - "ContentType": content_type, - "Metadata": { - "created-by-id": f"{user_id}", - "file-name": filename, - }, - }, - ExpiresIn=60 * valid_for, - ) - - return PresignedUrl(url=url, object_key=s3_filename) diff --git a/src/paper/tests/test_paper_upload_view.py b/src/paper/tests/test_paper_upload_view.py index d301688030..1ac549cadc 100644 --- a/src/paper/tests/test_paper_upload_view.py +++ b/src/paper/tests/test_paper_upload_view.py @@ -38,8 +38,10 @@ def test_post(self): }, ) self.mock_storage_service.create_presigned_url.assert_called_once_with( + "paper", "test.pdf", request.user.id, + "application/pdf", ) def test_post_fails_with_validation_error(self): diff --git a/src/paper/tests/test_storage_service.py b/src/paper/tests/test_storage_service.py deleted file mode 100644 index e288d2746a..0000000000 --- a/src/paper/tests/test_storage_service.py +++ /dev/null @@ -1,51 +0,0 @@ -import uuid -from unittest import TestCase -from unittest.mock import Mock, patch - -from django.conf import settings - -from paper.services import storage_service -from paper.services.storage_service import StorageService - - -class StorageServiceTest(TestCase): - - @patch("paper.services.storage_service.aws_utils.create_client") - @patch("paper.services.storage_service.uuid.uuid4") - def test_create_presigned_url(self, mock_uuid, mock_create_client): - # Arrange - uuid1 = uuid.uuid4() - mock_uuid.return_value = uuid1 - - mock_s3_client = Mock() - mock_create_client.return_value = mock_s3_client - - mock_s3_client.generate_presigned_url.return_value = "https://presignedUrl1" - - service = StorageService() - - # Act - url = service.create_presigned_url("file1.pdf", "userId1", valid_for=2) - - # Assert - mock_s3_client.generate_presigned_url.assert_called_once_with( - "put_object", - Params={ - "Bucket": settings.AWS_STORAGE_BUCKET_NAME, - "Key": f"uploads/papers/users/userId1/{uuid1}/file1.pdf", - "ContentType": "application/pdf", - "Metadata": { - "created-by-id": "userId1", - "file-name": "file1.pdf", - }, - }, - ExpiresIn=60 * 2, - ) - - self.assertEqual( - url, - storage_service.PresignedUrl( - url="https://presignedUrl1", - object_key=f"uploads/papers/users/userId1/{uuid1}/file1.pdf", - ), - ) diff --git a/src/paper/views/paper_upload_views.py b/src/paper/views/paper_upload_views.py index 59b6fe0630..e98227f022 100644 --- a/src/paper/views/paper_upload_views.py +++ b/src/paper/views/paper_upload_views.py @@ -4,7 +4,7 @@ from rest_framework.views import APIView from paper.serializers.paper_upload_serializer import PaperUploadSerializer -from paper.services.storage_service import StorageService +from researchhub.services.storage_service import S3StorageService class PaperUploadView(APIView): @@ -15,7 +15,7 @@ class PaperUploadView(APIView): permission_classes = [IsAuthenticated] def dispatch(self, request, *args, **kwargs): - self.storage_service = kwargs.pop("storage_service", StorageService()) + self.storage_service = kwargs.pop("storage_service", S3StorageService()) return super().dispatch(request, *args, **kwargs) def post(self, request: Request, *args, **kwargs) -> Response: @@ -31,7 +31,9 @@ def post(self, request: Request, *args, **kwargs) -> Response: filename = data.get("filename") - presigned_url = self.storage_service.create_presigned_url(filename, user.id) + presigned_url = self.storage_service.create_presigned_url( + "paper", filename, user.id, "application/pdf" + ) return Response( { diff --git a/src/researchhub/serializers/__init__.py b/src/researchhub/serializers/__init__.py new file mode 100644 index 0000000000..4a820434ff --- /dev/null +++ b/src/researchhub/serializers/__init__.py @@ -0,0 +1,2 @@ +from .asset_upload_serializer import AssetUploadSerializer +from .serializers import DynamicModelFieldSerializer diff --git a/src/researchhub/serializers/asset_upload_serializer.py b/src/researchhub/serializers/asset_upload_serializer.py new file mode 100644 index 0000000000..6620a0cd20 --- /dev/null +++ b/src/researchhub/serializers/asset_upload_serializer.py @@ -0,0 +1,17 @@ +from rest_framework import serializers + +from researchhub.services.storage_service import ( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_ENTITIES, +) + + +class AssetUploadSerializer(serializers.Serializer): + """ + Serializer for uploading an asset into ResearchHub storage. + Used to validate request data. + """ + + content_type = serializers.ChoiceField(choices=SUPPORTED_CONTENT_TYPES) + entity = serializers.ChoiceField(choices=SUPPORTED_ENTITIES) + filename = serializers.CharField(required=True) diff --git a/src/researchhub/serializers.py b/src/researchhub/serializers/serializers.py similarity index 71% rename from src/researchhub/serializers.py rename to src/researchhub/serializers/serializers.py index a9760dad8c..98c459cf98 100644 --- a/src/researchhub/serializers.py +++ b/src/researchhub/serializers/serializers.py @@ -19,22 +19,8 @@ def __init__(self, *args, **kwargs): super(DynamicModelFieldSerializer, self).__init__(*args, **kwargs) - # instance_class_name = self.instance.__class__.__name__ - # is_manager = ( - # instance_class_name == "RelatedManager" - # or instance_class_name == "ManyRelatedManager" - # ) - # known_related_objects = getattr(self.instance, "_known_related_objects", []) - # if is_manager or len(known_related_objects) > 0: - # if _include_fields == "_all_": - # _include_fields = None - # _exclude_fields = "__all__" - # elif _include_fields == "_all_": - # _include_fields = "__all__" - if _include_fields is not None and _include_fields != "__all__": - # Drop any fields that are not specified in the - # `_include_fields` argument. + # Drop any fields that are not specified in the `_include_fields` argument. allowed = set(_include_fields) existing = set(self.fields) for field_name in existing - allowed: diff --git a/src/researchhub/services/storage_service.py b/src/researchhub/services/storage_service.py new file mode 100644 index 0000000000..8d7adbef86 --- /dev/null +++ b/src/researchhub/services/storage_service.py @@ -0,0 +1,74 @@ +import uuid +from typing import NamedTuple + +from django.conf import settings + +from utils import aws as aws_utils + + +class PresignedUrl(NamedTuple): + object_key: str + object_url: str + url: str + + +SUPPORTED_CONTENT_TYPES = ["application/pdf", "image/png", "image/jpeg"] +SUPPORTED_ENTITIES = ["comment", "note", "paper"] + + +class StorageService: + def create_presigned_url( + self, + entity: str, + filename: str, + user_id: str, + content_type: str, + ) -> PresignedUrl: ... + + +class S3StorageService(StorageService): + """ + Service for interacting with S3 storage. + """ + + def create_presigned_url( + self, + entity: str, + filename: str, + user_id: str, + content_type: str, + valid_for_min: int = 2, + ) -> PresignedUrl: + """ + Create a presigned URL for uploading a file to S3 that is time-limited. + """ + + if entity not in SUPPORTED_ENTITIES: + raise ValueError(f"Unsupported entity: {entity}") + + if content_type not in SUPPORTED_CONTENT_TYPES: + raise ValueError(f"Unsupported content type: {content_type}") + + s3_filename = f"uploads/{entity}s/users/{user_id}/{uuid.uuid4()}/{filename}" + + s3_client = aws_utils.create_client("s3") + + url = s3_client.generate_presigned_url( + "put_object", + Params={ + "Bucket": settings.AWS_STORAGE_BUCKET_NAME, + "Key": s3_filename, + "ContentType": content_type, + "Metadata": { + "created-by-id": f"{user_id}", + "file-name": filename, + }, + }, + ExpiresIn=60 * valid_for_min, + ) + + return PresignedUrl( + url=url, + object_key=s3_filename, + object_url=f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/{s3_filename}", + ) diff --git a/src/researchhub/tests/__init__.py b/src/researchhub/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/researchhub/tests/test_asset_upload_view.py b/src/researchhub/tests/test_asset_upload_view.py new file mode 100644 index 0000000000..b06ecebd88 --- /dev/null +++ b/src/researchhub/tests/test_asset_upload_view.py @@ -0,0 +1,104 @@ +from unittest.mock import Mock + +from rest_framework.test import APIRequestFactory, APITestCase, force_authenticate + +from researchhub.views.asset_upload_view import AssetUploadView +from user.related_models.user_model import User + + +class AssetUploadViewTest(APITestCase): + + def setUp(self): + self.factory = APIRequestFactory() + self.view = AssetUploadView.as_view() + self.mock_storage_service = Mock() + self.user = User.objects.create(username="user1") + + def test_post(self): + # Arrange + request = self.factory.post( + "/asset/upload/", + { + "content_type": "application/pdf", + "entity": "paper", + "filename": "test.pdf", + }, + ) + + force_authenticate(request, self.user) + + # Act + response = self.view(request, storage_service=self.mock_storage_service) + + # Assert + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.data, + { + "presigned_url": self.mock_storage_service.create_presigned_url.return_value.url, + "object_key": self.mock_storage_service.create_presigned_url.return_value.object_key, + "object_url": self.mock_storage_service.create_presigned_url.return_value.object_url, + }, + ) + self.mock_storage_service.create_presigned_url.assert_called_once_with( + "paper", + "test.pdf", + request.user.id, + "application/pdf", + ) + + def test_post_fails_unauthenticated(self): + # Arrange + request = self.factory.post("/assets/upload/") + + # Act + response = self.view(request, storage_service=self.mock_storage_service) + + # Assert + self.assertEqual(response.status_code, 401) + self.mock_storage_service.create_presigned_url.assert_not_called() + + def test_post_fails_with_validation_error(self): + # Arrange + request = self.factory.post( + "/asset/upload/", + { + # content type is missing! + "entity": "paper", + "filename": "test.pdf", + }, + ) + + force_authenticate(request, self.user) + + # Act + response = self.view(request, storage_service=self.mock_storage_service) + + # Assert + self.assertEqual(response.status_code, 400) + self.assertEqual(response.data, {"content_type": ["This field is required."]}) + self.mock_storage_service.create_presigned_url.assert_not_called() + + def test_post_with_unsupported_entity(self): + # Arrange + request = self.factory.post( + "/asset/upload/", + { + "content_type": "application/pdf", + "entity": "UNSUPPORTED", + "filename": "test.pdf", + }, + ) + + force_authenticate(request, self.user) + + # Act + response = self.view(request, storage_service=self.mock_storage_service) + + # Assert + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.data, + {"entity": ['"UNSUPPORTED" is not a valid choice.']}, + ) + self.mock_storage_service.create_presigned_url.assert_not_called() diff --git a/src/researchhub/tests/test_storage_service.py b/src/researchhub/tests/test_storage_service.py new file mode 100644 index 0000000000..deb44fbb6c --- /dev/null +++ b/src/researchhub/tests/test_storage_service.py @@ -0,0 +1,71 @@ +import uuid +from unittest.mock import Mock, patch + +from django.conf import settings +from django.test import TestCase, override_settings + +from researchhub.services import storage_service +from researchhub.services.storage_service import S3StorageService + + +@override_settings(AWS_S3_CUSTOM_DOMAIN="storage.test.researchhub.com") +class StorageServiceTest(TestCase): + + @patch("researchhub.services.storage_service.aws_utils.create_client") + @patch("researchhub.services.storage_service.uuid.uuid4") + def test_create_presigned_url(self, mock_uuid, mock_create_client): + # Arrange + uuid1 = uuid.uuid4() + mock_uuid.return_value = uuid1 + + mock_s3_client = Mock() + mock_create_client.return_value = mock_s3_client + + mock_s3_client.generate_presigned_url.return_value = "https://presignedUrl1" + + service = S3StorageService() + + # Act + url = service.create_presigned_url( + "paper", "file1.pdf", "userId1", "application/pdf", valid_for_min=3 + ) + + # Assert + mock_s3_client.generate_presigned_url.assert_called_once_with( + "put_object", + Params={ + "Bucket": settings.AWS_STORAGE_BUCKET_NAME, + "Key": f"uploads/papers/users/userId1/{uuid1}/file1.pdf", + "ContentType": "application/pdf", + "Metadata": { + "created-by-id": "userId1", + "file-name": "file1.pdf", + }, + }, + ExpiresIn=60 * 3, + ) + + self.assertEqual( + url, + storage_service.PresignedUrl( + url="https://presignedUrl1", + object_key=f"uploads/papers/users/userId1/{uuid1}/file1.pdf", + object_url=f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/uploads/papers/users/userId1/{uuid1}/file1.pdf", + ), + ) + + def test_create_presigned_url_unsupported_entity(self): + with self.assertRaises(ValueError): + S3StorageService().create_presigned_url( + "UNSUPPORTED", + "file1.pdf", + "userId1", + "application/pdf", + valid_for_min=3, + ) + + def test_create_presigned_url_unsupported_content_type(self): + with self.assertRaises(ValueError): + S3StorageService().create_presigned_url( + "paper", "file1.pdf", "userId1", "UNSUPPORTED", valid_for_min=3 + ) diff --git a/src/researchhub/urls.py b/src/researchhub/urls.py index d1dc18fbe6..888205e87c 100644 --- a/src/researchhub/urls.py +++ b/src/researchhub/urls.py @@ -37,6 +37,7 @@ from feed.views import FeedViewSet from paper.views import paper_upload_views from purchase.views import stripe_webhook_view +from researchhub.views import asset_upload_view from researchhub_comment.views.rh_comment_view import RhCommentViewSet from review.views.peer_review_view import PeerReviewViewSet from review.views.review_view import ReviewViewSet @@ -261,6 +262,11 @@ ), path("email_notifications/", mailing_list.views.email_notifications), path("", researchhub.views.index, name="index"), + path( + "api/asset/upload/", + asset_upload_view.AssetUploadView.as_view(), + name="asset_upload", + ), path( "paper/upload/", paper_upload_views.PaperUploadView.as_view(), diff --git a/src/researchhub/views/__init__.py b/src/researchhub/views/__init__.py new file mode 100644 index 0000000000..d9e12aa10f --- /dev/null +++ b/src/researchhub/views/__init__.py @@ -0,0 +1,2 @@ +from .asset_upload_view import AssetUploadView +from .views import * diff --git a/src/researchhub/views/asset_upload_view.py b/src/researchhub/views/asset_upload_view.py new file mode 100644 index 0000000000..8395d3a0a8 --- /dev/null +++ b/src/researchhub/views/asset_upload_view.py @@ -0,0 +1,46 @@ +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + +from researchhub.serializers.asset_upload_serializer import AssetUploadSerializer +from researchhub.services.storage_service import S3StorageService + + +class AssetUploadView(APIView): + """ + View for uploading assets into ResearchHub storage. + """ + + permission_classes = [IsAuthenticated] + + def dispatch(self, request, *args, **kwargs): + self.storage_service = kwargs.pop("storage_service", S3StorageService()) + return super().dispatch(request, *args, **kwargs) + + def post(self, request: Request, *args, **kwargs) -> Response: + """ + Creates a presigned URL for uploading an asset and returns it. + """ + user = request.user + data = request.data + + # Validate request data + serializer = AssetUploadSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + content_type = data.get("content_type") + entity = data.get("entity") + filename = data.get("filename") + + presigned_url = self.storage_service.create_presigned_url( + entity, filename, user.id, content_type + ) + + return Response( + { + "presigned_url": presigned_url.url, + "object_key": presigned_url.object_key, + "object_url": presigned_url.object_url, + } + ) diff --git a/src/researchhub/views.py b/src/researchhub/views/views.py similarity index 88% rename from src/researchhub/views.py rename to src/researchhub/views/views.py index 0a8837d72e..be7a9c7736 100644 --- a/src/researchhub/views.py +++ b/src/researchhub/views/views.py @@ -6,17 +6,17 @@ from researchhub.settings import BASE_DIR -def index(request): +def index(): return HttpResponse("Authenticate with a token in the Authorization header.") -def permissions(request): +def permissions(): path = os.path.join(BASE_DIR, "static", "researchhub", "user_permissions.json") with open(path, "r") as file: data = file.read() return HttpResponse(content=data, content_type="application/json") -def robots_txt(request): +def robots_txt(): content = render_to_string("robots.txt") return HttpResponse(content, content_type="text/plain")