Skip to content

Commit

Permalink
formatting changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
mshawFD committed Jun 26, 2024
1 parent 871394c commit d316b5b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
VectorStoreQueryResult,
)
from llama_index.vector_stores.kdbai.utils import (
default_sparse_encoder_v1,
default_sparse_encoder_v1,
convert_metadata_col_v1,
default_sparse_encoder_v2,
default_sparse_encoder_v2,
convert_metadata_col_v2,
)

Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(

if hybrid_search:
if sparse_encoder is None:
if kdbai.version("kdbai_client") >= '1.2.0':
if kdbai.version("kdbai_client") >= "1.2.0":
self._sparse_encoder = default_sparse_encoder_v2
else:
self._sparse_encoder = default_sparse_encoder_v1
Expand Down Expand Up @@ -125,22 +125,22 @@ def add(
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

df = pd.DataFrame()
docs = []

if kdbai.version("kdbai_client") >= '1.2.0':
schema = self._table.schema['schema']['c']
types = self._table.schema['schema']['t'].decode('utf-8')
if kdbai.version("kdbai_client") >= "1.2.0":
schema = self._table.schema["schema"]["c"]
types = self._table.schema["schema"]["t"].decode("utf-8")
else:
schema = self._table.schema()["columns"]

if self.hybrid_search:
if kdbai.version("kdbai_client") >= '1.2.0':
if kdbai.version("kdbai_client") >= "1.2.0":
schema = [item for item in schema if item != "sparseVectors"]
else:
schema = [item for item in schema if item["name"] != "sparseVectors"]

try:
for node in nodes:
doc = {
Expand All @@ -154,8 +154,11 @@ def add(

# handle extra columns
if len(schema) > len(DEFAULT_COLUMN_NAMES):
if kdbai.version("kdbai_client") >= '1.2.0':
for column_name, column_type in zip(schema[len(DEFAULT_COLUMN_NAMES):], types[len(DEFAULT_COLUMN_NAMES):]):
if kdbai.version("kdbai_client") >= "1.2.0":
for column_name, column_type in zip(
schema[len(DEFAULT_COLUMN_NAMES) :],
types[len(DEFAULT_COLUMN_NAMES) :],
):
try:
doc[column_name] = convert_metadata_col_v2(
column_name, column_type, node.metadata[column_name]
Expand Down Expand Up @@ -192,9 +195,8 @@ def add(

except Exception as e:
logger.error(f"Error preparing data for KDB.AI: {e}.")

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
try:
import kdbai_client as kdbai

Expand All @@ -205,7 +207,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
"Could not import kdbai_client package."
"Please add it to the dependencies."
)

if query.filters is None:
filter = []
else:
Expand All @@ -214,7 +216,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
if self.hybrid_search:
alpha = query.alpha if query.alpha is not None else 0.5

if kdbai.version("kdbai_client") >= '1.2.0':
if kdbai.version("kdbai_client") >= "1.2.0":
sparse_vectors = [self._sparse_encoder([query.query_str])]
else:
sparse_vectors = self._sparse_encoder([query.query_str])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ def default_sparse_encoder_v2(texts: List[str]) -> Dict[int, int]:
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenized_texts = tokenizer(texts, padding=True, truncation=True, max_length=512)["input_ids"]
tokenized_texts = tokenizer(texts, padding=True, truncation=True, max_length=512)[
"input_ids"
]

flat_tokenized_texts = [token_id for sublist in tokenized_texts for token_id in sublist]
flat_tokenized_texts = [
token_id for sublist in tokenized_texts for token_id in sublist
]

return dict(Counter(flat_tokenized_texts))

sparse_encoding = dict(Counter(flat_tokenized_texts))
return sparse_encoding

# Match the metadata column's data type to its Python type (pytype).
def convert_metadata_col_v2(column_name, column_type, column_value):
Expand All @@ -40,6 +44,7 @@ def convert_metadata_col_v2(column_name, column_type, column_value):
f"Failed to convert column {column_name} to qtype {column_type}: {e}"
)


def default_sparse_encoder_v1(texts: List[str]) -> List[Dict[int, int]]:
try:
from transformers import BertTokenizer
Expand All @@ -60,6 +65,7 @@ def default_sparse_encoder_v1(texts: List[str]) -> List[Dict[int, int]]:
results.append(sparse_encoding)
return results


def convert_metadata_col_v1(column, value):
try:
if column["pytype"] == "str":
Expand All @@ -75,6 +81,3 @@ def convert_metadata_col_v1(column, value):
logger.error(
f"Failed to convert column {column['name']} to type {column['pytype']}: {e}"
)



Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ version = "0.1.6"
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
pykx = "^2.1.1"
kdbai-client = "^0.1.2"
kdbai-client = ">=1.1.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
Expand Down

0 comments on commit d316b5b

Please sign in to comment.