diff --git a/dandiapi/api/tests/test_asset.py b/dandiapi/api/tests/test_asset.py index 14588dec9..f1f381ecc 100644 --- a/dandiapi/api/tests/test_asset.py +++ b/dandiapi/api/tests/test_asset.py @@ -1539,4 +1539,5 @@ def test_asset_rest_glob(api_client, asset_factory, version, glob_pattern, expec {'glob': glob_pattern}, ) - assert expected_paths == [asset['path'] for asset in resp.json()['results']] + # Sort both lists before comparing since ordering is not considered + assert sorted(expected_paths) == sorted([asset['path'] for asset in resp.json()['results']]) diff --git a/dandiapi/api/views/asset.py b/dandiapi/api/views/asset.py index 34d35e8f1..cef64a44c 100644 --- a/dandiapi/api/views/asset.py +++ b/dandiapi/api/views/asset.py @@ -22,7 +22,6 @@ # This should only be used for type interrogation, never instantiation MinioStorage = type('FakeMinioStorage', (), {}) -from typing import TYPE_CHECKING from django.conf import settings from django.db import transaction @@ -57,9 +56,6 @@ AssetValidationSerializer, ) -if TYPE_CHECKING: - from django.db.models import QuerySet - class AssetFilter(filters.FilterSet): path = filters.CharFilter(lookup_expr='istartswith') @@ -232,7 +228,7 @@ class NestedAssetViewSet(NestedViewSetMixin, AssetViewSet, ReadOnlyModelViewSet) def raise_if_unauthorized(self): version = get_object_or_404( - Version, + Version.objects.select_related('dandiset'), dandiset__pk=self.kwargs['versions__dandiset__pk'], version=self.kwargs['versions__version'], ) @@ -244,6 +240,15 @@ def raise_if_unauthorized(self): # The user does not have ownership permission raise PermissionDenied + # Override this here to prevent the need for raise_if_unauthorized to exist in get_queryset + def get_object(self): + self.raise_if_unauthorized() + return super().get_object() + + # Override this to skip the call to raise_if_unauthorized in AssetViewSet + def get_queryset(self): + return super(AssetViewSet, self).get_queryset() + # Redefine info and download actions to update swagger manual_parameters @swagger_auto_schema( @@ -389,19 +394,22 @@ def destroy(self, request, versions__dandiset__pk, versions__version, **kwargs): @swagger_auto_schema(query_serializer=AssetListSerializer, responses={200: AssetSerializer}) def list(self, request, *args, **kwargs): + # Manually call this to ensure user is authorized + self.raise_if_unauthorized() + serializer = AssetListSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) - # Fetch initial queryset - queryset: QuerySet[Asset] = self.filter_queryset( - self.get_queryset().select_related('blob', 'embargoed_blob', 'zarr') + # Retrieve version first and then fetch assets, to remove a join + version = Version.objects.get( + dandiset__pk=self.kwargs['versions__dandiset__pk'], + version=self.kwargs['versions__version'], ) - # Don't include metadata field if not asked for - include_metadata = serializer.validated_data['metadata'] - if not include_metadata: - queryset = queryset.defer('metadata') + # Apply filtering from included filter class first + asset_queryset = self.filter_queryset(version.assets.all()) + # Must do glob pattern matching before pagination glob_pattern: str | None = serializer.validated_data.get('glob') if glob_pattern is not None: # Escape special characters in the glob pattern. This is a security precaution taken @@ -409,16 +417,34 @@ def list(self, request, *args, **kwargs): # include a regex as part of the glob expression, which postgres would happily parse # and use if it's not escaped. glob_pattern = f'^{re.escape(glob_pattern)}$' - queryset = queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*')) + asset_queryset = asset_queryset.filter(path__iregex=glob_pattern.replace('\\*', '.*')) + + # Retrieve just the first N asset IDs, and use them for pagination + page_of_asset_ids = self.paginate_queryset(asset_queryset.values_list('id', flat=True)) + + # Not sure when the page is ever None, but this condition is checked for compatibility with + # the original implementation: https://github.com/encode/django-rest-framework/blob/f4194c4684420ac86485d9610adf760064db381f/rest_framework/mixins.py#L37-L46 + # This is checked here since the query can't continue if the page is `None` anyway + if page_of_asset_ids is None: + serializer = self.get_serializer(Asset.objects.none(), many=True) + return Response(serializer.data) + + # Now we can retrieve the actual fully joined rows using the limited number of assets we're + # going to return + queryset = self.filter_queryset( + Asset.objects.filter(id__in=page_of_asset_ids).select_related( + 'blob', 'embargoed_blob', 'zarr' + ) + ) - # Paginate and return - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True, metadata=include_metadata) - return self.get_paginated_response(serializer.data) + # Must apply this to the main queryset, since it affects the data returned + include_metadata = serializer.validated_data['metadata'] + if not include_metadata: + queryset = queryset.defer('metadata') + # Paginate and return serializer = self.get_serializer(queryset, many=True, metadata=include_metadata) - return Response(serializer.data) + return self.get_paginated_response(serializer.data) @swagger_auto_schema( query_serializer=AssetPathsQueryParameterSerializer,