Skip to content

Commit

Permalink
Merge pull request #10 from hv0905/indexer
Browse files Browse the repository at this point in the history
Refactor indexing into a standalone service and decouple it from local directories
  • Loading branch information
pk5ls20 authored Dec 29, 2023
2 parents 29342cc + 4966194 commit 2b035d8
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 103 deletions.
2 changes: 1 addition & 1 deletion app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from app.Models.admin_api_model import ImageOptUpdateModel
from app.Models.api_response.base import NekoProtocol
from app.Services import db_context
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util import directories
Expand Down
3 changes: 1 addition & 2 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services import db_context
from app.Services import transformers_service
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

Expand Down
2 changes: 2 additions & 0 deletions app/Models/img_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ImageData(BaseModel):
height: Optional[int] = None
aspect_ratio: Optional[float] = None
starred: Optional[bool] = False
categories: Optional[list[str]] = []
local: Optional[bool] = False

@computed_field()
Expand All @@ -27,6 +28,7 @@ def ocr_text_lower(self) -> str | None:
return None
return self.ocr_text.lower()


@property
def payload(self):
result = self.model_dump(exclude={'id', 'index_date'})
Expand Down
28 changes: 0 additions & 28 deletions app/Services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +0,0 @@
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

# Module-level singletons shared across the application.
transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

# OCR backends are heavyweight, so the concrete implementation is imported
# lazily and only when this process is configured to do local indexing.
if environment.local_indexing:
    match config.ocr_search.ocr_module:
        case "easyocr":
            from .ocr_services import EasyOCRService

            ocr_service = EasyOCRService()
        case "easypaddleocr":
            from .ocr_services import EasyPaddleOCRService

            ocr_service = EasyPaddleOCRService()
        case "paddleocr":
            from .ocr_services import PaddleOCRService

            ocr_service = PaddleOCRService()
        case _:
            raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")
else:
    # Non-indexing deployments get a stub that satisfies the OCRService interface.
    from .ocr_services import DisabledOCRService

    ocr_service = DisabledOCRService()
4 changes: 1 addition & 3 deletions app/Services/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@


def verify_access_token(token: str | None) -> bool:
    """Return True when access protection is disabled, or when *token* equals the configured access token.

    NOTE: the diff residue left both the old three-line body and the new
    one-liner in place (an unreachable duplicate ``return``); only the new
    single-expression form is kept.
    """
    return (not config.access_protected) or (token is not None and token == config.access_token)


def permissive_access_token_verify(
Expand Down
38 changes: 38 additions & 0 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from PIL import Image

from app.Models.img_data import ImageData
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config


class IndexService:
    """Prepare images (dimensions, CLIP vector, optional OCR text vector) and store them in the vector DB."""

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Fill *image_data* in place: dimensions, CLIP image vector and, unless skipped, OCR text + BERT vector."""
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        if image.mode != 'RGB':
            image = image.convert('RGB')  # to reduce convert in next steps
        image_data.image_vector = self._transformers_service.get_image_vector(image)
        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            if image_data.ocr_text != "":
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                # Normalize "no text found" to None so it is excluded from search payloads.
                image_data.ocr_text = None

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Prepare a single image and insert its record into the database."""
        self._prepare_image(image, image_data, skip_ocr)
        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
        """Prepare a batch of images and insert all records in one database call.

        Raises ValueError if the two lists differ in length (the original
        ``range(len(...))`` loop would have raised IndexError or silently
        inserted unprepared records instead).
        """
        for img, data in zip(image, image_data, strict=True):
            self._prepare_image(img, data, skip_ocr)
        await self._db_context.insertItems(image_data)
31 changes: 31 additions & 0 deletions app/Services/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .index_service import IndexService
from .transformers_service import TransformersService
from .vector_db_context import VectorDbContext
from ..config import config, environment

# Module-level singletons: construct each service once and share it across the app.
transformers_service = TransformersService()
db_context = VectorDbContext()
ocr_service = None

# OCR backends are heavyweight, so the concrete implementation is imported
# lazily and only when this process is configured to do local indexing.
if environment.local_indexing:
    match config.ocr_search.ocr_module:
        case "easyocr":
            from .ocr_services import EasyOCRService

            ocr_service = EasyOCRService()
        case "easypaddleocr":
            from .ocr_services import EasyPaddleOCRService

            ocr_service = EasyPaddleOCRService()
        case "paddleocr":
            from .ocr_services import PaddleOCRService

            ocr_service = PaddleOCRService()
        case _:
            raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.")
else:
    # Non-indexing deployments get a stub that satisfies the OCRService interface.
    from .ocr_services import DisabledOCRService

    ocr_service = DisabledOCRService()

# Facade that ties OCR, embedding and storage together for the indexing pipeline.
index_service = IndexService(ocr_service, transformers_service, db_context)
22 changes: 10 additions & 12 deletions app/Services/transformers_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}",
self.device, config.clip.model, config.ocr_search.bert_model)
self.clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self.clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
self._clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self._clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
logger.success("CLIP Model loaded successfully")
if config.ocr_search.enable:
self.bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self.bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
self._bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self._bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
logger.success("BERT Model loaded successfully")
else:
logger.info("OCR search is disabled. Skipping OCR and BERT model loading.")
Expand All @@ -34,32 +34,30 @@ def get_image_vector(self, image: Image.Image) -> ndarray:
image = image.convert("RGB")
logger.info("Processing image...")
start_time = time()
inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device)
logger.success("Image processed, now inferencing with CLIP model...")
outputs: FloatTensor = self.clip_model.get_image_features(**inputs)
outputs: FloatTensor = self._clip_model.get_image_features(**inputs)
logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
logger.info("Norm: {}", outputs.norm(dim=-1).item())
outputs /= outputs.norm(dim=-1, keepdim=True)
return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_text_vector(self, text: str) -> ndarray:
logger.info("Processing text...")
start_time = time()
inputs = self.clip_processor(text=text, return_tensors="pt").to(self.device)
inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device)
logger.success("Text processed, now inferencing with CLIP model...")
outputs: FloatTensor = self.clip_model.get_text_features(**inputs)
outputs: FloatTensor = self._clip_model.get_text_features(**inputs)
logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
logger.info("Norm: {}", outputs.norm(dim=-1).item())
outputs /= outputs.norm(dim=-1, keepdim=True)
return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_bert_vector(self, text: str) -> ndarray:
start_time = time()
logger.info("Inferencing with BERT model...")
inputs = self.bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
outputs = self.bert_model(**inputs)
inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
outputs = self._bert_model(**inputs)
vector = outputs.last_hidden_state.mean(dim=1).squeeze()
logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time)
return vector.cpu().numpy()
Expand Down
4 changes: 4 additions & 0 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Requirements for development and testing

pytest
pylint
Empty file added scripts/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion scripts/db_migrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from loguru import logger

from app.Services import db_context, transformers_service
from app.Services.provider import db_context, transformers_service

CURRENT_VERSION = 2

Expand Down
8 changes: 3 additions & 5 deletions scripts/local_create_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from PIL import Image
from loguru import logger

from app.Services import db_context
from app.Services.provider import db_context
from app.config import config
from .local_utility import gather_valid_files


async def main():
Expand All @@ -14,17 +15,14 @@ async def main():
if not static_thumb_path.exists():
static_thumb_path.mkdir()
count = 0
for item in static_path.glob('*.*'):
for item in gather_valid_files(static_path, '*.*'):
count += 1
logger.info("[{}] Processing {}", str(count), str(item.relative_to(static_path)))
size = item.stat().st_size
if size < 1024 * 500:
logger.warning("File size too small: {}. Skip...", size)
continue
try:
if item.suffix not in ['.jpg', '.png', '.jpeg']:
logger.warning("Unsupported file type: {}. Skip...", item.suffix)
continue
if (static_thumb_path / f'{item.stem}.webp').exists():
logger.warning("Thumbnail for {} already exists. Skip...", item.stem)
continue
Expand Down
64 changes: 13 additions & 51 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from datetime import datetime
from pathlib import Path
from shutil import copy2
Expand All @@ -9,77 +8,40 @@
from loguru import logger

from app.Models.img_data import ImageData
from app.Services import transformers_service, db_context, ocr_service
from app.Services.provider import index_service
from app.config import config
from .local_utility import gather_valid_files


def parse_args():
parser = argparse.ArgumentParser(description='Create Qdrant collection')
parser.add_argument('--copy-from', dest="local_index_target_dir", type=str, required=True,
help="Copy from this directory")
return parser.parse_args()


def copy_and_index(file_path: Path) -> ImageData | None:
async def copy_and_index(file_path: Path):
try:
img = Image.open(file_path)
except PIL.UnidentifiedImageError as e:
logger.error("Error when opening image {}: {}", file_path, e)
return None
return
image_id = uuid4()
img_ext = file_path.suffix
image_ocr_result = None
text_contain_vector = None
[width, height] = img.size
try:
image_vector = transformers_service.get_image_vector(img)
if config.ocr_search.enable:
image_ocr_result = ocr_service.ocr_interface(img) # This will modify img if you use preprocess!
if image_ocr_result != "":
text_contain_vector = transformers_service.get_bert_vector(image_ocr_result)
else:
image_ocr_result = None
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return None
imgdata = ImageData(id=image_id,
url=f'/static/{image_id}{img_ext}',
image_vector=image_vector,
text_contain_vector=text_contain_vector,
index_date=datetime.now(),
width=width,
height=height,
aspect_ratio=float(width) / height,
ocr_text=image_ocr_result,
local=True)

try:
await index_service.index_image(img, imgdata)
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return
# copy to static
copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
return imgdata


@logger.catch()
async def main(args):
root = Path(args.local_index_target_dir)
static_path = Path(config.static_file.path)
if not static_path.exists():
static_path.mkdir()
buffer = []
static_path.mkdir(exist_ok=True)
counter = 0
for item in root.glob('**/*.*'):
for item in gather_valid_files(root):
counter += 1
logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root)))
if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']:
imgdata = copy_and_index(item)
if imgdata is not None:
buffer.append(imgdata)
if len(buffer) >= 20:
logger.info("Upload {} element to database", len(buffer))
await db_context.insertItems(buffer)
buffer.clear()
else:
logger.warning("Unsupported file type: {}. Skip...", item.suffix)
if len(buffer) > 0:
logger.info("Upload {} element to database", len(buffer))
await db_context.insertItems(buffer)
logger.success("Indexing completed! {} images indexed", counter)
await copy_and_index(item)
logger.success("Indexing completed! {} images indexed", counter)
11 changes: 11 additions & 0 deletions scripts/local_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pathlib import Path

from loguru import logger


def gather_valid_files(root: Path, pattern: str = '**/*.*'):
    """Yield files under *root* matching *pattern* whose extension is a supported image type.

    Unsupported files are skipped with a warning. The extension check is
    case-insensitive, so files such as ``IMG.JPG`` are accepted (the original
    exact-match check silently skipped upper-case extensions).
    """
    valid_extensions = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}
    for item in root.glob(pattern):
        if item.suffix.lower() in valid_extensions:
            yield item
        else:
            logger.warning("Unsupported file type: {}. Skip...", item.suffix)

0 comments on commit 2b035d8

Please sign in to comment.