Skip to content

Commit

Permalink
feat: Add task queue for image uploading
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Apr 30, 2024
1 parent b69a590 commit d761026
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 35 deletions.
35 changes: 4 additions & 31 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import UUID

from PIL import Image, UnidentifiedImageError
from fastapi import APIRouter, Depends, HTTPException, params, BackgroundTasks, UploadFile, File
from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File
from loguru import logger

from app.Models.api_models.admin_api_model import ImageOptUpdateModel
Expand All @@ -14,7 +14,7 @@
from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context, storage_service, index_service
from app.Services.provider import db_context, storage_service, upload_service
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util.generate_uuid import generate_uuid
Expand Down Expand Up @@ -75,32 +75,6 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
return NekoProtocol(message="Image updated.")


async def upload_task(img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool):
logger.info('Start indexing image {}. Local: {}', img_data.id, img_data.local)
logger.info('Content length is {} bytes.', len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
if img_data.local:
img_data.url = await storage_service.active_storage.url(file_name)
if len(img_bytes) > 1024 * 500:
img_thumb = img.copy()
img_data.thumbnail_url = await storage_service.active_storage.url(f"thumbnails/{img_data.id}.webp")

await index_service.index_image(img, img_data, skip_ocr=skip_ocr) # The img might be modified after calling this
logger.success("Image {} indexed.", img_data.id)

if img_data.local:
logger.info("Start uploading image {} to local storage.", img_data.id)
await storage_service.active_storage.upload(img_bytes, file_name)
logger.success("Image {} uploaded to local storage.", img_data.id)
if len(img_bytes) > 1024 * 500:
img_thumb.thumbnail((256, 256))
img_byte_arr = BytesIO()
img_thumb.save(img_byte_arr, 'WebP')
await storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)


IMAGE_MIMES = {
"image/jpeg": "jpeg",
"image/png": "png",
Expand All @@ -112,8 +86,7 @@ async def upload_task(img: Image.Image, img_data: ImageData, img_bytes: bytes, s
@admin_router.post("/upload",
description="Upload image to server. The image will be indexed and stored in the database. If local is set to true, the image will be uploaded to local storage.")
async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")],
model: Annotated[UploadImageModel, Depends()],
background_tasks: BackgroundTasks):
model: Annotated[UploadImageModel, Depends()]):
# generate an ID for the image
img_type = None
if image_file.content_type.lower() in IMAGE_MIMES:
Expand Down Expand Up @@ -143,7 +116,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
format=img_type,
index_date=datetime.now())

background_tasks.add_task(upload_task, image, image_data, img_bytes, model.skip_ocr)
await upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
11 changes: 9 additions & 2 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from PIL import Image
from fastapi.concurrency import run_in_threadpool

from app.Models.img_data import ImageData
from app.Services.ocr_services import OCRService
Expand Down Expand Up @@ -38,10 +39,16 @@ async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
result = await self._db_context.validate_ids(image_id_list)
return len(result) != 0

async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False):
async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
background=False):
if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
raise PointDuplicateError("The uploaded points are contained in the database!")
self._prepare_image(image, image_data, skip_ocr)

if background:
await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
else:
self._prepare_image(image, image_data, skip_ocr)

await self._db_context.insertItems([image_data])

async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
Expand Down
6 changes: 6 additions & 0 deletions app/Services/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .index_service import IndexService
from .storage import StorageService
from .transformers_service import TransformersService
from .upload_service import UploadService
from .vector_db_context import VectorDbContext
from ..config import config, environment

Expand Down Expand Up @@ -35,3 +36,8 @@
index_service = IndexService(ocr_service, transformers_service, db_context)
storage_service = StorageService()
logger.info(f"Storage service '{type(storage_service.active_storage).__name__}' initialized.")

upload_service = None

if config.admin_api_enable:
upload_service = UploadService(storage_service, db_context, index_service)
63 changes: 63 additions & 0 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import asyncio
from io import BytesIO

from PIL import Image
from loguru import logger

from app.Models.img_data import ImageData
from app.Services.index_service import IndexService
from app.Services.storage import StorageService
from app.Services.vector_db_context import VectorDbContext


class UploadService:
def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService):
self._storage_service = storage_service
self._db_context = db_context
self._index_service = index_service

self._queue = asyncio.Queue(200)
self._upload_worker_task = asyncio.create_task(self._upload_worker())

async def _upload_worker(self):
while True:
img, img_data, img_bytes, skip_ocr = await self._queue.get()
try:
await self._upload_task(img, img_data, img_bytes, skip_ocr)
logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize())
except Exception as ex:
logger.error("Error occurred while uploading image {}", img_data.id)
logger.exception(ex)
finally:
self._queue.task_done()

async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool):
logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
img_thumb = None
if img_data.local:
img_data.url = await self._storage_service.active_storage.url(file_name)
if len(img_bytes) > 1024 * 500:
img_thumb = img.copy()
img_data.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{img_data.id}.webp")

await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr,
background=True) # The img might be modified after calling this
logger.success("Image {} indexed.", img_data.id)

if img_data.local:
logger.info("Start uploading image {} to local storage.", img_data.id)
await self._storage_service.active_storage.upload(img_bytes, file_name)
logger.success("Image {} uploaded to local storage.", img_data.id)
if len(img_bytes) > 1024 * 500:
img_thumb.thumbnail((256, 256), resample=Image.LANCZOS)
img_byte_arr = BytesIO()
img_thumb.save(img_byte_arr, 'WebP')
await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)

async def upload_image(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool):
await self._queue.put((img, img_data, img_bytes, skip_ocr))
logger.info("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())
4 changes: 2 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class QdrantSettings(BaseModel):
host: str = 'localhost'
port: int = 6333
grpc_port: int = 6334
coll: str = 'NekoImageGallery'
prefer_grpc: bool = False
coll: str = 'NekoImg'
prefer_grpc: bool = True
api_key: str | None = None


Expand Down

0 comments on commit d761026

Please sign in to comment.