Skip to content

Commit

Permalink
Complete IndexService
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 29, 2023
1 parent b333b04 commit f669b83
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 89 deletions.
2 changes: 1 addition & 1 deletion app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from app.Models.admin_api_model import ImageOptUpdateModel
from app.Models.api_response.base import NekoProtocol
from app.Services import db_context
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util import directories
Expand Down
3 changes: 1 addition & 2 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services import db_context
from app.Services import transformers_service
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

Expand Down
28 changes: 0 additions & 28 deletions app/Services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +0,0 @@
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

if environment.local_indexing:
match config.ocr_search.ocr_module:
case "easyocr":
from .ocr_services import EasyOCRService

ocr_service = EasyOCRService()
case "easypaddleocr":
from .ocr_services import EasyPaddleOCRService

ocr_service = EasyPaddleOCRService()
case "paddleocr":
from .ocr_services import PaddleOCRService

ocr_service = PaddleOCRService()
case _:
raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")
else:
from .ocr_services import DisabledOCRService

ocr_service = DisabledOCRService()
11 changes: 6 additions & 5 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from PIL import Image

from app.Models.img_data import ImageData
from app.Services import TransformersService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config

Expand All @@ -17,6 +17,7 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal
image_data.width = image.width
image_data.height = image.height
image_data.aspect_ratio = float(image.width) / image.height

if image.mode != 'RGB':
image = image.convert('RGB') # to reduce convert in next steps
image_data.image_vector = self._transformers_service.get_image_vector(image)
Expand All @@ -27,11 +28,11 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal
else:
image_data.ocr_text = None

def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
self._prepare_image(image, image_data, skip_ocr)
self._db_context.insertItems([image_data])
await self._db_context.insertItems([image_data])

def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
for i in range(len(image)):
self._prepare_image(image[i], image_data[i], skip_ocr)
self._db_context.insertItems(image_data)
await self._db_context.insertItems(image_data)
31 changes: 31 additions & 0 deletions app/Services/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .index_service import IndexService
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

# Module-level service instances shared by the controllers and scripts that
# import from this provider module.
transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

# Choose the OCR implementation. Each branch imports only the class it
# instantiates, so the import happens at selection time rather than at the
# top of the module.
if not environment.local_indexing:
    from .ocr_services import DisabledOCRService

    ocr_service = DisabledOCRService()
elif config.ocr_search.ocr_module == "easyocr":
    from .ocr_services import EasyOCRService

    ocr_service = EasyOCRService()
elif config.ocr_search.ocr_module == "easypaddleocr":
    from .ocr_services import EasyPaddleOCRService

    ocr_service = EasyPaddleOCRService()
elif config.ocr_search.ocr_module == "paddleocr":
    from .ocr_services import PaddleOCRService

    ocr_service = PaddleOCRService()
else:
    # Unknown module name in the configuration: fail at import time.
    raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")

# The index service is built from the three collaborators created above.
index_service = IndexService(ocr_service, transformers_service, db_context)
Empty file added scripts/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion scripts/db_migrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from loguru import logger

from app.Services import db_context, transformers_service
from app.Services.provider import db_context, transformers_service

CURRENT_VERSION = 2

Expand Down
2 changes: 1 addition & 1 deletion scripts/local_create_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from PIL import Image
from loguru import logger

from app.Services import db_context
from app.Services.provider import db_context
from app.config import config


Expand Down
64 changes: 13 additions & 51 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from datetime import datetime
from pathlib import Path
from shutil import copy2
Expand All @@ -9,77 +8,40 @@
from loguru import logger

from app.Models.img_data import ImageData
from app.Services import transformers_service, db_context, ocr_service
from app.Services.provider import index_service
from app.config import config
from .local_utility import gather_valid_files


def parse_args():
parser = argparse.ArgumentParser(description='Create Qdrant collection')
parser.add_argument('--copy-from', dest="local_index_target_dir", type=str, required=True,
help="Copy from this directory")
return parser.parse_args()


def copy_and_index(file_path: Path) -> ImageData | None:
async def copy_and_index(file_path: Path):
try:
img = Image.open(file_path)
except PIL.UnidentifiedImageError as e:
logger.error("Error when opening image {}: {}", file_path, e)
return None
return
image_id = uuid4()
img_ext = file_path.suffix
image_ocr_result = None
text_contain_vector = None
[width, height] = img.size
try:
image_vector = transformers_service.get_image_vector(img)
if config.ocr_search.enable:
image_ocr_result = ocr_service.ocr_interface(img) # This will modify img if you use preprocess!
if image_ocr_result != "":
text_contain_vector = transformers_service.get_bert_vector(image_ocr_result)
else:
image_ocr_result = None
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return None
imgdata = ImageData(id=image_id,
url=f'/static/{image_id}{img_ext}',
image_vector=image_vector,
text_contain_vector=text_contain_vector,
index_date=datetime.now(),
width=width,
height=height,
aspect_ratio=float(width) / height,
ocr_text=image_ocr_result,
local=True)

try:
await index_service.index_image(img, imgdata)
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return
# copy to static
copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
return imgdata


@logger.catch()
async def main(args):
root = Path(args.local_index_target_dir)
static_path = Path(config.static_file.path)
if not static_path.exists():
static_path.mkdir()
buffer = []
static_path.mkdir(exist_ok=True)
counter = 0
for item in root.glob('**/*.*'):
for item in gather_valid_files(root):
counter += 1
logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root)))
if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']:
imgdata = copy_and_index(item)
if imgdata is not None:
buffer.append(imgdata)
if len(buffer) >= 20:
logger.info("Upload {} element to database", len(buffer))
await db_context.insertItems(buffer)
buffer.clear()
else:
logger.warning("Unsupported file type: {}. Skip...", item.suffix)
if len(buffer) > 0:
logger.info("Upload {} element to database", len(buffer))
await db_context.insertItems(buffer)
logger.success("Indexing completed! {} images indexed", counter)
await copy_and_index(item)
logger.success("Indexing completed! {} images indexed", counter)
11 changes: 11 additions & 0 deletions scripts/local_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pathlib import Path

from loguru import logger


def gather_valid_files(root: Path):
    """Yield every supported image file found recursively under *root*.

    A path is yielded when it is a regular file whose extension, compared
    case-insensitively, is one of ``.jpg``, ``.png``, ``.jpeg``, ``.jfif``,
    ``.webp`` or ``.gif``. Files with any other extension are reported with
    a warning and skipped.

    :param root: Directory to search recursively.
    :return: A generator of ``Path`` objects, in the order ``Path.glob``
        produces them.
    """
    supported_suffixes = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}
    for item in root.glob('**/*.*'):
        # The pattern also matches directories whose names contain a dot
        # (e.g. "album.jpg/"); those are not images and must not be yielded.
        if not item.is_file():
            continue
        # Lower-case the suffix so files such as "photo.JPG" are accepted.
        if item.suffix.lower() in supported_suffixes:
            yield item
        else:
            logger.warning("Unsupported file type: {}. Skip...", item.suffix)

0 comments on commit f669b83

Please sign in to comment.