Skip to content

Commit

Permalink
lancedb integration update (#11490)
Browse files Browse the repository at this point in the history
* updated example, pkg breakage due to refactor errors fixed

* typos

* nb clean

* output restored

* formatting

* typo

* integration update

* error fix

* pytest fixes
  • Loading branch information
raghavdixit99 authored Feb 29, 2024
1 parent 238f82c commit 5a4d246
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, List, Optional

import numpy as np
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import (
BaseNode,
MetadataMode,
Expand All @@ -13,7 +14,7 @@
)
from llama_index.core.vector_stores.types import (
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
Expand Down Expand Up @@ -54,7 +55,7 @@ def _to_llama_similarities(results: DataFrame) -> List[float]:
return normalized_similarities.tolist()


class LanceDBVectorStore(VectorStore):
class LanceDBVectorStore(BasePydanticVectorStore):
"""
The LanceDB Vector Store.
Expand Down Expand Up @@ -84,10 +85,18 @@ class LanceDBVectorStore(VectorStore):

stores_text = True
flat_metadata: bool = True
_connection: Any = PrivateAttr()
uri: Optional[str]
table_name: Optional[str]
vector_column_name: Optional[str]
nprobes: Optional[int]
refine_factor: Optional[int]
text_key: Optional[str]
doc_id_key: Optional[str]

def __init__(
self,
uri: str,
uri: Optional[str],
table_name: str = "vectors",
vector_column_name: str = "vector",
nprobes: int = 20,
Expand All @@ -97,19 +106,48 @@ def __init__(
**kwargs: Any,
) -> None:
"""Init params."""
self.connection = lancedb.connect(uri)
self.uri = uri
self.table_name = table_name
self.vector_column_name = vector_column_name
self.nprobes = nprobes
self.text_key = text_key
self.doc_id_key = doc_id_key
self.refine_factor = refine_factor
self._connection = lancedb.connect(uri)
super().__init__(
uri=uri,
table_name=table_name,
vector_column_name=vector_column_name,
nprobes=nprobes,
refine_factor=refine_factor,
text_key=text_key,
doc_id_key=doc_id_key,
**kwargs,
)

@property
def client(self) -> None:
"""Get client."""
return
return self._connection

@classmethod
def from_params(
cls,
uri: Optional[str],
table_name: str = "vectors",
vector_column_name: str = "vector",
nprobes: int = 20,
refine_factor: Optional[int] = None,
text_key: str = DEFAULT_TEXT_KEY,
doc_id_key: str = DEFAULT_DOC_ID_KEY,
**kwargs: Any,
) -> "LanceDBVectorStore":
"""Create instance from params."""
_connection_ = cls._connection
return cls(
_connection=_connection_,
uri=uri,
table_name=table_name,
vector_column_name=vector_column_name,
nprobes=nprobes,
refine_factor=refine_factor,
text_key=text_key,
doc_id_key=doc_id_key,
**kwargs,
)

def add(
self,
Expand All @@ -132,11 +170,11 @@ def add(
data.append(append_data)
ids.append(node.node_id)

if self.table_name in self.connection.table_names():
tbl = self.connection.open_table(self.table_name)
if self.table_name in self._connection.table_names():
tbl = self._connection.open_table(self.table_name)
tbl.add(data)
else:
self.connection.create_table(self.table_name, data)
self._connection.create_table(self.table_name, data)
return ids

def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
Expand All @@ -147,7 +185,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
ref_doc_id (str): The doc_id of the document to delete.
"""
table = self.connection.open_table(self.table_name)
table = self._connection.open_table(self.table_name)
table.delete('document_id = "' + ref_doc_id + '"')

def query(
Expand All @@ -167,7 +205,7 @@ def query(
else:
where = kwargs.pop("where", None)

table = self.connection.open_table(self.table_name)
table = self._connection.open_table(self.table_name)
lance_query = (
table.search(
query=query.query_embedding,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.lancedb import LanceDBVectorStore


def test_class():
names_of_base_classes = [b.__name__ for b in LanceDBVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes

0 comments on commit 5a4d246

Please sign in to comment.