Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add endpoint for uploading assets #2130

Merged
merged 11 commits into from
Feb 26, 2025
48 changes: 0 additions & 48 deletions src/paper/services/storage_service.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/paper/tests/test_paper_upload_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def test_post(self):
},
)
self.mock_storage_service.create_presigned_url.assert_called_once_with(
"paper",
"test.pdf",
request.user.id,
"application/pdf",
)

def test_post_fails_with_validation_error(self):
Expand Down
51 changes: 0 additions & 51 deletions src/paper/tests/test_storage_service.py

This file was deleted.

8 changes: 5 additions & 3 deletions src/paper/views/paper_upload_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rest_framework.views import APIView

from paper.serializers.paper_upload_serializer import PaperUploadSerializer
from paper.services.storage_service import StorageService
from researchhub.services.storage_service import S3StorageService


class PaperUploadView(APIView):
Expand All @@ -15,7 +15,7 @@ class PaperUploadView(APIView):
permission_classes = [IsAuthenticated]

def dispatch(self, request, *args, **kwargs):
self.storage_service = kwargs.pop("storage_service", StorageService())
self.storage_service = kwargs.pop("storage_service", S3StorageService())
return super().dispatch(request, *args, **kwargs)

def post(self, request: Request, *args, **kwargs) -> Response:
Expand All @@ -31,7 +31,9 @@ def post(self, request: Request, *args, **kwargs) -> Response:

filename = data.get("filename")

presigned_url = self.storage_service.create_presigned_url(filename, user.id)
presigned_url = self.storage_service.create_presigned_url(
"paper", filename, user.id, "application/pdf"
)

return Response(
{
Expand Down
2 changes: 2 additions & 0 deletions src/researchhub/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .asset_upload_serializer import AssetUploadSerializer
from .serializers import DynamicModelFieldSerializer
17 changes: 17 additions & 0 deletions src/researchhub/serializers/asset_upload_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from rest_framework import serializers

from researchhub.services.storage_service import (
SUPPORTED_CONTENT_TYPES,
SUPPORTED_ENTITIES,
)


class AssetUploadSerializer(serializers.Serializer):
"""
Serializer for uploading an asset into ResearchHub storage.
Used to validate request data.
"""

content_type = serializers.ChoiceField(choices=SUPPORTED_CONTENT_TYPES)
entity = serializers.ChoiceField(choices=SUPPORTED_ENTITIES)
filename = serializers.CharField(required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,8 @@ def __init__(self, *args, **kwargs):

super(DynamicModelFieldSerializer, self).__init__(*args, **kwargs)

# instance_class_name = self.instance.__class__.__name__
# is_manager = (
# instance_class_name == "RelatedManager"
# or instance_class_name == "ManyRelatedManager"
# )
# known_related_objects = getattr(self.instance, "_known_related_objects", [])
# if is_manager or len(known_related_objects) > 0:
# if _include_fields == "_all_":
# _include_fields = None
# _exclude_fields = "__all__"
# elif _include_fields == "_all_":
# _include_fields = "__all__"

if _include_fields is not None and _include_fields != "__all__":
# Drop any fields that are not specified in the
# `_include_fields` argument.
# Drop any fields that are not specified in the `_include_fields` argument.
allowed = set(_include_fields)
existing = set(self.fields)
for field_name in existing - allowed:
Expand Down
74 changes: 74 additions & 0 deletions src/researchhub/services/storage_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import uuid
from typing import NamedTuple

from django.conf import settings

from utils import aws as aws_utils


class PresignedUrl(NamedTuple):
object_key: str
object_url: str
url: str


SUPPORTED_CONTENT_TYPES = ["application/pdf", "image/png", "image/jpeg"]
SUPPORTED_ENTITIES = ["comment", "note", "paper"]


class StorageService:
def create_presigned_url(
self,
entity: str,
filename: str,
user_id: str,
content_type: str,
) -> PresignedUrl: ...


class S3StorageService(StorageService):
"""
Service for interacting with S3 storage.
"""

def create_presigned_url(
self,
entity: str,
filename: str,
user_id: str,
content_type: str,
valid_for_min: int = 2,
) -> PresignedUrl:
"""
Create a presigned URL for uploading a file to S3 that is time-limited.
"""

if entity not in SUPPORTED_ENTITIES:
raise ValueError(f"Unsupported entity: {entity}")

if content_type not in SUPPORTED_CONTENT_TYPES:
raise ValueError(f"Unsupported content type: {content_type}")

s3_filename = f"uploads/{entity}s/users/{user_id}/{uuid.uuid4()}/{filename}"

s3_client = aws_utils.create_client("s3")

url = s3_client.generate_presigned_url(
"put_object",
Params={
"Bucket": settings.AWS_STORAGE_BUCKET_NAME,
"Key": s3_filename,
"ContentType": content_type,
"Metadata": {
"created-by-id": f"{user_id}",
"file-name": filename,
},
},
ExpiresIn=60 * valid_for_min,
)

return PresignedUrl(
url=url,
object_key=s3_filename,
object_url=f"https://{settings.AWS_S3_CUSTOM_DOMAIN}/{s3_filename}",
)
Empty file.
104 changes: 104 additions & 0 deletions src/researchhub/tests/test_asset_upload_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from unittest.mock import Mock

from rest_framework.test import APIRequestFactory, APITestCase, force_authenticate

from researchhub.views.asset_upload_view import AssetUploadView
from user.related_models.user_model import User


class AssetUploadViewTest(APITestCase):

def setUp(self):
self.factory = APIRequestFactory()
self.view = AssetUploadView.as_view()
self.mock_storage_service = Mock()
self.user = User.objects.create(username="user1")

def test_post(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
"content_type": "application/pdf",
"entity": "paper",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.data,
{
"presigned_url": self.mock_storage_service.create_presigned_url.return_value.url,
"object_key": self.mock_storage_service.create_presigned_url.return_value.object_key,
"object_url": self.mock_storage_service.create_presigned_url.return_value.object_url,
},
)
self.mock_storage_service.create_presigned_url.assert_called_once_with(
"paper",
"test.pdf",
request.user.id,
"application/pdf",
)

def test_post_fails_unauthenticated(self):
# Arrange
request = self.factory.post("/assets/upload/")

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 401)
self.mock_storage_service.create_presigned_url.assert_not_called()

def test_post_fails_with_validation_error(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
# content type is missing!
"entity": "paper",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 400)
self.assertEqual(response.data, {"content_type": ["This field is required."]})
self.mock_storage_service.create_presigned_url.assert_not_called()

def test_post_with_unsupported_entity(self):
# Arrange
request = self.factory.post(
"/asset/upload/",
{
"content_type": "application/pdf",
"entity": "UNSUPPORTED",
"filename": "test.pdf",
},
)

force_authenticate(request, self.user)

# Act
response = self.view(request, storage_service=self.mock_storage_service)

# Assert
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.data,
{"entity": ['"UNSUPPORTED" is not a valid choice.']},
)
self.mock_storage_service.create_presigned_url.assert_not_called()
Loading
Loading