Skip to content

Commit

Permalink
Merge pull request #14 from hv0905/feat/storage
Browse files Browse the repository at this point in the history
Basic implementation of storage service
  • Loading branch information
pk5ls20 authored Apr 26, 2024
2 parents a0f9616 + 59216ec commit f96a420
Show file tree
Hide file tree
Showing 18 changed files with 666 additions and 123 deletions.
30 changes: 13 additions & 17 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import PurePath
from typing import Annotated
from uuid import UUID

Expand All @@ -8,10 +9,9 @@
from app.Models.api_response.admin_api_response import ServerInfoResponse
from app.Models.api_response.base import NekoProtocol
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context
from app.Services.provider import db_context, storage_service
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util import directories

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

Expand All @@ -23,28 +23,24 @@ async def delete_image(
image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")]) -> NekoProtocol:
try:
point = await db_context.retrieve_by_id(str(image_id))
except PointNotFoundError:
raise HTTPException(404, "Cannot find the image with the given ID.")

except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex
await db_context.deleteItems([str(point.id)])
logger.success("Image {} deleted from database.", point.id)

if point.url.startswith('/') and config.static_file.enable: # local image
image_files = list(directories.static_dir.glob(f"{point.id}.*"))
if point.local and config.storage.method.enabled: # local image
image_files = [itm[0] async for itm in storage_service.active_storage.list_files("", f"{point.id}.*")]
assert len(image_files) <= 1

if not image_files:
logger.warning("Image {} is a local image but not found in static folder.", point.id)
else:
directories.deleted_dir.mkdir(parents=True, exist_ok=True)

image_files[0].rename(directories.deleted_dir / image_files[0].name)
logger.success("Local image {} removed.", image_files[0].name)

await storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
logger.success("Image {} removed.", image_files[0].name)
if point.thumbnail_url is not None:
thumbnail_file = directories.thumbnails_dir / f"{point.id}.webp"
if thumbnail_file.is_file():
thumbnail_file.unlink()
thumbnail_file = PurePath(f"thumbnails/{point.id}.webp")
if await storage_service.active_storage.is_exist(thumbnail_file):
await storage_service.active_storage.delete(thumbnail_file)
logger.success("Thumbnail {} removed.", thumbnail_file.name)
else:
logger.warning("Thumbnail {} not found.", thumbnail_file.name)
Expand All @@ -59,8 +55,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
raise HTTPException(422, "Nothing to update.")
try:
point = await db_context.retrieve_by_id(str(image_id))
except PointNotFoundError:
raise HTTPException(404, "Cannot find the image with the given ID.")
except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex

if model.starred is not None:
point.starred = model.starred
Expand Down
11 changes: 10 additions & 1 deletion app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service
from app.Services.provider import db_context, transformers_service, storage_service
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

# Router for the search endpoints. The access-token dependency is attached only
# when `config.access_protected` is truthy; otherwise no router-level dependency is set.
searchRouter = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
tags=["Search"])


# Rewrite the image URLs of every search result through the active storage backend
# and return the same (mutated) response object.
# NOTE(review): leading indentation was lost in this view, so it cannot be determined
# from here whether the thumbnail branch is nested inside the
# `item.img.local and config.storage.method.enabled` check or is a sibling of it.
# Confirm against the repository before relying on either reading.
async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
for item in resp.result:
# Only images flagged as local are resolved via the storage backend, and only when storage is enabled.
if item.img.local and config.storage.method.enabled:
item.img.url = await storage_service.active_storage.get_image_url(item.img)
if item.img.thumbnail_url is not None:
# NOTE(review): the BaseStorage interface visible in this commit declares `url`,
# `presign_url` and `get_image_url`, but no `get_url`; verify the concrete backends provide it.
item.img.thumbnail_url = await storage_service.active_storage.get_url(item.img.thumbnail_url)
return resp


class SearchBasisParams:
def __init__(self,
basis: Annotated[SearchBasisEnum, Query(
Expand Down
1 change: 1 addition & 0 deletions app/Models/img_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ImageData(BaseModel):
starred: Optional[bool] = False
categories: Optional[list[str]] = []
local: Optional[bool] = False
format: Optional[str] = None # required for s3 local storage

@computed_field()
@property
Expand Down
5 changes: 5 additions & 0 deletions app/Services/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from loguru import logger
from .index_service import IndexService
from .storage import StorageService
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment
Expand Down Expand Up @@ -27,5 +29,8 @@
from .ocr_services import DisabledOCRService

ocr_service = DisabledOCRService()
logger.info(f"OCR service '{type(ocr_service).__name__}' initialized.")

index_service = IndexService(ocr_service, transformers_service, db_context)
storage_service = StorageService()
logger.info(f"Storage service '{type(storage_service.active_storage).__name__}' initialized.")
20 changes: 20 additions & 0 deletions app/Services/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from app.Services.storage.local_storage import LocalStorage
from app.Services.storage.s3_compatible_storage import S3Storage
from app.config import config, StorageMode


class StorageService:
def __init__(self):
self.local_storage = LocalStorage()
self.active_storage = None
match config.storage.method:
case StorageMode.LOCAL:
self.active_storage = self.local_storage
case StorageMode.S3:
self.active_storage = S3Storage()
case StorageMode.DISABLED:
return
case _:
raise NotImplementedError(f"Storage method {config.storage.method} not implemented. "
f"Available methods: local, s3")
self.active_storage.pre_check()
153 changes: 153 additions & 0 deletions app/Services/storage/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import abc
import os
from typing import TypeVar, Generic, TypeAlias, Optional, AsyncGenerator

from app.Models.img_data import ImageData

# Type variable for the file-metadata type that parameterizes BaseStorage (see Generic[FileMetaDataT]).
FileMetaDataT = TypeVar('FileMetaDataT')

# Any path-like value: a plain string or an os.PathLike object.
PathLikeType: TypeAlias = str | os.PathLike
# A local upload source: either a path or bytes.
LocalFilePathType: TypeAlias = PathLikeType | bytes
# A path on the storage backend.
RemoteFilePathType: TypeAlias = PathLikeType
# Both metadata aliases resolve to the same type variable; the two names only
# distinguish the role of each argument at call sites.
LocalFileMetaDataType: TypeAlias = FileMetaDataT
RemoteFileMetaDataType: TypeAlias = FileMetaDataT


class BaseStorage(abc.ABC, Generic[FileMetaDataT]):
    """Abstract interface that every storage backend implements.

    All ``remote_file`` arguments are paths relative to ``static_dir``.
    Everything except ``get_image_url`` is abstract and must be overridden.
    """

    def __init__(self):
        # Annotation-only declarations: no value is assigned here, so concrete
        # backends are expected to set these attributes themselves.
        self.static_dir: os.PathLike
        self.thumbnails_dir: os.PathLike
        self.deleted_dir: os.PathLike
        self.file_metadata: FileMetaDataT

    @abc.abstractmethod
    def pre_check(self):
        """
        Synchronous start-up hook, called once by StorageService after the backend is selected.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def is_exist(self,
                       remote_file: RemoteFilePathType) -> bool:
        """
        Check if a remote_file exists.
        :param remote_file: The file path relative to static_dir
        :return: True if the file exists, False otherwise
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def size(self,
                   remote_file: RemoteFilePathType) -> int:
        """
        Get the size of a file in static_dir.
        :param remote_file: The file path relative to static_dir
        :return: file's size
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def url(self,
                  remote_file: RemoteFilePathType) -> str:
        """
        Get the original URL of a file in static_dir.
        This url will be placed in the payload field of the qdrant.
        :param remote_file: The file path relative to static_dir
        :return: file's "original URL"
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def presign_url(self,
                          remote_file: RemoteFilePathType,
                          expire_second: int = 3600) -> str:
        """
        Get the presigned URL of a file in static_dir.
        :param remote_file: The file path relative to static_dir
        :param expire_second: Validity period of the presigned URL, in seconds
        :return: file's "presigned URL"
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def fetch(self,
                    remote_file: RemoteFilePathType) -> bytes:
        """
        Fetch a file from static_dir.
        :param remote_file: The file path relative to static_dir
        :return: file's content
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def upload(self,
                     local_file: "LocalFilePathType",
                     remote_file: RemoteFilePathType) -> None:
        """
        Move a local picture file to the static_dir.
        :param local_file: The absolute path to the local file or bytes.
        :param remote_file: The file path relative to static_dir
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def copy(self,
                   old_remote_file: RemoteFilePathType,
                   new_remote_file: RemoteFilePathType) -> None:
        """
        Copy a file in static_dir.
        :param old_remote_file: The source file path relative to static_dir
        :param new_remote_file: The destination file path relative to static_dir
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def move(self,
                   old_remote_file: RemoteFilePathType,
                   new_remote_file: RemoteFilePathType) -> None:
        """
        Move a file in static_dir.
        :param old_remote_file: The source file path relative to static_dir
        :param new_remote_file: The destination file path relative to static_dir
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def delete(self,
                     remote_file: RemoteFilePathType) -> None:
        """
        Delete a file in static_dir.
        :param remote_file: The file path relative to static_dir
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def list_files(self,
                         path: RemoteFilePathType,
                         pattern: Optional[str] = "*",
                         batch_max_files: Optional[int] = None,
                         valid_extensions: Optional[set[str]] = None) \
            -> AsyncGenerator[list[RemoteFilePathType], None]:
        """
        Asynchronously generates lists of files from a given base directory path that match a specified pattern and
        set of file extensions.
        :param path: The base directory path, relative to static_dir, from which to start listing files.
        :param pattern: A glob pattern to filter files based on their names. Defaults to '*' which selects all files.
        :param batch_max_files: The maximum number of files to return. If None, all matching files are returned.
        :param valid_extensions: An extra set of file extensions to include (e.g., {".jpg", ".png"}).
        If None, files are not filtered by extension.
        :return: An asynchronous generator yielding lists of RemoteFilePathType objects representing the matching files.
        Usage example:
        async for batch in list_files(path=".", pattern="*", batch_max_files=100, valid_extensions={".jpg", ".png"}):
            print(f"Batch: {batch}")
        """
        raise NotImplementedError
        # Unreachable on purpose: the presence of `yield` makes this an async
        # generator function, so the abstract method matches its declared
        # AsyncGenerator return type and fails with NotImplementedError (not a
        # TypeError about a coroutine) if it is ever iterated with `async for`.
        yield  # pragma: no cover

    @abc.abstractmethod
    async def update_metadata(self,
                              local_file_metadata: LocalFileMetaDataType,
                              remote_file_metadata: RemoteFileMetaDataType) -> None:
        """
        Update file metadata on the backend.
        NOTE(review): the exact semantics of the two arguments are not defined in this file; confirm against
        the concrete backends.
        :param local_file_metadata: Metadata value of the backend's FileMetaDataT type
        :param remote_file_metadata: Metadata value of the backend's FileMetaDataT type
        """
        raise NotImplementedError

    async def get_image_url(self, img: ImageData) -> str:
        """
        Return the URL to serve for an image. This default implementation returns ``img.url`` unchanged;
        it is the only non-abstract method, so backends may override it.
        :param img: The image record
        :return: The image URL
        """
        return img.url
30 changes: 30 additions & 0 deletions app/Services/storage/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class StorageExtension(Exception):
    """Base class of the storage exceptions defined in this module.

    NOTE(review): the name looks like it was meant to be ``StorageException``;
    it is kept unchanged because other modules may import it.
    """


class LocalFileNotFoundError(StorageExtension):
    """Storage error for a missing local file (intent inferred from the name; raise sites are outside this module)."""


class LocalFileExistsError(StorageExtension):
    """Storage error for a local file that already exists (intent inferred from the name; raise sites are outside this module)."""


class LocalFilePermissionError(StorageExtension):
    """Storage error for a local permission failure (intent inferred from the name; raise sites are outside this module)."""


class RemoteFileNotFoundError(StorageExtension):
    """Storage error for a missing remote file (intent inferred from the name; raise sites are outside this module)."""


class RemoteFileExistsError(StorageExtension):
    """Storage error for a remote file that already exists (intent inferred from the name; raise sites are outside this module)."""


class RemoteFilePermissionError(StorageExtension):
    """Storage error for a remote permission failure (intent inferred from the name; raise sites are outside this module)."""


class RemoteConnectError(StorageExtension):
    """Storage error for a failed connection to the remote backend (intent inferred from the name; raise sites are outside this module)."""
Loading

0 comments on commit f96a420

Please sign in to comment.