Skip to content

Commit

Permalink
Fix some lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 26, 2023
1 parent ef2ece1 commit ff48240
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 50 deletions.
39 changes: 19 additions & 20 deletions app/Services/ocr_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
class OCRService:
def __init__(self):
self._device = config.device
self._easyOCRModule, self._paddleOCRModule = None, None
if self._device == "auto":
self._device = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -21,9 +20,9 @@ def _image_preprocess(img: Image.Image) -> Image.Image:
img = img.convert('RGB')
if img.size[0] > 1024 or img.size[1] > 1024:
img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
newImg = Image.new('RGB', (1024, 1024), (0, 0, 0))
newImg.paste(img, ((1024 - img.size[0]) // 2, (1024 - img.size[1]) // 2))
return newImg
new_img = Image.new('RGB', (1024, 1024), (0, 0, 0))
new_img.paste(img, ((1024 - img.size[0]) // 2, (1024 - img.size[1]) // 2))
return new_img

def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
pass
Expand All @@ -33,18 +32,18 @@ class EasyPaddleOCRService(OCRService):
def __init__(self):
super().__init__()
from easypaddleocr import EasyPaddleOCR
self._paddleOCRModule = EasyPaddleOCR(use_angle_cls=True, needWarmUp=True, devices=self._device)
self._paddle_ocr_module = EasyPaddleOCR(use_angle_cls=True, needWarmUp=True, devices=self._device)
logger.success("EasyPaddleOCR loaded successfully")

def _easy_paddleocr_process(self, img: Image.Image) -> str:
_, _ocrResult, _ = self._paddleOCRModule.ocr(np.array(img))
if _ocrResult:
return "".join(itm[0] for itm in _ocrResult if float(itm[1]) > config.ocr_search.ocr_min_confidence)
_, ocr_result, _ = self._paddle_ocr_module.ocr(np.array(img))
if ocr_result:
return "".join(itm[0] for itm in ocr_result if float(itm[1]) > config.ocr_search.ocr_min_confidence)
return ""

def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
start_time = time()
logger.info(f"Processing text with EasyPaddleOCR...")
logger.info("Processing text with EasyPaddleOCR...")
res = self._easy_paddleocr_process(self._image_preprocess(img) if need_preprocess else img)
logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time)
return res
Expand All @@ -55,17 +54,17 @@ def __init__(self):
super().__init__()
# noinspection PyPackageRequirements
import easyocr # pylint: disable=import-error
self._easyOCRModule = easyocr.Reader(config.ocr_search.ocr_language,
gpu=True if self._device == "cuda" else False)
self._easy_ocr_module = easyocr.Reader(config.ocr_search.ocr_language,
gpu=self._device == "cuda")
logger.success("easyOCR loaded successfully")

def _easyocr_process(self, img: Image.Image) -> str:
_ocrResult = self._easyOCRModule.readtext(np.array(img))
return " ".join(itm[1] for itm in _ocrResult if itm[2] > config.ocr_search.ocr_min_confidence)
ocr_result = self._easy_ocr_module.readtext(np.array(img))
return " ".join(itm[1] for itm in ocr_result if itm[2] > config.ocr_search.ocr_min_confidence)

def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
start_time = time()
logger.info(f"Processing text with easyOCR...")
logger.info("Processing text with easyOCR...")
res = self._easyocr_process(self._image_preprocess(img) if need_preprocess else img)
logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time)
return res
Expand All @@ -76,19 +75,19 @@ def __init__(self):
super().__init__()
# noinspection PyPackageRequirements
import paddleocr # pylint: disable=import-error
self._paddleOCRModule = paddleocr.PaddleOCR(lang="ch", use_angle_cls=True,
use_gpu=True if self._device == "cuda" else False)
self._paddle_ocr_module = paddleocr.PaddleOCR(lang="ch", use_angle_cls=True,
use_gpu=self._device == "cuda")
logger.success("PaddleOCR loaded successfully")

def _paddleocr_process(self, img: Image.Image) -> str:
_ocrResult = self._paddleOCRModule.ocr(np.array(img), cls=True)
if _ocrResult[0]:
return "".join(itm[1][0] for itm in _ocrResult[0] if itm[1][1] > config.ocr_search.ocr_min_confidence)
ocr_result = self._paddle_ocr_module.ocr(np.array(img), cls=True)
if ocr_result[0]:
return "".join(itm[1][0] for itm in ocr_result[0] if itm[1][1] > config.ocr_search.ocr_min_confidence)
return ""

def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
start_time = time()
logger.info(f"Processing text with PaddleOCR...")
logger.info("Processing text with PaddleOCR...")
res = self._paddleocr_process(self._image_preprocess(img) if need_preprocess else img)
logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time)
return res
Expand Down
25 changes: 12 additions & 13 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def __init__(self):
prefer_grpc=config.qdrant.prefer_grpc)
self.collection_name = config.qdrant.coll

async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData:
logger.info("Retrieving item {} from database...", id)
result = await self.client.retrieve(collection_name=self.collection_name, ids=[id], with_payload=True,
async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
logger.info("Retrieving item {} from database...", image_id)
result = await self.client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True,
with_vectors=with_vectors)
if len(result) != 1:
logger.error("Point not exist.")
raise PointNotFoundError(id)
raise PointNotFoundError(image_id)
return ImageData.from_payload(result[0].id, result[0].payload,
numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None)

Expand Down Expand Up @@ -67,8 +67,8 @@ async def querySimilar(self,
_strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if
mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE)
# since only combined_search need return vectors, We can define _combined_search_need_vectors like below
_combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] \
if with_vectors else None
_combined_search_need_vectors = [
self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] if with_vectors else None
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.recommend(collection_name=self.collection_name,
using=query_vector_name,
Expand Down Expand Up @@ -99,7 +99,7 @@ def result_transform(t):
async def insertItems(self, items: list[ImageData]):
logger.info("Inserting {} items into Qdrant...", len(items))

def getPoint(img_data):
def get_point(img_data):
vector = {
self.IMG_VECTOR: img_data.image_vector.tolist(),
}
Expand All @@ -111,7 +111,7 @@ def getPoint(img_data):
payload=img_data.payload
)

points = [getPoint(t) for t in items]
points = [get_point(t) for t in items]

response = await self.client.upsert(collection_name=self.collection_name,
wait=True,
Expand Down Expand Up @@ -188,9 +188,8 @@ def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter
)
))

if len(filters) > 0:
return models.Filter(
must=filters
)
else:
if len(filters) == 0:
return None
return models.Filter(
must=filters
)
21 changes: 11 additions & 10 deletions scripts/local_create_thumbnail.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import uuid
from pathlib import Path
from loguru import logger

from PIL import Image
from loguru import logger

from app.Services import db_context
from app.config import config
import uuid


async def main():
Expand All @@ -15,7 +16,7 @@ async def main():
count = 0
for item in static_path.glob('*.*'):
count += 1
logger.info("[{}] Processing {}", str(count), item.relative_to(static_path).__str__())
logger.info("[{}] Processing {}", str(count), str(item.relative_to(static_path)))
size = item.stat().st_size
if size < 1024 * 500:
logger.warning("File size too small: {}. Skip...", size)
Expand All @@ -27,14 +28,14 @@ async def main():
if (static_thumb_path / f'{item.stem}.webp').exists():
logger.warning("Thumbnail for {} already exists. Skip...", item.stem)
continue
id = uuid.UUID(item.stem)
image_id = uuid.UUID(item.stem)
except ValueError:
logger.warning("Invalid file name: {}. Skip...", item.stem)
continue
try:
imgdata = await db_context.retrieve_by_id(str(id))
imgdata = await db_context.retrieve_by_id(str(image_id))
except Exception as e:
logger.error("Error when retrieving image {}: {}", id, e)
logger.error("Error when retrieving image {}: {}", image_id, e)
continue
try:
img = Image.open(item)
Expand All @@ -44,13 +45,13 @@ async def main():

# generate thumbnail max size 256*256
img.thumbnail((256, 256))
img.save(static_thumb_path / f'{str(id)}.webp', 'WebP')
img.save(static_thumb_path / f'{str(image_id)}.webp', 'WebP')
img.close()
logger.success("Thumbnail for {} generated!", id)
logger.success("Thumbnail for {} generated!", image_id)

# update payload
imgdata.thumbnail_url = f'/static/thumbnails/{str(id)}.webp'
imgdata.thumbnail_url = f'/static/thumbnails/{str(image_id)}.webp'
await db_context.updatePayload(imgdata)
logger.success("Payload for {} updated!", id)
logger.success("Payload for {} updated!", image_id)

logger.success("OK. Updated {} items.", count)
13 changes: 7 additions & 6 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from shutil import copy2
from uuid import uuid4

import PIL
from PIL import Image
from loguru import logger

Expand All @@ -22,10 +23,10 @@ def parse_args():
def copy_and_index(file_path: Path) -> ImageData | None:
try:
img = Image.open(file_path)
except Exception as e:
except PIL.UnidentifiedImageError as e:
logger.error("Error when opening image {}: {}", file_path, e)
return None
id = uuid4()
image_id = uuid4()
img_ext = file_path.suffix
image_ocr_result = None
text_contain_vector = None
Expand All @@ -41,8 +42,8 @@ def copy_and_index(file_path: Path) -> ImageData | None:
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return None
imgdata = ImageData(id=id,
url=f'/static/{id}{img_ext}',
imgdata = ImageData(id=image_id,
url=f'/static/{image_id}{img_ext}',
image_vector=image_vector,
text_contain_vector=text_contain_vector,
index_date=datetime.now(),
Expand All @@ -52,7 +53,7 @@ def copy_and_index(file_path: Path) -> ImageData | None:
ocr_text=image_ocr_result)

# copy to static
copy2(file_path, Path(config.static_file.path) / f'{id}{img_ext}')
copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
return imgdata


Expand All @@ -66,7 +67,7 @@ async def main(args):
counter = 0
for item in root.glob('**/*.*'):
counter += 1
logger.info("[{}] Indexing {}", str(counter), item.relative_to(root).__str__())
logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root)))
if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp']:
imgdata = copy_and_index(item)
if imgdata is not None:
Expand Down
3 changes: 2 additions & 1 deletion scripts/qdrant_create_collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from qdrant_client import qdrant_client, models
import argparse

from qdrant_client import qdrant_client, models


def parsing_args():
parser = argparse.ArgumentParser(description='Create Qdrant collection')
Expand Down

0 comments on commit ff48240

Please sign in to comment.