diff --git a/dandiapi/api/models/asset.py b/dandiapi/api/models/asset.py
index a330c9aba..dbd78448c 100644
--- a/dandiapi/api/models/asset.py
+++ b/dandiapi/api/models/asset.py
@@ -162,6 +162,13 @@ def is_blob(self):
     def is_zarr(self):
         return self.zarr is not None
 
+    @property
+    def is_embargoed(self) -> bool:
+        if self.blob is not None:
+            return self.blob.embargoed
+
+        return self.zarr.embargoed  # type: ignore  # noqa: PGH003
+
     @property
     def size(self):
         if self.is_blob:
@@ -195,7 +202,7 @@ def is_different_from(
         metadata: dict,
         path: str,
     ) -> bool:
-        from dandiapi.zarr.models import EmbargoedZarrArchive, ZarrArchive
+        from dandiapi.zarr.models import ZarrArchive
 
         if isinstance(asset_blob, AssetBlob) and self.blob is not None and self.blob != asset_blob:
             return True
@@ -207,9 +214,6 @@ def is_different_from(
         ):
             return True
 
-        if isinstance(zarr_archive, EmbargoedZarrArchive):
-            raise NotImplementedError
-
         if self.path != path:
             return True
 
@@ -234,9 +238,8 @@ def full_metadata(self):
             'access': [
                 {
                     'schemaKey': 'AccessRequirements',
-                    # TODO: When embargoed zarrs land, include that logic here
                     'status': AccessType.EmbargoedAccess.value
-                    if self.blob and self.blob.embargoed
+                    if self.is_embargoed
                     else AccessType.OpenAccess.value,
                 }
             ],
@@ -283,7 +286,4 @@ def total_size(cls):
             .aggregate(size=models.Sum('size'))['size']
             or 0
             for cls in (AssetBlob, ZarrArchive)
-            # adding of Zarrs to embargoed dandisets is not supported
-            # so no point of adding EmbargoedZarr here since would also result in error
-            # TODO: add EmbagoedZarr whenever supported
         )
diff --git a/dandiapi/api/services/asset/__init__.py b/dandiapi/api/services/asset/__init__.py
index e1ef3f4d1..834e00520 100644
--- a/dandiapi/api/services/asset/__init__.py
+++ b/dandiapi/api/services/asset/__init__.py
@@ -168,15 +168,16 @@ def add_asset_to_version(
         raise ZarrArchiveBelongsToDifferentDandisetError
 
     with transaction.atomic():
-        # Creating an asset in an OPEN dandiset that points to an embargoed blob results in that
-        # blob being unembargoed
+        # Creating an asset in an OPEN dandiset that points to an
+        # embargoed blob results in that blob being unembargoed.
+        # NOTE: This only applies to asset blobs, as zarrs cannot belong to
+        # multiple dandisets at once.
         if (
             asset_blob is not None
             and asset_blob.embargoed
             and version.dandiset.embargo_status == Dandiset.EmbargoStatus.OPEN
         ):
-            asset_blob.embargoed = False
-            asset_blob.save()
+            AssetBlob.objects.filter(blob_id=asset_blob.blob_id).update(embargoed=False)
             transaction.on_commit(
                 lambda: remove_asset_blob_embargoed_tag_task.delay(blob_id=asset_blob.blob_id)
             )
diff --git a/dandiapi/api/services/embargo/__init__.py b/dandiapi/api/services/embargo/__init__.py
index e2f671998..e7a8841ca 100644
--- a/dandiapi/api/services/embargo/__init__.py
+++ b/dandiapi/api/services/embargo/__init__.py
@@ -1,67 +1,32 @@
 from __future__ import annotations
 
-from concurrent.futures import ThreadPoolExecutor
 import logging
 from typing import TYPE_CHECKING
 
-from botocore.config import Config
-from django.conf import settings
 from django.db import transaction
-from more_itertools import chunked
 
 from dandiapi.api.mail import send_dandiset_unembargoed_message
 from dandiapi.api.models import AssetBlob, Dandiset, Version
 from dandiapi.api.services import audit
 from dandiapi.api.services.asset.exceptions import DandisetOwnerRequiredError
+from dandiapi.api.services.embargo.utils import _delete_object_tags, remove_dandiset_embargo_tags
 from dandiapi.api.services.exceptions import DandiError
 from dandiapi.api.services.metadata import validate_version_metadata
 from dandiapi.api.storage import get_boto_client
 from dandiapi.api.tasks import unembargo_dandiset_task
+from dandiapi.zarr.models import ZarrArchive
 
 from .exceptions import (
     AssetBlobEmbargoedError,
-    AssetTagRemovalError,
     DandisetActiveUploadsError,
     DandisetNotEmbargoedError,
 )
 
 if TYPE_CHECKING:
     from django.contrib.auth.models import User
-    from mypy_boto3_s3 import S3Client
 
 
 logger = logging.getLogger(__name__)
-ASSET_BLOB_TAG_REMOVAL_CHUNK_SIZE = 5000
-
-
-def _delete_asset_blob_tags(client: S3Client, blob: str):
-    client.delete_object_tagging(
-        Bucket=settings.DANDI_DANDISETS_BUCKET_NAME,
-        Key=blob,
-    )
-
-
-# NOTE: In testing this took ~2 minutes for 100,000 files
-def _remove_dandiset_asset_blob_embargo_tags(dandiset: Dandiset):
-    client = get_boto_client(config=Config(max_pool_connections=100))
-    embargoed_asset_blobs = (
-        AssetBlob.objects.filter(embargoed=True, assets__versions__dandiset=dandiset)
-        .values_list('blob', flat=True)
-        .iterator(chunk_size=ASSET_BLOB_TAG_REMOVAL_CHUNK_SIZE)
-    )
-
-    # Chunk the blobs so we're never storing a list of all embargoed blobs
-    chunks = chunked(embargoed_asset_blobs, ASSET_BLOB_TAG_REMOVAL_CHUNK_SIZE)
-    for chunk in chunks:
-        with ThreadPoolExecutor(max_workers=100) as e:
-            futures = [
-                e.submit(_delete_asset_blob_tags, client=client, blob=blob) for blob in chunk
-            ]
-
-        # Check if any failed and raise exception if so
-        failed = [blob for i, blob in enumerate(chunk) if futures[i].exception() is not None]
-        if failed:
-            raise AssetTagRemovalError('Some blobs failed to remove tags', blobs=failed)
 
 
 @transaction.atomic()
@@ -80,13 +45,17 @@ def unembargo_dandiset(ds: Dandiset, user: User):
 
     # Remove tags in S3
     logger.info('Removing tags...')
-    _remove_dandiset_asset_blob_embargo_tags(ds)
+    remove_dandiset_embargo_tags(ds)
 
-    # Update embargoed flag on asset blobs
-    updated = AssetBlob.objects.filter(embargoed=True, assets__versions__dandiset=ds).update(
+    # Update embargoed flag on asset blobs and zarrs
+    updated_blobs = AssetBlob.objects.filter(embargoed=True, assets__versions__dandiset=ds).update(
         embargoed=False
     )
-    logger.info('Updated %s asset blobs', updated)
+    updated_zarrs = ZarrArchive.objects.filter(
+        embargoed=True, assets__versions__dandiset=ds
+    ).update(embargoed=False)
+    logger.info('Updated %s asset blobs', updated_blobs)
+    logger.info('Updated %s zarrs', updated_zarrs)
 
     # Set status to OPEN
     Dandiset.objects.filter(pk=ds.pk).update(embargo_status=Dandiset.EmbargoStatus.OPEN)
@@ -118,7 +87,7 @@ def remove_asset_blob_embargoed_tag(asset_blob: AssetBlob) -> None:
     if asset_blob.embargoed:
         raise AssetBlobEmbargoedError
 
-    _delete_asset_blob_tags(client=get_boto_client(), blob=asset_blob.blob.name)
+    _delete_object_tags(client=get_boto_client(), blob=asset_blob.blob.name)
 
 
 def kickoff_dandiset_unembargo(*, user: User, dandiset: Dandiset):
diff --git a/dandiapi/api/services/embargo/utils.py b/dandiapi/api/services/embargo/utils.py
new file mode 100644
index 000000000..1b7c405ff
--- /dev/null
+++ b/dandiapi/api/services/embargo/utils.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from concurrent.futures import ThreadPoolExecutor
+import logging
+from typing import TYPE_CHECKING
+
+from botocore.config import Config
+from django.conf import settings
+from django.db.models import Q
+from more_itertools import chunked
+
+from dandiapi.api.models.asset import Asset
+from dandiapi.api.storage import get_boto_client
+from dandiapi.zarr.models import zarr_s3_path
+
+from .exceptions import AssetTagRemovalError
+
+if TYPE_CHECKING:
+    from mypy_boto3_s3 import S3Client
+
+    from dandiapi.api.models.dandiset import Dandiset
+
+
+logger = logging.getLogger(__name__)
+TAG_REMOVAL_CHUNK_SIZE = 5000
+
+
+def retry(times: int, exceptions: tuple[type[Exception], ...]):
+    """
+    Retry Decorator.
+
+    Retries the wrapped function/method ``times`` times if any of the exceptions
+    listed in ``exceptions`` are raised.
+
+    :param times: The number of times to repeat the wrapped function/method
+    :param exceptions: Tuple of exceptions that trigger a retry attempt
+    """
+
+    def decorator(func):
+        def newfn(*args, **kwargs):
+            attempt = 0
+            while attempt < times:
+                try:
+                    return func(*args, **kwargs)
+                except exceptions:
+                    attempt += 1
+            return func(*args, **kwargs)
+
+        return newfn
+
+    return decorator
+
+
+@retry(times=3, exceptions=(Exception,))
+def _delete_object_tags(client: S3Client, blob: str):
+    client.delete_object_tagging(
+        Bucket=settings.DANDI_DANDISETS_BUCKET_NAME,
+        Key=blob,
+    )
+
+
+@retry(times=3, exceptions=(Exception,))
+def _delete_zarr_object_tags(client: S3Client, zarr: str):
+    paginator = client.get_paginator('list_objects_v2')
+    pages = paginator.paginate(
+        Bucket=settings.DANDI_DANDISETS_BUCKET_NAME, Prefix=zarr_s3_path(zarr_id=zarr)
+    )
+
+    with ThreadPoolExecutor(max_workers=100) as e:
+        for page in pages:
+            keys = [obj['Key'] for obj in page.get('Contents', [])]
+            futures = [e.submit(_delete_object_tags, client=client, blob=key) for key in keys]
+
+            # Check if any failed and raise exception if so
+            failed = [key for i, key in enumerate(keys) if futures[i].exception() is not None]
+            if failed:
+                raise AssetTagRemovalError('Some zarr files failed to remove tags', blobs=failed)
+
+
+def remove_dandiset_embargo_tags(dandiset: Dandiset):
+    client = get_boto_client(config=Config(max_pool_connections=100))
+    embargoed_assets = (
+        Asset.objects.filter(versions__dandiset=dandiset)
+        .filter(Q(blob__embargoed=True) | Q(zarr__embargoed=True))
+        .values_list('blob__blob', 'zarr__zarr_id')
+        .iterator(chunk_size=TAG_REMOVAL_CHUNK_SIZE)
+    )
+
+    # Chunk the assets so we're never storing a list of all embargoed assets
+    chunks = chunked(embargoed_assets, TAG_REMOVAL_CHUNK_SIZE)
+    for chunk in chunks:
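+        # Each chunk row is a (blob key, zarr ID) pair from values_list above; one of the two is None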
+        futures = []
+        with ThreadPoolExecutor(max_workers=100) as e:
+            for blob, zarr in chunk:
+                if blob is not None:
+                    futures.append(e.submit(_delete_object_tags, client=client, blob=blob))
+                if zarr is not None:
+                    futures.append(e.submit(_delete_zarr_object_tags, client=client, zarr=zarr))
+
+        # Check if any failed and raise exception if so
+        failed = [blob for i, blob in enumerate(chunk) if futures[i].exception() is not None]
+        if failed:
+            raise AssetTagRemovalError('Some assets failed to remove tags', blobs=failed)
diff --git a/dandiapi/api/services/publish/__init__.py b/dandiapi/api/services/publish/__init__.py
index 882fa0c7c..f3a2ecb5b 100644
--- a/dandiapi/api/services/publish/__init__.py
+++ b/dandiapi/api/services/publish/__init__.py
@@ -55,7 +55,7 @@ def _lock_dandiset_for_publishing(*, user: User, dandiset: Dandiset) -> None:  #
     if dandiset.embargo_status != Dandiset.EmbargoStatus.OPEN:
         raise NotAllowedError('Operation only allowed on OPEN dandisets', 400)
 
-    if dandiset.zarr_archives.exists() or dandiset.embargoed_zarr_archives.exists():
+    if dandiset.zarr_archives.exists():
        raise NotAllowedError('Cannot publish dandisets which contain zarrs', 400)
 
     with transaction.atomic():
diff --git a/dandiapi/api/tests/test_asset.py b/dandiapi/api/tests/test_asset.py
index a08ff6b8b..adba44319 100644
--- a/dandiapi/api/tests/test_asset.py
+++ b/dandiapi/api/tests/test_asset.py
@@ -173,9 +173,6 @@ def test_asset_total_size(
 
     assert Asset.total_size() == asset_blob.size + zarr_archive.size
 
-    # TODO: add testing for embargoed zarr added, whenever embargoed zarrs
-    # supported, ATM they are not and tested by test_zarr_rest_create_embargoed_dandiset
-
 
 @pytest.mark.django_db
 def test_asset_full_metadata(draft_asset_factory):
@@ -233,6 +230,42 @@ def test_asset_full_metadata_zarr(draft_asset_factory, zarr_archive):
     }
 
 
+@pytest.mark.django_db
+def test_asset_full_metadata_access(draft_asset_factory, asset_blob_factory, zarr_archive_factory):
+    raw_metadata = {
+        'foo': 'bar',
+        'schemaVersion': settings.DANDI_SCHEMA_VERSION,
+    }
+    embargoed_zarr_asset: Asset = draft_asset_factory(
+        metadata=raw_metadata, blob=None, zarr=zarr_archive_factory(embargoed=True)
+    )
+    open_zarr_asset: Asset = draft_asset_factory(
+        metadata=raw_metadata, blob=None, zarr=zarr_archive_factory(embargoed=False)
+    )
+
+    embargoed_blob_asset: Asset = draft_asset_factory(
+        metadata=raw_metadata, blob=asset_blob_factory(embargoed=True), zarr=None
+    )
+    open_blob_asset: Asset = draft_asset_factory(
+        metadata=raw_metadata, blob=asset_blob_factory(embargoed=False), zarr=None
+    )
+
+    # Test that access is correctly inferred from embargo status
+    assert embargoed_zarr_asset.full_metadata['access'] == [
+        {'schemaKey': 'AccessRequirements', 'status': AccessType.EmbargoedAccess.value}
+    ]
+    assert embargoed_blob_asset.full_metadata['access'] == [
+        {'schemaKey': 'AccessRequirements', 'status': AccessType.EmbargoedAccess.value}
+    ]
+
+    assert open_zarr_asset.full_metadata['access'] == [
+        {'schemaKey': 'AccessRequirements', 'status': AccessType.OpenAccess.value}
+    ]
+    assert open_blob_asset.full_metadata['access'] == [
+        {'schemaKey': 'AccessRequirements', 'status': AccessType.OpenAccess.value}
+    ]
+
+
 # API Tests
 
 
@@ -1048,6 +1081,40 @@ def test_asset_create_existing_path(api_client, user, draft_version, asset_blob,
     assert resp.status_code == 409
 
 
+# Must use transaction=True as the tested function uses a transaction on_commit hook
+@pytest.mark.django_db(transaction=True)
+def test_asset_create_on_open_dandiset_embargoed_asset_blob(
+    api_client, user, draft_version, embargoed_asset_blob, mocker
+):
+    mocked = mocker.patch('dandiapi.api.services.asset.remove_asset_blob_embargoed_tag_task.delay')
+
+    assert embargoed_asset_blob.embargoed
+
+    assign_perm('owner', user, draft_version.dandiset)
+    api_client.force_authenticate(user)
+
+    path = 'test/create/asset.txt'
+    metadata = {
+        'encodingFormat': 'application/x-nwb',
+        'path': path,
+    }
+
+    resp = api_client.post(
+        f'/api/dandisets/{draft_version.dandiset.identifier}'
+        f'/versions/{draft_version.version}/assets/',
+        {'metadata': metadata, 'blob_id': embargoed_asset_blob.blob_id},
+        format='json',
+    )
+    assert resp.status_code == 200
+
+    # Check that asset blob is no longer embargoed
+    embargoed_asset_blob.refresh_from_db()
+    assert not embargoed_asset_blob.embargoed
+
+    # Check that tag removal function called
+    mocked.assert_called_once()
+
+
 @pytest.mark.django_db
 def test_asset_rest_rename(api_client, user, draft_version, asset_blob):
     assign_perm('owner', user, draft_version.dandiset)
diff --git a/dandiapi/api/tests/test_unembargo.py b/dandiapi/api/tests/test_unembargo.py
index ebd86b336..8fa6149f0 100644
--- a/dandiapi/api/tests/test_unembargo.py
+++ b/dandiapi/api/tests/test_unembargo.py
@@ -10,7 +10,6 @@
 from dandiapi.api.models.version import Version
 from dandiapi.api.services.embargo import (
     AssetBlobEmbargoedError,
-    _remove_dandiset_asset_blob_embargo_tags,
     remove_asset_blob_embargoed_tag,
     unembargo_dandiset,
 )
@@ -18,22 +17,20 @@
     AssetTagRemovalError,
     DandisetActiveUploadsError,
 )
+from dandiapi.api.services.embargo.utils import (
+    _delete_zarr_object_tags,
+    remove_dandiset_embargo_tags,
+)
 from dandiapi.api.services.exceptions import DandiError
+from dandiapi.api.storage import get_boto_client
 from dandiapi.api.tasks import unembargo_dandiset_task
+from dandiapi.zarr.models import ZarrArchive, ZarrArchiveStatus, zarr_s3_path
+from dandiapi.zarr.tasks import ingest_zarr_archive
 
 if TYPE_CHECKING:
     from dandiapi.api.models.asset import AssetBlob
 
 
-@pytest.mark.django_db
-def test_remove_asset_blob_embargoed_tag_fails_on_embargod(embargoed_asset_blob, asset_blob):
-    with pytest.raises(AssetBlobEmbargoedError):
-        remove_asset_blob_embargoed_tag(embargoed_asset_blob)
-
-    # Test that error not raised on non-embargoed asset blob
-    remove_asset_blob_embargoed_tag(asset_blob)
-
-
 @pytest.mark.django_db
 def test_kickoff_dandiset_unembargo_dandiset_not_embargoed(
     api_client, user, dandiset_factory, draft_version_factory
@@ -125,16 +122,16 @@ def test_unembargo_dandiset_uploads_exist(draft_version_factory, upload_factory,
 
 
 @pytest.mark.django_db
-def test_remove_dandiset_asset_blob_embargo_tags_chunks(
+def test_remove_dandiset_embargo_tags_chunks(
     draft_version_factory,
     asset_factory,
     embargoed_asset_blob_factory,
     mocker,
 ):
     delete_asset_blob_tags_mock = mocker.patch(
-        'dandiapi.api.services.embargo._delete_asset_blob_tags'
+        'dandiapi.api.services.embargo.utils._delete_object_tags'
     )
-    chunk_size = mocker.patch('dandiapi.api.services.embargo.ASSET_BLOB_TAG_REMOVAL_CHUNK_SIZE', 2)
+    chunk_size = mocker.patch('dandiapi.api.services.embargo.utils.TAG_REMOVAL_CHUNK_SIZE', 2)
 
     draft_version: Version = draft_version_factory(
         dandiset__embargo_status=Dandiset.EmbargoStatus.UNEMBARGOING
@@ -144,32 +141,92 @@ def test_remove_dandiset_asset_blob_embargo_tags_chunks(
         asset = asset_factory(blob=embargoed_asset_blob_factory())
         draft_version.assets.add(asset)
 
-    _remove_dandiset_asset_blob_embargo_tags(dandiset=ds)
+    remove_dandiset_embargo_tags(dandiset=ds)
 
-    # Assert that _delete_asset_blob_tags was called chunk_size +1 times, to ensure that it works
+    # Assert that _delete_object_tags was called chunk_size +1 times, to ensure that it works
     # correctly across chunks
     assert len(delete_asset_blob_tags_mock.mock_calls) == chunk_size + 1
 
 
 @pytest.mark.django_db
-def test_delete_asset_blob_tags_fails(
+def test_remove_dandiset_embargo_tags_fails_remove_tags(
     draft_version_factory,
     asset_factory,
     embargoed_asset_blob_factory,
     mocker,
 ):
-    mocker.patch('dandiapi.api.services.embargo._delete_asset_blob_tags', side_effect=ValueError)
+    # Patch function to raise error when called
+    mocker.patch('dandiapi.api.services.embargo.utils._delete_object_tags', side_effect=ValueError)
+
+    # Create dandiset/version and add assets
     draft_version: Version = draft_version_factory(
         dandiset__embargo_status=Dandiset.EmbargoStatus.UNEMBARGOING
     )
     ds: Dandiset = draft_version.dandiset
-    asset = asset_factory(blob=embargoed_asset_blob_factory())
-    draft_version.assets.add(asset)
+    for _ in range(2):
+        asset = asset_factory(blob=embargoed_asset_blob_factory())
+        draft_version.assets.add(asset)
 
-    # Check that if an exception within `_delete_asset_blob_tags` is raised, it's propagated upwards
-    # as an AssetTagRemovalError
+    # Remove tags
     with pytest.raises(AssetTagRemovalError):
-        _remove_dandiset_asset_blob_embargo_tags(dandiset=ds)
+        remove_dandiset_embargo_tags(dandiset=ds)
+
+
+@pytest.mark.django_db
+def test_remove_asset_blob_embargoed_tag_fails_on_embargod(embargoed_asset_blob, asset_blob):
+    with pytest.raises(AssetBlobEmbargoedError):
+        remove_asset_blob_embargoed_tag(embargoed_asset_blob)
+
+    # Test that error not raised on non-embargoed asset blob
+    remove_asset_blob_embargoed_tag(asset_blob)
+
+
+@pytest.mark.django_db
+def test_remove_asset_blob_embargoed_tag(asset_blob, mocker):
+    mocked_func = mocker.patch('dandiapi.api.services.embargo._delete_object_tags')
+    remove_asset_blob_embargoed_tag(asset_blob)
+    mocked_func.assert_called_once()
+
+
+@pytest.mark.django_db
+def test_delete_zarr_object_tags_fails_remove_tags(zarr_archive, zarr_file_factory, mocker):
+    mocked = mocker.patch(
+        'dandiapi.api.services.embargo.utils._delete_object_tags', side_effect=ValueError
+    )
+    files = [zarr_file_factory(zarr_archive) for _ in range(2)]
+
+    with pytest.raises(AssetTagRemovalError):
+        _delete_zarr_object_tags(client=get_boto_client(), zarr=zarr_archive.zarr_id)
+
+    # Check that each file was called 4 times total. Once initially, and 3 retries
+    assert mocked.call_count == 4 * len(files)
+    for file in files:
+        calls = [
+            c
+            for c in mocked.mock_calls
+            if c.kwargs['blob']
+            == zarr_s3_path(zarr_id=zarr_archive.zarr_id, zarr_path=str(file.path))
+        ]
+        assert len(calls) == 4
+
+
+@pytest.mark.django_db
+def test_delete_zarr_object_tags(zarr_archive, zarr_file_factory, mocker):
+    mocked_delete_object_tags = mocker.patch(
+        'dandiapi.api.services.embargo.utils._delete_object_tags'
+    )
+
+    # Create files
+    files = [zarr_file_factory(zarr_archive) for _ in range(10)]
+
+    # This should call the mocked function for each file
+    _delete_zarr_object_tags(client=get_boto_client(), zarr=zarr_archive.zarr_id)
+
+    assert mocked_delete_object_tags.call_count == len(files)
+
+    called_blobs = sorted([call.kwargs['blob'] for call in mocked_delete_object_tags.mock_calls])
+    file_bucket_paths = sorted([zarr_archive.s3_path(str(file.path)) for file in files])
+    assert called_blobs == file_bucket_paths
 
 
 @pytest.mark.django_db
@@ -177,6 +234,8 @@ def test_unembargo_dandiset(
     draft_version_factory,
     asset_factory,
     embargoed_asset_blob_factory,
+    embargoed_zarr_archive_factory,
+    zarr_file_factory,
     mocker,
     mailoutbox,
     user_factory,
@@ -190,20 +249,29 @@ def test_unembargo_dandiset(
     assign_perm('owner', user, ds)
 
     embargoed_blob: AssetBlob = embargoed_asset_blob_factory()
-    asset = asset_factory(blob=embargoed_blob)
-    draft_version.assets.add(asset)
-    assert embargoed_blob.embargoed
+    draft_version.assets.add(asset_factory(blob=embargoed_blob))
+
+    zarr_archive: ZarrArchive = embargoed_zarr_archive_factory(
+        dandiset=ds, status=ZarrArchiveStatus.UPLOADED
+    )
+    for _ in range(5):
+        zarr_file_factory(zarr_archive)
+    ingest_zarr_archive(zarr_id=zarr_archive.zarr_id)
+    zarr_archive.refresh_from_db()
+    draft_version.assets.add(asset_factory(zarr=zarr_archive, blob=None))
+
+    assert all(asset.is_embargoed for asset in draft_version.assets.all())
 
     # Patch this function to check if it's been called, since we can't test the tagging directly
-    patched = mocker.patch('dandiapi.api.services.embargo._delete_asset_blob_tags')
+    patched = mocker.patch('dandiapi.api.services.embargo.utils._delete_object_tags')
 
     unembargo_dandiset(ds, owners[0])
 
-    patched.assert_called_once()
-    embargoed_blob.refresh_from_db()
+    assert patched.call_count == 1 + zarr_archive.file_count
+    assert not any(asset.is_embargoed for asset in draft_version.assets.all())
+
     ds.refresh_from_db()
     draft_version.refresh_from_db()
-    assert not embargoed_blob.embargoed
     assert ds.embargo_status == Dandiset.EmbargoStatus.OPEN
     assert (
         draft_version.metadata['access'][0]['status']
diff --git a/dandiapi/api/views/asset.py b/dandiapi/api/views/asset.py
index ef335c16c..062622c17 100644
--- a/dandiapi/api/views/asset.py
+++ b/dandiapi/api/views/asset.py
@@ -87,10 +87,8 @@ def raise_if_unauthorized(self):
         if asset_id is None:
             return
 
-        asset = get_object_or_404(Asset.objects.select_related('blob'), asset_id=asset_id)
-
-        # TODO: When EmbargoedZarrArchive is implemented, check that as well
-        if not (asset.blob and asset.blob.embargoed):
+        asset = get_object_or_404(Asset.objects.select_related('blob', 'zarr'), asset_id=asset_id)
+        if not asset.is_embargoed:
             return
 
         # Clients must be authenticated to access it
diff --git a/dandiapi/conftest.py b/dandiapi/conftest.py
index ce8826796..332b8640e 100644
--- a/dandiapi/conftest.py
+++ b/dandiapi/conftest.py
@@ -23,7 +23,7 @@
     UploadFactory,
     UserFactory,
 )
-from dandiapi.zarr.tests.factories import ZarrArchiveFactory
+from dandiapi.zarr.tests.factories import EmbargoedZarrArchiveFactory, ZarrArchiveFactory
 from dandiapi.zarr.tests.utils import upload_zarr_file
 
 if TYPE_CHECKING:
@@ -47,6 +47,7 @@
 
 # zarr app
 register(ZarrArchiveFactory)
+register(EmbargoedZarrArchiveFactory, _name='embargoed_zarr_archive')
 
 
 # Register zarr file/directory factories
diff --git a/dandiapi/zarr/admin.py b/dandiapi/zarr/admin.py
index b0fba4eda..13bcd2916 100644
--- a/dandiapi/zarr/admin.py
+++ b/dandiapi/zarr/admin.py
@@ -3,17 +3,21 @@
 from django.contrib import admin, messages
 from django.utils.translation import ngettext
 
-from dandiapi.zarr.models import EmbargoedZarrArchive, ZarrArchive
+from dandiapi.zarr.models import ZarrArchive
 from dandiapi.zarr.tasks import ingest_zarr_archive
 
 
 @admin.register(ZarrArchive)
 class ZarrArchiveAdmin(admin.ModelAdmin):
     search_fields = ['zarr_id', 'name']
-    list_display = ['id', 'zarr_id', 'name', 'dandiset']
+    list_display = ['id', 'zarr_id', 'name', 'dandiset', 'public']
     list_display_links = ['id', 'zarr_id', 'name']
     actions = ('ingest_zarr_archive',)
 
+    @admin.display(boolean=True, description='Public Access', ordering='embargoed')
+    def public(self, obj: ZarrArchive):
+        return not obj.embargoed
+
     @admin.action(description='Ingest selected zarr archives')
     def ingest_zarr_archive(self, request, queryset):
         for zarr in queryset:
@@ -30,10 +34,3 @@ def ingest_zarr_archive(self, request, queryset):
             % queryset.count(),
             messages.SUCCESS,
         )
-
-
-@admin.register(EmbargoedZarrArchive)
-class EmbargoedZarrArchiveAdmin(admin.ModelAdmin):
-    search_fields = ['zarr_id', 'name']
-    list_display = ['id', 'zarr_id', 'name', 'dandiset']
-    list_display_links = ['id', 'zarr_id', 'name']
diff --git a/dandiapi/zarr/migrations/0004_zarrarchive_embargoed_delete_embargoedzarrarchive.py b/dandiapi/zarr/migrations/0004_zarrarchive_embargoed_delete_embargoedzarrarchive.py
new file mode 100644
index 000000000..722464124
--- /dev/null
+++ b/dandiapi/zarr/migrations/0004_zarrarchive_embargoed_delete_embargoedzarrarchive.py
@@ -0,0 +1,21 @@
+# Generated by Django 4.1.13 on 2024-08-21 16:06
+from __future__ import annotations
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ('zarr', '0003_alter_embargoedzarrarchive_options_and_more'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='zarrarchive',
+            name='embargoed',
+            field=models.BooleanField(default=False),
+        ),
+        migrations.DeleteModel(
+            name='EmbargoedZarrArchive',
+        ),
+    ]
diff --git a/dandiapi/zarr/models.py b/dandiapi/zarr/models.py
index 9ffd6a55b..512b3266e 100644
--- a/dandiapi/zarr/models.py
+++ b/dandiapi/zarr/models.py
@@ -15,6 +15,14 @@
 logger = logging.getLogger(name=__name__)
 
 
+# TODO: Move this somewhere better?
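+# Used by the embargo utilities, which hold only a zarr ID rather than a ZarrArchive instance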
+def zarr_s3_path(zarr_id: str, zarr_path: str = ''):
+    return (
+        f'{settings.DANDI_DANDISETS_BUCKET_PREFIX}{settings.DANDI_ZARR_PREFIX_NAME}/'
+        f'{zarr_id}/{zarr_path}'
+    )
+
+
 # The status of the zarr ingestion (checksums, size, file count)
 class ZarrArchiveStatus(models.TextChoices):
     PENDING = 'Pending'
@@ -23,14 +31,14 @@ class ZarrArchiveStatus(models.TextChoices):
     COMPLETE = 'Complete'
 
 
-class BaseZarrArchive(TimeStampedModel):
+class ZarrArchive(TimeStampedModel):
     UUID_REGEX = r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}'
     INGEST_ERROR_MSG = 'Zarr archive is currently ingesting or has already ingested'
+    storage = get_storage()
 
     class Meta:
         ordering = ['created']
         get_latest_by = 'modified'
-        abstract = True
         constraints = [
             models.UniqueConstraint(
                 name='%(app_label)s-%(class)s-unique-name',
@@ -50,7 +58,9 @@ class Meta:
             ),
         ]
 
+    dandiset = models.ForeignKey(Dandiset, related_name='zarr_archives', on_delete=models.CASCADE)
     zarr_id = models.UUIDField(unique=True, default=uuid4, db_index=True)
+    embargoed = models.BooleanField(default=False)
     name = models.CharField(max_length=512)
     file_count = models.BigIntegerField(default=0)
     size = models.BigIntegerField(default=0)
@@ -72,6 +82,10 @@ def s3_url(self):
         parsed = urlparse(signed_url)
         return urlunparse((parsed[0], parsed[1], parsed[2], '', '', ''))
 
+    def s3_path(self, zarr_path: str) -> str:
+        """Generate a full S3 object path from a path in this zarr_archive."""
+        return zarr_s3_path(str(self.zarr_id), zarr_path)
+
     def generate_upload_urls(self, path_md5s: list[dict]):
         return [
             self.storage.generate_presigned_put_object_url(self.s3_path(o['path']), o['base64md5'])
@@ -94,29 +108,3 @@ def delete_files(self, paths: list[str]):
         # Files deleted, mark pending
         self.mark_pending()
         self.save()
-
-
-class ZarrArchive(BaseZarrArchive):
-    storage = get_storage()
-    dandiset = models.ForeignKey(Dandiset, related_name='zarr_archives', on_delete=models.CASCADE)
-
-    def s3_path(self, zarr_path: str) -> str:
-        """Generate a full S3 object path from a path in this zarr_archive."""
-        return (
-            f'{settings.DANDI_DANDISETS_BUCKET_PREFIX}{settings.DANDI_ZARR_PREFIX_NAME}/'
-            f'{self.zarr_id}/{zarr_path}'
-        )
-
-
-class EmbargoedZarrArchive(BaseZarrArchive):
-    storage = get_storage()
-    dandiset = models.ForeignKey(
-        Dandiset, related_name='embargoed_zarr_archives', on_delete=models.CASCADE
-    )
-
-    def s3_path(self, zarr_path: str) -> str:
-        """Generate a full S3 object path from a path in this zarr_archive."""
-        return (
-            f'{settings.DANDI_ZARR_PREFIX_NAME}/'
-            f'{self.dandiset.identifier}/{self.zarr_id}/{zarr_path}'
-        )
diff --git a/dandiapi/zarr/tests/factories.py b/dandiapi/zarr/tests/factories.py
index d536e3152..b9e472225 100644
--- a/dandiapi/zarr/tests/factories.py
+++ b/dandiapi/zarr/tests/factories.py
@@ -13,3 +13,7 @@ class Meta:
     zarr_id = factory.Faker('uuid4')
     name = factory.Faker('catch_phrase')
     dandiset = factory.SubFactory(DandisetFactory)
+
+
+class EmbargoedZarrArchiveFactory(ZarrArchiveFactory):
+    embargoed = True
diff --git a/dandiapi/zarr/tests/test_zarr.py b/dandiapi/zarr/tests/test_zarr.py
index 5a0a886ff..b38813c8a 100644
--- a/dandiapi/zarr/tests/test_zarr.py
+++ b/dandiapi/zarr/tests/test_zarr.py
@@ -99,8 +99,7 @@ def test_zarr_rest_create_embargoed_dandiset(
         },
         format='json',
     )
-    assert resp.status_code == 400
-    assert resp.json() == ['Cannot add zarr to embargoed dandiset']
+    assert resp.status_code == 200
 
 
 @pytest.mark.django_db
@@ -128,6 +127,37 @@ def test_zarr_rest_get(authenticated_api_client, storage, zarr_archive_factory,
     }
 
 
+@pytest.mark.django_db
+def test_zarr_rest_get_embargoed(authenticated_api_client, user, embargoed_zarr_archive):
+    assert user not in embargoed_zarr_archive.dandiset.owners
+
+    resp = authenticated_api_client.get(f'/api/zarr/{embargoed_zarr_archive.zarr_id}/')
+    assert resp.status_code == 404
+
+    embargoed_zarr_archive.dandiset.set_owners([user])
+    resp = authenticated_api_client.get(f'/api/zarr/{embargoed_zarr_archive.zarr_id}/')
+    assert resp.status_code == 200
+
+
+@pytest.mark.django_db
+def test_zarr_rest_list_embargoed(authenticated_api_client, user, dandiset, zarr_archive_factory):
+    # Create some embargoed and some open zarrs
+    open_zarrs = [zarr_archive_factory() for _ in range(3)]
+    embargoed_zarrs = [zarr_archive_factory(embargoed=True, dandiset=dandiset) for _ in range(3)]
+
+    # Assert only open zarrs are returned
+    zarrs = authenticated_api_client.get('/api/zarr/').json()['results']
+    assert sorted(z['zarr_id'] for z in zarrs) == sorted(z.zarr_id for z in open_zarrs)
+
+    # Assert that all zarrs returned when user has access to embargoed zarrs
+    dandiset.set_owners([user])
+    zarrs = authenticated_api_client.get('/api/zarr/').json()['results']
+    assert len(zarrs) == len(open_zarrs + embargoed_zarrs)
+    assert sorted(z['zarr_id'] for z in zarrs) == sorted(
+        z.zarr_id for z in (open_zarrs + embargoed_zarrs)
+    )
+
+
 @pytest.mark.django_db
 def test_zarr_rest_list_filter(authenticated_api_client, dandiset_factory, zarr_archive_factory):
     # Create dandisets and zarrs
diff --git a/dandiapi/zarr/views/__init__.py b/dandiapi/zarr/views/__init__.py
index 39510008d..dd8a47c78 100644
--- a/dandiapi/zarr/views/__init__.py
+++ b/dandiapi/zarr/views/__init__.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 from django.db import IntegrityError, transaction
+from django.db.models import Q
 from django.http import HttpResponseRedirect
 from drf_yasg.utils import no_body, swagger_auto_schema
 from rest_framework import serializers, status
@@ -14,7 +15,7 @@
 from rest_framework.utils.urls import replace_query_param
 from rest_framework.viewsets import ReadOnlyModelViewSet
 
-from dandiapi.api.models.dandiset import Dandiset
+from dandiapi.api.models.dandiset import Dandiset, DandisetUserObjectPermission
 from dandiapi.api.services import audit
 from dandiapi.api.storage import get_boto_client
 from dandiapi.api.views.pagination import DandiPagination
@@ -97,6 +98,22 @@ class ZarrViewSet(ReadOnlyModelViewSet):
     lookup_field = 'zarr_id'
     lookup_value_regex = ZarrArchive.UUID_REGEX
 
+    def get_queryset(self) -> QuerySet:
+        qs = super().get_queryset()
+
+        # Start with just public zarrs, in case user is not authenticated
+        queryset_filter = Q(embargoed=False)
+
+        # Filter zarrs to either open access or owned
+        if self.request.user.is_authenticated:
+            user_owned_dandiset_ids = DandisetUserObjectPermission.objects.filter(
+                user=self.request.user, permission__codename='owner'
+            ).values_list('content_object_id', flat=True)
+            queryset_filter |= Q(dandiset_id__in=user_owned_dandiset_ids)
+
+        # Apply filter
+        return qs.filter(queryset_filter)
+
     @swagger_auto_schema(
         query_serializer=ZarrListQuerySerializer,
         responses={200: ZarrListSerializer(many=True)},
@@ -105,10 +122,10 @@ class ZarrViewSet(ReadOnlyModelViewSet):
     def list(self, request, *args, **kwargs):
         query_serializer = ZarrListQuerySerializer(data=self.request.query_params)
         query_serializer.is_valid(raise_exception=True)
-
-        # Add filters from query parameters
         data = query_serializer.validated_data
         queryset: QuerySet[ZarrArchive] = self.get_queryset()
+
+        # Add filters from query parameters
         if 'dandiset' in data:
             queryset = queryset.filter(dandiset=int(data['dandiset'].lstrip('0')))
         if 'name' in data:
@@ -136,8 +153,6 @@ def create(self, request):
             )
         if not self.request.user.has_perm('owner', dandiset):
             raise PermissionDenied
-        if dandiset.embargo_status != Dandiset.EmbargoStatus.OPEN:
-            raise ValidationError('Cannot add zarr to embargoed dandiset')
         zarr_archive: ZarrArchive = ZarrArchive(name=name, dandiset=dandiset)
         with transaction.atomic():
             # Use nested transaction block to prevent zarr creation race condition