From 9db6816028bd00d75f533a183a402aa3df998439 Mon Sep 17 00:00:00 2001 From: EdgeNeko <edgeneko@aiursoft.com> Date: Fri, 29 Dec 2023 20:38:29 +0800 Subject: [PATCH] Basic approach of indexing service --- app/Services/index_service.py | 30 +++++++++++++++++++++++++----- app/Services/ocr_services.py | 2 +- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/app/Services/index_service.py b/app/Services/index_service.py index 867ba4d..3f224e4 100644 --- a/app/Services/index_service.py +++ b/app/Services/index_service.py @@ -3,15 +3,35 @@ from app.Models.img_data import ImageData from app.Services import TransformersService from app.Services.ocr_services import OCRService +from app.Services.vector_db_context import VectorDbContext +from app.config import config class IndexService: - def __init__(self, ocr_service: OCRService, transformers_service: TransformersService): + def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext): self._ocr_service = ocr_service self._transformers_service = transformers_service + self._db_context = db_context - def _calculate_vectors(self): - pass + def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): + image_data.width = image.width + image_data.height = image.height + image_data.aspect_ratio = float(image.width) / image.height + if image.mode != 'RGB': + image = image.convert('RGB') # to reduce convert in next steps + image_data.image_vector = self._transformers_service.get_image_vector(image) + if not skip_ocr and config.ocr_search.enable: + image_data.ocr_text = self._ocr_service.ocr_interface(image) + if image_data.ocr_text != "": + image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text) + else: + image_data.ocr_text = None - def index_image(self, image: Image.Image, image_data: ImageData): - pass + def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): + self._prepare_image(image, image_data, skip_ocr) + self._db_context.insertItems([image_data]) + + def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False): + for i in range(len(image)): + self._prepare_image(image[i], image_data[i], skip_ocr) + self._db_context.insertItems(image_data) diff --git a/app/Services/ocr_services.py b/app/Services/ocr_services.py index 0012505..6a4fb4c 100644 --- a/app/Services/ocr_services.py +++ b/app/Services/ocr_services.py @@ -41,7 +41,7 @@ def _easy_paddleocr_process(self, img: Image.Image) -> str: return "".join(itm[0] for itm in ocr_result if float(itm[1]) > config.ocr_search.ocr_min_confidence) return "" - def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + def ocr_interface(self, img: Image.Image, need_preprocess=False) -> str: start_time = time() logger.info("Processing text with EasyPaddleOCR...") res = self._easy_paddleocr_process(self._image_preprocess(img) if need_preprocess else img)