Skip to content

Commit

Permalink
Merge pull request #41 from hv0905/refactor/local-util
Browse files Browse the repository at this point in the history
Refactor local utilities and CLI interface
  • Loading branch information
hv0905 authored Jul 5, 2024
2 parents d01e75e + ae00f0f commit d870b2d
Show file tree
Hide file tree
Showing 18 changed files with 256 additions and 224 deletions.
2 changes: 1 addition & 1 deletion .idea/NekoImageGallery.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 11 additions & 9 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \
DuplicateValidationResponse
from app.Models.api_response.base import NekoProtocol
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util.generate_uuid import generate_uuid, generate_uuid_from_sha1
from app.util.generate_uuid import generate_uuid_from_sha1
from app.util.local_file_utility import VALID_IMAGE_EXTENSIONS

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

Expand Down Expand Up @@ -106,19 +108,19 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
img_type = IMAGE_MIMES[image_file.content_type.lower()]
elif image_file.filename:
extension = PurePath(image_file.filename).suffix.lower()
if extension in {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}:
if extension in VALID_IMAGE_EXTENSIONS:
img_type = extension[1:]
if not img_type:
logger.warning("Failed to infer image format of the uploaded image. Content Type: {}, Filename: {}",
image_file.content_type, image_file.filename)
raise HTTPException(415, "Unsupported image format.")
img_bytes = await image_file.read()
img_id = generate_uuid(img_bytes)
if img_id in services.upload_service.uploading_ids or len(
await services.db_context.validate_ids([str(img_id)])) != 0: # check for duplicate points
logger.warning("Duplicate upload request for image id: {}", img_id)
raise HTTPException(409, f"The uploaded point is already contained in the database! entity id: {img_id}")

try:
img_id = await services.upload_service.assign_image_id(img_bytes)
except PointDuplicateError as ex:
raise HTTPException(409,
f"The uploaded point is already contained in the database! entity id: {ex.entity_id}") \
from ex
try:
image = Image.open(BytesIO(img_bytes))
image.verify()
Expand All @@ -136,7 +138,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
format=img_type,
index_date=datetime.now())

await services.upload_service.upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
10 changes: 10 additions & 0 deletions app/Models/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from uuid import UUID


class PointDuplicateError(ValueError):
def __init__(self, message: str, entity_id: UUID | None = None):
self.message = message
self.entity_id = entity_id
super().__init__(message)

pass
7 changes: 2 additions & 5 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image
from fastapi.concurrency import run_in_threadpool

from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
Expand All @@ -9,10 +10,6 @@
from app.config import config


class PointDuplicateError(ValueError):
pass


class IndexService(LifespanService):
def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
self._ocr_service = ocr_service
Expand Down Expand Up @@ -45,7 +42,7 @@ async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
background=False):
if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
raise PointDuplicateError("The uploaded points are contained in the database!")
raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)

if background:
await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
Expand Down
17 changes: 7 additions & 10 deletions app/Services/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing import Optional, AsyncGenerator

import aiofiles
from aiopath import Path as asyncPath
from loguru import logger

from app.Services.storage.base import BaseStorage, FileMetaDataT, RemoteFilePathType, LocalFilePathType
from app.Services.storage.exception import RemoteFileNotFoundError, LocalFileNotFoundError, RemoteFilePermissionError, \
LocalFilePermissionError, LocalFileExistsError, RemoteFileExistsError
from app.config import config
from app.util.local_file_utility import glob_local_files


def transform_exception(param: str):
Expand Down Expand Up @@ -129,16 +129,13 @@ async def list_files(self,
batch_max_files: Optional[int] = None,
valid_extensions: Optional[set[str]] = None) \
-> AsyncGenerator[list[RemoteFilePathType], None]:
_path = asyncPath(self.file_path_warp(path))
local_path = self.file_path_warp(path)
files = []
if valid_extensions is None:
valid_extensions = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}
async for file in _path.glob(pattern):
if file.suffix.lower() in valid_extensions:
files.append(syncPath(file))
if batch_max_files is not None and len(files) == batch_max_files:
yield files
files = []
for file in glob_local_files(local_path, pattern, valid_extensions):
files.append(file)
if batch_max_files is not None and len(files) == batch_max_files:
yield files
files = []
if files:
yield files

Expand Down
3 changes: 2 additions & 1 deletion app/Services/storage/s3_compatible_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from app.Services.storage.exception import LocalFileNotFoundError, RemoteFileNotFoundError, RemoteFilePermissionError, \
RemoteFileExistsError
from app.config import config
from app.util.local_file_utility import VALID_IMAGE_EXTENSIONS


def transform_exception(func):
Expand Down Expand Up @@ -138,7 +139,7 @@ async def list_files(self,
valid_extensions: Optional[set[str]] = None) \
-> AsyncGenerator[list[RemoteFilePathType], None]:
if valid_extensions is None:
valid_extensions = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}
valid_extensions = VALID_IMAGE_EXTENSIONS
files = []
# In opendal, current path should be "" instead of "."
_path = "" if self._file_path_str_warp(path) == "." else self._file_path_str_warp(path)
Expand Down
23 changes: 20 additions & 3 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import asyncio
import gc
import io
import pathlib
from io import BytesIO

from PIL import Image
from loguru import logger

from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.lifespan_service import LifespanService
from app.Services.index_service import IndexService
from app.Services.lifespan_service import LifespanService
from app.Services.storage import StorageService
from app.Services.vector_db_context import VectorDbContext
from app.config import config
from app.util.generate_uuid import generate_uuid


class UploadService(LifespanService):
Expand Down Expand Up @@ -75,11 +79,24 @@ async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bo

img.close()

async def upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                             thumbnail_mode: UploadImageThumbnailMode):
    """Put an image on the upload queue.

    The image id is recorded in ``self.uploading_ids`` before the job tuple
    ``(img_data, img_bytes, skip_ocr, thumbnail_mode)`` is placed on ``self._queue``.
    """
    job = (img_data, img_bytes, skip_ocr, thumbnail_mode)
    self.uploading_ids.add(img_data.id)
    await self._queue.put(job)
    queue_length = self._queue.qsize()
    logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, queue_length)

async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
    """Generate the id for an image and reject it when that id is already known.

    The id is produced by ``generate_uuid(img_file)``. It counts as a duplicate
    when it is already in ``self.uploading_ids`` or when
    ``self._db_context.validate_ids`` returns a non-empty result for it.

    Returns:
        The generated image id.

    Raises:
        PointDuplicateError: The id is a duplicate; the error carries the id.
    """
    img_id = generate_uuid(img_file)
    # check for duplicate points; the database is only asked when the id is not already pending
    duplicated = img_id in self.uploading_ids
    if not duplicated:
        matches = await self._db_context.validate_ids([str(img_id)])
        duplicated = len(matches) != 0
    if not duplicated:
        return img_id
    logger.warning("Duplicate upload request for image id: {}", img_id)
    raise PointDuplicateError(f"The uploaded point is already contained in the database! entity id: {img_id}",
                              img_id)

async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                            thumbnail_mode: UploadImageThumbnailMode):
    """Run the upload task for one image directly and wait for it to finish.

    Unlike ``queue_upload_image`` this does not go through ``self._queue``;
    it awaits ``self._upload_task`` with the given arguments.
    """
    upload = self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode)
    await upload

def get_queue_size(self):
    """Return the value reported by ``self._queue.qsize()``."""
    pending = self._queue.qsize()
    return pending
6 changes: 6 additions & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Package metadata for the `app` package, exposed as module-level dunder attributes.
# The title, description and version are passed to the FastAPI constructor in app/webapp.py.
__title__ = 'NekoImageGallery'
__description__ = 'An AI-powered natural language & reverse Image Search Engine powered by CLIP & qdrant.'
__version__ = '1.2.0'
# Multiple authors are separated by "; ".
__author__ = 'EdgeNeko; pk5ls20'
__author_email__ = '[email protected]'
__url__ = 'https://github.com/hv0905/NekoImageGallery'
12 changes: 12 additions & 0 deletions app/util/local_file_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pathlib import Path

VALID_IMAGE_EXTENSIONS = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}


def glob_local_files(path: Path, pattern: str = "*", valid_extensions: set[str] = None):
if valid_extensions is None:
valid_extensions = VALID_IMAGE_EXTENSIONS

for file in path.glob(pattern):
if file.suffix.lower() in valid_extensions:
yield file
4 changes: 3 additions & 1 deletion app/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.params import Depends
from fastapi.staticfiles import StaticFiles

import app
import app.Controllers.admin as admin_controller
import app.Controllers.search as search_controller
from app.Services.authentication import permissive_access_token_verify, permissive_admin_token_verify
Expand All @@ -30,9 +31,10 @@ async def lifespan(_: FastAPI):
await provider.onexit()


app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan, title=app.__title__, description=app.__description__, version=app.__version__)
init_logging()

# noinspection PyTypeChecker
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors_origins,
Expand Down
2 changes: 1 addition & 1 deletion docker-compose-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ services:
# dockerfile: cpu-only.Dockerfile
image: edgeneko/neko-image-gallery:latest-cpu
ports:
- "127.0.0.1:8000:8000"
- "8000:8000"
volumes:
- "./static:/opt/NekoImageGallery/static"
environment:
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ services:
# dockerfile: Dockerfile
image: edgeneko/neko-image-gallery:latest
ports:
- "127.0.0.1:8000:8000"
- "8000:8000"
volumes:
- "./static:/opt/NekoImageGallery/static"
environment:
Expand Down
Loading

0 comments on commit d870b2d

Please sign in to comment.