Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize asset list endpoint #1904

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dandiapi/api/tests/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,4 +1539,5 @@ def test_asset_rest_glob(api_client, asset_factory, version, glob_pattern, expec
{'glob': glob_pattern},
)

assert expected_paths == [asset['path'] for asset in resp.json()['results']]
# Sort both lists before comparing since ordering is not considered
assert sorted(expected_paths) == sorted([asset['path'] for asset in resp.json()['results']])
64 changes: 45 additions & 19 deletions dandiapi/api/views/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
# This should only be used for type interrogation, never instantiation
MinioStorage = type('FakeMinioStorage', (), {})

from typing import TYPE_CHECKING

from django.conf import settings
from django.db import transaction
Expand Down Expand Up @@ -57,9 +56,6 @@
AssetValidationSerializer,
)

if TYPE_CHECKING:
from django.db.models import QuerySet


class AssetFilter(filters.FilterSet):
path = filters.CharFilter(lookup_expr='istartswith')
Expand Down Expand Up @@ -232,7 +228,7 @@ class NestedAssetViewSet(NestedViewSetMixin, AssetViewSet, ReadOnlyModelViewSet)

def raise_if_unauthorized(self):
version = get_object_or_404(
Version,
Version.objects.select_related('dandiset'),
dandiset__pk=self.kwargs['versions__dandiset__pk'],
version=self.kwargs['versions__version'],
)
Expand All @@ -244,6 +240,15 @@ def raise_if_unauthorized(self):
# The user does not have ownership permission
raise PermissionDenied

# Override this here to prevent the need for raise_if_unauthorized to exist in get_queryset
def get_object(self):
    """Return the requested object after checking authorization.

    The authorization check is done here so that it does not have to live
    in ``get_queryset``. Any exception raised by ``raise_if_unauthorized``
    propagates to the caller before the object lookup happens.
    """
    # Authorize first; the lookup below must not run for unauthorized users.
    self.raise_if_unauthorized()

    found = super().get_object()
    return found

# Override this to skip the call to raise_if_unauthorized in AssetViewSet
def get_queryset(self):
    """Return the queryset from the class that follows ``AssetViewSet`` in the MRO.

    ``AssetViewSet.get_queryset`` is deliberately bypassed so that its call
    to ``raise_if_unauthorized`` is not triggered from here.
    """
    # Two-argument super() starts the method lookup *after* AssetViewSet,
    # which is what skips AssetViewSet's own get_queryset.
    beyond_asset_viewset = super(AssetViewSet, self)
    return beyond_asset_viewset.get_queryset()

# Redefine info and download actions to update swagger manual_parameters

@swagger_auto_schema(
Expand Down Expand Up @@ -389,36 +394,57 @@ def destroy(self, request, versions__dandiset__pk, versions__version, **kwargs):

@swagger_auto_schema(query_serializer=AssetListSerializer, responses={200: AssetSerializer})
def list(self, request, *args, **kwargs):
# Manually call this to ensure user is authorized
self.raise_if_unauthorized()

serializer = AssetListSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)

# Fetch initial queryset
queryset: QuerySet[Asset] = self.filter_queryset(
self.get_queryset().select_related('blob', 'embargoed_blob', 'zarr')
# Retrieve version first and then fetch assets, to remove a join
version = Version.objects.get(
dandiset__pk=self.kwargs['versions__dandiset__pk'],
version=self.kwargs['versions__version'],
)

# Don't include metadata field if not asked for
include_metadata = serializer.validated_data['metadata']
if not include_metadata:
queryset = queryset.defer('metadata')
# Apply filtering from included filter class first
asset_queryset = self.filter_queryset(version.assets.all())

# Must do glob pattern matching before pagination
glob_pattern: str | None = serializer.validated_data.get('glob')
if glob_pattern is not None:
# Escape special characters in the glob pattern. This is a security precaution taken
# since we are using postgres' regex search. A malicious user who knows this could
# include a regex as part of the glob expression, which postgres would happily parse
# and use if it's not escaped.
glob_pattern = f'^{re.escape(glob_pattern)}$'
queryset = queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*'))
asset_queryset = asset_queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*'))

# Retrieve just the first N asset IDs, and use them for pagination
page_of_asset_ids = self.paginate_queryset(asset_queryset.values_list('id', flat=True))

# `paginate_queryset` returns `None` only when pagination is disabled for the view (i.e. no
# paginator is configured). This condition is kept for compatibility with
# the original implementation: https://github.com/encode/django-rest-framework/blob/f4194c4684420ac86485d9610adf760064db381f/rest_framework/mixins.py#L37-L46
# It is checked here since the follow-up query can't continue if the page is `None` anyway
if page_of_asset_ids is None:
serializer = self.get_serializer(Asset.objects.none(), many=True)
return Response(serializer.data)

# Now we can retrieve the actual fully joined rows using the limited number of assets we're
# going to return
queryset = self.filter_queryset(
Asset.objects.filter(id__in=page_of_asset_ids).select_related(
'blob', 'embargoed_blob', 'zarr'
)
)

# Paginate and return
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True, metadata=include_metadata)
return self.get_paginated_response(serializer.data)
# Must apply this to the main queryset, since it affects the data returned
include_metadata = serializer.validated_data['metadata']
if not include_metadata:
queryset = queryset.defer('metadata')

# Paginate and return
serializer = self.get_serializer(queryset, many=True, metadata=include_metadata)
return Response(serializer.data)
return self.get_paginated_response(serializer.data)

@swagger_auto_schema(
query_serializer=AssetPathsQueryParameterSerializer,
Expand Down