From f669b83a010239c0dbdc1bd5eba4fc0f00e8ed6d Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Sat, 30 Dec 2023 01:31:43 +0800 Subject: [PATCH] Complete IndexService --- app/Controllers/admin.py | 2 +- app/Controllers/search.py | 3 +- app/Services/__init__.py | 28 -------------- app/Services/index_service.py | 11 +++--- app/Services/provider.py | 31 +++++++++++++++ scripts/__init__.py | 0 scripts/db_migrations.py | 2 +- scripts/local_create_thumbnail.py | 2 +- scripts/local_indexing.py | 64 +++++++------------------------ scripts/local_utility.py | 11 ++++++ 10 files changed, 65 insertions(+), 89 deletions(-) create mode 100644 app/Services/provider.py create mode 100644 scripts/__init__.py create mode 100644 scripts/local_utility.py diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index cd3995b..2e9956f 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -6,8 +6,8 @@ from app.Models.admin_api_model import ImageOptUpdateModel from app.Models.api_response.base import NekoProtocol -from app.Services import db_context from app.Services.authentication import force_admin_token_verify +from app.Services.provider import db_context from app.Services.vector_db_context import PointNotFoundError from app.config import config from app.util import directories diff --git a/app/Controllers/search.py b/app/Controllers/search.py index fa7904f..00fc860 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -11,9 +11,8 @@ from app.Models.api_response.search_api_response import SearchApiResponse from app.Models.query_params import SearchPagingParams, FilterParams from app.Models.search_result import SearchResult -from app.Services import db_context -from app.Services import transformers_service from app.Services.authentication import force_access_token_verify +from app.Services.provider import db_context, transformers_service from app.config import config from app.util.calculate_vectors_cosine import calculate_vectors_cosine diff --git a/app/Services/__init__.py b/app/Services/__init__.py index 9bfd075..e69de29 100644 --- a/app/Services/__init__.py +++ b/app/Services/__init__.py @@ -1,28 +0,0 @@ -from .transformers_service import TransformersService -from .vector_db_context import VectorDbContext -from ..config import config, environment - -transformers_service = TransformersService() -db_context = VectorDbContext() -ocr_service = None - -if environment.local_indexing: - match config.ocr_search.ocr_module: - case "easyocr": - from .ocr_services import EasyOCRService - - ocr_service = EasyOCRService() - case "easypaddleocr": - from .ocr_services import EasyPaddleOCRService - - ocr_service = EasyPaddleOCRService() - case "paddleocr": - from .ocr_services import PaddleOCRService - - ocr_service = PaddleOCRService() - case _: - raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.") -else: - from .ocr_services import DisabledOCRService - - ocr_service = DisabledOCRService() diff --git a/app/Services/index_service.py b/app/Services/index_service.py index 3f224e4..b9c491d 100644 --- a/app/Services/index_service.py +++ b/app/Services/index_service.py @@ -1,8 +1,8 @@ from PIL import Image from app.Models.img_data import ImageData -from app.Services import TransformersService from app.Services.ocr_services import OCRService +from app.Services.transformers_service import TransformersService from app.Services.vector_db_context import VectorDbContext from app.config import config @@ -17,6 +17,7 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal image_data.width = image.width image_data.height = image.height image_data.aspect_ratio = float(image.width) / image.height + if image.mode != 'RGB': image = image.convert('RGB') # to reduce convert in next steps image_data.image_vector = self._transformers_service.get_image_vector(image) @@ -27,11 +28,11 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal else: image_data.ocr_text = None - def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): + async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): self._prepare_image(image, image_data, skip_ocr) - self._db_context.insertItems([image_data]) + await self._db_context.insertItems([image_data]) - def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False): + async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False): for i in range(len(image)): self._prepare_image(image[i], image_data[i], skip_ocr) - self._db_context.insertItems(image_data) + await self._db_context.insertItems(image_data) diff --git a/app/Services/provider.py b/app/Services/provider.py new file mode 100644 index 0000000..12604d7 --- /dev/null +++ b/app/Services/provider.py @@ -0,0 +1,31 @@ +from .index_service import IndexService +from .transformers_service import TransformersService +from .vector_db_context import VectorDbContext +from ..config import config, environment + +transformers_service = TransformersService() +db_context = VectorDbContext() +ocr_service = None + +if environment.local_indexing: + match config.ocr_search.ocr_module: + case "easyocr": + from .ocr_services import EasyOCRService + + ocr_service = EasyOCRService() + case "easypaddleocr": + from .ocr_services import EasyPaddleOCRService + + ocr_service = EasyPaddleOCRService() + case "paddleocr": + from .ocr_services import PaddleOCRService + + ocr_service = PaddleOCRService() + case _: + raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.") +else: + from .ocr_services import DisabledOCRService + + ocr_service = DisabledOCRService() + +index_service = IndexService(ocr_service, transformers_service, db_context) diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/db_migrations.py b/scripts/db_migrations.py index 86805a9..919a823 100644 --- a/scripts/db_migrations.py +++ b/scripts/db_migrations.py @@ -1,6 +1,6 @@ from loguru import logger -from app.Services import db_context, transformers_service +from app.Services.provider import db_context, transformers_service CURRENT_VERSION = 2 diff --git a/scripts/local_create_thumbnail.py b/scripts/local_create_thumbnail.py index f9f2a42..23209b3 100644 --- a/scripts/local_create_thumbnail.py +++ b/scripts/local_create_thumbnail.py @@ -4,7 +4,7 @@ from PIL import Image from loguru import logger -from app.Services import db_context +from app.Services.provider import db_context from app.config import config diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index 343fd1e..74e0920 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -1,4 +1,3 @@ -import argparse from datetime import datetime from pathlib import Path from shutil import copy2 @@ -9,77 +8,40 @@ from loguru import logger from app.Models.img_data import ImageData -from app.Services import transformers_service, db_context, ocr_service +from app.Services.provider import index_service from app.config import config +from .local_utility import gather_valid_files -def parse_args(): - parser = argparse.ArgumentParser(description='Create Qdrant collection') - parser.add_argument('--copy-from', dest="local_index_target_dir", type=str, required=True, - help="Copy from this directory") - return parser.parse_args() - - -def copy_and_index(file_path: Path) -> ImageData | None: +async def copy_and_index(file_path: Path): try: img = Image.open(file_path) except PIL.UnidentifiedImageError as e: logger.error("Error when opening image {}: {}", file_path, e) - return None + return image_id = uuid4() img_ext = file_path.suffix - image_ocr_result = None - text_contain_vector = None - [width, height] = img.size - try: - image_vector = transformers_service.get_image_vector(img) - if config.ocr_search.enable: - image_ocr_result = ocr_service.ocr_interface(img) # This will modify img if you use preprocess! - if image_ocr_result != "": - text_contain_vector = transformers_service.get_bert_vector(image_ocr_result) - else: - image_ocr_result = None - except Exception as e: - logger.error("Error when processing image {}: {}", file_path, e) - return None imgdata = ImageData(id=image_id, url=f'/static/{image_id}{img_ext}', - image_vector=image_vector, - text_contain_vector=text_contain_vector, index_date=datetime.now(), - width=width, - height=height, - aspect_ratio=float(width) / height, - ocr_text=image_ocr_result, local=True) - + try: + await index_service.index_image(img, imgdata) + except Exception as e: + logger.error("Error when processing image {}: {}", file_path, e) + return # copy to static copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}') - return imgdata @logger.catch() async def main(args): root = Path(args.local_index_target_dir) static_path = Path(config.static_file.path) - if not static_path.exists(): - static_path.mkdir() - buffer = [] + static_path.mkdir(exist_ok=True) counter = 0 - for item in root.glob('**/*.*'): + for item in gather_valid_files(root): counter += 1 logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root))) - if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']: - imgdata = copy_and_index(item) - if imgdata is not None: - buffer.append(imgdata) - if len(buffer) >= 20: - logger.info("Upload {} element to database", len(buffer)) - await db_context.insertItems(buffer) - buffer.clear() - else: - logger.warning("Unsupported file type: {}. Skip...", item.suffix) - if len(buffer) > 0: - logger.info("Upload {} element to database", len(buffer)) - await db_context.insertItems(buffer) - logger.success("Indexing completed! {} images indexed", counter) + await copy_and_index(item) + logger.success("Indexing completed! {} images indexed", counter) diff --git a/scripts/local_utility.py b/scripts/local_utility.py new file mode 100644 index 0000000..0c11a87 --- /dev/null +++ b/scripts/local_utility.py @@ -0,0 +1,11 @@ +from pathlib import Path + +from loguru import logger + + +def gather_valid_files(root: Path): + for item in root.glob('**/*.*'): + if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']: + yield item + else: + logger.warning("Unsupported file type: {}. Skip...", item.suffix)