Commit: vector search
Signed-off-by: Praneeth Bedapudi <[email protected]>
bedapudi6788 committed Jan 28, 2024
1 parent dff4f69 commit 40a594e
Showing 2 changed files with 60 additions and 94 deletions.
143 changes: 50 additions & 93 deletions liteindex/defined_index.py
@@ -52,13 +52,7 @@

 class DefinedIndex:
     def __init__(
-        self,
-        name,
-        schema=None,
-        db_path=None,
-        ram_cache_mb=64,
-        compression_level=-1,
-        sortable_vector_keys_dims={},
+        self, name, schema=None, db_path=None, ram_cache_mb=64, compression_level=-1
     ):
         if sqlite3.sqlite_version < "3.35.0":
             raise ValueError(
@@ -70,7 +64,6 @@ def __init__(
         self.db_path = ":memory:" if db_path is None else db_path
         self.ram_cache_mb = ram_cache_mb
         self.compression_level = compression_level
-        self.sortable_vector_keys_dims = sortable_vector_keys_dims
 
         if self.name.startswith("__"):
             raise ValueError("Index name cannot start with '__'")
@@ -79,6 +72,7 @@
         self.__column_names = ["id", "updated_at"]
 
         self.__vector_search_indexes = {}
+        self.__vector_indexes_last_updated_at = {}
 
         self.__local_storage = threading.local()
 
@@ -98,93 +92,49 @@ def __init__(
         self.__parse_schema()
         self.__create_table_and_meta_table()
 
-        if self.sortable_vector_keys_dims:
-            if faiss is None:
-                raise ValueError("faiss must be installed. `pip install faiss-cpu`")
-
-            for k in self.sortable_vector_keys_dims:
-                if k not in self.schema:
-                    raise ValueError(f"{k} must be one of {list(self.schema.keys())}")
-                if self.schema[k] != "normalized_embedding":
-                    raise ValueError(f"{k} must be of type normalized_embedding")
-
-                if self.sortable_vector_keys_dims[k] is None:
-                    raise ValueError("vector dimensions must be provided")
-
-            self.__build_vector_search_indexes()
+        self.__meta_schema = self.schema.copy()
+        self.__meta_schema["updated_at"] = "number"
+        self.__meta_schema["integer_id"] = "number"
 
-    def __build_vector_search_indexes(self):
-        if len(self.__vector_search_indexes) > 0:
-            return
-
-        self.__vector_search_indexes = {
-            k: faiss.IndexIDMap(faiss.IndexFlatIP(self.sortable_vector_keys_dims[k]))
-            for k in self.sortable_vector_keys_dims
-        }
+    def __update_vector_search_index(self, for_key, dim=None):
+        if for_key not in self.__vector_indexes_last_updated_at:
+            self.__vector_search_indexes[for_key] = faiss.IndexIDMap(
+                faiss.IndexFlatIP(dim)
+            )
+            self.__vector_indexes_last_updated_at[for_key] = 0
 
-        embeddings_batches = {}
-        integer_id_batches = {}
+        embeddings_batch = []
+        integer_id_batch = []
         batch_len = 0
 
         for _id, _data in self.search(
-            return_metadata=True, select_keys=self.sortable_vector_keys_dims.keys()
+            query={for_key: {"$ne": None}},
+            return_metadata=True,
+            select_keys=[for_key],
+            meta_query={
+                "integer_id": {"$gte": self.__vector_indexes_last_updated_at[for_key]}
+            },
         ).items():
-            for k in self.sortable_vector_keys_dims:
-                embedding = _data[k]
-                if embedding is not None:
-                    if k not in embeddings_batches:
-                        embeddings_batches[k] = []
-                        integer_id_batches[k] = []
-
-                    embeddings_batches[k].append(embedding)
-                    integer_id_batches[k].append(_data["__meta"]["integer_id"])
-                    batch_len += 1
-
-                    if batch_len >= 100000:
-                        for k in integer_id_batches:
-                            if integer_id_batches[k]:
-                                self.__vector_search_indexes[k].add_with_ids(
-                                    np.array(embeddings_batches[k], dtype=np.float32),
-                                    np.array(integer_id_batches[k], dtype=np.int64),
-                                )
-                                embeddings_batches[k] = []
-                                integer_id_batches[k] = []
+            self.__vector_indexes_last_updated_at[for_key] = time.time()
+
+            embeddings_batch.append(_data[for_key])
+            integer_id_batch.append(_data["__meta"]["integer_id"])
+            batch_len += 1
+
+            if batch_len >= 10000:
+                self.__vector_search_indexes[for_key].add_with_ids(
+                    np.array(embeddings_batch, dtype=np.float32),
+                    np.array(integer_id_batch, dtype=np.int64),
+                )
+                batch_len = 0
+                embeddings_batch = []
+                integer_id_batch = []
 
         if batch_len > 0:
-            for k in integer_id_batches:
-                if integer_id_batches[k]:
-                    self.__vector_search_indexes[k].add_with_ids(
-                        np.array(embeddings_batches[k], dtype=np.float32),
-                        np.array(integer_id_batches[k], dtype=np.int64),
-                    )
-                    embeddings_batches[k] = []
-                    integer_id_batches[k] = []
-
-    def __add_to_vector_search_index(self, _ids):
-        embeddings_batches = {}
-        integer_id_batches = {}
-
-        for _id, _data in self.get(
-            list(_ids), select_keys=self.sortable_vector_keys_dims, return_metadata=True
-        ).items():
-            for key_name in self.sortable_vector_keys_dims:
-                if key_name not in embeddings_batches:
-                    embeddings_batches[key_name] = []
-                    integer_id_batches[key_name] = []
-
-                embedding = _data[key_name]
-                if embedding is not None:
-                    embeddings_batches[key_name].append(embedding)
-                    integer_id_batches[key_name].append(_data["__meta"]["integer_id"])
-
-        for key_name in self.sortable_vector_keys_dims:
-            if len(integer_id_batches[key_name]) > 0:
-                self.__vector_search_indexes[key_name].add_with_ids(
-                    np.array(embeddings_batches[key_name], dtype=np.float32),
-                    np.array(integer_id_batches[key_name], dtype=np.int64),
-                )
+            self.__vector_search_indexes[for_key].add_with_ids(
+                np.array(embeddings_batch, dtype=np.float32),
+                np.array(integer_id_batch, dtype=np.int64),
+            )
 
     def __get_scores_and_integer_ids_table_name(self, query_embedding, key_name):
         query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
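
The new __update_vector_search_index lazily creates one faiss index per embedding key and tops it up in batches of at most 10,000 vectors added since the last refresh. A minimal standalone sketch of that pattern, assuming faiss-cpu and numpy are installed (the dimension, batch size, and variable names are illustrative, not from this commit):

import faiss
import numpy as np

dim = 64

# An ID-mapped flat inner-product index, the same kind of index the method
# above builds for each normalized_embedding key.
index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))

# One batch of L2-normalized float32 embeddings with stable integer ids.
embeddings = np.random.rand(1000, dim).astype(np.float32)
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
integer_ids = np.arange(1000, dtype=np.int64)
index.add_with_ids(embeddings, integer_ids)

# Inner product over unit vectors is cosine similarity, so search() returns
# the closest stored embeddings and their scores for a query vector.
scores, ids = index.search(embeddings[:1], 5)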
@@ -378,9 +328,6 @@ def yield_transaction():

         self.__connection.executemany(sql, yield_transaction())
 
-        if self.sortable_vector_keys_dims:
-            self.__add_to_vector_search_index(data.keys())
-
     def get(
         self,
         ids,
@@ -473,6 +420,7 @@ def search(
         return_metadata=False,
         metadata_key_name="__meta",
         query_vector=None,
+        meta_query={},
     ):
         if not sort_by:
             sort_by = "updated_at"
@@ -482,24 +430,30 @@

         sorting_by_vector = False
 
-        if sort_by in self.sortable_vector_keys_dims:
+        if self.schema.get(sort_by) == "normalized_embedding":
             if query_vector is None:
                 raise ValueError("query_vector must be provided")
 
+            self.__update_vector_search_index(sort_by, len(query_vector))
+
+            sorting_by_vector = True
+
             integer_ids_to_scores_table_name = (
                 self.__get_scores_and_integer_ids_table_name(query_vector, sort_by)
             )
-            sorting_by_vector = True
 
         if not select_keys:
             select_keys = self.schema
 
         select_keys = tuple(select_keys)
 
+        if meta_query:
+            query.update(meta_query)
+
         sql_query, sql_params = search_query(
             table_name=self.name,
             query=query,
-            schema=self.schema,
+            schema=self.__meta_schema,
             sort_by=sort_by,
             reversed_sort=reversed_sort,
             n=n,
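
With the constructor argument gone, the vector path in search() is now driven entirely by the schema: any key declared as normalized_embedding becomes sortable once a query_vector is supplied. A hedged usage sketch; the schema, ids, and the dict-of-records update() call are assumptions based on liteindex's existing API rather than part of this diff:

import numpy as np
from liteindex import DefinedIndex

index = DefinedIndex(
    "docs",
    schema={"title": "string", "embedding": "normalized_embedding"},
)

vector = np.random.rand(64).astype(np.float32)
vector /= np.linalg.norm(vector)
index.update({"doc-1": {"title": "hello", "embedding": vector}})

# Sorting by an embedding key refreshes its faiss index on demand and joins
# the SQL results on similarity score. The new meta_query parameter filters
# on internal columns such as integer_id.
results = index.search(
    query={},
    sort_by="embedding",
    query_vector=vector,
    n=10,
    meta_query={"integer_id": {"$gte": 0}},
)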
@@ -520,8 +474,6 @@
f"INNER JOIN {integer_ids_to_scores_table_name} ON {self.name}.integer_id = {integer_ids_to_scores_table_name}._integer_id ORDER BY {integer_ids_to_scores_table_name}.score",
)

print(sql_query, sql_params)

_results = None

if update:
@@ -541,6 +493,11 @@
         else:
             _results = self.__connection.execute(sql_query, sql_params).fetchall()
 
+        if sorting_by_vector:
+            self.__connection.execute(
+                f"DROP TABLE IF EXISTS {integer_ids_to_scores_table_name}"
+            )
+
         results = {}
 
         for result in _results:
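
For reference, the score join that the vector path sets up, sketched standalone: faiss yields (integer_id, score) pairs, __get_scores_and_integer_ids_table_name writes them into a throwaway table, the main query is rewritten to INNER JOIN against it and ORDER BY score, and the DROP TABLE added above cleans it up afterwards. Apart from the _integer_id and score column names, everything below is illustrative:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs (integer_id INTEGER PRIMARY KEY, title TEXT)")
conn.executemany("INSERT INTO docs VALUES (?, ?)", [(1, "a"), (2, "b"), (3, "c")])

# Pretend these (id, score) pairs came from a faiss search.
scores_table = "temp_scores"
conn.execute(f"CREATE TABLE {scores_table} (_integer_id INTEGER, score REAL)")
conn.executemany(f"INSERT INTO {scores_table} VALUES (?, ?)", [(3, 0.91), (1, 0.88)])

rows = conn.execute(
    f"SELECT docs.title, {scores_table}.score FROM docs "
    f"INNER JOIN {scores_table} ON docs.integer_id = {scores_table}._integer_id "
    f"ORDER BY {scores_table}.score"
).fetchall()

# Dropped after fetching, mirroring the cleanup search() now performs.
conn.execute(f"DROP TABLE IF EXISTS {scores_table}")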
11 changes: 10 additions & 1 deletion liteindex/defined_serializers.py
@@ -103,7 +103,16 @@ def serialize_record(schema, record, compressor, _id=None, _updated_at=None):
             _record[k] = None if v is None else json.dumps(v)
 
         elif _type == "normalized_embedding":
-            v = None if v is None else v.tobytes()
+            if v is None:
+                v = None
+            else:
+                try:
+                    if v.ndim == 1 and v.dtype == np.float32:
+                        v = v.tobytes()
+                    else:
+                        raise ValueError("Invalid embedding")
+                except Exception:
+                    raise ValueError("Invalid embedding")
 
             _record[k] = (
                 None
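
The serializer previously called tobytes() on whatever it was given; it now accepts only 1-D float32 arrays and raises ValueError for anything else. A short sketch of preparing a value that passes the new check (shapes and data are illustrative):

import numpy as np

raw = np.random.rand(128)                # float64: would now be rejected
embedding = raw.astype(np.float32)       # 1-D float32 passes the check
embedding /= np.linalg.norm(embedding)   # unit norm, as normalized_embedding implies

assert embedding.ndim == 1 and embedding.dtype == np.float32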
