diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index fc063d7..b4cb497 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -115,9 +115,9 @@ async def updatePayload(self, new_data: ImageData): async def updateVectors(self, new_points: list[ImageData]): resp = await self._client.update_vectors(collection_name=self.collection_name, - points=[models.PointVectors(**self._get_point_from_img_data(t) - .model_dump()) for t in new_points], + points=[self._get_vector_from_img_data(t) for t in new_points], ) + logger.success("Update vectors completed! Status: {}", resp.status) async def scroll_points(self, from_id: str | None = None, @@ -132,16 +132,23 @@ async def scroll_points(self, return [self._get_img_data_from_point(t) for t in resp], next_id @classmethod - def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: + def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: vector = {} if img_data.image_vector is not None: vector[cls.IMG_VECTOR] = img_data.image_vector.tolist() if img_data.text_contain_vector is not None: vector[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist() + return models.PointVectors( + id=str(img_data.id), + vector=vector + ) + + @classmethod + def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: return models.PointStruct( id=str(img_data.id), - vector=vector, - payload=img_data.payload + payload=img_data.payload, + vector=cls._get_vector_from_img_data(img_data).vector ) @classmethod diff --git a/scripts/db_migrations.py b/scripts/db_migrations.py index b88fb4b..86805a9 100644 --- a/scripts/db_migrations.py +++ b/scripts/db_migrations.py @@ -14,14 +14,17 @@ async def migrate_v1_v2(): for point in points: count += 1 logger.info("[{}] Migrating point {}", count, point.id) - if point.ocr_text is not None: - point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower) if point.url.startswith('/'): - point.local = True # V1 database assuming all image with '/' as begins is a local image, + # V1 database assuming all image with '/' as begins is a local image, # v2 migrate to a more strict approach - await db_context.updatePayload(point) # store the new ocr_text_lower, vector won't update currently + point.local = True + await db_context.updatePayload(point) # This will also store ocr_text_lower field, if present + if point.ocr_text is not None: + point.text_contain_vector = transformers_service.get_bert_vector(point.ocr_text_lower) + logger.info("Updating vectors...") - await db_context.updateVectors(points) # Update vectors for this group of points + # Update vectors for this group of points + await db_context.updateVectors([t for t in points if t.text_contain_vector is not None]) if next_id is None: break