diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 293a43d..25070e4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,10 +47,13 @@ repos:
       - id: mypy
         language_version: '3.10'
         additional_dependencies:
+          - aiosqlite==0.17.0
           - click==8.1.3
+          - databases==0.6.0
           - fastapi==0.75.2
           - pydantic==1.9.0
           - pytest==7.1.2
+          - sqlalchemy2-stubs==0.0.2a24
           - types-requests==2.27.27
           - types-ujson==4.2.1
         args: [--config-file=pyproject.toml]
@@ -73,6 +76,7 @@ repos:
         pass_filenames: false
         additional_dependencies:
           - click==8.1.3
+          - databases==0.6.0
           - email-validator==1.2.1
           - fastapi==0.75.2
           - python-multipart==0.0.5
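Note on the `PRAGMA foreign_keys` workaround introduced in `main.py` below: SQLite leaves foreign-key enforcement disabled on every new connection, which is why the reference implementation switches it on in a custom `Connection` factory. A minimal, stdlib-only sketch of the underlying behavior (the `parent`/`child` tables are illustrative, not part of this change):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE parent (id INTEGER PRIMARY KEY)")
con.execute("CREATE TABLE child (pid INTEGER REFERENCES parent (id))")

# Foreign keys are OFF by default: this dangling reference is silently accepted.
con.execute("INSERT INTO child VALUES (42)")

# With the pragma enabled, the same kind of insert is rejected.
con.execute("PRAGMA foreign_keys=ON")
try:
    con.execute("INSERT INTO child VALUES (43)")
except sqlite3.IntegrityError as error:
    print(error)  # FOREIGN KEY constraint failed
```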
diff --git a/marketplace_standard_app_api/main.py b/marketplace_standard_app_api/main.py
index cfd3ad8..73b641d 100644
--- a/marketplace_standard_app_api/main.py
+++ b/marketplace_standard_app_api/main.py
@@ -1,6 +1,12 @@
+import os
+import sqlite3
+import uuid
+from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Union

+import databases
 import requests
+import sqlalchemy
 from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.responses import HTMLResponse, Response

@@ -22,9 +28,49 @@
     TransformationUpdateModel,
     TransformationUpdateResponse,
 )
+from .reference import object_storage
+from .reference.common import metadata
 from .security import AuthTokenBearer
 from .version import __version__

+# The standard approach to enabling foreign-key support for sqlite3 is shown
+# below; however, since we use the async databases library, we instead use a
+# custom Connection object, as implemented in the get_database() function.
+# See also: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#foreign-key-support
+# @event.listens_for(Engine, "connect")
+# def set_sqlite_pragma(dbapi_connection, connection_record):
+#     print("SET SQLITE PRAGMA")
+#     cursor = dbapi_connection.cursor()
+#     cursor.execute("PRAGMA foreign_keys=ON")
+#     cursor.close()
+
+
+DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///./app.db")
+DATA_DIR = Path.cwd() / "data"
+
+database = None
+engine = None
+
+
+def get_database() -> databases.Database:
+    """Get the database connection."""
+    global database, engine
+    if database is None:
+
+        # Work-around for sqlite3 due to a limitation in encode/databases
+        # and aiosqlite: https://github.com/encode/databases/issues/169
+        class Connection(sqlite3.Connection):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.execute("PRAGMA foreign_keys=ON")
+
+        database = databases.Database(DATABASE_URL, factory=Connection)
+
+        engine = sqlalchemy.create_engine(
+            DATABASE_URL, connect_args={"check_same_thread": False}
+        )
+    return database
+

 async def catch_authentication_request_errors_middleware(
     request: Request, call_next: Callable
@@ -38,6 +84,9 @@ async def catch_authentication_request_errors_middleware(
         raise


+auth_token_bearer = AuthTokenBearer()
+
+
 class MarketPlaceAPI(FastAPI):
     def openapi(self) -> Dict[str, Any]:
         openapi_schema = super().openapi()
@@ -56,11 +105,23 @@ def openapi(self) -> Dict[str, Any]:
         "email": "dirk.helm@iwm.fraunhofer.de",
     },
     license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"},
-    dependencies=[Depends(AuthTokenBearer())],
+    dependencies=[Depends(auth_token_bearer)],
 )
 api.middleware("http")(catch_authentication_request_errors_middleware)


+@api.on_event("startup")
+async def startup():
+    database = get_database()
+    metadata.create_all(engine)
+    await database.connect()
+
+
+@api.on_event("shutdown")
+async def shutdown():
+    await get_database().disconnect()
+
+
 @api.get(
     "/",
     operation_id="frontend",
@@ -131,7 +192,8 @@ async def list_collections(
     limit: int = 100, offset: int = 0
 ) -> Union[CollectionListResponse, Response]:
     """List all collections."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    collections = await object_storage.list_collections(get_database(), limit, offset)
+    return collections or Response(status_code=204)


 @api.get(
@@ -152,7 +214,21 @@ async def list_datasets(
     collection_name: CollectionName, limit: int = 100, offset: int = 0
 ) -> Union[DatasetListResponse, Response]:
     """List all datasets."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    try:
+        datasets, headers = await object_storage.list_datasets(
+            get_database(), collection_name, limit, offset
+        )
+        if datasets:
+            return Response(
+                content="[{}]".format(
+                    ",".join([dataset.json() for dataset in datasets])
+                ),
+                headers=headers,
+            )
+        else:
+            return Response(status_code=204, headers=headers)
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="Collection not found.")


 CREATE_COLLECTION_DESCRIPTION = """
@@ -210,7 +286,14 @@ async def create_collection(
     request: Request, collection_name: CollectionName = None
 ) -> Response:
     """Create a new or replace an existing collection."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    # TODO: Support updates.
+    if collection_name is None:
+        collection_name = CollectionName(str(uuid.uuid4()))
+
+    await object_storage.create_collection(
+        get_database(), collection_name, request.headers
+    )
+    return Response(status_code=201, content=collection_name)


 @api.head(
@@ -231,7 +314,13 @@ async def create_collection(
 )
 async def get_collection_metadata(collection_name: CollectionName) -> Response:
     """Get the metadata for a collection."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    try:
+        headers = await object_storage.get_collection_metadata_headers(
+            get_database(), collection_name
+        )
+        return Response(status_code=204, headers=headers)
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="Collection not found.")


 @api.delete(
@@ -254,7 +343,11 @@ async def get_collection_metadata(collection_name: CollectionName) -> Response:
 )
 async def delete_collection(collection_name: CollectionName) -> Response:
     """Delete an empty collection."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    try:
+        await object_storage.delete_collection(get_database(), collection_name)
+        return Response(status_code=204)
+    except object_storage.ConflictError as error:
+        raise HTTPException(status_code=409, detail=str(error))


 CREATE_DATASET_DESCRIPTION = """
@@ -309,7 +402,18 @@ async def create_dataset(
     dataset_name: Optional[DatasetName] = None,
 ) -> Union[DatasetCreateResponse, Response]:
     """Create a new or replace an existing dataset."""
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    if dataset_name is None:
+        dataset_name = DatasetName(str(uuid.uuid4()))
+
+    await object_storage.create_dataset(
+        get_database(),
+        DATA_DIR,
+        collection_name,
+        dataset_name,
+        file,
+        dict(request.headers),
+    )
+    return Response(status_code=201, content=dataset_name)


 @api.post(
@@ -390,8 +494,13 @@ async def get_dataset_metadata(
     storage API:
     https://docs.openstack.org/api-ref/object-store/index.html#show-object-metadata
     """
-    # return Response(content=None, headers={"X-Object-Meta-my-key": "some-value"})
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    try:
+        headers = await object_storage.get_dataset_metadata_headers(
+            get_database(), collection_name, dataset_name
+        )
+        return Response(status_code=200, headers=headers)
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="Not found.")


 @api.get(
@@ -433,8 +542,13 @@ async def get_dataset(
     storage API:
     https://docs.openstack.org/api-ref/object-store/index.html#get-object-content-and-metadata
     """
-    # return Response(content=data, headers={"X-Object-Meta-my-key": "some-value"})
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    try:
+        content, headers = await object_storage.get_dataset(
+            get_database(), DATA_DIR, collection_name, dataset_name
+        )
+        return Response(content=content, headers=headers)
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="Not found.")


 @api.delete(
@@ -459,7 +573,10 @@ async def delete_dataset(
     storage API:
     https://docs.openstack.org/api-ref/object-store/index.html#delete-object
     """
-    raise HTTPException(status_code=501, detail="Not implemented.")
+    await object_storage.delete_dataset(
+        get_database(), DATA_DIR, collection_name, dataset_name
+    )
+    return Response(status_code=204)


 @api.post(
diff --git a/marketplace_standard_app_api/reference/__init__.py b/marketplace_standard_app_api/reference/__init__.py
new file mode 100644
index 0000000..e69de29
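With the object-storage endpoints above wired to the reference implementation, a client can exercise them over HTTP. A hedged sketch only: the base URL and bearer token are placeholders, and the paths and status codes follow the endpoints above and the tests at the end of this diff.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: server running locally
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder token

# Create a collection (the server generates a name) with one metadata entry.
response = requests.put(
    f"{BASE_URL}/data/", headers={**HEADERS, "X-Object-Meta-owner": "alice"}
)
assert response.status_code == 201
collection = response.text

# Upload a dataset into the collection ...
response = requests.put(
    f"{BASE_URL}/data/{collection}/example.json",
    headers=HEADERS,
    files={"file": ("example.json", b'{"foo": "bar"}', "application/json")},
)
assert response.status_code == 201

# ... and read it back together with its headers.
response = requests.get(
    f"{BASE_URL}/data/{collection}/example.json", headers=HEADERS
)
print(response.json())  # {'foo': 'bar'}
```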
diff --git a/marketplace_standard_app_api/reference/common.py b/marketplace_standard_app_api/reference/common.py
new file mode 100644
index 0000000..483d0ca
--- /dev/null
+++ b/marketplace_standard_app_api/reference/common.py
@@ -0,0 +1,3 @@
+import sqlalchemy
+
+metadata = sqlalchemy.MetaData()
diff --git a/marketplace_standard_app_api/reference/object_storage.py b/marketplace_standard_app_api/reference/object_storage.py
new file mode 100644
index 0000000..46186a8
--- /dev/null
+++ b/marketplace_standard_app_api/reference/object_storage.py
@@ -0,0 +1,258 @@
+import sqlite3
+from datetime import datetime
+from pathlib import Path
+from typing import List, Mapping, Tuple, Union
+
+import sqlalchemy
+from databases import Database
+from databases.interfaces import Record
+from fastapi import UploadFile
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, and_, select
+
+from marketplace_standard_app_api.models.object_storage import (
+    CollectionListItemModel,
+    CollectionListResponse,
+    CollectionName,
+    DatasetListResponse,
+    DatasetModel,
+    DatasetName,
+)
+
+from .common import metadata
+
+collections = sqlalchemy.Table(
+    "collections",
+    metadata,
+    Column("name", String, primary_key=True),
+)
+
+collections_metadata = sqlalchemy.Table(
+    "collections_metadata",
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column(
+        "collection_name",
+        ForeignKey("collections.name", ondelete="CASCADE"),
+        nullable=False,
+    ),
+    Column("key", String),
+    Column("value", String),
+)
+
+datasets = sqlalchemy.Table(
+    "datasets",
+    metadata,
+    Column("name", String, primary_key=True),
+    Column(
+        "collection_name",
+        ForeignKey("collections.name", ondelete="RESTRICT"),
+        nullable=False,
+    ),
+    Column("content_type", String),
+    Column("last_modified", DateTime),
+)
+
+datasets_metadata = sqlalchemy.Table(
+    "datasets_metadata",
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column(
+        "collection_name",
+        ForeignKey("collections.name", ondelete="CASCADE"),
+        nullable=False,
+    ),
+    Column(
+        "dataset_name", ForeignKey("datasets.name", ondelete="CASCADE"), nullable=False
+    ),
+    Column("key", String),
+    Column("value", String),
+)
+
+
+class ConflictError(RuntimeError):
+    pass
+
+
+def metadata_from_headers(headers: Mapping):
+    metadata_magic_keyword = "X-Object-Meta-"
+    for key, value in headers.items():
+        if key.lower().startswith(metadata_magic_keyword.lower()):
+            yield key[len(metadata_magic_keyword) :], value
+
+
+def metadata_to_headers(metadata: List[Record]):
+    metadata_magic_keyword = "X-Object-Meta-"
+    for key, value in metadata:
+        yield metadata_magic_keyword + key, value
+
+
+async def create_collection(
+    database: Database, collection_name: CollectionName, headers: Mapping
+):
+    metadata = dict(metadata_from_headers(headers))
+    async with database.transaction():
+        insert_stmt = collections.insert().values(name=collection_name)
+        await database.execute(insert_stmt)
+        for key, value in metadata.items():
+            insert_stmt = collections_metadata.insert().values(
+                collection_name=collection_name, key=key, value=value
+            )
+            await database.execute(insert_stmt)
+
+
+async def delete_collection(database: Database, collection_name: CollectionName):
+    delete_stmt = collections.delete().where(collections.c.name == collection_name)
+    try:
+        await database.execute(delete_stmt)
+    except sqlite3.IntegrityError:
+        raise ConflictError("Collection is not empty.")
+
+
+async def list_collections(
+    database: Database, limit: int, offset: int
+) -> CollectionListResponse:
+    query = select(collections).offset(offset).limit(limit)
+    result = await database.fetch_all(query)
+    if len(result):
+        return [
+            CollectionListItemModel(name=row[0], count=0, bytes=0) for row in result
+        ]
+    else:
+        return []
+
+
+async def get_collection_metadata_headers(
+    database: Database, collection_name: CollectionName
+):
+    query = select(collections).where(collections.c.name == collection_name)
+    entry = await database.fetch_one(query)
+    if entry:
+        query = select(collections_metadata.c.key, collections_metadata.c.value).where(
+            collections_metadata.c.collection_name == collection_name
+        )
+        rows = await database.fetch_all(query)
+        headers = dict(metadata_to_headers(rows))
+        return headers
+    raise FileNotFoundError()
+
+
+async def list_datasets(
+    database: Database, collection_name: CollectionName, limit: int, offset: int
+) -> Tuple[DatasetListResponse, dict]:
+    headers = await get_collection_metadata_headers(database, collection_name)
+    query = (
+        select(datasets)
+        .where(datasets.c.collection_name == collection_name)
+        .offset(offset)
+        .limit(limit)
+    )
+    rows = await database.fetch_all(query)
+    if len(rows):
+        return [
+            DatasetModel(
+                id=row._mapping["name"],
+                content_type=row._mapping["content_type"],
+                last_modified=str(row._mapping["last_modified"]),
+            )
+            for row in rows
+        ], headers
+    else:
+        return [], headers
+
+
+async def get_dataset_metadata_headers(
+    database: Database, collection_name: CollectionName, dataset_name: DatasetName
+) -> dict:
+    query = select(datasets).where(
+        and_(
+            datasets.c.collection_name == collection_name,
+            datasets.c.name == dataset_name,
+        )
+    )
+    entry = await database.fetch_one(query)
+    if entry:
+        query = select(datasets_metadata.c.key, datasets_metadata.c.value).where(
+            and_(
+                datasets_metadata.c.collection_name == collection_name,
+                datasets_metadata.c.dataset_name == dataset_name,
+            )
+        )
+        rows = await database.fetch_all(query)
+        headers = dict(metadata_to_headers(rows))
+        headers["Content-Type"] = entry._mapping["content_type"]
+        headers["Last-Modified"] = str(entry._mapping["last_modified"])
+        return headers
+    raise FileNotFoundError()
+
+
+async def create_dataset(
+    database: Database,
+    data_dir: Path,
+    collection_name: CollectionName,
+    dataset_name: DatasetName,
+    file: UploadFile,
+    headers: dict,
+):
+    metadata = dict(metadata_from_headers(headers))
+
+    async with database.transaction():
+        # Create dataset entry in database
+        insert_stmt = datasets.insert().values(
+            name=dataset_name,
+            collection_name=collection_name,
+            content_type=file.content_type,
+            last_modified=datetime.utcnow(),
+        )
+        await database.execute(insert_stmt)
+
+        # Create dataset metadata in database
+        for key, value in metadata.items():
+            insert_stmt = datasets_metadata.insert().values(
+                collection_name=collection_name,
+                dataset_name=dataset_name,
+                key=key,
+                value=value,
+            )
+            await database.execute(insert_stmt)
+
+        # Move file into data directory
+        dst = data_dir / collection_name / dataset_name
+        dst.parent.mkdir(parents=True, exist_ok=True)
+        contents = await file.read()  # TODO: optimize this
+        if isinstance(contents, str):
+            dst.write_text(contents)
+        else:
+            dst.write_bytes(contents)
+
+
+async def get_dataset(
+    database: Database,
+    data_dir: Path,
+    collection_name: CollectionName,
+    dataset_name: DatasetName,
+) -> Tuple[Union[str, bytes], dict]:
+    # TODO: Support bytes
+    src = data_dir / collection_name / dataset_name
+    headers = await get_dataset_metadata_headers(
+        database, collection_name, dataset_name
+    )
+    content = src.read_text()  # TODO: should be non-blocking
+    return content, headers
+
+
+async def delete_dataset(
+    database: Database,
+    data_dir: Path,
+    collection_name: CollectionName,
+    dataset_name: DatasetName,
+):
+    async with database.transaction():
+        delete_stmt = datasets.delete().where(
+            sqlalchemy.and_(
+                datasets.c.collection_name == collection_name,
+                datasets.c.name == dataset_name,
+            )
+        )
+        path = data_dir / collection_name / dataset_name
+        path.unlink()
+        await database.execute(delete_stmt)
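The `metadata_from_headers`/`metadata_to_headers` helpers above implement the Swift-style `X-Object-Meta-*` header convention. A quick round-trip sketch (plain tuples stand in for the database `Record` rows, which unpack the same way):

```python
from marketplace_standard_app_api.reference.object_storage import (
    metadata_from_headers,
    metadata_to_headers,
)

headers = {"X-Object-Meta-owner": "alice", "Content-Type": "application/json"}

# Only X-Object-Meta-* entries survive; the prefix is stripped case-insensitively.
meta = dict(metadata_from_headers(headers))
print(meta)  # {'owner': 'alice'}

# Converting back re-attaches the prefix.
print(dict(metadata_to_headers(list(meta.items()))))  # {'X-Object-Meta-owner': 'alice'}
```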
diff --git a/pyproject.toml b/pyproject.toml
index 94e74a5..7cfbacf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,8 @@ classifiers = ["License :: OSI Approved :: MIT License"]
 dynamic = ["version", "description"]
 requires-python = ">=3.8,<4"
 dependencies = [
+    "SQLAlchemy==1.4.37",
+    "databases[aiosqlite]==0.6.0",
     "email-validator==1.2.1",
     "fastapi==0.75.2",
     "pydantic==1.9.0",
@@ -40,7 +42,7 @@ name = "marketplace_standard_app_api"


 [tool.mypy]
-plugins = ["pydantic.mypy"]
+plugins = ["pydantic.mypy", "sqlalchemy.ext.mypy.plugin"]

 [tool.bumpver]
 current_version = "v0.1.0"
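`databases[aiosqlite]` is the new runtime dependency that provides the async query layer used throughout the reference implementation. A self-contained sketch of its core API (the scratch database path is illustrative):

```python
import asyncio

import databases


async def main() -> None:
    database = databases.Database("sqlite:///./scratch.db")  # illustrative path
    await database.connect()
    await database.execute(
        "CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value TEXT)"
    )
    await database.execute(
        "INSERT OR REPLACE INTO kv VALUES (:key, :value)",
        {"key": "greeting", "value": "hello"},
    )
    row = await database.fetch_one(
        "SELECT value FROM kv WHERE key = :key", {"key": "greeting"}
    )
    print(row[0])  # hello
    await database.disconnect()


asyncio.run(main())
```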
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..6d60665
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,122 @@
+import json
+
+import pytest
+from fastapi import Request
+from fastapi.testclient import TestClient
+
+import marketplace_standard_app_api.main
+from marketplace_standard_app_api import api
+from marketplace_standard_app_api.main import auth_token_bearer
+
+
+async def _fake_auth_token_bearer(request: Request):
+    return None
+
+
+@pytest.fixture
+def client(tmp_path, monkeypatch):
+    db_path = tmp_path / "test.db"
+    monkeypatch.setattr(
+        marketplace_standard_app_api.main,
+        "DATABASE_URL",
+        f"sqlite:///{db_path.resolve()}",
+    )
+    monkeypatch.setattr(
+        marketplace_standard_app_api.main,
+        "DATA_DIR",
+        tmp_path / "data",
+    )
+
+    api.dependency_overrides[auth_token_bearer] = _fake_auth_token_bearer
+    client = TestClient(api)
+    with client:
+        yield client
+    api.dependency_overrides = {}
+
+
+def test_frontend(client):
+    response = client.get("/")
+    assert response.status_code == 501
+
+
+@pytest.fixture
+def collection(client):
+    # Create collection
+    response = client.put("/data/", headers={"X-Object-Meta-foo": "bar"})
+    assert response.ok
+    assert response.status_code == 201
+    collection_name = response.content.decode("utf-8")
+
+    # Yield collection name
+    yield collection_name
+
+    # Delete all datasets.
+    response = client.get(f"/data/{collection_name}")
+    if response.status_code == 200:
+        for dataset in response.json():
+            client.delete(f"/data/{collection_name}/{dataset['id']}")
+    # Delete collection
+    assert client.delete(f"/data/{collection_name}").status_code == 204
+    # Check that collection is deleted
+    assert client.head(f"/data/{collection_name}").status_code == 404
+
+
+def test_insert_collection_with_metadata(client):
+    response = client.get("/data")
+    assert response.ok
+    assert response.status_code == 204
+
+    response = client.put("/data/", headers={"X-Object-Meta-foo": "bar"})
+    assert response.ok
+    assert response.status_code == 201
+    collection_id = response.content.decode("utf-8")
+
+    response = client.head(f"/data/{collection_id}")
+    assert response.ok
+    assert response.status_code == 204
+    assert "X-Object-Meta-foo".lower() in (
+        key.lower() for key in response.headers.keys()
+    )
+
+    response = client.get("/data")
+    assert response.ok
+    assert response.status_code == 200
+    collections = response.json()
+    assert len(collections) == 1
+
+
+def test_create_and_delete_json_dataset(client, collection, tmp_path):
+    p = tmp_path / "test.json"
+    p.write_text(json.dumps({"foo": "bar"}))
+
+    assert client.get(f"/data/{collection}").status_code == 204
+    assert client.head(f"/data/{collection}/test.json").status_code == 404
+    assert client.get(f"/data/{collection}/test.json").status_code == 404
+
+    assert (
+        client.put(
+            f"data/{collection}/test.json",
+            files={"file": ("test.json", p.open("rb"), "application/json")},
+        ).status_code
+        == 201
+    )
+
+    assert client.head(f"/data/{collection}").status_code == 204
+    assert client.get(f"/data/{collection}").status_code == 200
+
+    response = client.head(f"/data/{collection}/test.json")
+    assert response.status_code == 200
+    assert response.headers["Content-Type"] == "application/json"
+    assert "last-modified" in response.headers.keys()
+
+    response = client.get(f"/data/{collection}/test.json")
+    assert response.status_code == 200
+    assert response.json() == {"foo": "bar"}
+    assert response.headers["Content-Type"] == "application/json"
+    assert "last-modified" in response.headers.keys()
+
+    assert client.delete(f"/data/{collection}").status_code == 409
+    assert client.delete(f"/data/{collection}/test.json").status_code == 204
+    assert client.head(f"/data/{collection}/test.json").status_code == 404
+    assert client.get(f"/data/{collection}/test.json").status_code == 404
+    assert client.get(f"/data/{collection}").status_code == 204
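To try the reference implementation end to end, the `api` app object added above can be served directly. A minimal launch sketch (assumes `uvicorn` is installed, which this diff does not pin; real requests still need a valid bearer token unless `auth_token_bearer` is overridden as in the tests above):

```python
import uvicorn

from marketplace_standard_app_api.main import api

if __name__ == "__main__":
    # Serves the FastAPI app locally; interactive docs are available at /docs.
    uvicorn.run(api, host="127.0.0.1", port=8000)
```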