From 9db6816028bd00d75f533a183a402aa3df998439 Mon Sep 17 00:00:00 2001
From: EdgeNeko <edgeneko@aiursoft.com>
Date: Fri, 29 Dec 2023 20:38:29 +0800
Subject: [PATCH] Basic approach of indexing service

---
 app/Services/index_service.py | 51 +++++++++++++++++++++++++++++++----
 app/Services/ocr_services.py  |  2 +-
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/app/Services/index_service.py b/app/Services/index_service.py
index 867ba4d..3f224e4 100644
--- a/app/Services/index_service.py
+++ b/app/Services/index_service.py
@@ -3,15 +3,56 @@
 from app.Models.img_data import ImageData
 from app.Services import TransformersService
 from app.Services.ocr_services import OCRService
+from app.Services.vector_db_context import VectorDbContext
+from app.config import config
 
 
 class IndexService:
-    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService):
+    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
         self._ocr_service = ocr_service
         self._transformers_service = transformers_service
+        self._db_context = db_context
 
-    def _calculate_vectors(self):
-        pass
+    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
+        """Fill ``image_data`` in place with the values derived from ``image``.
+
+        Always sets width, height, aspect ratio and the image vector. When
+        ``skip_ocr`` is false and ``config.ocr_search.enable`` is set, also runs
+        OCR: non-empty text is stored together with its text vector, while an
+        empty result is stored as ``ocr_text = None`` with no text vector.
+        """
+        image_data.width = image.width
+        image_data.height = image.height
+        image_data.aspect_ratio = float(image.width) / image.height
+        if image.mode != 'RGB':
+            image = image.convert('RGB')  # convert once so later steps need not convert again
+        image_data.image_vector = self._transformers_service.get_image_vector(image)
+        if not skip_ocr and config.ocr_search.enable:
+            image_data.ocr_text = self._ocr_service.ocr_interface(image)
+            if image_data.ocr_text != "":
+                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
+            else:
+                image_data.ocr_text = None
 
-    def index_image(self, image: Image.Image, image_data: ImageData):
-        pass
+    def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
+        """Prepare ``image_data`` from ``image`` and pass it to the database context for insertion."""
+        self._prepare_image(image, image_data, skip_ocr)
+        self._db_context.insertItems([image_data])
+
+    def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
+        """Prepare every image/metadata pair, then insert them with a single call.
+
+        ``image[i]`` is paired with ``image_data[i]``.
+
+        Raises:
+            ValueError: if the two lists differ in length. This is checked
+                before any work is done, so entries that were never prepared
+                are not inserted.
+        """
+        if len(image) != len(image_data):
+            raise ValueError(
+                f"image and image_data must have the same length, got {len(image)} and {len(image_data)}"
+            )
+        for img, data in zip(image, image_data):
+            self._prepare_image(img, data, skip_ocr)
+        self._db_context.insertItems(image_data)
diff --git a/app/Services/ocr_services.py b/app/Services/ocr_services.py
index 0012505..6a4fb4c 100644
--- a/app/Services/ocr_services.py
+++ b/app/Services/ocr_services.py
@@ -41,7 +41,7 @@ def _easy_paddleocr_process(self, img: Image.Image) -> str:
             return "".join(itm[0] for itm in ocr_result if float(itm[1]) > config.ocr_search.ocr_min_confidence)
         return ""
 
-    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
+    def ocr_interface(self, img: Image.Image, need_preprocess=False) -> str:
         start_time = time()
         logger.info("Processing text with EasyPaddleOCR...")
         res = self._easy_paddleocr_process(self._image_preprocess(img) if need_preprocess else img)