From abebfd282638724f4d1ab22901ef01593f96e643 Mon Sep 17 00:00:00 2001 From: Haotian Zhang Date: Fri, 16 Feb 2024 10:11:10 -0500 Subject: [PATCH] cr --- .../databricks-vector-search/base.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-databricks-vector-search/llama_index/vector_stores/databricks-vector-search/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-databricks-vector-search/llama_index/vector_stores/databricks-vector-search/base.py index 9ef9d0298d3d8..1089f770b5dbb 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-databricks-vector-search/llama_index/vector_stores/databricks-vector-search/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-databricks-vector-search/llama_index/vector_stores/databricks-vector-search/base.py @@ -112,6 +112,7 @@ class DatabricksVectorSearch(BasePydanticVectorStore): Install ``databricks-vectorsearch`` package using the following in a Databricks notebook: %pip install databricks-vectorsearch dbutils.library.restartPython() + """ stores_text: bool = True @@ -150,7 +151,8 @@ def __init__( self._direct_access_index_spec = index_description.direct_access_index_spec super().__init__( - text_column=text_column, columns=columns, + text_column=text_column, + columns=columns, ) # initialize the column name for the text column in the delta table @@ -190,7 +192,11 @@ def __init__( f"columns missing from schema: {', '.join(missing_columns)}" ) - def add(self, nodes: List[BaseNode], **add_kwargs: Any,) -> List[str]: + def add( + self, + nodes: List[BaseNode], + **add_kwargs: Any, + ) -> List[str]: """Add nodes to index. Args: @@ -226,7 +232,9 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any,) -> List[str]: ids.append(node_id) # attempt the upsert - upsert_resp = self._index.upsert(entries,) + upsert_resp = self._index.upsert( + entries, + ) # return the successful IDs response_status = upsert_resp.get("status") @@ -261,11 +269,12 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: ref_doc_id (str): The doc_id of the document to delete. """ - self._index.delete(primary_keys=[ref_doc_id],) + self._index.delete( + primary_keys=[ref_doc_id], + ) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. - """ + """Query index for top k most similar nodes.""" if self._is_databricks_managed_embeddings(): query_text = query.query_str query_vector = None @@ -295,13 +304,12 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul ) columns = [ - col["name"] - for col in search_resp.get("manifest", dict()).get("columns", []) + col["name"] for col in search_resp.get("manifest", {}).get("columns", []) ] top_k_nodes = [] top_k_ids = [] top_k_scores = [] - for result in search_resp.get("result", dict()).get("data_array", []): + for result in search_resp.get("result", {}).get("data_array", []): doc_id = result[columns.index(self._primary_key)] text_content = result[columns.index(self.text_column)] metadata = { @@ -325,7 +333,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul @property def client(self) -> Any: - """Return VectorStoreIndex""" + """Return VectorStoreIndex.""" return self._index # The remaining utilities (and snippets of the above) are taken from @@ -355,7 +363,7 @@ def _embedding_vector_column(self) -> dict: if self._is_delta_sync_index() else self._direct_access_index_spec ) - return next(iter(index_spec.get("embedding_vector_columns") or list()), dict()) + return next(iter(index_spec.get("embedding_vector_columns") or []), {}) def _embedding_source_column_name(self) -> Optional[str]: """Return the name of the embedding source column. @@ -368,8 +376,8 @@ def _embedding_source_column(self) -> dict: Empty if the index is not a Databricks-managed embedding index. """ return next( - iter(self._delta_sync_index_spec.get("embedding_source_columns") or list()), - dict(), + iter(self._delta_sync_index_spec.get("embedding_source_columns") or []), + {}, ) def _is_delta_sync_index(self) -> bool: