diff --git a/breadbox/breadbox/api/dataset_uploads.py b/breadbox/breadbox/api/dataset_uploads.py
index 5c176c31c..f22cba5d4 100644
--- a/breadbox/breadbox/api/dataset_uploads.py
+++ b/breadbox/breadbox/api/dataset_uploads.py
@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.encoders import jsonable_encoder
 from breadbox.compute.dataset_uploads_tasks import run_dataset_upload
+from ..schemas.custom_http_exception import UserError
 from ..schemas.dataset import DatasetParams, AddDatasetResponse
 from .dependencies import get_user
 
@@ -64,6 +65,12 @@ def add_dataset_uploads(
     - `col_type`: Annotation type for the column. Annotation types may include: `continuous`, `categorical`, `text`, or `list_strings`
     """
+
+    if not dataset.is_transient and dataset.expiry_in_seconds is not None:
+        raise UserError(
+            "Dataset was not marked as 'transient' but expiry_in_seconds is set."
+        )
+
     # Converts a data type (like a Pydantic model) to something compatible with JSON, in this case a dict. Although Celery uses a JSON serializer to serialize arguments to tasks by default, pydantic models are too complex for their default serializer. Pydantic models have a built-in .dict() method but it turns out it doesn't convert enums to strings which celery can't JSON serialize, so I opted to use fastapi's jsonable_encoder() which appears to successfully json serialize enums
     dataset_json = jsonable_encoder(dataset)
     result = run_dataset_upload.delay(dataset_json, user)  # pyright: ignore
diff --git a/breadbox/breadbox/commands.py b/breadbox/breadbox/commands.py
index b8ea68397..7af22e06c 100644
--- a/breadbox/breadbox/commands.py
+++ b/breadbox/breadbox/commands.py
@@ -9,6 +9,7 @@
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 
+from breadbox.crud.dataset import find_expired_datasets, delete_dataset
 from breadbox.db.session import SessionWithUser
 from breadbox.config import Settings, get_settings
 from breadbox.crud.access_control import PUBLIC_GROUP_ID, TRANSIENT_GROUP_ID
@@ -26,6 +27,7 @@
 import shutil
 from db_load import upload_example_datasets
 import hashlib
+from datetime import timedelta
 
 
 @click.group()
@@ -33,6 +35,27 @@ def cli():
     pass
 
 
+@cli.command()
+@click.option("--dryrun", is_flag=True, default=False)
+@click.option("--maxdays", default=60, type=int)
+def delete_expired_datasets(maxdays, dryrun):
+    db = _get_db_connection()
+    settings = get_settings()
+    expired_datasets = find_expired_datasets(db, timedelta(days=maxdays))
+
+    print(f"Found {len(expired_datasets)} expired datasets")
+
+    with transaction(db):
+        for dataset in expired_datasets:
+            dataset_summary = f"{dataset.id} (upload_date={dataset.upload_date}, expiry={dataset.expiry})"
+            if dryrun:
+                print(f"dryrun: Would have deleted {dataset_summary}")
+            else:
+                print(f"Deleting {dataset_summary}")
+                delete_dataset(db, db.user, dataset, settings.filestore_location)
+    print("Done")
+
+
 @cli.command()
 @click.argument("user_email")
 @click.argument("group_name")
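
The cleanup command above is intended to be run periodically (for example from a cron job). As a minimal sketch, here is one way it could be exercised from Python via click's test runner; the `breadbox.commands` module path and the dash-separated command name (click 7+ derives it from the function name) are assumptions, not something this diff adds:

    # Sketch only: drives the new command through click's CliRunner rather than a shell.
    from click.testing import CliRunner

    from breadbox.commands import cli  # assumed import path for the click group defined above

    runner = CliRunner()
    # --dryrun reports what would be deleted without touching the database
    result = runner.invoke(cli, ["delete-expired-datasets", "--dryrun", "--maxdays", "30"])
    print(result.output)
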
diff --git a/breadbox/breadbox/compute/analysis_tasks.py b/breadbox/breadbox/compute/analysis_tasks.py
index a903624fb..9ed3b9de9 100644
--- a/breadbox/breadbox/compute/analysis_tasks.py
+++ b/breadbox/breadbox/compute/analysis_tasks.py
@@ -482,6 +482,7 @@ def create_cell_line_group(
         taiga_id=None,
         dataset_metadata=None,
         dataset_md5=None,
+        expiry=None,
     )
     dataset_service.add_matrix_dataset(
         db,
diff --git a/breadbox/breadbox/compute/dataset_tasks.py b/breadbox/breadbox/compute/dataset_tasks.py
index 07022b0c0..a95102b1a 100644
--- a/breadbox/breadbox/compute/dataset_tasks.py
+++ b/breadbox/breadbox/compute/dataset_tasks.py
@@ -262,6 +262,7 @@ def upload_dataset(
         allowed_values=valid_fields.valid_allowed_values,
         dataset_metadata=dataset_metadata,
         dataset_md5=None,
+        expiry=None,
     )
 
     added_dataset = dataset_service.add_matrix_dataset(
diff --git a/breadbox/breadbox/compute/dataset_uploads_tasks.py b/breadbox/breadbox/compute/dataset_uploads_tasks.py
index 1b5414ae4..97a0cc32d 100644
--- a/breadbox/breadbox/compute/dataset_uploads_tasks.py
+++ b/breadbox/breadbox/compute/dataset_uploads_tasks.py
@@ -1,3 +1,4 @@
+from datetime import timedelta
 from uuid import UUID, uuid4
 from typing import Any, List, Optional, Union, Literal, Dict
 
@@ -81,6 +82,12 @@ def dataset_upload(
     )
     dataset_id = str(uuid4())
 
+    expiry = None
+    if dataset_params.expiry_in_seconds is not None:
+        assert dataset_params.is_transient
+        expiry = dataset_crud.get_current_datetime() + timedelta(
+            seconds=dataset_params.expiry_in_seconds
+        )
 
     unknown_ids = []
@@ -132,6 +139,7 @@ def dataset_upload(
         sample_type_name=dataset_params.sample_type,
         data_type=dataset_params.data_type,
         is_transient=dataset_params.is_transient,
+        expiry=expiry,
         group_id=str(dataset_params.group_id),
         value_type=dataset_params.value_type,
         priority=dataset_params.priority,
@@ -186,6 +194,7 @@ def dataset_upload(
         index_type_name=dataset_params.index_type,
         data_type=dataset_params.data_type,
         is_transient=dataset_params.is_transient,
+        expiry=expiry,
         group_id=str(dataset_params.group_id),
         priority=dataset_params.priority,
         taiga_id=dataset_params.taiga_id,
diff --git a/breadbox/breadbox/crud/dataset.py b/breadbox/breadbox/crud/dataset.py
index 418c57c49..cd91c6873 100644
--- a/breadbox/breadbox/crud/dataset.py
+++ b/breadbox/breadbox/crud/dataset.py
@@ -1,5 +1,6 @@
 import logging
 from collections import defaultdict
+from datetime import datetime, timedelta
 from typing import Any, Dict, Optional, List, Type, Union, Tuple, Set
 from uuid import UUID, uuid4
 import warnings
@@ -680,6 +681,35 @@ def update_dataset(
     return dataset
 
 
+def get_current_datetime():
+    # this method only exists to allow us to mock it in tests. Since `datetime` is a built-in we're not able to
+    # mutate it.
+    return datetime.now()
+
+
+def find_expired_datasets(db: SessionWithUser, max_age: timedelta) -> List[Dataset]:
+    """
+    Finds transient datasets which can be deleted (because they've "expired").
+    Two ways a transient dataset can be expired:
+    1. the `expiry` field can be explicitly set, and that time is in the past
+    2. the upload_date is before now - `max_age`.
+    """
+
+    now = get_current_datetime()
+    min_upload_date = now - max_age
+
+    expired_datasets = (
+        db.query(Dataset)
+        .filter(
+            Dataset.is_transient == True,
+            or_(Dataset.expiry < now, Dataset.upload_date < min_upload_date,),
+        )
+        .all()
+    )
+
+    return expired_datasets
+
+
 def delete_dataset(
     db: SessionWithUser, user: str, dataset: Dataset, filestore_location: str
 ):
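
To spell out the two expiry paths that `find_expired_datasets` encodes in SQL, here is a small plain-Python equivalent; `is_expired` is a hypothetical helper written only for illustration, not part of this change:

    from datetime import datetime, timedelta

    def is_expired(is_transient, expiry, upload_date, now, max_age):
        # only transient datasets are ever eligible for cleanup
        if not is_transient:
            return False
        # path 1: an explicit expiry timestamp that has already passed
        if expiry is not None and expiry < now:
            return True
        # path 2: uploaded longer ago than max_age (--maxdays), even with no explicit expiry
        return upload_date < now - max_age

    # a transient dataset uploaded 90 days ago with no explicit expiry counts as expired
    assert is_expired(True, None, datetime(2025, 1, 1), datetime(2025, 4, 1), timedelta(days=60))
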
diff --git a/breadbox/breadbox/models/dataset.py b/breadbox/breadbox/models/dataset.py
index 8937f2edd..f414af3d2 100644
--- a/breadbox/breadbox/models/dataset.py
+++ b/breadbox/breadbox/models/dataset.py
@@ -93,7 +93,13 @@ class Dataset(Base, UUIDMixin, GroupMixin):
     data_type: Mapped[str] = mapped_column(
         String, ForeignKey(DataType.data_type), nullable=False
     )
+
     is_transient: Mapped[bool] = mapped_column(Boolean, nullable=False)
+    # only meaningful for datasets where is_transient==True. Indicates when a transient dataset
+    # should be deleted.
+    expiry: Mapped[Optional[DateTime]] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
 
     priority: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
     taiga_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
diff --git a/breadbox/breadbox/schemas/dataset.py b/breadbox/breadbox/schemas/dataset.py
index 1f662b91c..59455549c 100644
--- a/breadbox/breadbox/schemas/dataset.py
+++ b/breadbox/breadbox/schemas/dataset.py
@@ -9,7 +9,7 @@
 from breadbox.schemas.custom_http_exception import UserError
 from .group import Group
 import enum
-
+from datetime import datetime
 
 # NOTE: Using multivalue Literals seems to be creating errors in pydantic models and fastapi request params.
 # It is possible that for our version of pydantic, the schema for Literals is messed up
@@ -106,6 +106,12 @@ class SharedDatasetParams(BaseModel):
             description="Transient datasets can be deleted - should only be set to true for non-public short-term-use datasets like custom analysis results.",
         ),
     ] = False
+    expiry_in_seconds: Annotated[
+        Optional[int],
+        Field(
+            description="The number of seconds before this dataset is expired (only applies to transient datasets)"
+        ),
+    ] = None
     dataset_metadata: Annotated[
         Optional[Dict[str, Any]],
         Body(
@@ -286,6 +292,7 @@ class SharedDatasetFields(BaseModel):
     priority: Annotated[Optional[int], Field(default=None, gt=0,)]
     taiga_id: Annotated[Optional[str], Field(default=None,)]
    is_transient: Annotated[bool, Field(default=False,)]
+    expiry: Annotated[Optional[datetime], Field(default=None,)]
     dataset_metadata: Annotated[
         Optional[Dict[str, Any]], Field()
     ]  # NOTE: Same as Dict[str, Any] = Field(None,)
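
In request terms, the new `expiry_in_seconds` field only passes validation together with `is_transient: true`. An illustrative body for the `/dataset-v2/` endpoint exercised in the test further down (field values are placeholders):

    # Illustrative payload only; see test_dataset_with_expiry below for the full field list.
    transient_upload = {
        "format": "matrix",
        "name": "custom analysis result",
        "is_transient": True,
        "expiry_in_seconds": 60 * 60 * 24,  # delete roughly one day after upload
        "group_id": TRANSIENT_GROUP_ID,  # plus units, feature_type, file_ids, etc. as in the test
    }
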
diff --git a/breadbox/breadbox/service/dataset.py b/breadbox/breadbox/service/dataset.py
index 456b7a81e..55207754f 100644
--- a/breadbox/breadbox/service/dataset.py
+++ b/breadbox/breadbox/service/dataset.py
@@ -29,7 +29,11 @@
     get_transient_group,
 )
 
-from ..crud.dataset import add_tabular_dimensions, add_matrix_dataset_dimensions
+from ..crud.dataset import (
+    add_tabular_dimensions,
+    add_matrix_dataset_dimensions,
+    get_current_datetime,
+)
 from ..crud.dimension_types import (
     set_properties_to_index,
     add_metadata_dimensions,
@@ -354,6 +358,7 @@ def add_tabular_dataset(
         index_type_name=dataset_in.index_type_name,
         data_type=dataset_in.data_type,
         is_transient=dataset_in.is_transient,
+        expiry=dataset_in.expiry,
         group_id=group.id,
         priority=dataset_in.priority,
         taiga_id=dataset_in.taiga_id,
@@ -362,6 +367,7 @@ def add_tabular_dataset(
         short_name=short_name,
         version=version,
         description=description,
+        upload_date=get_current_datetime(),
     )
     db.add(dataset)
     db.flush()
@@ -421,6 +427,7 @@ def add_matrix_dataset(
         sample_type_name=dataset_in.sample_type_name,
         data_type=dataset_in.data_type,
         is_transient=dataset_in.is_transient,
+        expiry=dataset_in.expiry,
         group_id=group.id,
         value_type=dataset_in.value_type,
         priority=dataset_in.priority,
@@ -431,6 +438,7 @@ def add_matrix_dataset(
         short_name=short_name,
         description=description,
         version=version,
+        upload_date=get_current_datetime(),
     )
     db.add(dataset)
     db.flush()
@@ -500,6 +508,7 @@ def add_dimension_type(
         priority=None,
         dataset_metadata=None,
         dataset_md5=None,  # This may change!
+        expiry=None,
     )
 
     check_id_mapping_is_valid(db, reference_column_mappings)
diff --git a/breadbox/tests/api/test_dataset_uploads.py b/breadbox/tests/api/test_dataset_uploads.py
index 09e831ba4..ebd3928cc 100644
--- a/breadbox/tests/api/test_dataset_uploads.py
+++ b/breadbox/tests/api/test_dataset_uploads.py
@@ -1,4 +1,5 @@
 import io
+from datetime import datetime
 
 from fastapi.testclient import TestClient
 
@@ -6,14 +7,24 @@
 from breadbox.schemas.dataset import AddDatasetResponse
 from breadbox.compute import dataset_uploads_tasks
 from breadbox.celery_task import utils
-from breadbox.models.dataset import TabularDataset, TabularCell, TabularColumn
+from breadbox.models.dataset import TabularDataset, TabularCell, TabularColumn, Dataset
 from sqlalchemy import and_
+from datetime import timedelta
 from typing import Dict
-from ..utils import assert_status_ok
 import pytest
 import numpy as np
 from ..utils import upload_and_get_file_ids
+import json
+import pandas as pd
+from breadbox.models.dataset import AnnotationType
+from fastapi.testclient import TestClient
+from breadbox.schemas.dataset import ColumnMetadata
+from breadbox.crud.access_control import PUBLIC_GROUP_ID, TRANSIENT_GROUP_ID
+from tests import factories
+from ..utils import assert_status_not_ok, assert_status_ok
+from breadbox.crud import dataset as dataset_crud
+from breadbox.service import dataset as dataset_service
 
 
 class TestPost:
@@ -1166,13 +1177,94 @@ def test_add_tabular_dataset_with_invalid_list_str_vals(
     assert tabular_dataset.status_code == 400
 
 
-import json
-import pandas as pd
-from breadbox.models.dataset import AnnotationType
-from fastapi.testclient import TestClient
-from breadbox.schemas.dataset import ColumnMetadata
-from breadbox.crud.access_control import PUBLIC_GROUP_ID
-from tests import factories
+def test_dataset_with_expiry(
+    client: TestClient, minimal_db: SessionWithUser, mock_celery, settings, monkeypatch
+):
+    user = settings.admin_users[0]
+    headers = {"X-Forwarded-User": user}
+    one_day_in_seconds = 60 * 60 * 24
+
+    file = factories.continuous_matrix_csv_file()
+
+    factories.feature_type(minimal_db, minimal_db.user, "feature_name")
+    factories.sample_type(minimal_db, minimal_db.user, "sample_name")
+
+    file_ids, expected_md5 = upload_and_get_file_ids(client, file)
+
+    def override_time(m, mock_now):
+        m.setattr(dataset_crud, "get_current_datetime", lambda: mock_now)
+        m.setattr(dataset_service, "get_current_datetime", lambda: mock_now)
+
+    with monkeypatch.context() as m:
+        override_time(m, datetime(year=2025, month=1, day=1))
+        response = client.post(
+            "/dataset-v2/",
+            json={
+                "format": "matrix",
+                "name": "a dataset",
+                "units": "a unit",
+                "feature_type": "feature_name",
+                "sample_type": "sample_name",
+                "data_type": "User upload",
+                "file_ids": file_ids,
+                "dataset_md5": expected_md5,
+                "is_transient": False,
+                "expiry_in_seconds": one_day_in_seconds,
+                "group_id": TRANSIENT_GROUP_ID,
+                "value_type": "continuous",
+                "allowed_values": None,
+            },
+            headers=headers,
+        )
+
+        # verify we can't specify expiry on a non-transient dataset
+        assert_status_not_ok(response)
+
+        # try again with transient = True
+        response = client.post(
+            "/dataset-v2/",
+            json={
+                "format": "matrix",
+                "name": "a dataset",
+                "units": "a unit",
+                "feature_type": "feature_name",
+                "sample_type": "sample_name",
+                "data_type": "User upload",
+                "file_ids": file_ids,
+                "dataset_md5": expected_md5,
+                "is_transient": True,
+                "expiry_in_seconds": one_day_in_seconds,
+                "group_id": TRANSIENT_GROUP_ID,
+                "value_type": "continuous",
"allowed_values": None, + }, + headers=headers, + ) + + # now this should work and we should see the expiry having been set + assert_status_ok(response) + dataset_id = response.json()["result"]["datasetId"] + dataset = minimal_db.query(Dataset).filter(Dataset.id == dataset_id).one() + assert dataset.expiry == datetime(year=2025, month=1, day=2) + + # now let's try a few dates to make sure we recognize when this is expired + # first, make sure that nothing is found before the expiration date + with monkeypatch.context() as m: + override_time(m, datetime(year=2025, month=1, day=1, hour=1)) + expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(days=1000)) + assert len(expired) == 0 + + # now make sure that we honor max_age when looking for expired data + with monkeypatch.context() as m: + override_time(m, datetime(year=2025, month=1, day=1, hour=1)) + expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(minutes=1)) + assert len(expired) == 1 + + # okay, now make sure that we honor the expiration + with monkeypatch.context() as m: + override_time(m, datetime(year=2025, month=1, day=2, hour=1)) + expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(days=1000)) + assert len(expired) == 1 def test_end_to_end_with_mismatched_metadata( diff --git a/breadbox/tests/factories.py b/breadbox/tests/factories.py index c3f263376..8242b8625 100644 --- a/breadbox/tests/factories.py +++ b/breadbox/tests/factories.py @@ -354,6 +354,7 @@ def tabular_dataset( taiga_id=taiga_id, dataset_metadata=dataset_metadata, dataset_md5=None, + expiry=None, ) assert columns_metadata is not None