Skip to content

Commit

Permalink
Use uuid5 associated with file SHA1 instead of randomly generated uuid4 (#11)
Browse files Browse the repository at this point in the history

Use uuid5 associated with file SHA1 instead of randomly generated uuid4
  • Loading branch information
hv0905 authored Dec 31, 2023
2 parents 2b035d8 + 1fb7f24 commit a230014
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 23 deletions.
23 changes: 19 additions & 4 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from app.config import config


class PointDuplicateError(ValueError):
pass


class IndexService:
def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
self._ocr_service = ocr_service
Expand All @@ -28,11 +32,22 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal
else:
image_data.ocr_text = None

async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
# currently, here only need just a simple check
async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
image_id_list = [str(item.id) for item in image_data]
result = await self._db_context.validate_ids(image_id_list)
return len(result) != 0

async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, allow_overwrite=False):
if not allow_overwrite and (await self._is_point_duplicate([image_data])):
raise PointDuplicateError("The uploaded points are contained in the database!")
self._prepare_image(image, image_data, skip_ocr)
await self._db_context.insertItems([image_data])

async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], skip_ocr=False):
for i in range(len(image)):
self._prepare_image(image[i], image_data[i], skip_ocr)
async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
skip_ocr=False, allow_overwrite=False):
if not allow_overwrite and (await self._is_point_duplicate(image_data)):
raise PointDuplicateError("The uploaded points are contained in the database!")
for img, img_data in zip(image, image_data):
self._prepare_image(img, img_data, skip_ocr)
await self._db_context.insertItems(image_data)
54 changes: 52 additions & 2 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, point_id: str):
class VectorDbContext:
IMG_VECTOR = "image_vector"
TEXT_VECTOR = "text_contain_vector"
AVAILABLE_POINT_TYPES = models.Record | models.ScoredPoint | models.PointStruct

def __init__(self):
self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
Expand All @@ -30,14 +31,55 @@ def __init__(self):
self.collection_name = config.qdrant.coll

async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
"""
Retrieve an item from database by id. Will raise PointNotFoundError if the given ID doesn't exist.
:param image_id: The ID to retrieve.
:param with_vectors: Whether to retrieve vectors.
:return: The retrieved item.
"""
logger.info("Retrieving item {} from database...", image_id)
result = await self._client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True,
result = await self._client.retrieve(collection_name=self.collection_name,
ids=[image_id],
with_payload=True,
with_vectors=with_vectors)
if len(result) != 1:
logger.error("Point not exist.")
raise PointNotFoundError(image_id)
return self._get_img_data_from_point(result[0])

async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[ImageData]:
"""
Retrieve items from the database by IDs.
An exception is thrown if there are items in the IDs that do not exist in the database.
:param image_id: The list of IDs to retrieve.
:param with_vectors: Whether to retrieve vectors.
:return: The list of retrieved items.
"""
logger.info("Retrieving {} items from database...", len(image_id))
result = await self._client.retrieve(collection_name=self.collection_name,
ids=image_id,
with_payload=True,
with_vectors=with_vectors)
result_point_ids = {t.id for t in result}
missing_point_ids = set(image_id) - result_point_ids
if len(missing_point_ids) > 0:
logger.error("{} points not exist.", len(missing_point_ids))
raise PointNotFoundError(str(missing_point_ids))
return self._get_img_data_from_points(result)

async def validate_ids(self, image_id: list[str]) -> list[str]:
"""
Validate a list of IDs. Will return a list of valid IDs.
:param image_id: The list of IDs to validate.
:return: The list of valid IDs.
"""
logger.info("Validating {} items from database...", len(image_id))
result = await self._client.retrieve(collection_name=self.collection_name,
ids=image_id,
with_payload=False,
with_vectors=False)
return [t.id for t in result]

async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR,
top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
logger.info("Querying Qdrant... top_k = {}", top_k)
Expand Down Expand Up @@ -131,6 +173,10 @@ async def scroll_points(self,

return [self._get_img_data_from_point(t) for t in resp], next_id

async def get_counts(self, exact: bool) -> int:
resp = await self._client.count(collection_name=self.collection_name, exact=exact)
return resp.count

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
vector = {}
Expand All @@ -152,7 +198,7 @@ def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
)

@classmethod
def _get_img_data_from_point(cls, point: models.Record | models.ScoredPoint | models.PointStruct) -> ImageData:
def _get_img_data_from_point(cls, point: AVAILABLE_POINT_TYPES) -> ImageData:
return (ImageData
.from_payload(point.id,
point.payload,
Expand All @@ -162,6 +208,10 @@ def _get_img_data_from_point(cls, point: models.Record | models.ScoredPoint | mo
if point.vector and cls.TEXT_VECTOR in point.vector else None
))

@classmethod
def _get_img_data_from_points(cls, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]:
return [cls._get_img_data_from_point(t) for t in points]

@classmethod
def _get_search_result_from_scored_point(cls, point: models.ScoredPoint) -> SearchResult:
return SearchResult(img=cls._get_img_data_from_point(point), score=point.score)
Expand Down
20 changes: 20 additions & 0 deletions app/util/generate_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import io
import pathlib
import hashlib
from uuid import UUID, uuid5, NAMESPACE_DNS

NAMESPACE_STR = 'github.com/hv0905/NekoImageGallery'


def generate(file_input: pathlib.Path | io.BytesIO) -> UUID:
namespace_uuid = uuid5(NAMESPACE_DNS, NAMESPACE_STR)
if isinstance(file_input, pathlib.Path):
with open(file_input, 'rb') as f:
file_content = f.read()
elif isinstance(file_input, io.BytesIO):
file_input.seek(0)
file_content = file_input.read()
else:
raise ValueError("Unsupported file type. Must be pathlib.Path or io.BytesIO.")
file_hash = hashlib.sha1(file_content).hexdigest()
return uuid5(namespace_uuid, file_hash)
59 changes: 47 additions & 12 deletions scripts/local_indexing.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,82 @@
import uuid
from datetime import datetime
from pathlib import Path
from shutil import copy2
from uuid import uuid4

import PIL
from PIL import Image
from loguru import logger

from app.Models.img_data import ImageData
from app.Services.provider import index_service
from app.Services.provider import index_service, db_context
from app.config import config
from .local_utility import gather_valid_files
from app.util import generate_uuid
from .local_utility import gather_valid_files, fetch_path_uuid_list

overall_count = 0

async def copy_and_index(file_path: Path):

async def copy_and_index(file_path: Path, uuid_str: str = None):
global overall_count
overall_count += 1
logger.info("[{}] Indexing {}", str(overall_count), str(file_path))
try:
img = Image.open(file_path)
except PIL.UnidentifiedImageError as e:
logger.error("Error when opening image {}: {}", file_path, e)
return
image_id = uuid4()
image_id = uuid.UUID(uuid_str) if uuid_str else generate_uuid.generate(file_path)
img_ext = file_path.suffix
imgdata = ImageData(id=image_id,
url=f'/static/{image_id}{img_ext}',
index_date=datetime.now(),
local=True)
try:
await index_service.index_image(img, imgdata)
# This has already been checked for duplicated, so there's no need to double-check.
await index_service.index_image(img, imgdata, allow_overwrite=True)
except Exception as e:
logger.error("Error when processing image {}: {}", file_path, e)
return
# copy to static
copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')


async def copy_and_index_batch(file_path_list: list[tuple[Path, str]]):
for file_path_uuid_tuple in file_path_list:
await copy_and_index(file_path_uuid_tuple[0], uuid_str=file_path_uuid_tuple[1])


@logger.catch()
async def main(args):
root = Path(args.local_index_target_dir)
static_path = Path(config.static_file.path)
static_path.mkdir(exist_ok=True)
counter = 0
for item in gather_valid_files(root):
counter += 1
logger.info("[{}] Indexing {}", str(counter), str(item.relative_to(root)))
await copy_and_index(item)
logger.success("Indexing completed! {} images indexed", counter)
# First, check if the database is empty
item_number = await db_context.get_counts(exact=False)
if item_number == 0:
# database is empty, do as usual
logger.warning("The database is empty, Will not check for duplicate points.")
for item in gather_valid_files(root):
await copy_and_index(item)
else:
# database is not empty, check for duplicate points
logger.warning("The database is not empty, Will check for duplicate points.")
for itm in gather_valid_files(root, max_files=5000):
local_file_path_with_uuid_list = fetch_path_uuid_list(itm)
local_file_uuid_list = [itm[1] for itm in local_file_path_with_uuid_list]
duplicate_uuid_list = await db_context.validate_ids(local_file_uuid_list)
if len(duplicate_uuid_list) > 0:
duplicate_uuid_list = set(duplicate_uuid_list)
local_file_path_with_uuid_list = [item for item in local_file_path_with_uuid_list
if item[1] not in duplicate_uuid_list]
logger.info("Found {} duplicate points, of which {} are duplicates in the database. "
"The remaining {} points will be indexed.",
len(itm) - len(local_file_path_with_uuid_list), len(duplicate_uuid_list),
len(local_file_path_with_uuid_list))
else:
logger.info("Found {} duplicate points, of which {} are duplicates in the database."
" The remaining {} points will be indexed.",
0, 0, len(local_file_path_with_uuid_list))
await copy_and_index_batch(local_file_path_with_uuid_list)

logger.success("Indexing completed! {} images indexed", overall_count)
38 changes: 33 additions & 5 deletions scripts/local_utility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
import itertools
from pathlib import Path

from loguru import logger

from app.util import generate_uuid

def gather_valid_files(root: Path, pattern: str = '**/*.*'):
for item in root.glob(pattern):
if item.suffix in ['.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif']:
yield item

def gather_valid_files(root: Path, pattern: str = '**/*.*', max_files=None):
valid_extensions = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}

def file_generator():
for file in root.glob(pattern):
if file.suffix.lower() in valid_extensions:
yield file
else:
logger.warning(f"Unsupported file type: {file.suffix}. Skipping file: {file}")

def generator():
gen = file_generator()
if max_files is None:
yield from gen
else:
logger.warning("Unsupported file type: {}. Skip...", item.suffix)
while True:
batch = list(itertools.islice(gen, max_files))
if not batch:
break
yield batch

return generator()


def calculate_uuid(file_path: Path) -> str:
return str(generate_uuid.generate(file_path))


def fetch_path_uuid_list(file_path: Path | list[Path]) -> list[tuple[Path, str]]:
file_path = [file_path] if isinstance(file_path, Path) else file_path
return [(itm, calculate_uuid(itm)) for itm in file_path]

0 comments on commit a230014

Please sign in to comment.