Skip to content

Commit

Permalink
cr
Browse files (browse the repository at this point in the history)
  • Loading branch information
hatianzhang committed Feb 16, 2024
1 parent 18d700c commit abebfd2
Showing 1 changed file with 21 additions and 13 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -112,6 +112,7 @@ class DatabricksVectorSearch(BasePydanticVectorStore):
Install ``databricks-vectorsearch`` package using the following in a Databricks notebook:
%pip install databricks-vectorsearch
dbutils.library.restartPython()
"""

stores_text: bool = True
Expand Down Expand Up @@ -150,7 +151,8 @@ def __init__(
self._direct_access_index_spec = index_description.direct_access_index_spec

super().__init__(
text_column=text_column, columns=columns,
text_column=text_column,
columns=columns,
)

# initialize the column name for the text column in the delta table
Expand Down Expand Up @@ -190,7 +192,11 @@ def __init__(
f"columns missing from schema: {', '.join(missing_columns)}"
)

def add(self, nodes: List[BaseNode], **add_kwargs: Any,) -> List[str]:
def add(
self,
nodes: List[BaseNode],
**add_kwargs: Any,
) -> List[str]:
"""Add nodes to index.
Args:
Expand Down Expand Up @@ -226,7 +232,9 @@ def add(self, nodes: List[BaseNode], **add_kwargs: Any,) -> List[str]:
ids.append(node_id)

# attempt the upsert
upsert_resp = self._index.upsert(entries,)
upsert_resp = self._index.upsert(
entries,
)

# return the successful IDs
response_status = upsert_resp.get("status")
Expand Down Expand Up @@ -261,11 +269,12 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
ref_doc_id (str): The doc_id of the document to delete.
"""
self._index.delete(primary_keys=[ref_doc_id],)
self._index.delete(
primary_keys=[ref_doc_id],
)

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
"""
"""Query index for top k most similar nodes."""
if self._is_databricks_managed_embeddings():
query_text = query.query_str
query_vector = None
Expand Down Expand Up @@ -295,13 +304,12 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
)

columns = [
col["name"]
for col in search_resp.get("manifest", dict()).get("columns", [])
col["name"] for col in search_resp.get("manifest", {}).get("columns", [])
]
top_k_nodes = []
top_k_ids = []
top_k_scores = []
for result in search_resp.get("result", dict()).get("data_array", []):
for result in search_resp.get("result", {}).get("data_array", []):
doc_id = result[columns.index(self._primary_key)]
text_content = result[columns.index(self.text_column)]
metadata = {
Expand All @@ -325,7 +333,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul

@property
def client(self) -> Any:
"""Return VectorStoreIndex"""
"""Return VectorStoreIndex."""
return self._index

# The remaining utilities (and snippets of the above) are taken from
Expand Down Expand Up @@ -355,7 +363,7 @@ def _embedding_vector_column(self) -> dict:
if self._is_delta_sync_index()
else self._direct_access_index_spec
)
return next(iter(index_spec.get("embedding_vector_columns") or list()), dict())
return next(iter(index_spec.get("embedding_vector_columns") or []), {})

def _embedding_source_column_name(self) -> Optional[str]:
"""Return the name of the embedding source column.
Expand All @@ -368,8 +376,8 @@ def _embedding_source_column(self) -> dict:
Empty if the index is not a Databricks-managed embedding index.
"""
return next(
iter(self._delta_sync_index_spec.get("embedding_source_columns") or list()),
dict(),
iter(self._delta_sync_index_spec.get("embedding_source_columns") or []),
{},
)

def _is_delta_sync_index(self) -> bool:
Expand Down

0 comments on commit abebfd2

Please sign in to comment.