From ff482402051429fdf2eb5eb969c0fb99006d5bda Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Tue, 26 Dec 2023 18:52:09 +0800 Subject: [PATCH] Fix some lint error --- app/Services/ocr_services.py | 39 ++++++++++++++--------------- app/Services/vector_db_context.py | 25 +++++++++--------- scripts/local_create_thumbnail.py | 21 ++++++++-------- scripts/local_indexing.py | 13 +++++----- scripts/qdrant_create_collection.py | 3 ++- 5 files changed, 51 insertions(+), 50 deletions(-) diff --git a/app/Services/ocr_services.py b/app/Services/ocr_services.py index 1c4def1..0012505 100644 --- a/app/Services/ocr_services.py +++ b/app/Services/ocr_services.py @@ -11,7 +11,6 @@ class OCRService: def __init__(self): self._device = config.device - self._easyOCRModule, self._paddleOCRModule = None, None if self._device == "auto": self._device = "cuda" if torch.cuda.is_available() else "cpu" @@ -21,9 +20,9 @@ def _image_preprocess(img: Image.Image) -> Image.Image: img = img.convert('RGB') if img.size[0] > 1024 or img.size[1] > 1024: img.thumbnail((1024, 1024), Image.Resampling.LANCZOS) - newImg = Image.new('RGB', (1024, 1024), (0, 0, 0)) - newImg.paste(img, ((1024 - img.size[0]) // 2, (1024 - img.size[1]) // 2)) - return newImg + new_img = Image.new('RGB', (1024, 1024), (0, 0, 0)) + new_img.paste(img, ((1024 - img.size[0]) // 2, (1024 - img.size[1]) // 2)) + return new_img def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: pass @@ -33,18 +32,18 @@ class EasyPaddleOCRService(OCRService): def __init__(self): super().__init__() from easypaddleocr import EasyPaddleOCR - self._paddleOCRModule = EasyPaddleOCR(use_angle_cls=True, needWarmUp=True, devices=self._device) + self._paddle_ocr_module = EasyPaddleOCR(use_angle_cls=True, needWarmUp=True, devices=self._device) logger.success("EasyPaddleOCR loaded successfully") def _easy_paddleocr_process(self, img: Image.Image) -> str: - _, _ocrResult, _ = self._paddleOCRModule.ocr(np.array(img)) - if _ocrResult: - return "".join(itm[0] for itm in _ocrResult if float(itm[1]) > config.ocr_search.ocr_min_confidence) + _, ocr_result, _ = self._paddle_ocr_module.ocr(np.array(img)) + if ocr_result: + return "".join(itm[0] for itm in ocr_result if float(itm[1]) > config.ocr_search.ocr_min_confidence) return "" def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: start_time = time() - logger.info(f"Processing text with EasyPaddleOCR...") + logger.info("Processing text with EasyPaddleOCR...") res = self._easy_paddleocr_process(self._image_preprocess(img) if need_preprocess else img) logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time) return res @@ -55,17 +54,17 @@ def __init__(self): super().__init__() # noinspection PyPackageRequirements import easyocr # pylint: disable=import-error - self._easyOCRModule = easyocr.Reader(config.ocr_search.ocr_language, - gpu=True if self._device == "cuda" else False) + self._easy_ocr_module = easyocr.Reader(config.ocr_search.ocr_language, + gpu=self._device == "cuda") logger.success("easyOCR loaded successfully") def _easyocr_process(self, img: Image.Image) -> str: - _ocrResult = self._easyOCRModule.readtext(np.array(img)) - return " ".join(itm[1] for itm in _ocrResult if itm[2] > config.ocr_search.ocr_min_confidence) + ocr_result = self._easy_ocr_module.readtext(np.array(img)) + return " ".join(itm[1] for itm in ocr_result if itm[2] > config.ocr_search.ocr_min_confidence) def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: start_time = time() - logger.info(f"Processing text with easyOCR...") + logger.info("Processing text with easyOCR...") res = self._easyocr_process(self._image_preprocess(img) if need_preprocess else img) logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time) return res @@ -76,19 +75,19 @@ def __init__(self): super().__init__() # noinspection PyPackageRequirements import paddleocr # pylint: disable=import-error - self._paddleOCRModule = paddleocr.PaddleOCR(lang="ch", use_angle_cls=True, - use_gpu=True if self._device == "cuda" else False) + self._paddle_ocr_module = paddleocr.PaddleOCR(lang="ch", use_angle_cls=True, + use_gpu=self._device == "cuda") logger.success("PaddleOCR loaded successfully") def _paddleocr_process(self, img: Image.Image) -> str: - _ocrResult = self._paddleOCRModule.ocr(np.array(img), cls=True) - if _ocrResult[0]: - return "".join(itm[1][0] for itm in _ocrResult[0] if itm[1][1] > config.ocr_search.ocr_min_confidence) + ocr_result = self._paddle_ocr_module.ocr(np.array(img), cls=True) + if ocr_result[0]: + return "".join(itm[1][0] for itm in ocr_result[0] if itm[1][1] > config.ocr_search.ocr_min_confidence) return "" def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: start_time = time() - logger.info(f"Processing text with PaddleOCR...") + logger.info("Processing text with PaddleOCR...") res = self._paddleocr_process(self._image_preprocess(img) if need_preprocess else img) logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time) return res diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 658e042..248de15 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -30,13 +30,13 @@ def __init__(self): prefer_grpc=config.qdrant.prefer_grpc) self.collection_name = config.qdrant.coll - async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData: - logger.info("Retrieving item {} from database...", id) - result = await self.client.retrieve(collection_name=self.collection_name, ids=[id], with_payload=True, + async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: + logger.info("Retrieving item {} from database...", image_id) + result = await self.client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True, with_vectors=with_vectors) if len(result) != 1: logger.error("Point not exist.") - raise PointNotFoundError(id) + raise PointNotFoundError(image_id) return ImageData.from_payload(result[0].id, result[0].payload, numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None) @@ -67,8 +67,8 @@ async def querySimilar(self, _strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE) # since only combined_search need return vectors, We can define _combined_search_need_vectors like below - _combined_search_need_vectors = [self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] \ - if with_vectors else None + _combined_search_need_vectors = [ + self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] if with_vectors else None logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.recommend(collection_name=self.collection_name, using=query_vector_name, @@ -99,7 +99,7 @@ def result_transform(t): async def insertItems(self, items: list[ImageData]): logger.info("Inserting {} items into Qdrant...", len(items)) - def getPoint(img_data): + def get_point(img_data): vector = { self.IMG_VECTOR: img_data.image_vector.tolist(), } @@ -111,7 +111,7 @@ def getPoint(img_data): payload=img_data.payload ) - points = [getPoint(t) for t in items] + points = [get_point(t) for t in items] response = await self.client.upsert(collection_name=self.collection_name, wait=True, @@ -188,9 +188,8 @@ def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter ) )) - if len(filters) > 0: - return models.Filter( - must=filters - ) - else: + if len(filters) == 0: return None + return models.Filter( + must=filters + ) diff --git a/scripts/local_create_thumbnail.py b/scripts/local_create_thumbnail.py index ec7828a..f9f2a42 100644 --- a/scripts/local_create_thumbnail.py +++ b/scripts/local_create_thumbnail.py @@ -1,10 +1,11 @@ +import uuid from pathlib import Path -from loguru import logger + from PIL import Image +from loguru import logger from app.Services import db_context from app.config import config -import uuid async def main(): @@ -15,7 +16,7 @@ async def main(): count = 0 for item in static_path.glob('*.*'): count += 1 - logger.info("[{}] Processing {}", str(count), item.relative_to(static_path).__str__()) + logger.info("[{}] Processing {}", str(count), str(item.relative_to(static_path))) size = item.stat().st_size if size < 1024 * 500: logger.warning("File size too small: {}. Skip...", size) @@ -27,14 +28,14 @@ async def main(): if (static_thumb_path / f'{item.stem}.webp').exists(): logger.warning("Thumbnail for {} already exists. Skip...", item.stem) continue - id = uuid.UUID(item.stem) + image_id = uuid.UUID(item.stem) except ValueError: logger.warning("Invalid file name: {}. Skip...", item.stem) continue try: - imgdata = await db_context.retrieve_by_id(str(id)) + imgdata = await db_context.retrieve_by_id(str(image_id)) except Exception as e: - logger.error("Error when retrieving image {}: {}", id, e) + logger.error("Error when retrieving image {}: {}", image_id, e) continue try: img = Image.open(item) @@ -44,13 +45,13 @@ async def main(): # generate thumbnail max size 256*256 img.thumbnail((256, 256)) - img.save(static_thumb_path / f'{str(id)}.webp', 'WebP') + img.save(static_thumb_path / f'{str(image_id)}.webp', 'WebP') img.close() - logger.success("Thumbnail for {} generated!", id) + logger.success("Thumbnail for {} generated!", image_id) # update payload - imgdata.thumbnail_url = f'/static/thumbnails/{str(id)}.webp' + imgdata.thumbnail_url = f'/static/thumbnails/{str(image_id)}.webp' await db_context.updatePayload(imgdata) - logger.success("Payload for {} updated!", id) + logger.success("Payload for {} updated!", image_id) logger.success("OK. Updated {} items.", count) diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index ee5e6ff..118f5d1 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -4,6 +4,7 @@ from shutil import copy2 from uuid import uuid4 +import PIL from PIL import Image from loguru import logger @@ -22,10 +23,10 @@ def parse_args(): def copy_and_index(file_path: Path) -> ImageData | None: try: img = Image.open(file_path) - except Exception as e: + except PIL.UnidentifiedImageError as e: logger.error("Error when opening image {}: {}", file_path, e) return None - id = uuid4() + image_id = uuid4() img_ext = file_path.suffix image_ocr_result = None text_contain_vector = None @@ -41,8 +42,8 @@ def copy_and_index(file_path: Path) -> ImageData | None: except Exception as e: logger.error("Error when processing image {}: {}", file_path, e) return None - imgdata = ImageData(id=id, - url=f'/static/{id}{img_ext}', + imgdata = ImageData(id=image_id, + url=f'/static/{image_id}{img_ext}', image_vector=image_vector, text_contain_vector=text_contain_vector, index_date=datetime.now(), @@ -52,7 +53,7 @@ def copy_and_index(file_path: Path) -> ImageData | None: ocr_text=image_ocr_result) # copy to static - copy2(file_path, Path(config.static_file.path) / f'{id}{img_ext}') + copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}') return imgdata @@ -66,7 +67,7 @@ async def main(args): counter = 0 for item in root.glob('**/*.*'): counter += 1 - logger.info("[{}] Indexing {}", str(counter), item.relative_to(root).__str__()) + logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root))) if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp']: imgdata = copy_and_index(item) if imgdata is not None: diff --git a/scripts/qdrant_create_collection.py b/scripts/qdrant_create_collection.py index 250da4f..592e590 100644 --- a/scripts/qdrant_create_collection.py +++ b/scripts/qdrant_create_collection.py @@ -1,6 +1,7 @@ -from qdrant_client import qdrant_client, models import argparse +from qdrant_client import qdrant_client, models + def parsing_args(): parser = argparse.ArgumentParser(description='Create Qdrant collection')