-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support for lancedb as vectordb (#644)
* Remove the Weaviate unit test Signed-off-by: SimFG <[email protected]> Signed-off-by: akashAD98 <[email protected]> * fix: avoid loading redis if not needed Signed-off-by: akashAD98 <[email protected]> * Update the version to `0.1.43` Signed-off-by: SimFG <[email protected]> Signed-off-by: akashAD98 <[email protected]> * add the note for the gptcache api Signed-off-by: akashAD <[email protected]> Signed-off-by: akashAD98 <[email protected]> * Fix the nil memory eviction when using the init_similar_cache method Signed-off-by: SimFG <[email protected]> Signed-off-by: akashAD <[email protected]> Signed-off-by: akashAD98 <[email protected]> * Update the version to 0.1.44 Signed-off-by: SimFG <[email protected]> Signed-off-by: akashAD <[email protected]> Signed-off-by: akashAD98 <[email protected]> * added support for lancedb as vectorstore Signed-off-by: akashAD <[email protected]> Signed-off-by: akashAD98 <[email protected]> * Fix pylint issues and improve codes structure Signed-off-by: akashAD <[email protected]> Signed-off-by: akashAD98 <[email protected]> * refactor & pylint fix code Signed-off-by: akashAD98 <[email protected]> * pylint issue fixing Signed-off-by: akashAD98 <[email protected]> --------- Signed-off-by: SimFG <[email protected]> Signed-off-by: akashAD98 <[email protected]> Signed-off-by: akashAD <[email protected]> Co-authored-by: SimFG <[email protected]> Co-authored-by: leio10 <[email protected]>
- Loading branch information
1 parent
17646e1
commit 7492681
Showing
8 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import lancedb | ||
from gptcache.manager.vector_data.base import VectorBase, VectorData | ||
from gptcache.utils import import_lancedb, import_torch | ||
|
||
import_torch() | ||
import_lancedb() | ||
|
||
|
||
class LanceDB(VectorBase): | ||
"""Vector store: LanceDB | ||
:param persist_directory: The directory to persist, defaults to '/tmp/lancedb'. | ||
:type persist_directory: str | ||
:param table_name: The name of the table in LanceDB, defaults to 'gptcache'. | ||
:type table_name: str | ||
:param top_k: The number of the vectors results to return, defaults to 1. | ||
:type top_k: int | ||
""" | ||
|
||
def __init__( | ||
self, | ||
persist_directory: Optional[str] = "/tmp/lancedb", | ||
table_name: str = "gptcache", | ||
top_k: int = 1, | ||
): | ||
self._persist_directory = persist_directory | ||
self._table_name = table_name | ||
self._top_k = top_k | ||
|
||
# Initialize LanceDB database | ||
self._db = lancedb.connect(self._persist_directory) | ||
|
||
# Initialize or open table | ||
if self._table_name not in self._db.table_names(): | ||
self._table = None # Table will be created with the first insertion | ||
else: | ||
self._table = self._db.open_table(self._table_name) | ||
|
||
def mul_add(self, datas: List[VectorData]): | ||
"""Add multiple vectors to the LanceDB table""" | ||
vectors, vector_ids = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) | ||
# Infer the dimension of the vectors | ||
vector_dim = len(vectors[0]) if vectors else 0 | ||
|
||
# Create table with the inferred schema if it doesn't exist | ||
if self._table is None: | ||
schema = pa.schema([ | ||
pa.field("id", pa.string()), | ||
pa.field("vector", pa.list_(pa.float32(), list_size=vector_dim)) | ||
]) | ||
self._table = self._db.create_table(self._table_name, schema=schema) | ||
|
||
# Prepare and add data to the table | ||
self._table.add(({"id": vector_id, "vector": vector} for vector_id, vector in zip(vector_ids, vectors))) | ||
|
||
def search(self, data: np.ndarray, top_k: int = -1): | ||
"""Search for the most similar vectors in the LanceDB table""" | ||
if len(self._table) == 0: | ||
return [] | ||
|
||
if top_k == -1: | ||
top_k = self._top_k | ||
|
||
results = self._table.search(data.tolist()).limit(top_k).to_list() | ||
return [(result["_distance"], int(result["id"])) for result in results] | ||
|
||
def delete(self, ids: List[int]): | ||
"""Delete vectors from the LanceDB table based on IDs""" | ||
for vector_id in ids: | ||
self._table.delete(f"id = '{vector_id}'") | ||
|
||
def rebuild(self, ids: Optional[List[int]] = None): | ||
"""Rebuild the index, if applicable""" | ||
return True | ||
|
||
def count(self): | ||
"""Return the total number of vectors in the table""" | ||
return len(self._table) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import unittest | ||
import numpy as np | ||
from gptcache.manager import VectorBase | ||
from gptcache.manager.vector_data.base import VectorData | ||
|
||
class TestLanceDB(unittest.TestCase): | ||
def test_normal(self): | ||
|
||
db = VectorBase("lancedb", persist_directory="/tmp/test_lancedb", top_k=3) | ||
|
||
# Add 100 vectors to the LanceDB | ||
db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)]) | ||
|
||
# Perform a search with a random query vector | ||
search_res = db.search(np.random.sample(10)) | ||
|
||
# Check that the search returns 3 results | ||
self.assertEqual(len(search_res), 3) | ||
|
||
# Delete vectors with specific IDs | ||
db.delete([1, 3, 5, 7]) | ||
|
||
# Check that the count of vectors in the table is now 96 | ||
self.assertEqual(db.count(), 96) |