From 86f5b979fbbf0966bbc7820c5f70c10b48b3721a Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Mon, 6 Nov 2023 02:12:19 +0530 Subject: [PATCH] chore: named vector, fix collection_check --- dlt/destinations/qdrant/qdrant_client.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/qdrant/qdrant_client.py b/dlt/destinations/qdrant/qdrant_client.py index 386cb21623..cba87e9528 100644 --- a/dlt/destinations/qdrant/qdrant_client.py +++ b/dlt/destinations/qdrant/qdrant_client.py @@ -53,6 +53,8 @@ def __init__( db_client.embedding_model_name) embeddings = list(embedding_model.embed( docs, batch_size=self.config.embedding_batch_size, parallel=self.config.embedding_parallelism)) + vector_name = db_client.get_vector_field_name() + embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] assert len(embeddings) == len(payloads) == len(ids) self._upload_data(vectors=embeddings, ids=ids, payloads=payloads) @@ -176,12 +178,14 @@ def _create_collection( full_collection_name (str): The name of the collection to be created. """ - # A straight-forward method named get_fastembed_vector_params() exists in the qdrant_client package. - # But, it generates a named vector with the model name as the vector name. But, we need an unnamed vector. - embeddings_size, distance = self.db_client._get_model_params( - model_name=self.db_client.embedding_model_name) - vectors_config = models.VectorParams( - size=embeddings_size, distance=distance) + # Generates config for a named vector according to the selected model. + # Eg: vector_config={ + # "fast-bge-small-en": { + # "size": 364, + # "distance": "Cosine" + # }, + # } + vectors_config = self.db_client.get_fastembed_vector_params() self.db_client.create_collection( collection_name=full_collection_name, vectors_config=vectors_config) @@ -239,7 +243,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self._create_collection(full_collection_name=qualified_table_name) def is_storage_initialized(self) -> bool: - return self._collection_exists(self.sentinel_collection) + return self._collection_exists(self.sentinel_collection, qualify_table_name=False) def _create_sentinel_collection(self) -> None: """Create an empty collection to indicate that the storage is initialized.""" @@ -393,10 +397,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: ) self._update_schema_in_storage(self.schema) - def _collection_exists(self, table_name: str) -> bool: + def _collection_exists(self, table_name: str, qualify_table_name: bool = True) -> bool: try: - self.db_client.get_collection( - self._make_qualified_collection_name(table_name)) + table_name = self._make_qualified_collection_name( + table_name) if qualify_table_name else table_name + self.db_client.get_collection(table_name) return True except UnexpectedResponse as e: if e.status_code == 404: