diff --git a/app/Services/index_service.py b/app/Services/index_service.py
index 45a60b1..0b673a0 100644
--- a/app/Services/index_service.py
+++ b/app/Services/index_service.py
@@ -2,6 +2,7 @@
 from fastapi.concurrency import run_in_threadpool
 
 from app.Models.img_data import ImageData
+from app.Services.lifespan_service import LifespanService
 from app.Services.ocr_services import OCRService
 from app.Services.transformers_service import TransformersService
 from app.Services.vector_db_context import VectorDbContext
@@ -12,7 +13,7 @@ class PointDuplicateError(ValueError):
     pass
 
 
-class IndexService:
+class IndexService(LifespanService):
     def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
         self._ocr_service = ocr_service
         self._transformers_service = transformers_service
diff --git a/app/Services/lifespan_service.py b/app/Services/lifespan_service.py
new file mode 100644
index 0000000..a6437f0
--- /dev/null
+++ b/app/Services/lifespan_service.py
@@ -0,0 +1,6 @@
+class LifespanService:
+    async def on_load(self):
+        pass
+
+    async def on_exit(self):
+        pass
diff --git a/app/Services/ocr_services.py b/app/Services/ocr_services.py
index 3c0534a..ac6907b 100644
--- a/app/Services/ocr_services.py
+++ b/app/Services/ocr_services.py
@@ -5,10 +5,11 @@
 from PIL import Image
 from loguru import logger
 
+from app.Services.lifespan_service import LifespanService
 from app.config import config
 
 
-class OCRService:
+class OCRService(LifespanService):
     def __init__(self):
         self._device = config.device
         if self._device == "auto":
diff --git a/app/Services/provider.py b/app/Services/provider.py
index 986ff40..6870aa9 100644
--- a/app/Services/provider.py
+++ b/app/Services/provider.py
@@ -1,6 +1,8 @@
+import asyncio
 from loguru import logger
 
 from .index_service import IndexService
+from .lifespan_service import LifespanService
 from .storage import StorageService
 from .transformers_service import TransformersService
 from .upload_service import UploadService
@@ -40,10 +42,15 @@ def __init__(self):
         self.storage_service = StorageService()
         logger.info(f"Storage service '{type(self.storage_service.active_storage).__name__}' initialized.")
 
-        self.upload_service = None
-
-        if config.admin_api_enable:
-            self.upload_service = UploadService(self.storage_service, self.db_context, self.index_service)
+        self.upload_service = UploadService(self.storage_service, self.db_context, self.index_service)
+        logger.info(f"Upload service '{type(self.upload_service).__name__}' initialized")
 
     async def onload(self):
-        await self.db_context.onload()
+        tasks = [service.on_load() for service_name in dir(self)
+                 if isinstance((service := getattr(self, service_name)), LifespanService)]
+        await asyncio.gather(*tasks)
+
+    async def onexit(self):
+        tasks = [service.on_exit() for service_name in dir(self)
+                 if isinstance((service := getattr(self, service_name)), LifespanService)]
+        await asyncio.gather(*tasks)
diff --git a/app/Services/storage/__init__.py b/app/Services/storage/__init__.py
index 5ba2034..1626794 100644
--- a/app/Services/storage/__init__.py
+++ b/app/Services/storage/__init__.py
@@ -1,20 +1,27 @@
+from app.Services.lifespan_service import LifespanService
+from app.Services.storage.base import BaseStorage
+from app.Services.storage.disabled_storage import DisabledStorage
 from app.Services.storage.local_storage import LocalStorage
 from app.Services.storage.s3_compatible_storage import S3Storage
 from app.config import config, StorageMode
 
 
-class StorageService:
+class StorageService(LifespanService):
     def __init__(self):
-        self.local_storage = LocalStorage()
         self.active_storage = None
         match config.storage.method:
             case StorageMode.LOCAL:
-                self.active_storage = self.local_storage
+                self.active_storage = LocalStorage()
             case StorageMode.S3:
                 self.active_storage = S3Storage()
             case StorageMode.DISABLED:
-                return
+                self.active_storage = DisabledStorage()
             case _:
                 raise NotImplementedError(f"Storage method {config.storage.method} not implemented. "
                                           f"Available methods: local, s3")
-        self.active_storage.pre_check()
+
+    async def on_load(self):
+        await self.active_storage.on_load()
+
+    async def on_exit(self):
+        await self.active_storage.on_exit()
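[Review note] With this change the provider no longer hard-codes which services take part in startup: onload()/onexit() scan dir(self) for attributes that are LifespanService instances and await every hook concurrently, so a service joins the lifecycle simply by subclassing LifespanService. A minimal, self-contained sketch of that discovery pattern (FakeDb and Provider below are illustrative stand-ins, not code from this repo):

    import asyncio


    class LifespanService:
        async def on_load(self):
            pass


    class FakeDb(LifespanService):
        async def on_load(self):
            print("db ready")


    class Provider:
        def __init__(self):
            self.db = FakeDb()         # picked up by the scan
            self.not_a_service = 42    # ignored: not a LifespanService

        async def onload(self):
            # dir() yields attribute *names*; the walrus keeps the resolved object
            tasks = [service.on_load() for name in dir(self)
                     if isinstance((service := getattr(self, name)), LifespanService)]
            await asyncio.gather(*tasks)


    asyncio.run(Provider().onload())   # prints: db ready
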
diff --git a/app/Services/storage/base.py b/app/Services/storage/base.py
index ec8f82f..15e81bb 100644
--- a/app/Services/storage/base.py
+++ b/app/Services/storage/base.py
@@ -2,6 +2,8 @@
 import os
 from typing import TypeVar, Generic, TypeAlias, Optional, AsyncGenerator
 
+from app.Services.lifespan_service import LifespanService
+
 FileMetaDataT = TypeVar('FileMetaDataT')
 
 PathLikeType: TypeAlias = str | os.PathLike
@@ -11,17 +13,13 @@
 RemoteFileMetaDataType: TypeAlias = FileMetaDataT
 
 
-class BaseStorage(abc.ABC, Generic[FileMetaDataT]):
+class BaseStorage(LifespanService, abc.ABC, Generic[FileMetaDataT]):
     def __init__(self):
         self.static_dir: os.PathLike
         self.thumbnails_dir: os.PathLike
         self.deleted_dir: os.PathLike
         self.file_metadata: FileMetaDataT
 
-    @abc.abstractmethod
-    def pre_check(self):
-        raise NotImplementedError
-
     @abc.abstractmethod
     async def is_exist(self, remote_file: RemoteFilePathType) -> bool:
diff --git a/app/Services/storage/disabled_storage.py b/app/Services/storage/disabled_storage.py
new file mode 100644
index 0000000..902b231
--- /dev/null
+++ b/app/Services/storage/disabled_storage.py
@@ -0,0 +1,43 @@
+from typing import Optional, AsyncGenerator
+
+from app.Services.storage import BaseStorage
+from app.Services.storage.base import RemoteFilePathType, LocalFileMetaDataType, RemoteFileMetaDataType, \
+    LocalFilePathType
+
+
+class DisabledStorage(BaseStorage):  # pragma: no cover
+    async def size(self, remote_file: RemoteFilePathType) -> int:
+        raise NotImplementedError
+
+    async def url(self, remote_file: RemoteFilePathType) -> str:
+        raise NotImplementedError
+
+    async def presign_url(self, remote_file: RemoteFilePathType, expire_second: int = 3600) -> str:
+        raise NotImplementedError
+
+    async def fetch(self, remote_file: RemoteFilePathType) -> bytes:
+        raise NotImplementedError
+
+    async def upload(self, local_file: "LocalFilePathType", remote_file: RemoteFilePathType) -> None:
+        raise NotImplementedError
+
+    async def copy(self, old_remote_file: RemoteFilePathType, new_remote_file: RemoteFilePathType) -> None:
+        raise NotImplementedError
+
+    async def move(self, old_remote_file: RemoteFilePathType, new_remote_file: RemoteFilePathType) -> None:
+        raise NotImplementedError
+
+    async def delete(self, remote_file: RemoteFilePathType) -> None:
+        raise NotImplementedError
+
+    async def update_metadata(self, local_file_metadata: LocalFileMetaDataType,
+                              remote_file_metadata: RemoteFileMetaDataType) -> None:
+        raise NotImplementedError
+
+    async def list_files(self, path: RemoteFilePathType, pattern: Optional[str] = "*",
+                         batch_max_files: Optional[int] = None, valid_extensions: Optional[set[str]] = None) -> \
+            AsyncGenerator[list[RemoteFilePathType], None]:
+        raise NotImplementedError
+
+    async def is_exist(self, remote_file: RemoteFilePathType) -> bool:
+        raise NotImplementedError
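[Review note] DisabledStorage is a null object standing in for the old bare `return`, which left active_storage as None. With it, a call against disabled storage fails with an explicit NotImplementedError rather than an AttributeError on None. A sketch of that behavioral difference (single simplified method, not the real BaseStorage interface):

    import asyncio


    class Storage:                     # simplified stand-in for BaseStorage
        async def fetch(self, remote_file: str) -> bytes:
            raise NotImplementedError


    class DisabledStorage(Storage):    # every operation fails loudly
        async def fetch(self, remote_file: str) -> bytes:
            raise NotImplementedError("storage is disabled")


    async def demo():
        active_storage = DisabledStorage()   # never None anymore
        try:
            await active_storage.fetch("img.png")
        except NotImplementedError as err:
            print(f"caught: {err}")    # explicit error, easy to trace


    asyncio.run(demo())                # prints: caught: storage is disabled
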
diff --git a/app/Services/storage/local_storage.py b/app/Services/storage/local_storage.py
index f79ce7c..88c26cf 100644
--- a/app/Services/storage/local_storage.py
+++ b/app/Services/storage/local_storage.py
@@ -47,7 +47,7 @@ def __init__(self):
     def file_path_wrap(self, path: RemoteFilePathType) -> syncPath:
         return self.static_dir / syncPath(path)
 
-    def pre_check(self):
+    async def on_load(self):
         if not self.static_dir.is_dir():
             self.static_dir.mkdir(parents=True)
             logger.warning(f"static_dir {self.static_dir} not found, created.")
diff --git a/app/Services/storage/s3_compatible_storage.py b/app/Services/storage/s3_compatible_storage.py
index a9f861e..9199c4d 100644
--- a/app/Services/storage/s3_compatible_storage.py
+++ b/app/Services/storage/s3_compatible_storage.py
@@ -64,9 +64,6 @@ def __init__(self):
     def _file_path_str_wrap(p: RemoteFilePathType):
         return str(PurePosixPath(p))
 
-    def pre_check(self):
-        pass
-
     async def is_exist(self, remote_file: "RemoteFilePathType") -> bool:
         try:
diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py
index 7c72ab1..2048c25 100644
--- a/app/Services/transformers_service.py
+++ b/app/Services/transformers_service.py
@@ -8,10 +8,11 @@
 from torch import FloatTensor, no_grad
 from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel
 
+from app.Services.lifespan_service import LifespanService
 from app.config import config
 
 
-class TransformersService:
+class TransformersService(LifespanService):
     def __init__(self):
         self.device = config.device
         if self.device == "auto":
diff --git a/app/Services/upload_service.py b/app/Services/upload_service.py
index 6a83bfe..6520a0c 100644
--- a/app/Services/upload_service.py
+++ b/app/Services/upload_service.py
@@ -7,13 +7,14 @@
 from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
 from app.Models.img_data import ImageData
+from app.Services.lifespan_service import LifespanService
 from app.Services.index_service import IndexService
 from app.Services.storage import StorageService
 from app.Services.vector_db_context import VectorDbContext
 from app.config import config
 
 
-class UploadService:
+class UploadService(LifespanService):
     def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService):
         self._storage_service = storage_service
         self._db_context = db_context
diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py
index e1a0f27..06fa7e6 100644
--- a/app/Services/vector_db_context.py
+++ b/app/Services/vector_db_context.py
@@ -12,6 +12,7 @@
 from app.Models.img_data import ImageData
 from app.Models.query_params import FilterParams
 from app.Models.search_result import SearchResult
+from app.Services.lifespan_service import LifespanService
 from app.config import config, QdrantMode
 from app.util.retry_deco_async import wrap_object, retry_async
@@ -22,7 +23,7 @@ def __init__(self, point_id: str):
         super().__init__(f"Point {point_id} not found.")
 
 
-class VectorDbContext:
+class VectorDbContext(LifespanService):
     IMG_VECTOR = "image_vector"
     TEXT_VECTOR = "text_contain_vector"
     AVAILABLE_POINT_TYPES = models.Record | models.ScoredPoint | models.PointStruct
@@ -44,7 +45,7 @@ def __init__(self):
             raise ValueError("Invalid Qdrant mode.")
         self.collection_name = config.qdrant.coll
 
-    async def onload(self):
+    async def on_load(self):
         if not await self.check_collection():
             logger.warning("Collection not found. Initializing...")
             await self.initialize_collection()
Initializing...") await self.initialize_collection() diff --git a/app/webapp.py b/app/webapp.py index d7c948c..a1ded75 100644 --- a/app/webapp.py +++ b/app/webapp.py @@ -27,6 +27,8 @@ async def lifespan(_: FastAPI): admin_controller.services = provider yield + await provider.onexit() + app = FastAPI(lifespan=lifespan) init_logging() diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index 9c38422..2632abb 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -8,6 +8,7 @@ from app.Models.img_data import ImageData from app.Services.provider import ServiceProvider +from app.config import config, StorageMode from app.util.generate_uuid import generate_uuid from .local_utility import fetch_path_uuid_list @@ -50,6 +51,7 @@ async def copy_and_index_batch(file_path_list: list[tuple[Path, str]]): @logger.catch() async def main(args): global services + config.storage.method = StorageMode.LOCAL # ensure to use LocalStorage services = ServiceProvider() await services.onload() root = Path(args.local_index_target_dir) @@ -58,12 +60,12 @@ async def main(args): if item_number == 0: # database is empty, do as usual logger.warning("The database is empty, Will not check for duplicate points.") - async for item in services.storage_service.local_storage.list_files(root, batch_max_files=1): + async for item in services.storage_service.active_storage.list_files(root, batch_max_files=1): await copy_and_index(item[0]) else: # database is not empty, check for duplicate points logger.warning("The database is not empty, Will check for duplicate points.") - async for itm in services.storage_service.local_storage.list_files(root, batch_max_files=5000): + async for itm in services.storage_service.active_storage.list_files(root, batch_max_files=5000): local_file_path_with_uuid_list = fetch_path_uuid_list(itm) local_file_uuid_list = [itm[1] for itm in local_file_path_with_uuid_list] duplicate_uuid_list = await services.db_context.validate_ids(local_file_uuid_list)