diff --git a/Makefile b/Makefile
index 05d262bb..e1040b1b 100644
--- a/Makefile
+++ b/Makefile
@@ -23,23 +23,8 @@ create_conda_env:
 remove_conda_env:
 	@bash ./scripts/manage_conda_env.sh remove
 
-docs_build:
-	cd docs && poetry run make html
-
-docs_clean:
-	cd docs && poetry run make clean
-
-docs_linkcheck:
-	poetry run linkchecker docs/_build/html/index.html
-
-PYTHON_FILES=.
-lint: PYTHON_FILES=.
-lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$')
-
-lint lint_diff:
-	poetry run mypy $(PYTHON_FILES)
-	poetry run black $(PYTHON_FILES) --check
-	poetry run ruff .
-
 pylint_check:
-	pylint --rcfile=pylint.conf --output-format=colorized gptcache
\ No newline at end of file
+	pylint --rcfile=pylint.conf --output-format=colorized gptcache
+
+pytest:
+	pytest tests/
\ No newline at end of file
diff --git a/examples/benchmark/benchmark_sqlite_faiss_onnx.py b/examples/benchmark/benchmark_sqlite_faiss_onnx.py
index b33e814e..76f107f5 100644
--- a/examples/benchmark/benchmark_sqlite_faiss_onnx.py
+++ b/examples/benchmark/benchmark_sqlite_faiss_onnx.py
@@ -65,12 +65,8 @@ def range(self):
 
     if not has_data:
         print('insert data')
-        id_origin = {}
-        for pair in mock_data:
-            question = pair['origin']
-            answer = pair['id']
-            id_origin[answer] = question
-            cache.data_manager.save(question, answer, cache.embedding_func(question))
+        questions, answers = zip(*((pair['origin'], pair['id']) for pair in mock_data))
+        cache.import_data(questions=questions, answers=answers)
         print('end insert data')
 
     all_time = 0.0
diff --git a/gptcache/__init__.py b/gptcache/__init__.py
index b85e637e..98be31fc 100644
--- a/gptcache/__init__.py
+++ b/gptcache/__init__.py
@@ -1,6 +1,8 @@
 import atexit
 import os
 import time
+from typing import List, Any, Optional
+
 import openai
 
 from gptcache.embedding.string import to_embeddings as string_embedding
@@ -55,7 +57,9 @@ def __init__(
         similarity_threshold=0.8,
     ):
         if similarity_threshold < 0 or similarity_threshold > 1:
-            raise CacheError("Invalid the similarity threshold param, reasonable range: 0-1")
+            raise CacheError(
+                "Invalid similarity threshold param, reasonable range: 0-1"
+            )
         self.log_time_func = log_time_func
         self.similarity_threshold = similarity_threshold
 
@@ -134,7 +138,7 @@ def __init__(self):
         self.cache_enable_func = None
         self.pre_embedding_func = None
         self.embedding_func = None
-        self.data_manager = None
+        self.data_manager: Optional[DataManager] = None
         self.post_process_messages_func = None
         self.config = Config()
         self.report = Report()
@@ -179,6 +183,19 @@ def close():
             except Exception as e:  # pylint: disable=W0703
                 print(e)
 
+    def import_data(self, questions: List[Any], answers: List[Any]) -> None:
+        """Import data into GPTCache.
+
+        :param questions: list of preprocessed questions
+        :param answers: list of answers to the questions
+        :return: None
+        """
+        self.data_manager.import_data(
+            questions=questions,
+            answers=answers,
+            embedding_datas=[self.embedding_func(question) for question in questions],
+        )
+
     @staticmethod
     def set_openai_key():
         openai.api_key = os.getenv("OPENAI_API_KEY")
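The new `Cache.import_data` wraps the embedding step so callers no longer reach into `cache.data_manager` directly: it computes `embedding_func(question)` for every question and forwards the three aligned lists to the data manager. A minimal usage sketch, not part of this patch; it assumes the default `cache.init()` wiring (string pre-embedding plus the map data manager):

```python
from gptcache import cache

# Assumption: init() with no arguments uses the default string embedding
# and a map-backed data manager.
cache.init()

pairs = [("hello", "hi there"), ("how are you", "fine, thanks")]
questions, answers = zip(*pairs)

# One call replaces the old per-pair data_manager.save(...) loop.
cache.import_data(questions=list(questions), answers=list(answers))
```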
diff --git a/gptcache/manager/data_manager.py b/gptcache/manager/data_manager.py
index c3768718..e6beee95 100644
--- a/gptcache/manager/data_manager.py
+++ b/gptcache/manager/data_manager.py
@@ -1,11 +1,13 @@
 from abc import abstractmethod, ABCMeta
 import pickle
+from typing import List, Any
+
 import cachetools
 import numpy as np
 
-from gptcache.utils.error import CacheError
-from gptcache.manager.scalar_data.base import CacheStorage
-from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
+from gptcache.utils.error import CacheError, NotFoundStrategyError, ParamError
+from gptcache.manager.scalar_data.base import CacheStorage, CacheData
+from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
 from gptcache.manager.eviction import EvictionManager
 
 
@@ -16,6 +18,12 @@ class DataManager(metaclass=ABCMeta):
     def save(self, question, answer, embedding_data, **kwargs):
         pass
 
+    @abstractmethod
+    def import_data(
+        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
+    ):
+        pass
+
     # should return the tuple, (question, answer)
     @abstractmethod
     def get_scalar_data(self, res_data, **kwargs):
@@ -49,12 +57,20 @@ def init(self):
             return
         except PermissionError:
             raise CacheError(  # pylint: disable=W0707
-                f"You don't have permission to access this file <${self.data_path}>."
+                f"You don't have permission to access this file <{self.data_path}>."
             )
 
     def save(self, question, answer, embedding_data, **kwargs):
         self.data[embedding_data] = (question, answer)
 
+    def import_data(
+        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
+    ):
+        if len(questions) != len(answers) or len(questions) != len(embedding_datas):
+            raise ParamError("Make sure that all parameters have the same length")
+        for i, embedding_data in enumerate(embedding_datas):
+            self.data[embedding_data] = (questions[i], answers[i])
+
     def get_scalar_data(self, res_data, **kwargs):
         return res_data
 
@@ -139,13 +155,39 @@ def save(self, question, answer, embedding_data, **kwargs):
 
         if self.cur_size >= self.max_size:
             self._clear()
-        embedding_data = normalize(embedding_data)
-        if self.v.clear_strategy() == ClearStrategy.DELETE:
-            key = self.s.insert(question, answer)
-        elif self.v.clear_strategy() == ClearStrategy.REBUILD:
-            key = self.s.insert(question, answer, embedding_data.astype("float32"))
-        self.v.add(key, embedding_data)
-        self.cur_size += 1
+
+        self.import_data([question], [answer], [embedding_data])
+
+    def import_data(
+        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
+    ):
+        if len(questions) != len(answers) or len(questions) != len(embedding_datas):
+            raise ParamError("Make sure that all parameters have the same length")
+        cache_datas = []
+        embedding_datas = [
+            normalize(embedding_data) for embedding_data in embedding_datas
+        ]
+        for i, embedding_data in enumerate(embedding_datas):
+            if self.v.clear_strategy() == ClearStrategy.DELETE:
+                cache_datas.append(CacheData(question=questions[i], answer=answers[i]))
+            elif self.v.clear_strategy() == ClearStrategy.REBUILD:
+                cache_datas.append(
+                    CacheData(
+                        question=questions[i],
+                        answer=answers[i],
+                        embedding_data=embedding_data.astype("float32"),
+                    )
+                )
+            else:
+                raise NotFoundStrategyError(self.v.clear_strategy())
+        ids = self.s.batch_insert(cache_datas)
+        self.v.mul_add(
+            [
+                VectorData(id=ids[i], data=embedding_data)
+                for i, embedding_data in enumerate(embedding_datas)
+            ]
+        )
+        self.cur_size += len(questions)
 
     def get_scalar_data(self, res_data, **kwargs):
         return self.s.get_data_by_id(res_data[1])
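`SSDataManager.save` is now a one-element `import_data`, so single writes and batch imports share one code path: one `batch_insert` against the scalar store and one `mul_add` against the vector store per call, instead of one round trip per row. The map-backed manager gets the same method and the same length check; a rough sketch of that contract, with constructor arguments assumed rather than taken from this patch:

```python
from gptcache.manager.data_manager import MapDataManager
from gptcache.utils.error import ParamError

dm = MapDataManager("map.cache", max_size=100)  # hypothetical args
dm.init()  # loads the pickle file if it already exists

# With the default string embedding, the "embedding" of a question is the
# question string itself, so it can serve as the dict key here.
dm.import_data(
    questions=["q1", "q2"],
    answers=["a1", "a2"],
    embedding_datas=["q1", "q2"],
)

# Mismatched lengths are rejected before anything is written.
try:
    dm.import_data(["q1"], ["a1", "a2"], ["q1"])
except ParamError as e:
    print(e)  # Make sure that all parameters have the same length
```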
diff --git a/gptcache/manager/scalar_data/base.py b/gptcache/manager/scalar_data/base.py
index a2c36cf7..7b462cb8 100644
--- a/gptcache/manager/scalar_data/base.py
+++ b/gptcache/manager/scalar_data/base.py
@@ -1,8 +1,17 @@
 from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from typing import Optional, Any, List
 
 import numpy as np
 
 
+@dataclass
+class CacheData:
+    question: Any
+    answer: Any
+    embedding_data: Optional[np.ndarray] = None
+
+
 class CacheStorage(metaclass=ABCMeta):
     """
     BaseStorage for scalar data.
@@ -13,7 +22,7 @@ def create(self):
         pass
 
     @abstractmethod
-    def insert(self, data, reply, embedding_data: np.ndarray = None):
+    def batch_insert(self, datas: List[CacheData]):
         pass
 
     @abstractmethod
diff --git a/gptcache/manager/scalar_data/sqlalchemy.py b/gptcache/manager/scalar_data/sqlalchemy.py
index 2ebdfdb0..275f313b 100644
--- a/gptcache/manager/scalar_data/sqlalchemy.py
+++ b/gptcache/manager/scalar_data/sqlalchemy.py
@@ -1,8 +1,9 @@
-import numpy as np
+from typing import List
+
 from datetime import datetime
 
 from gptcache.utils import import_sqlalchemy
-from gptcache.manager.scalar_data.base import CacheStorage
+from gptcache.manager.scalar_data.base import CacheStorage, CacheData
 
 import_sqlalchemy()
 
@@ -29,8 +30,8 @@ class CacheTable(Base):
         __table_args__ = {"extend_existing": True}
 
         id = Column(Integer, primary_key=True, autoincrement=True)
-        data = Column(String(1000), nullable=False)
-        reply = Column(String(1000), nullable=False)
+        question = Column(String(1000), nullable=False)
+        answer = Column(String(1000), nullable=False)
         create_on = Column(DateTime, default=datetime.now)
         last_access = Column(DateTime, default=datetime.now)
         embedding_data = Column(LargeBinary, nullable=True)
@@ -48,8 +49,8 @@ class CacheTableSequence(Base):
         id = Column(
             Integer, Sequence("id_seq", start=1), primary_key=True, autoincrement=True
         )
-        data = Column(String(1000), nullable=False)
-        reply = Column(String(1000), nullable=False)
+        question = Column(String(1000), nullable=False)
+        answer = Column(String(1000), nullable=False)
         create_on = Column(DateTime, default=datetime.now)
         last_access = Column(DateTime, default=datetime.now)
         embedding_data = Column(LargeBinary, nullable=True)
@@ -68,10 +69,10 @@ class SQLDataBase(CacheStorage):
     """
 
     def __init__(
-            self,
-            db_type: str = "sqlite",
-            url: str = "sqlite:///./sqlite.db",
-            table_name: str = "gptcache",
+        self,
+        db_type: str = "sqlite",
+        url: str = "sqlite:///./sqlite.db",
+        table_name: str = "gptcache",
    ):
         self._url = url
         self._model = get_model(table_name, db_type)
@@ -83,19 +84,25 @@ def __init__(
     def create(self):
         self._model.__table__.create(bind=self._engine, checkfirst=True)
 
-    def insert(self, data, reply, embedding_data: np.ndarray = None):
-        if embedding_data is None:
-            model_obj = self._model(data=data, reply=reply)
-        else:
-            embedding_data = embedding_data.tobytes()
-            model_obj = self._model(data=data, reply=reply, embedding_data=embedding_data)
-        self._session.add(model_obj)
+    def batch_insert(self, datas: List[CacheData]):
+        model_objs = []
+        for data in datas:
+            model_obj = self._model(
+                question=data.question,
+                answer=data.answer,
+                embedding_data=data.embedding_data.tobytes()
+                if data.embedding_data is not None
+                else None,
+            )
+            model_objs.append(model_obj)
+
+        self._session.add_all(model_objs)
         self._session.commit()
-        return model_obj.id
+        return [model_obj.id for model_obj in model_objs]
 
     def get_data_by_id(self, key):
         res = (
-            self._session.query(self._model.data, self._model.reply)
+            self._session.query(self._model.question, self._model.answer)
             .filter(self._model.id == key)
             .filter(self._model.state == 0)
             .first()
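On the scalar side, `insert` becomes `batch_insert`: one `add_all` plus a single commit per batch, with the generated primary keys returned in insertion order. A sketch against the SQLite backend; the file path is illustrative, and `create()` is called explicitly in case the constructor does not do so itself:

```python
import numpy as np

from gptcache.manager.scalar_data.base import CacheData
from gptcache.manager.scalar_data.sqlalchemy import SQLDataBase

storage = SQLDataBase(db_type="sqlite", url="sqlite:///./example.db")
storage.create()  # safe to repeat: the table is created with checkfirst=True

ids = storage.batch_insert(
    [
        CacheData(question="q1", answer="a1"),
        # embedding_data is optional; it is serialized via tobytes() when present
        CacheData(question="q2", answer="a2",
                  embedding_data=np.ones(4, dtype=np.float32)),
    ]
)
print(ids)                             # e.g. [1, 2], in insertion order
print(storage.get_data_by_id(ids[0]))  # ("q1", "a1")
```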
diff --git a/gptcache/manager/vector_data/base.py b/gptcache/manager/vector_data/base.py
index ded6ca3d..a61d266c 100644
--- a/gptcache/manager/vector_data/base.py
+++ b/gptcache/manager/vector_data/base.py
@@ -1,5 +1,9 @@
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
+from typing import List
+
+import numpy as np
 
 
 class ClearStrategy(Enum):
@@ -7,22 +11,28 @@ class ClearStrategy(Enum):
     DELETE = 1
 
 
+@dataclass
+class VectorData:
+    id: int
+    data: np.ndarray
+
+
 class VectorBase(ABC):
     """VectorBase: base vector store interface"""
 
     @abstractmethod
-    def add(self, key: str, data: "ndarray"):
+    def mul_add(self, datas: List[VectorData]):
         pass
 
     @abstractmethod
-    def search(self, data: "ndarray"):
+    def search(self, data: np.ndarray):
         pass
 
     @abstractmethod
     def clear_strategy(self):
         pass
 
-    def rebuild(self) -> bool:
+    def rebuild(self, all_data, keys) -> bool:
         raise NotImplementedError
 
     def delete(self, ids) -> bool:
diff --git a/gptcache/manager/vector_data/chroma.py b/gptcache/manager/vector_data/chroma.py
index e719643c..f581ca0d 100644
--- a/gptcache/manager/vector_data/chroma.py
+++ b/gptcache/manager/vector_data/chroma.py
@@ -1,4 +1,6 @@
-from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
+from typing import List
+
+from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
 from gptcache.utils import import_chromadb
 
 import_chromadb()
@@ -30,8 +32,9 @@ def __init__(
             self._persist_directory = persist_directory
         self._collection = self._client.get_or_create_collection(name=collection_name)
 
-    def add(self, key, data):
-        self._collection.add(embeddings=[data], ids=[key])
+    def mul_add(self, datas: List[VectorData]):
+        data_array, id_array = map(list, zip(*((data.data, str(data.id)) for data in datas)))
+        self._collection.add(embeddings=data_array, ids=id_array)
 
     def search(self, data):
         if self._collection.count() == 0:
diff --git a/gptcache/manager/vector_data/faiss.py b/gptcache/manager/vector_data/faiss.py
index 7fd82f8b..d55b300a 100644
--- a/gptcache/manager/vector_data/faiss.py
+++ b/gptcache/manager/vector_data/faiss.py
@@ -1,7 +1,9 @@
 import os
+from typing import List
+
 import numpy as np
 
-from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
+from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
 from gptcache.utils import import_faiss
 
 import_faiss()
@@ -23,17 +25,13 @@ def __init__(self, index_file_path, dimension, top_k, skip_file=False):
         if os.path.isfile(index_file_path) and not skip_file:
             self.index = faiss.read_index(index_file_path)
 
-    def add(self, key: int, data: "ndarray"):
-        np_data = np.array(data).astype("float32").reshape(1, -1)
-        ids = np.array([key])
+    def mul_add(self, datas: List[VectorData]):
+        data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
+        np_data = np.array(data_array).astype("float32")
+        ids = np.array(id_array)
         self.index.add_with_ids(np_data, ids)
 
-    def _mult_add(self, datas, keys):
-        np_data = np.array(datas).astype("float32")
-        ids = np.array(keys).astype(np.int64)
-        self.index.add_with_ids(np_data, ids)
-
-    def search(self, data: "ndarray"):
+    def search(self, data: np.ndarray):
         if self.index.ntotal == 0:
             return None
         np_data = np.array(data).astype("float32").reshape(1, -1)
@@ -45,11 +43,11 @@ def clear_strategy(self):
         return ClearStrategy.REBUILD
 
     def rebuild(self, all_data, keys):
-        f = Faiss(
-            self.index_file_path, self.dimension, top_k=self.top_k, skip_file=True
-        )
-        f._mult_add(all_data, keys)  # pylint: disable=protected-access
-        return f
+        self.index = faiss.index_factory(self.dimension, "IDMap,Flat", faiss.METRIC_L2)
+        datas = []
+        for i, key in enumerate(keys):
+            datas.append(VectorData(id=key, data=all_data[i]))
+        self.mul_add(datas)
 
     def close(self):
         faiss.write_index(self.index, self.index_file_path)
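Every vector store now exposes the same batched entry point: `mul_add` takes a list of `VectorData` and replaces both the public per-key `add` and the private `_mult_add`. A sketch with the Faiss store; the index path and sizes are illustrative:

```python
import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.faiss import Faiss

dim, top_k = 8, 3
store = Faiss("example.index", dim, top_k)

vectors = np.random.rand(10, dim).astype("float32")
store.mul_add([VectorData(id=i, data=vectors[i]) for i in range(10)])

# search still takes a single query vector; in this codebase it returns
# the nearest neighbours as distance/id pairs.
print(store.search(vectors[0]))

store.close()  # persists the index to example.index
```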
diff --git a/gptcache/manager/vector_data/hnswlib_store.py b/gptcache/manager/vector_data/hnswlib_store.py
index 7b18ae59..e6172197 100644
--- a/gptcache/manager/vector_data/hnswlib_store.py
+++ b/gptcache/manager/vector_data/hnswlib_store.py
@@ -1,7 +1,9 @@
 import os
+from typing import List
+
 import numpy as np
 
-from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
+from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
 from gptcache.utils import import_hnswlib
 
 import_hnswlib()
@@ -12,7 +14,7 @@ class Hnswlib(VectorBase):
     """vector store: hnswlib"""
 
-    def __init__(self, index_file_path: int, top_k: str, dimension: int, max_elements: int):
+    def __init__(self, index_file_path: str, top_k: int, dimension: int, max_elements: int):
         self._index_file_path = index_file_path
         self._dimension = dimension
         self._max_elements = max_elements
@@ -24,15 +26,17 @@ def __init__(self, index_file_path: str, top_k: int, dimension: int, max_elements: int):
             self._index.init_index(max_elements=max_elements, ef_construction=100, M=16)
             self._index.set_ef(self._top_k * 2)
 
-    def add(self, key: int, data: "ndarray"):
+    def add(self, key: int, data: np.ndarray):
         np_data = np.array(data).astype("float32").reshape(1, -1)
-        self._index.add_items(np_data, np.asarray([key]))
+        self._index.add_items(np_data, np.array([key]))
 
-    def _mult_add(self, data, keys):
-        np_data = np.array(data).astype("float32")
-        self._index.add_items(np_data, np.asarray(keys))
+    def mul_add(self, datas: List[VectorData]):
+        data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
+        np_data = np.array(data_array).astype("float32")
+        ids = np.array(id_array)
+        self._index.add_items(np_data, ids)
 
-    def search(self, data: "ndarray"):
+    def search(self, data: np.ndarray):
         np_data = np.array(data).astype("float32").reshape(1, -1)
         ids, dist = self._index.knn_query(data=np_data, k=self._top_k)
         return list(zip(dist[0], ids[0]))
@@ -45,7 +49,10 @@ def rebuild(self, all_data, keys):
         new_index.init_index(max_elements=self._max_elements, ef_construction=100, M=16)
         new_index.set_ef(self._top_k * 2)
         self._index = new_index
-        self._mult_add(all_data, keys)
+        datas = []
+        for i, key in enumerate(keys):
+            datas.append(VectorData(id=key, data=all_data[i]))
+        self.mul_add(datas)
 
     def close(self):
         self._index.save_index(self._index_file_path)
diff --git a/gptcache/manager/vector_data/milvus.py b/gptcache/manager/vector_data/milvus.py
index 1584c22a..ece9a191 100644
--- a/gptcache/manager/vector_data/milvus.py
+++ b/gptcache/manager/vector_data/milvus.py
@@ -1,8 +1,9 @@
+from typing import List
 from uuid import uuid4
 import numpy as np
 
 from gptcache.utils import import_pymilvus
-from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
+from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
 
 import_pymilvus()
 
@@ -117,8 +118,10 @@ def _create_collection(self, collection_name):
 
         self.col.load()
 
-    def add(self, key: str, data: np.ndarray):
-        entities = [[key], data.reshape(1, self.dimension)]
+    def mul_add(self, datas: List[VectorData]):
+        data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
+        np_data = np.array(data_array).astype("float32")
+        entities = [id_array, np_data]
         self.col.insert(entities)
 
     def search(self, data: np.ndarray):
diff --git a/gptcache/utils/error.py b/gptcache/utils/error.py
index eddf276d..603338b7 100644
--- a/gptcache/utils/error.py
+++ b/gptcache/utils/error.py
@@ -5,7 +5,7 @@ class CacheError(Exception):
 class NotInitError(CacheError):
     """Raise when the cache has been used before it's inited"""
     def __init__(self):
-        super().__init__("the cache should be inited before using")
+        super().__init__("The cache should be inited before using")
 
 
 class NotFoundStoreError(CacheError):
@@ -22,3 +22,9 @@ class PipInstallError(CacheError):
     """Raise when failed to install package."""
     def __init__(self, package):
         super().__init__(f"Ran into error installing {package}.")
+
+
+class NotFoundStrategyError(CacheError):
+    """Raise when the clear strategy of a vector store is not supported."""
+    def __init__(self, strategy):
+        super().__init__(f"Unsupported vector store strategy, {strategy}.")
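With the interface settled, a third-party vector store only has to implement `mul_add`, `search`, and `clear_strategy`; any strategy other than `DELETE` or `REBUILD` now fails fast inside `SSDataManager.import_data` with the new `NotFoundStrategyError`. A minimal in-memory store, purely illustrative and not part of this patch:

```python
from typing import List

import numpy as np

from gptcache.manager.vector_data.base import ClearStrategy, VectorBase, VectorData


class InMemoryStore(VectorBase):
    """Toy brute-force store showing the new batched interface."""

    def __init__(self, top_k: int = 3):
        self._top_k = top_k
        self._ids: List[int] = []
        self._vectors: List[np.ndarray] = []

    def mul_add(self, datas: List[VectorData]):
        for data in datas:
            self._ids.append(data.id)
            self._vectors.append(np.asarray(data.data, dtype="float32"))

    def search(self, data: np.ndarray):
        # Brute-force L2 scan, returning (distance, id) pairs like the
        # built-in stores.
        query = np.asarray(data, dtype="float32")
        dists = [float(np.linalg.norm(v - query)) for v in self._vectors]
        order = np.argsort(dists)[: self._top_k]
        return [(dists[i], self._ids[i]) for i in order]

    def clear_strategy(self):
        return ClearStrategy.DELETE

    def delete(self, ids) -> bool:
        drop = set(ids)
        kept = [(i, v) for i, v in zip(self._ids, self._vectors) if i not in drop]
        self._ids = [i for i, _ in kept]
        self._vectors = [v for _, v in kept]
        return True
```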
diff --git a/tests/integration_tests/examples/map/test_example_map.py b/tests/integration_tests/examples/map/test_example_map.py
index a479ab57..1105db09 100644
--- a/tests/integration_tests/examples/map/test_example_map.py
+++ b/tests/integration_tests/examples/map/test_example_map.py
@@ -23,24 +23,17 @@ def test_map():
     ]
 
     if not os.path.isfile(bak_data_file):
-        for i in range(10):
-            question = f"foo{i}"
-            answer = f"receiver the foo {i}"
-            cache.data_manager.save(question, answer, cache.embedding_func(question))
+        cache.import_data(
+            [f"foo{i}" for i in range(10)], [f"receiver the foo {i}" for i in range(10)]
+        )
     if not os.path.isfile(data_file):
-        for i in range(10, 20):
-            question = f"foo{i}"
-            answer = f"receiver the foo {i}"
-            bak_cache.data_manager.save(
-                question, answer, bak_cache.embedding_func(question)
-            )
+        bak_cache.import_data(
+            [f"foo{i}" for i in range(10, 20)],
+            [f"receiver the foo {i}" for i in range(10, 20)],
+        )
 
     answer = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
         messages=mock_messages,
     )
     print(answer)
-
-
-# if __name__ == "__main__":
-#     run()
diff --git a/tests/integration_tests/examples/sqlite_faiss_mock/test_example_sqlite_faiss.py b/tests/integration_tests/examples/sqlite_faiss_mock/test_example_sqlite_faiss.py
index f549b559..10244c66 100644
--- a/tests/integration_tests/examples/sqlite_faiss_mock/test_example_sqlite_faiss.py
+++ b/tests/integration_tests/examples/sqlite_faiss_mock/test_example_sqlite_faiss.py
@@ -34,10 +34,9 @@ def test_sqlite_faiss():
         {"role": "user", "content": "foo"},
     ]
     if not has_data:
-        for i in range(10):
-            question = f"foo{i}"
-            answer = f"receiver the foo {i}"
-            cache.data_manager.save(question, answer, cache.embedding_func(question))
+        cache.import_data(
+            [f"foo{i}" for i in range(10)], [f"receiver the foo {i}" for i in range(10)]
+        )
 
     answer = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
diff --git a/tests/integration_tests/examples/sqlite_faiss_onnx/test_example_sqlite_faiss_onnx.py b/tests/integration_tests/examples/sqlite_faiss_onnx/test_example_sqlite_faiss_onnx.py
index 4e506c1e..14e37c7a 100644
--- a/tests/integration_tests/examples/sqlite_faiss_onnx/test_example_sqlite_faiss_onnx.py
+++ b/tests/integration_tests/examples/sqlite_faiss_onnx/test_example_sqlite_faiss_onnx.py
@@ -25,20 +25,17 @@ def log_time_func(func_name, delta_time):
         embedding_func=onnx.to_embeddings,
         data_manager=data_manager,
         similarity_evaluation=SearchDistanceEvaluation(),
-        config=Config(
-            log_time_func=log_time_func,
-            similarity_threshold=0.9
-        ),
+        config=Config(log_time_func=log_time_func, similarity_threshold=0.9),
     )
 
     if not has_data:
         question = "what do you think about chatgpt"
         answer = "chatgpt is a good application"
-        cache.data_manager.save(question, answer, cache.embedding_func(question))
+        cache.import_data([question], [answer])
 
     mock_messages = [
-        {'role': 'system', 'content': 'You are a helpful assistant.'},
-        {'role': 'user', 'content': 'what do you think chatgpt'}
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "what do you think chatgpt"},
     ]
 
     start_time = time.time()
diff --git a/tests/unit_tests/manager/test_chromadb.py b/tests/unit_tests/manager/test_chromadb.py
index 3f2f9cfe..eca74ba8 100644
--- a/tests/unit_tests/manager/test_chromadb.py
+++ b/tests/unit_tests/manager/test_chromadb.py
@@ -1,13 +1,14 @@
 import unittest
 import numpy as np
+
+from gptcache.manager.vector_data.base import VectorData
 from gptcache.manager.vector_data.chroma import Chromadb
 
 
 class TestChromadb(unittest.TestCase):
     def test_normal(self):
         db = Chromadb(**{"client_settings": {}, "top_k": 3})
-        for i in range(100):
-            db.add(str(i), np.random.sample(10))
+        db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)])
         self.assertEqual(len(db.search(np.random.sample(10))), 3)
         db.delete(["1", "3", "5", "7"])
         self.assertEqual(db._collection.count(), 96)
diff --git a/tests/unit_tests/manager/test_hnswlib_store.py b/tests/unit_tests/manager/test_hnswlib_store.py
index 2ecfb97d..7dfa5aa4 100644
--- a/tests/unit_tests/manager/test_hnswlib_store.py
+++ b/tests/unit_tests/manager/test_hnswlib_store.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 
 from gptcache.manager.vector_data.hnswlib_store import Hnswlib
-from gptcache.manager.vector_data.base import ClearStrategy
+from gptcache.manager.vector_data.base import ClearStrategy, VectorData
 from gptcache.manager.vector_data import VectorBase
 
 
@@ -17,7 +17,7 @@ def test_normal(self):
         top_k = 10
         index = Hnswlib(index_path, top_k, dim, size + 10)
         data = np.random.randn(size, dim).astype(np.float32)
-        index._mult_add(data, list(range(size)))
+        index.mul_add([VectorData(id=i, data=data[i]) for i in range(size)])
         self.assertEqual(len(index.search(data[0])), top_k)
         index.add(size, data[0])
         ret = index.search(data[0])
@@ -32,7 +32,7 @@ def test_with_rebuild(self):
         top_k = 10
         index = Hnswlib(index_path, top_k, dim, size + 10)
         data = np.random.randn(size, dim).astype(np.float32)
-        index._mult_add(data, list(range(1, data.shape[0] + 1)))
+        index.mul_add([VectorData(id=i, data=data[i - 1]) for i in range(1, data.shape[0] + 1)])
         self.assertEqual(index.clear_strategy(), ClearStrategy.REBUILD)
 
         index.rebuild(data[1:], list(range(size - 1)))
@@ -46,7 +46,7 @@ def test_reload(self):
         top_k = 10
         index = Hnswlib(index_path, top_k, dim, size + 10)
         data = np.random.randn(size, dim).astype(np.float32)
-        index._mult_add(data, list(range(size)))
+        index.mul_add([VectorData(id=i, data=data[i]) for i in range(size)])
         index.close()
 
         new_index = Hnswlib(index_path, top_k, dim, size + 10)
@@ -62,6 +62,5 @@ def test_create_from_vector_base(self):
         index = VectorBase('hnswlib', top_k=3, dimension=512, max_elements=5000, index_path=index_path)
         data = np.random.randn(100, 512).astype(np.float32)
-        for i in range(100):
-            index.add(i, data[i])
+        index.mul_add([VectorData(id=i, data=data[i]) for i in range(100)])
         self.assertEqual(index.search(data[0])[0][1], 0)