Skip to content

Commit

Permalink
Merge pull request #2130 from ResearchHub/storage-service
Browse files Browse the repository at this point in the history
feat: Add endpoint for uploading assets
  • Loading branch information
gzurowski authored Feb 26, 2025
2 parents 93bddaf + 2fe20d1 commit 1be6a83
Show file tree
Hide file tree
Showing 15 changed files with 333 additions and 120 deletions.
48 changes: 0 additions & 48 deletions src/paper/services/storage_service.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/paper/tests/test_paper_upload_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def test_post(self):
},
)
self.mock_storage_service.create_presigned_url.assert_called_once_with(
"paper",
"test.pdf",
request.user.id,
"application/pdf",
)

def test_post_fails_with_validation_error(self):
Expand Down
51 changes: 0 additions & 51 deletions src/paper/tests/test_storage_service.py

This file was deleted.

8 changes: 5 additions & 3 deletions src/paper/views/paper_upload_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rest_framework.views import APIView

from paper.serializers.paper_upload_serializer import PaperUploadSerializer
from paper.services.storage_service import StorageService
from researchhub.services.storage_service import S3StorageService


class PaperUploadView(APIView):
Expand All @@ -15,7 +15,7 @@ class PaperUploadView(APIView):
permission_classes = [IsAuthenticated]

def dispatch(self, request, *args, **kwargs):
self.storage_service = kwargs.pop("storage_service", StorageService())
self.storage_service = kwargs.pop("storage_service", S3StorageService())
return super().dispatch(request, *args, **kwargs)

def post(self, request: Request, *args, **kwargs) -> Response:
Expand All @@ -31,7 +31,9 @@ def post(self, request: Request, *args, **kwargs) -> Response:

filename = data.get("filename")

presigned_url = self.storage_service.create_presigned_url(filename, user.id)
presigned_url = self.storage_service.create_presigned_url(
"paper", filename, user.id, "application/pdf"
)

return Response(
{
Expand Down
2 changes: 2 additions & 0 deletions src/researchhub/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .asset_upload_serializer import AssetUploadSerializer
from .serializers import DynamicModelFieldSerializer
17 changes: 17 additions & 0 deletions src/researchhub/serializers/asset_upload_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from rest_framework import serializers

from researchhub.services.storage_service import (
SUPPORTED_CONTENT_TYPES,
SUPPORTED_ENTITIES,
)


class AssetUploadSerializer(serializers.Serializer):
"""
Serializer for uploading an asset into ResearchHub storage.
Used to validate request data.
"""

content_type = serializers.ChoiceField(choices=SUPPORTED_CONTENT_TYPES)
entity = serializers.ChoiceField(choices=SUPPORTED_ENTITIES)
filename = serializers.CharField(required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,8 @@ def __init__(self, *args, **kwargs):

super(DynamicModelFieldSerializer, self).__init__(*args, **kwargs)

# instance_class_name = self.instance.__class__.__name__
# is_manager = (
# instance_class_name == "RelatedManager"
# or instance_class_name == "ManyRelatedManager"
# )
# known_related_objects = getattr(self.instance, "_known_related_objects", [])
# if is_manager or len(known_related_objects) > 0:
# if _include_fields == "_all_":
# _include_fields = None
# _exclude_fields = "__all__"
# elif _include_fields == "_all_":
# _include_fields = "__all__"

if _include_fields is not None and _include_fields != "__all__":
# Drop any fields that are not specified in the
# `_include_fields` argument.
# Drop any fields that are not specified in the `_include_fields` argument.
allowed = set(_include_fields)
existing = set(self.fields)
for field_name in existing - allowed:
Expand Down
74 changes: 74 additions & 0 deletions src/researchhub/services/storage_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import uuid
from typing import NamedTuple

from django.conf import settings

from utils import aws as aws_utils


class PresignedUrl(NamedTuple):
object_key: str
object_url: str
url: str


SUPPORTED_CONTENT_TYPES = ["application/pdf", "image/png", "image/jpeg"]
SUPPORTED_ENTITIES = ["comment", "note", "paper"]


class StorageService:
def create_presigned_url(
self,
entity: str,
filename: str,
user_id: str,
content_type: str,
) -> PresignedUrl: ...


class S3StorageService(StorageService):
"""
Service for interacting with S3 storage.
"""

def create_presigned_url(
self,
entity: str,
filename: str,
user_id: str,
content_type: str,
valid_for_min: int = 2,
) -> PresignedUrl:
"""
Create a presigned URL for uploading a file to S3 that is time-limited.
"""

if entity not in SUPPORTED_ENTITIES:
raise ValueError(f"Unsupported entity: {entity}")

if content_type not in SUPPORTED_CONTENT_TYPES:
raise ValueError(f"Unsupported content type: {content_type}")

s3_filename = f"uploads/{entity}s/users/{user_id}/{uuid.uuid4()}/{filename}"

s3_client = aws_utils.create_client("s3")

url = s3_client.generate_presigned_url(
"put_object",
Params={
"Bucket": settings.AWS_STORAGE_BUCKET_NAME,
"Key": s3_filename,
"ContentType": content_type,
"Metadata": {
"created-by-id": f"{user_id}",
"file-name": filename,
},
},
ExpiresIn=60 * valid_for_min,
)

return PresignedUrl(
url=url,
object_key=s3_filename,
object_url=f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/{s3_filename}",
)
Empty file.
104 changes: 104 additions & 0 deletions src/researchhub/tests/test_asset_upload_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from unittest.mock import Mock

from rest_framework.test import APIRequestFactory, APITestCase, force_authenticate

from researchhub.views.asset_upload_view import AssetUploadView
from user.related_models.user_model import User


class AssetUploadViewTest(APITestCase):

def setUp(self):
self.factory = APIRequestFactory()
self.view = AssetUploadView.as_view()
self.mock_storage_service = Mock()
self.user = User.objects.create(username="user1")

def test_post(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
"content_type": "application/pdf",
"entity": "paper",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.data,
{
"presigned_url": self.mock_storage_service.create_presigned_url.return_value.url,
"object_key": self.mock_storage_service.create_presigned_url.return_value.object_key,
"object_url": self.mock_storage_service.create_presigned_url.return_value.object_url,
},
)
self.mock_storage_service.create_presigned_url.assert_called_once_with(
"paper",
"test.pdf",
request.user.id,
"application/pdf",
)

def test_post_fails_unauthenticated(self):
# Arrange
request = self.factory.post("/assets/upload/")

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 401)
self.mock_storage_service.create_presigned_url.assert_not_called()

def test_post_fails_with_validation_error(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
# content type is missing!
"entity": "paper",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 400)
self.assertEqual(response.data, {"content_type": ["This field is required."]})
self.mock_storage_service.create_presigned_url.assert_not_called()

def test_post_with_unsupported_entity(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
"content_type": "application/pdf",
"entity": "UNSUPPORTED",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.data,
{"entity": ['"UNSUPPORTED" is not a valid choice.']},
)
self.mock_storage_service.create_presigned_url.assert_not_called()
Loading

0 comments on commit 1be6a83

Please sign in to comment.