Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kdbai version compatible #14402

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.kdbai.utils import default_sparse_encoder
from llama_index.vector_stores.kdbai.utils import (
default_sparse_encoder_v1,
convert_metadata_col_v1,
default_sparse_encoder_v2,
convert_metadata_col_v2,
)

DEFAULT_COLUMN_NAMES = ["document_id", "text", "embedding"]

Expand All @@ -27,24 +32,6 @@
logger = logging.getLogger(__name__)


# MATCH THE METADATA COLUMN DATA TYPE TO ITS PYTYPE
def convert_metadata_col(column, value):
try:
if column["pytype"] == "str":
return str(value)
elif column["pytype"] == "bytes":
return value.encode("utf-8")
elif column["pytype"] == "datetime64[ns]":
return pd.to_datetime(value)
elif column["pytype"] == "timedelta64[ns]":
return pd.to_timedelta(value)
return value.astype(column["pytype"])
except Exception as e:
logger.error(
f"Failed to convert column {column['name']} to type {column['pytype']}: {e}"
)


class KDBAIVectorStore(BasePydanticVectorStore):
"""The KDBAI Vector Store.

Expand Down Expand Up @@ -97,7 +84,10 @@ def __init__(

if hybrid_search:
if sparse_encoder is None:
self._sparse_encoder = default_sparse_encoder
if kdbai.version("kdbai_client") >= "1.2.0":
self._sparse_encoder = default_sparse_encoder_v2
else:
self._sparse_encoder = default_sparse_encoder_v1
else:
self._sparse_encoder = sparse_encoder

Expand Down Expand Up @@ -125,11 +115,31 @@ def add(
Returns:
List[str]: List of document IDs that were added.
"""
try:
import kdbai_client as kdbai

logger.info("KDBAI client version: " + kdbai.__version__)

except ImportError:
raise ValueError(
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

df = pd.DataFrame()
docs = []
schema = self._table.schema()["columns"]

if kdbai.version("kdbai_client") >= "1.2.0":
schema = self._table.schema["schema"]["c"]
types = self._table.schema["schema"]["t"].decode("utf-8")
else:
schema = self._table.schema()["columns"]

if self.hybrid_search:
schema = [item for item in schema if item["name"] != "sparseVectors"]
if kdbai.version("kdbai_client") >= "1.2.0":
schema = [item for item in schema if item != "sparseVectors"]
else:
schema = [item for item in schema if item["name"] != "sparseVectors"]

try:
for node in nodes:
Expand All @@ -144,15 +154,29 @@ def add(

# handle extra columns
if len(schema) > len(DEFAULT_COLUMN_NAMES):
for column in schema[len(DEFAULT_COLUMN_NAMES) :]:
try:
doc[column["name"]] = convert_metadata_col(
column, node.metadata[column["name"]]
)
except Exception as e:
logger.error(
f"Error writing column {column['name']} as type {column['pytype']}: {e}."
)
if kdbai.version("kdbai_client") >= "1.2.0":
for column_name, column_type in zip(
schema[len(DEFAULT_COLUMN_NAMES) :],
types[len(DEFAULT_COLUMN_NAMES) :],
):
try:
doc[column_name] = convert_metadata_col_v2(
column_name, column_type, node.metadata[column_name]
)
except Exception as e:
logger.error(
f"Error writing column {column_name} as qtype {column_type}: {e}."
)
else:
for column in schema[len(DEFAULT_COLUMN_NAMES) :]:
try:
doc[column["name"]] = convert_metadata_col_v1(
column, node.metadata[column["name"]]
)
except Exception as e:
logger.error(
f"Error writing column {column['name']} as type {column['pytype']}: {e}."
)

docs.append(doc)

Expand All @@ -173,14 +197,30 @@ def add(
logger.error(f"Error preparing data for KDB.AI: {e}.")

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
try:
import kdbai_client as kdbai

logger.info("KDBAI client version: " + kdbai.__version__)

except ImportError:
raise ValueError(
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

if query.filters is None:
filter = []
else:
filter = query.filters

if self.hybrid_search:
alpha = query.alpha if query.alpha is not None else 0.5
sparse_vectors = self._sparse_encoder([query.query_str])

if kdbai.version("kdbai_client") >= "1.2.0":
sparse_vectors = [self._sparse_encoder([query.query_str])]
else:
sparse_vectors = self._sparse_encoder([query.query_str])

results = self._table.hybrid_search(
dense_vectors=[query.query_embedding],
sparse_vectors=sparse_vectors,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,51 @@
from typing import List, Dict
import logging
import pandas as pd

logger = logging.getLogger(__name__)

def default_sparse_encoder(texts: List[str]) -> List[Dict[int, int]]:

def default_sparse_encoder_v2(texts: List[str]) -> Dict[int, int]:
try:
from transformers import BertTokenizer
from collections import Counter
except ImportError:
raise ImportError(
"Could not import transformers library. "
'Please install transformers with `pip install "transformers"`'
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenized_texts = tokenizer(texts, padding=True, truncation=True, max_length=512)[
"input_ids"
]

flat_tokenized_texts = [
token_id for sublist in tokenized_texts for token_id in sublist
]

return dict(Counter(flat_tokenized_texts))


# MATCH THE METADATA COLUMN DATA TYPE TO ITS PYTYPE
def convert_metadata_col_v2(column_name, column_type, column_value):
try:
if column_type == "s":
return str(column_value)
elif column_type == "C":
return column_value.encode("utf-8")
elif column_type == "p":
return pd.to_datetime(column_value)
elif column_type == "n":
return pd.to_timedelta(column_value)
return column_value.astype(column_type)
except Exception as e:
logger.error(
f"Failed to convert column {column_name} to qtype {column_type}: {e}"
)


def default_sparse_encoder_v1(texts: List[str]) -> List[Dict[int, int]]:
try:
from transformers import BertTokenizer
from collections import Counter
Expand All @@ -20,3 +64,20 @@ def default_sparse_encoder(texts: List[str]) -> List[Dict[int, int]]:
sparse_encoding = dict(Counter(tokenized_text))
results.append(sparse_encoding)
return results


def convert_metadata_col_v1(column, value):
try:
if column["pytype"] == "str":
return str(value)
elif column["pytype"] == "bytes":
return value.encode("utf-8")
elif column["pytype"] == "datetime64[ns]":
return pd.to_datetime(value)
elif column["pytype"] == "timedelta64[ns]":
return pd.to_timedelta(value)
return value.astype(column["pytype"])
except Exception as e:
logger.error(
f"Failed to convert column {column['name']} to type {column['pytype']}: {e}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ version = "0.1.6"
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
pykx = "^2.1.1"
kdbai-client = "^0.1.2"
kdbai-client = ">=1.1.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
Expand Down
Loading