Skip to content

Commit

Permalink
fix bugs during mugration
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 29, 2023
1 parent e45bd8a commit 6c7a433
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
17 changes: 12 additions & 5 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ async def updatePayload(self, new_data: ImageData):

async def updateVectors(self, new_points: list[ImageData]):
resp = await self._client.update_vectors(collection_name=self.collection_name,
points=[models.PointVectors(**self._get_point_from_img_data(t)
.model_dump()) for t in new_points],
points=[self._get_vector_from_img_data(t) for t in new_points],
)
logger.success("Update vectors completed! Status: {}", resp.status)

async def scroll_points(self,
from_id: str | None = None,
Expand All @@ -132,16 +132,23 @@ async def scroll_points(self,
return [self._get_img_data_from_point(t) for t in resp], next_id

@classmethod
def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
vector = {}
if img_data.image_vector is not None:
vector[cls.IMG_VECTOR] = img_data.image_vector.tolist()
if img_data.text_contain_vector is not None:
vector[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist()
return models.PointVectors(
id=str(img_data.id),
vector=vector
)

@classmethod
def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
return models.PointStruct(
id=str(img_data.id),
vector=vector,
payload=img_data.payload
payload=img_data.payload,
vector=cls._get_vector_from_img_data(img_data).vector
)

@classmethod
Expand Down
13 changes: 8 additions & 5 deletions scripts/db_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ async def migrate_v1_v2():
for point in points:
count += 1
logger.info("[{}] Migrating point {}", count, point.id)
if point.ocr_text is not None:
point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower)
if point.url.startswith('/'):
point.local = True # V1 database assuming all image with '/' as begins is a local image,
# V1 database assuming all image with '/' as begins is a local image,
# v2 migrate to a more strict approach
await db_context.updatePayload(point) # store the new ocr_text_lower, vector won't update currently
point.local = True
await db_context.updatePayload(point) # This will also store ocr_text_lower field, if present
if point.ocr_text is not None:
point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower)

logger.info("Updating vectors...")
await db_context.updateVectors(points) # Update vectors for this group of points
# Update vectors for this group of points
await db_context.updateVectors([t for t in points if t.text_contain_vector is not None])
if next_id is None:
break

Expand Down

0 comments on commit 6c7a433

Please sign in to comment.