Skip to content

Commit

Permalink
Merge pull request #1904 from dandi/improve-asset-list-query
Browse files Browse the repository at this point in the history
Optimize asset list endpoint
  • Loading branch information
jjnesbitt authored Mar 25, 2024
2 parents ca7c89b + 886e297 commit 3a34cc0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 20 deletions.
3 changes: 2 additions & 1 deletion dandiapi/api/tests/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,4 +1539,5 @@ def test_asset_rest_glob(api_client, asset_factory, version, glob_pattern, expec
{'glob': glob_pattern},
)

assert expected_paths == [asset['path'] for asset in resp.json()['results']]
# Sort both lists before comparing since ordering is not considered
assert sorted(expected_paths) == sorted([asset['path'] for asset in resp.json()['results']])
64 changes: 45 additions & 19 deletions dandiapi/api/views/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
# This should only be used for type interrogation, never instantiation
MinioStorage = type('FakeMinioStorage', (), {})

from typing import TYPE_CHECKING

from django.conf import settings
from django.db import transaction
Expand Down Expand Up @@ -57,9 +56,6 @@
AssetValidationSerializer,
)

if TYPE_CHECKING:
from django.db.models import QuerySet


class AssetFilter(filters.FilterSet):
path = filters.CharFilter(lookup_expr='istartswith')
Expand Down Expand Up @@ -232,7 +228,7 @@ class NestedAssetViewSet(NestedViewSetMixin, AssetViewSet, ReadOnlyModelViewSet)

def raise_if_unauthorized(self):
version = get_object_or_404(
Version,
Version.objects.select_related('dandiset'),
dandiset__pk=self.kwargs['versions__dandiset__pk'],
version=self.kwargs['versions__version'],
)
Expand All @@ -244,6 +240,15 @@ def raise_if_unauthorized(self):
# The user does not have ownership permission
raise PermissionDenied

# Override this so the authorization check runs when a single object is retrieved, which removes the need to call raise_if_unauthorized from get_queryset
def get_object(self):
self.raise_if_unauthorized()
return super().get_object()

# Override this to skip the call to raise_if_unauthorized in AssetViewSet
def get_queryset(self):
return super(AssetViewSet, self).get_queryset()

# Redefine info and download actions to update swagger manual_parameters

@swagger_auto_schema(
Expand Down Expand Up @@ -389,36 +394,57 @@ def destroy(self, request, versions__dandiset__pk, versions__version, **kwargs):

@swagger_auto_schema(query_serializer=AssetListSerializer, responses={200: AssetSerializer})
def list(self, request, *args, **kwargs):
# Manually call this to ensure user is authorized
self.raise_if_unauthorized()

serializer = AssetListSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)

# Fetch initial queryset
queryset: QuerySet[Asset] = self.filter_queryset(
self.get_queryset().select_related('blob', 'embargoed_blob', 'zarr')
# Retrieve version first and then fetch assets, to remove a join
version = Version.objects.get(
dandiset__pk=self.kwargs['versions__dandiset__pk'],
version=self.kwargs['versions__version'],
)

# Don't include metadata field if not asked for
include_metadata = serializer.validated_data['metadata']
if not include_metadata:
queryset = queryset.defer('metadata')
# Apply filtering from included filter class first
asset_queryset = self.filter_queryset(version.assets.all())

# Must do glob pattern matching before pagination
glob_pattern: str | None = serializer.validated_data.get('glob')
if glob_pattern is not None:
# Escape special characters in the glob pattern. This is a security precaution taken
# since we are using postgres' regex search. A malicious user who knows this could
# include a regex as part of the glob expression, which postgres would happily parse
# and use if it's not escaped.
glob_pattern = f'^{re.escape(glob_pattern)}$'
queryset = queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*'))
asset_queryset = asset_queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*'))

# Retrieve just the first N asset IDs, and use them for pagination
page_of_asset_ids = self.paginate_queryset(asset_queryset.values_list('id', flat=True))

# The page is `None` only when pagination is disabled for the view (DRF's `paginate_queryset`
# returns `None` if no paginator is configured). This condition is checked for compatibility with
# the original implementation: https://github.com/encode/django-rest-framework/blob/f4194c4684420ac86485d9610adf760064db381f/rest_framework/mixins.py#L37-L46
# It is checked here since the query can't continue if the page is `None` anyway
if page_of_asset_ids is None:
serializer = self.get_serializer(Asset.objects.none(), many=True)
return Response(serializer.data)

# Now we can retrieve the actual fully joined rows using the limited number of assets we're
# going to return
queryset = self.filter_queryset(
Asset.objects.filter(id__in=page_of_asset_ids).select_related(
'blob', 'embargoed_blob', 'zarr'
)
)

# Paginate and return
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True, metadata=include_metadata)
return self.get_paginated_response(serializer.data)
# Must apply this to the main queryset, since it affects the data returned
include_metadata = serializer.validated_data['metadata']
if not include_metadata:
queryset = queryset.defer('metadata')

# Serialize the already-paginated assets and return the paginated response
serializer = self.get_serializer(queryset, many=True, metadata=include_metadata)
return Response(serializer.data)
return self.get_paginated_response(serializer.data)

@swagger_auto_schema(
query_serializer=AssetPathsQueryParameterSerializer,
Expand Down

0 comments on commit 3a34cc0

Please sign in to comment.