Skip to content

Commit

Permalink
KDB.AI client version compatibility (#14402)
Browse files Browse the repository at this point in the history
  • Loading branch information
mshawFD authored Jun 26, 2024
1 parent 39cdf82 commit 1a0a999
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.kdbai.utils import default_sparse_encoder
from llama_index.vector_stores.kdbai.utils import (
default_sparse_encoder_v1,
convert_metadata_col_v1,
default_sparse_encoder_v2,
convert_metadata_col_v2,
)

DEFAULT_COLUMN_NAMES = ["document_id", "text", "embedding"]

Expand All @@ -27,24 +32,6 @@
logger = logging.getLogger(__name__)


# Coerce a metadata value to the Python type declared by its schema column.
def convert_metadata_col(column, value):
    """Convert ``value`` to the type named in ``column['pytype']``.

    Known pytypes ("str", "bytes", "datetime64[ns]", "timedelta64[ns]")
    are handled explicitly; anything else falls through to the value's
    own ``astype``. On failure the error is logged and ``None`` is
    returned (callers treat conversion as best-effort).
    """
    try:
        pytype = column["pytype"]
        if pytype == "str":
            return str(value)
        if pytype == "bytes":
            return value.encode("utf-8")
        if pytype == "datetime64[ns]":
            return pd.to_datetime(value)
        if pytype == "timedelta64[ns]":
            return pd.to_timedelta(value)
        # Fallback: assume a numpy-like value that supports astype.
        return value.astype(pytype)
    except Exception as e:
        logger.error(
            f"Failed to convert column {column['name']} to type {column['pytype']}: {e}"
        )


class KDBAIVectorStore(BasePydanticVectorStore):
"""The KDBAI Vector Store.
Expand Down Expand Up @@ -97,7 +84,10 @@ def __init__(

if hybrid_search:
if sparse_encoder is None:
self._sparse_encoder = default_sparse_encoder
if kdbai.version("kdbai_client") >= "1.2.0":
self._sparse_encoder = default_sparse_encoder_v2
else:
self._sparse_encoder = default_sparse_encoder_v1
else:
self._sparse_encoder = sparse_encoder

Expand Down Expand Up @@ -125,11 +115,31 @@ def add(
Returns:
List[str]: List of document IDs that were added.
"""
try:
import kdbai_client as kdbai

logger.info("KDBAI client version: " + kdbai.__version__)

except ImportError:
raise ValueError(
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

df = pd.DataFrame()
docs = []
schema = self._table.schema()["columns"]

if kdbai.version("kdbai_client") >= "1.2.0":
schema = self._table.schema["schema"]["c"]
types = self._table.schema["schema"]["t"].decode("utf-8")
else:
schema = self._table.schema()["columns"]

if self.hybrid_search:
schema = [item for item in schema if item["name"] != "sparseVectors"]
if kdbai.version("kdbai_client") >= "1.2.0":
schema = [item for item in schema if item != "sparseVectors"]
else:
schema = [item for item in schema if item["name"] != "sparseVectors"]

try:
for node in nodes:
Expand All @@ -144,15 +154,29 @@ def add(

# handle extra columns
if len(schema) > len(DEFAULT_COLUMN_NAMES):
for column in schema[len(DEFAULT_COLUMN_NAMES) :]:
try:
doc[column["name"]] = convert_metadata_col(
column, node.metadata[column["name"]]
)
except Exception as e:
logger.error(
f"Error writing column {column['name']} as type {column['pytype']}: {e}."
)
if kdbai.version("kdbai_client") >= "1.2.0":
for column_name, column_type in zip(
schema[len(DEFAULT_COLUMN_NAMES) :],
types[len(DEFAULT_COLUMN_NAMES) :],
):
try:
doc[column_name] = convert_metadata_col_v2(
column_name, column_type, node.metadata[column_name]
)
except Exception as e:
logger.error(
f"Error writing column {column_name} as qtype {column_type}: {e}."
)
else:
for column in schema[len(DEFAULT_COLUMN_NAMES) :]:
try:
doc[column["name"]] = convert_metadata_col_v1(
column, node.metadata[column["name"]]
)
except Exception as e:
logger.error(
f"Error writing column {column['name']} as type {column['pytype']}: {e}."
)

docs.append(doc)

Expand All @@ -173,14 +197,30 @@ def add(
logger.error(f"Error preparing data for KDB.AI: {e}.")

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
try:
import kdbai_client as kdbai

logger.info("KDBAI client version: " + kdbai.__version__)

except ImportError:
raise ValueError(
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

if query.filters is None:
filter = []
else:
filter = query.filters

if self.hybrid_search:
alpha = query.alpha if query.alpha is not None else 0.5
sparse_vectors = self._sparse_encoder([query.query_str])

if kdbai.version("kdbai_client") >= "1.2.0":
sparse_vectors = [self._sparse_encoder([query.query_str])]
else:
sparse_vectors = self._sparse_encoder([query.query_str])

results = self._table.hybrid_search(
dense_vectors=[query.query_embedding],
sparse_vectors=sparse_vectors,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
from typing import List, Dict
import logging
import pandas as pd

logger = logging.getLogger(__name__)

def default_sparse_encoder(texts: List[str]) -> List[Dict[int, int]]:

def default_sparse_encoder_v2(texts: List[str]) -> Dict[int, int]:
    """Encode *texts* into a single sparse bag-of-tokens vector.

    Tokenizes every string with the ``bert-base-uncased`` WordPiece
    tokenizer and returns one combined ``{token_id: count}`` mapping over
    all inputs — the single-dict shape used by the v2 (kdbai-client
    >= 1.2.0) hybrid-search code path.

    Args:
        texts: Documents or queries to encode.

    Returns:
        Dict[int, int]: token id -> occurrence count across all texts.

    Raises:
        ImportError: if the ``transformers`` package is not installed.
    """
    try:
        from transformers import BertTokenizer
        from collections import Counter
    except ImportError:
        raise ImportError(
            "Could not import transformers library. "
            'Please install transformers with `pip install "transformers"`'
        )

    # Loading the tokenizer is expensive (vocab read from disk/cache);
    # memoize it on the function so repeated calls reuse one instance.
    tokenizer = getattr(default_sparse_encoder_v2, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        default_sparse_encoder_v2._tokenizer = tokenizer

    # NOTE(review): padding=True counts [PAD] ids when texts have unequal
    # lengths, which may skew the sparse vector — confirm intended.
    tokenized_texts = tokenizer(texts, padding=True, truncation=True, max_length=512)[
        "input_ids"
    ]

    # Flatten the per-text id lists into one stream and count occurrences.
    flat_tokenized_texts = [
        token_id for sublist in tokenized_texts for token_id in sublist
    ]

    return dict(Counter(flat_tokenized_texts))


# Coerce a metadata value to the Python type matching its q-type code.
def convert_metadata_col_v2(column_name, column_type, column_value):
    """Convert ``column_value`` per the single-character q-type code.

    Codes handled explicitly: ``s`` (symbol -> str), ``C`` (char list ->
    utf-8 bytes), ``p`` (timestamp -> pandas datetime), ``n`` (timespan ->
    pandas timedelta). Any other code falls through to the value's own
    ``astype``. Failures are logged and ``None`` is returned.
    """
    converters = {
        "s": str,
        "C": lambda v: v.encode("utf-8"),
        "p": pd.to_datetime,
        "n": pd.to_timedelta,
    }
    try:
        convert = converters.get(column_type)
        if convert is not None:
            return convert(column_value)
        # Fallback: assume a numpy-like value supporting astype.
        return column_value.astype(column_type)
    except Exception as e:
        logger.error(
            f"Failed to convert column {column_name} to qtype {column_type}: {e}"
        )


def default_sparse_encoder_v1(texts: List[str]) -> List[Dict[int, int]]:
try:
from transformers import BertTokenizer
from collections import Counter
Expand All @@ -20,3 +64,20 @@ def default_sparse_encoder(texts: List[str]) -> List[Dict[int, int]]:
sparse_encoding = dict(Counter(tokenized_text))
results.append(sparse_encoding)
return results


def convert_metadata_col_v1(column, value):
    """Convert ``value`` to the type named in ``column['pytype']`` (v1 schema).

    Strings, bytes, datetimes, and timedeltas get dedicated conversions;
    any other pytype is delegated to ``value.astype``. On failure the
    error is logged and ``None`` is returned (best-effort conversion).
    """
    dispatch = {
        "str": str,
        "bytes": lambda v: v.encode("utf-8"),
        "datetime64[ns]": pd.to_datetime,
        "timedelta64[ns]": pd.to_timedelta,
    }
    try:
        handler = dispatch.get(column["pytype"])
        if handler is not None:
            return handler(value)
        # Fallback: assume a numpy-like value supporting astype.
        return value.astype(column["pytype"])
    except Exception as e:
        logger.error(
            f"Failed to convert column {column['name']} to type {column['pytype']}: {e}"
        )
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-kdbai"
readme = "README.md"
version = "0.1.6"
version = "0.1.7"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
pykx = "^2.1.1"
kdbai-client = "^0.1.2"
kdbai-client = ">=1.1.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
Expand Down

0 comments on commit 1a0a999

Please sign in to comment.