Skip to content

Commit

Permalink
Merge branch 'master' into indexer
Browse files Browse the repository at this point in the history
# Conflicts:
#	app/Models/img_data.py
  • Loading branch information
hv0905 committed Dec 29, 2023
2 parents e033942 + 6c7a433 commit cebf1d3
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 69 deletions.
13 changes: 11 additions & 2 deletions app/Models/img_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from uuid import UUID

from numpy import ndarray
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, computed_field


class ImageData(BaseModel):
Expand All @@ -19,10 +19,19 @@ class ImageData(BaseModel):
aspect_ratio: Optional[float] = None
starred: Optional[bool] = False
categories: Optional[list[str]] = []
local: Optional[bool] = False

@computed_field()
@property
def ocr_text_lower(self) -> str | None:
    """Lower-cased copy of ``ocr_text``, or ``None`` when there is no OCR text.

    Declared as a pydantic computed field so the lower-cased text is included
    when the model is dumped (and therefore stored in the point payload),
    enabling case-insensitive text matching.
    """
    text = self.ocr_text
    return text.lower() if text is not None else None


@property
def payload(self):
    """Dict form of this model suitable for storing as a Qdrant point payload.

    ``id`` and ``index_date`` are excluded from the dump; ``index_date`` is
    then re-added as an ISO-8601 string.
    """
    # Merge residue removed: the pre-merge dump (which also excluded the
    # vector fields) was immediately overwritten by the post-merge dump.
    result = self.model_dump(exclude={'id', 'index_date'})
    # Qdrant database cannot accept datetime object, so we have to convert it to string
    result['index_date'] = self.index_date.isoformat()
    return result
Expand Down
155 changes: 91 additions & 64 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import PointStruct
from qdrant_client.models import RecommendStrategy

from app.Models.api_model import SearchModelEnum, SearchBasisEnum
Expand All @@ -25,32 +24,31 @@ class VectorDbContext:
TEXT_VECTOR = "text_contain_vector"

def __init__(self):
    """Create the async Qdrant client from app config and remember the target collection."""
    # Merge residue removed: the pre-merge version also created a second
    # client bound to the old public attribute ``self.client``.
    self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
                                     grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
                                     prefer_grpc=config.qdrant.prefer_grpc)
    self.collection_name = config.qdrant.coll

async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
    """Fetch a single point by its ID and convert it to an ImageData.

    :param image_id: ID of the point to retrieve.
    :param with_vectors: when True, also fetch the stored vectors.
    :raises PointNotFoundError: if no point with that ID exists.
    """
    logger.info("Retrieving item {} from database...", image_id)
    # Merge residue removed: the pre-merge duplicate call on ``self.client``
    # and its direct ImageData.from_payload return are gone.
    result = await self._client.retrieve(collection_name=self.collection_name, ids=[image_id],
                                         with_payload=True, with_vectors=with_vectors)
    if len(result) != 1:
        logger.error("Point not exist.")
        raise PointNotFoundError(image_id)
    return self._get_img_data_from_point(result[0])

async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR,
                      top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]:
    """Run a vector-similarity search against the collection.

    :param query_vector: embedding to search with.
    :param query_vector_name: which named vector to search (IMG_VECTOR or TEXT_VECTOR).
    :param top_k: maximum number of results to return.
    :param skip: pagination offset.
    :param filter_param: optional payload filter parameters.
    :return: scored results wrapped as SearchResult objects.
    """
    logger.info("Querying Qdrant... top_k = {}", top_k)
    # Merge residue removed: the pre-merge duplicate search on ``self.client``
    # and its inline result transformation are gone.
    result = await self._client.search(collection_name=self.collection_name,
                                       query_vector=(query_vector_name, query_vector),
                                       query_filter=self.getFiltersByFilterParam(filter_param),
                                       limit=top_k,
                                       offset=skip,
                                       with_payload=True)
    logger.success("Query completed!")
    return [self._get_search_result_from_scored_point(t) for t in result]

async def querySimilar(self,
query_vector_name: str = IMG_VECTOR,
Expand All @@ -70,61 +68,37 @@ async def querySimilar(self,
_combined_search_need_vectors = [
self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] if with_vectors else None
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self.client.recommend(collection_name=self.collection_name,
using=query_vector_name,
positive=_positive_vectors,
negative=_negative_vectors,
strategy=_strategy,
with_vectors=_combined_search_need_vectors,
query_filter=self.getFiltersByFilterParam(filter_param),
limit=top_k,
offset=skip,
with_payload=True)
result = await self._client.recommend(collection_name=self.collection_name,
using=query_vector_name,
positive=_positive_vectors,
negative=_negative_vectors,
strategy=_strategy,
with_vectors=_combined_search_need_vectors,
query_filter=self.getFiltersByFilterParam(filter_param),
limit=top_k,
offset=skip,
with_payload=True)
logger.success("Query completed!")

def result_transform(t):
return SearchResult(
img=ImageData.from_payload(
t.id,
t.payload,
numpy.array(t.vector['image_vector']) if t.vector and 'image_vector' in t.vector else None,
numpy.array(
t.vector['text_contain_vector']) if t.vector and 'text_contain_vector' in t.vector else None
),
score=t.score
)

return [result_transform(t) for t in result]
return [self._get_search_result_from_scored_point(t) for t in result]

async def insertItems(self, items: list[ImageData]):
    """Upsert the given images (payload + named vectors) into the collection.

    Waits for the operation to complete before returning.
    :param items: images to insert or overwrite.
    """
    logger.info("Inserting {} items into Qdrant...", len(items))
    # Merge residue removed: the pre-merge inline ``get_point`` helper and its
    # duplicate upsert on ``self.client`` would have inserted everything twice.
    points = [self._get_point_from_img_data(t) for t in items]
    response = await self._client.upsert(collection_name=self.collection_name,
                                         wait=True,
                                         points=points)
    logger.success("Insert completed! Status: {}", response.status)

async def deleteItems(self, ids: list[str]):
    """Delete the points with the given IDs from the collection.

    :param ids: point IDs to delete.
    """
    logger.info("Deleting {} items from Qdrant...", len(ids))
    # Merge residue removed: the pre-merge duplicate delete on ``self.client``.
    response = await self._client.delete(collection_name=self.collection_name,
                                         points_selector=models.PointIdsList(points=ids))
    logger.success("Delete completed! Status: {}", response.status)

async def updatePayload(self, new_data: ImageData):
    """Update the payload of an existing point in the collection.

    Warning: This method will not update the vector of the item.
    :param new_data: The new data to update.
    """
    # Merge residue removed: the pre-merge duplicate set_payload on ``self.client``.
    response = await self._client.set_payload(collection_name=self.collection_name,
                                              payload=new_data.payload,
                                              points=[str(new_data.id)],
                                              wait=True)
    logger.success("Update completed! Status: {}", response.status)

async def updateVectors(self, new_points: list[ImageData]):
    """Replace the stored named vectors of the given points.

    Payloads are left untouched; only the vectors carried by each ImageData
    are written back.
    :param new_points: images whose vectors should be updated.
    """
    vectors = [self._get_vector_from_img_data(item) for item in new_points]
    resp = await self._client.update_vectors(collection_name=self.collection_name,
                                             points=vectors)
    logger.success("Update vectors completed! Status: {}", resp.status)

async def scroll_points(self,
                        from_id: str | None = None,
                        count=50,
                        with_vectors=False) -> tuple[list[ImageData], str]:
    """Page through the collection starting at ``from_id``.

    :param from_id: scroll offset ID; None starts from the beginning.
    :param count: maximum number of points per page.
    :param with_vectors: when True, also fetch stored vectors for each point.
    :return: the images in this page plus the offset for the next page
             (as returned by Qdrant; falsy when the collection is exhausted).
    """
    records, next_page_id = await self._client.scroll(collection_name=self.collection_name,
                                                      limit=count,
                                                      offset=from_id,
                                                      with_vectors=with_vectors)
    images = [self._get_img_data_from_point(record) for record in records]
    return images, next_page_id

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
    """Build a PointVectors containing whichever named vectors ``img_data`` carries.

    Vectors that are None are simply omitted, so a point may hold the image
    vector alone or together with the OCR-text vector.
    """
    named_vectors = {}
    if img_data.image_vector is not None:
        named_vectors[cls.IMG_VECTOR] = img_data.image_vector.tolist()
    if img_data.text_contain_vector is not None:
        named_vectors[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist()
    return models.PointVectors(id=str(img_data.id), vector=named_vectors)

@classmethod
def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
    """Convert an ImageData into a complete Qdrant point (id + payload + vectors)."""
    named_vectors = cls._get_vector_from_img_data(img_data).vector
    return models.PointStruct(id=str(img_data.id),
                              payload=img_data.payload,
                              vector=named_vectors)

@classmethod
def _get_img_data_from_point(cls, point: models.Record | models.ScoredPoint | models.PointStruct) -> ImageData:
    """Rebuild an ImageData from a Qdrant point, attaching any named vectors it carries."""

    def vector_or_none(name):
        # point.vector may be None (queried without vectors) or lack a named vector.
        if point.vector and name in point.vector:
            return numpy.array(point.vector[name], dtype=numpy.float32)
        return None

    return ImageData.from_payload(point.id,
                                  point.payload,
                                  image_vector=vector_or_none(cls.IMG_VECTOR),
                                  text_contain_vector=vector_or_none(cls.TEXT_VECTOR))

@classmethod
def _get_search_result_from_scored_point(cls, point: models.ScoredPoint) -> SearchResult:
    """Wrap a scored Qdrant point as a SearchResult (image data + similarity score)."""
    img_data = cls._get_img_data_from_point(point)
    return SearchResult(img=img_data, score=point.score)

@classmethod
def getVectorByBasis(cls, basis: SearchBasisEnum) -> str:
match basis:
Expand Down
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import collections

import uvicorn
Expand All @@ -13,6 +14,9 @@ def parse_args():
help="Initialize qdrant database using connection settings in "
"config.py. When this flag is set, will not"
"start the server.")
parser.add_argument('--migrate-db', dest="migrate_from_version", type=int,
help="Migrate qdrant database using connection settings in config from version specified."
"When this flag is set, will not start the server.")
parser.add_argument('--local-index', dest="local_index_target_dir", type=str,
help="Index all the images in this directory and copy them to "
"static folder set in config.py. When this flag is set, "
Expand All @@ -39,17 +43,19 @@ def parse_args():
qdrant_create_collection.create_coll(
collections.namedtuple('Options', ['host', 'port', 'name'])(config.qdrant.host, config.qdrant.port,
config.qdrant.coll))
elif args.migrate_from_version is not None:
from scripts import db_migrations

asyncio.run(db_migrations.migrate(args.migrate_from_version))
elif args.local_index_target_dir is not None:
from app.config import environment

environment.local_indexing = True
from scripts import local_indexing
import asyncio

asyncio.run(local_indexing.main(args))
elif args.local_create_thumbnail:
from scripts import local_create_thumbnail
import asyncio

asyncio.run(local_create_thumbnail.main())
else:
Expand Down
40 changes: 40 additions & 0 deletions scripts/db_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from loguru import logger

from app.Services import db_context, transformers_service

CURRENT_VERSION = 2


async def migrate_v1_v2():
    """Migrate the Qdrant collection from schema v1 to v2.

    Walks every point in batches of 100 and, for each point:
      * marks it ``local`` when its URL begins with '/'
      * re-writes the payload (which also persists the new lower-cased OCR
        text field, per the note below)
      * recomputes the OCR-text vector from the lower-cased OCR text
    After each batch, the recomputed vectors are pushed back in one call.
    """
    logger.info("Migrating from v1 to v2...")
    next_id = None  # scroll offset; None means start from the beginning
    count = 0  # running total of migrated points, for log output only
    while True:
        points, next_id = await db_context.scroll_points(next_id, count=100)
        for point in points:
            count += 1
            logger.info("[{}] Migrating point {}", count, point.id)
            if point.url.startswith('/'):
                # The v1 database assumed any image whose URL begins with '/'
                # is a local image; v2 migrates to a stricter explicit flag.
                point.local = True
            await db_context.updatePayload(point)  # This will also store ocr_text_lower field, if present
            if point.ocr_text is not None:
                point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower)

        logger.info("Updating vectors...")
        # Update vectors for this group of points
        await db_context.updateVectors([t for t in points if t.text_contain_vector is not None])
        if next_id is None:
            break


async def migrate(from_version: int):
match from_version:
case 1:
await migrate_v1_v2()
case 2:
logger.info("Already up to date.")
pass
case _:
raise Exception(f"Unknown version {from_version}")
3 changes: 2 additions & 1 deletion scripts/local_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def copy_and_index(file_path: Path) -> ImageData | None:
width=width,
height=height,
aspect_ratio=float(width) / height,
ocr_text=image_ocr_result)
ocr_text=image_ocr_result,
local=True)

# copy to static
copy2(file_path, Path(config.static_file.path) / f'{image_id}{img_ext}')
Expand Down

0 comments on commit cebf1d3

Please sign in to comment.