From d7610263cd401f691af48a6e8e90eeab900e428b Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Wed, 1 May 2024 01:10:59 +0800 Subject: [PATCH] feat: Add task queue for image uploading --- app/Controllers/admin.py | 35 +++---------------- app/Services/index_service.py | 11 ++++-- app/Services/provider.py | 6 ++++ app/Services/upload_service.py | 63 ++++++++++++++++++++++++++++++++++ app/config.py | 4 +-- 5 files changed, 84 insertions(+), 35 deletions(-) create mode 100644 app/Services/upload_service.py diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index ecc9ffa..7e56c58 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -5,7 +5,7 @@ from uuid import UUID from PIL import Image, UnidentifiedImageError -from fastapi import APIRouter, Depends, HTTPException, params, BackgroundTasks, UploadFile, File +from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File from loguru import logger from app.Models.api_models.admin_api_model import ImageOptUpdateModel @@ -14,7 +14,7 @@ from app.Models.api_response.base import NekoProtocol from app.Models.img_data import ImageData from app.Services.authentication import force_admin_token_verify -from app.Services.provider import db_context, storage_service, index_service +from app.Services.provider import db_context, storage_service, upload_service from app.Services.vector_db_context import PointNotFoundError from app.config import config from app.util.generate_uuid import generate_uuid @@ -75,32 +75,6 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id return NekoProtocol(message="Image updated.") -async def upload_task(img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool): - logger.info('Start indexing image {}. Local: {}', img_data.id, img_data.local) - logger.info('Content length is {} bytes.', len(img_bytes)) - file_name = f"{img_data.id}.{img_data.format}" - thumb_path = f"thumbnails/{img_data.id}.webp" - if img_data.local: - img_data.url = await storage_service.active_storage.url(file_name) - if len(img_bytes) > 1024 * 500: - img_thumb = img.copy() - img_data.thumbnail_url = await storage_service.active_storage.url(f"thumbnails/{img_data.id}.webp") - - await index_service.index_image(img, img_data, skip_ocr=skip_ocr) # The img might be modified after calling this - logger.success("Image {} indexed.", img_data.id) - - if img_data.local: - logger.info("Start uploading image {} to local storage.", img_data.id) - await storage_service.active_storage.upload(img_bytes, file_name) - logger.success("Image {} uploaded to local storage.", img_data.id) - if len(img_bytes) > 1024 * 500: - img_thumb.thumbnail((256, 256)) - img_byte_arr = BytesIO() - img_thumb.save(img_byte_arr, 'WebP') - await storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path) - logger.success("Thumbnail for {} generated and uploaded!", img_data.id) - - IMAGE_MIMES = { "image/jpeg": "jpeg", "image/png": "png", @@ -112,8 +86,7 @@ async def upload_task(img: Image.Image, img_data: ImageData, img_bytes: bytes, s @admin_router.post("/upload", description="Upload image to server. The image will be indexed and stored in the database. If local is set to true, the image will be uploaded to local storage.") async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")], - model: Annotated[UploadImageModel, Depends()], - background_tasks: BackgroundTasks): + model: Annotated[UploadImageModel, Depends()]): # generate an ID for the image img_type = None if image_file.content_type.lower() in IMAGE_MIMES: @@ -143,7 +116,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i format=img_type, index_date=datetime.now()) - background_tasks.add_task(upload_task, image, image_data, img_bytes, model.skip_ocr) + await upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr) return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id) diff --git a/app/Services/index_service.py b/app/Services/index_service.py index e8a8e8b..be9cb89 100644 --- a/app/Services/index_service.py +++ b/app/Services/index_service.py @@ -1,4 +1,5 @@ from PIL import Image +from fastapi.concurrency import run_in_threadpool from app.Models.img_data import ImageData from app.Services.ocr_services import OCRService @@ -38,10 +39,16 @@ async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool: result = await self._db_context.validate_ids(image_id_list) return len(result) != 0 - async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False): + async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False, + background=False): if not skip_duplicate_check and (await self._is_point_duplicate([image_data])): raise PointDuplicateError("The uploaded points are contained in the database!") - self._prepare_image(image, image_data, skip_ocr) + + if background: + await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr) + else: + self._prepare_image(image, image_data, skip_ocr) + await self._db_context.insertItems([image_data]) async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], diff --git a/app/Services/provider.py b/app/Services/provider.py index 0936385..5589c41 100644 --- a/app/Services/provider.py +++ b/app/Services/provider.py @@ -3,6 +3,7 @@ from .index_service import IndexService from .storage import StorageService from .transformers_service import TransformersService +from .upload_service import UploadService from .vector_db_context import VectorDbContext from ..config import config, environment @@ -35,3 +36,8 @@ index_service = IndexService(ocr_service, transformers_service, db_context) storage_service = StorageService() logger.info(f"Storage service '{type(storage_service.active_storage).__name__}' initialized.") + +upload_service = None + +if config.admin_api_enable: + upload_service = UploadService(storage_service, db_context, index_service) diff --git a/app/Services/upload_service.py b/app/Services/upload_service.py new file mode 100644 index 0000000..5aaf077 --- /dev/null +++ b/app/Services/upload_service.py @@ -0,0 +1,63 @@ +import asyncio +from io import BytesIO + +from PIL import Image +from loguru import logger + +from app.Models.img_data import ImageData +from app.Services.index_service import IndexService +from app.Services.storage import StorageService +from app.Services.vector_db_context import VectorDbContext + + +class UploadService: + def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService): + self._storage_service = storage_service + self._db_context = db_context + self._index_service = index_service + + self._queue = asyncio.Queue(200) + self._upload_worker_task = asyncio.create_task(self._upload_worker()) + + async def _upload_worker(self): + while True: + img, img_data, img_bytes, skip_ocr = await self._queue.get() + try: + await self._upload_task(img, img_data, img_bytes, skip_ocr) + logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize()) + except Exception as ex: + logger.error("Error occurred while uploading image {}", img_data.id) + logger.exception(ex) + finally: + self._queue.task_done() + + async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool): + logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes)) + file_name = f"{img_data.id}.{img_data.format}" + thumb_path = f"thumbnails/{img_data.id}.webp" + img_thumb = None + if img_data.local: + img_data.url = await self._storage_service.active_storage.url(file_name) + if len(img_bytes) > 1024 * 500: + img_thumb = img.copy() + img_data.thumbnail_url = await self._storage_service.active_storage.url( + f"thumbnails/{img_data.id}.webp") + + await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, + background=True) # The img might be modified after calling this + logger.success("Image {} indexed.", img_data.id) + + if img_data.local: + logger.info("Start uploading image {} to local storage.", img_data.id) + await self._storage_service.active_storage.upload(img_bytes, file_name) + logger.success("Image {} uploaded to local storage.", img_data.id) + if len(img_bytes) > 1024 * 500: + img_thumb.thumbnail((256, 256), resample=Image.LANCZOS) + img_byte_arr = BytesIO() + img_thumb.save(img_byte_arr, 'WebP') + await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path) + logger.success("Thumbnail for {} generated and uploaded!", img_data.id) + + async def upload_image(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool): + await self._queue.put((img, img_data, img_bytes, skip_ocr)) + logger.info("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize()) diff --git a/app/config.py b/app/config.py index ee57e3d..40999da 100644 --- a/app/config.py +++ b/app/config.py @@ -9,8 +9,8 @@ class QdrantSettings(BaseModel): host: str = 'localhost' port: int = 6333 grpc_port: int = 6334 - coll: str = 'NekoImageGallery' - prefer_grpc: bool = False + coll: str = 'NekoImg' + prefer_grpc: bool = True api_key: str | None = None