diff --git a/include/vsag/engine.h b/include/vsag/engine.h index 4252a8da..ca5b5810 100644 --- a/include/vsag/engine.h +++ b/include/vsag/engine.h @@ -17,16 +17,19 @@ #include +#include "dataset.h" #include "index.h" #include "resource.h" namespace vsag { class Engine { public: - explicit Engine(Resource* resource = nullptr); + explicit Engine(); + + explicit Engine(Resource* resource); void - Shutdown(); + Shutdown(); // like ~Engine(), but will warn whether the resources are still reference outside tl::expected, Error> CreateIndex(const std::string& name, const std::string& parameters); diff --git a/include/vsag/vsag.h b/include/vsag/vsag.h index 48c04541..a9ec91ba 100644 --- a/include/vsag/vsag.h +++ b/include/vsag/vsag.h @@ -41,6 +41,7 @@ init(); #include "bitset.h" #include "constants.h" #include "dataset.h" +#include "engine.h" #include "errors.h" #include "expected.hpp" #include "factory.h" diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index e93d3d06..6c728dc9 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -53,14 +53,14 @@ next_multiple_of_power_of_two(uint64_t x, uint64_t n) { HGraph::HGraph(const JsonType& index_param, const vsag::IndexCommonParam& common_param) noexcept : index_param_(index_param), common_param_(common_param), - label_lookup_(common_param.allocator_), - label_op_mutex_(MAX_LABEL_OPERATION_LOCKS, common_param_.allocator_), - neighbors_mutex_(0, common_param_.allocator_), - route_graphs_(common_param.allocator_), - labels_(common_param.allocator_) { + label_lookup_(common_param.allocator_.get()), + label_op_mutex_(MAX_LABEL_OPERATION_LOCKS, common_param.allocator_.get()), + neighbors_mutex_(0, common_param.allocator_.get()), + route_graphs_(common_param.allocator_.get()), + labels_(common_param.allocator_.get()) { this->dim_ = common_param.dim_; this->metric_ = common_param.metric_; - this->allocator_ = common_param.allocator_; + this->allocator_ = common_param.allocator_.get(); } tl::expected diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h index 2029f6fe..eb9b14a4 100644 --- a/src/data_cell/flatten_datacell.h +++ b/src/data_cell/flatten_datacell.h @@ -130,7 +130,7 @@ template FlattenDataCell::FlattenDataCell(const JsonType& quantization_param, const JsonType& io_param, const IndexCommonParam& common_param) - : allocator_(common_param.allocator_) { + : allocator_(common_param.allocator_.get()) { this->quantizer_ = std::make_shared(quantization_param, common_param); this->io_ = std::make_shared(io_param, common_param); this->code_size_ = quantizer_->GetCodeSize(); diff --git a/src/data_cell/flatten_datacell_test.cpp b/src/data_cell/flatten_datacell_test.cpp index 1a38c6d1..9fab0eb1 100644 --- a/src/data_cell/flatten_datacell_test.cpp +++ b/src/data_cell/flatten_datacell_test.cpp @@ -16,6 +16,7 @@ #include "flatten_datacell.h" #include +#include #include "catch2/catch_template_test_macros.hpp" #include "default_allocator.h" @@ -23,6 +24,7 @@ #include "flatten_interface_test.h" #include "io/io_headers.h" #include "quantization/quantizer_headers.h" +#include "safe_allocator.h" using namespace vsag; @@ -36,7 +38,7 @@ TestFlattenDataCell(int dim, auto counts = {100, 1000}; IndexCommonParam common; common.dim_ = dim; - common.allocator_ = allocator.get(); + common.allocator_ = std::move(allocator); common.metric_ = metric; for (auto count : counts) { auto flatten = @@ -67,7 +69,7 @@ TestFlattenDataCellFP32(int dim, } TEST_CASE("fp32", "[ut][flatten_data_cell]") { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto fp32_param = JsonType::parse("{}"); auto io_param = JsonType::parse("{}"); auto dims = {8, 64, 512}; @@ -96,7 +98,7 @@ TestFlattenDataCellSQ8(int dim, } TEST_CASE("sq8", "[ut][flatten_data_cell]") { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto sq8_param = JsonType::parse("{}"); auto io_param = JsonType::parse("{}"); auto dims = {32, 64, 512}; diff --git a/src/data_cell/graph_datacell_test.cpp b/src/data_cell/graph_datacell_test.cpp index 67b21dcc..098fb1fd 100644 --- a/src/data_cell/graph_datacell_test.cpp +++ b/src/data_cell/graph_datacell_test.cpp @@ -19,6 +19,7 @@ #include "fmt/format-inl.h" #include "graph_interface_test.h" #include "io/io_headers.h" +#include "safe_allocator.h" using namespace vsag; template @@ -37,7 +38,7 @@ TestGraphDataCell(const JsonType& graph_param, } TEST_CASE("graph basic test", "[ut][graph_datacell]") { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto dims = {32, 64}; auto max_degrees = {5, 12, 24, 32, 64, 128}; auto max_capacities = {1, 100, 10000, 10'000'000, 32'179'837}; @@ -58,7 +59,7 @@ TEST_CASE("graph basic test", "[ut][graph_datacell]") { for (auto dim : dims) { IndexCommonParam param; param.dim_ = dim; - param.allocator_ = allocator.get(); + param.allocator_ = allocator; for (auto& gp : graph_params) { TestGraphDataCell(gp, io_param, param); TestGraphDataCell(gp, io_param, param); diff --git a/src/data_cell/graph_interface_test.cpp b/src/data_cell/graph_interface_test.cpp index f101803a..fe675973 100644 --- a/src/data_cell/graph_interface_test.cpp +++ b/src/data_cell/graph_interface_test.cpp @@ -20,12 +20,13 @@ #include "catch2/catch_test_macros.hpp" #include "default_allocator.h" #include "fixtures.h" +#include "safe_allocator.h" using namespace vsag; void GraphInterfaceTest::BasicTest(uint64_t max_id, uint64_t count, GraphInterfacePtr other) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto max_degree = this->graph_->MaximumDegree(); auto old_count = this->graph_->TotalCount(); UnorderedMap>> maps(allocator.get()); diff --git a/src/data_cell/sparse_graph_datacell.cpp b/src/data_cell/sparse_graph_datacell.cpp index 735f3873..3ae27d32 100644 --- a/src/data_cell/sparse_graph_datacell.cpp +++ b/src/data_cell/sparse_graph_datacell.cpp @@ -19,7 +19,7 @@ namespace vsag { SparseGraphDataCell::SparseGraphDataCell(const JsonType& graph_param, const IndexCommonParam& common_param) - : allocator_(common_param.allocator_), neighbors_(common_param.allocator_) { + : allocator_(common_param.allocator_.get()), neighbors_(common_param.allocator_.get()) { if (graph_param.contains(GRAPH_PARAM_MAX_DEGREE)) { this->maximum_degree_ = graph_param[GRAPH_PARAM_MAX_DEGREE]; } diff --git a/src/data_cell/sparse_graph_datacell_test.cpp b/src/data_cell/sparse_graph_datacell_test.cpp index bc886396..057e28b0 100644 --- a/src/data_cell/sparse_graph_datacell_test.cpp +++ b/src/data_cell/sparse_graph_datacell_test.cpp @@ -19,6 +19,7 @@ #include "default_allocator.h" #include "fmt/format-inl.h" #include "graph_interface_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -35,7 +36,7 @@ TestSparseGraphDataCell(const JsonType& graph_param, const IndexCommonParam& par } TEST_CASE("graph basic test", "[ut][sparse_graph_datacell]") { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto dims = {32, 64}; auto max_degrees = {5, 12, 24, 32, 64, 128}; auto max_capacities = {1, 100, 10000, 10'000'000, 32'179'837}; @@ -55,7 +56,7 @@ TEST_CASE("graph basic test", "[ut][sparse_graph_datacell]") { for (auto dim : dims) { IndexCommonParam param; param.dim_ = dim; - param.allocator_ = allocator.get(); + param.allocator_ = allocator; for (auto& gp : graph_params) { TestSparseGraphDataCell(gp, param); } diff --git a/src/default_allocator.h b/src/default_allocator.h index ee4aee50..bc5177ee 100644 --- a/src/default_allocator.h +++ b/src/default_allocator.h @@ -25,16 +25,6 @@ namespace vsag { class DefaultAllocator : public Allocator { -public: - static std::shared_ptr - Instance() { - static std::shared_ptr s_instance; - if (s_instance == nullptr) { - s_instance = std::make_shared(); - } - return s_instance; - } - public: DefaultAllocator() = default; ~DefaultAllocator() override { diff --git a/src/engine.cpp b/src/engine.cpp index 65ad3063..d53d5404 100644 --- a/src/engine.cpp +++ b/src/engine.cpp @@ -32,6 +32,10 @@ namespace vsag { +Engine::Engine() { + this->resource_ = std::make_shared(new Resource(), /*owned*/ true); +} + Engine::Engine(Resource* resource) { if (resource == nullptr) { this->resource_ = std::make_shared(new Resource(), /*owned*/ true); @@ -42,13 +46,16 @@ Engine::Engine(Resource* resource) { void Engine::Shutdown() { + auto refcount = this->resource_.use_count(); this->resource_.reset(); + + // TODO(LHT): add refcount warning } tl::expected, Error> Engine::CreateIndex(const std::string& origin_name, const std::string& parameters) { try { - auto* allocator = this->resource_->allocator.get(); + auto& allocator = this->resource_->allocator; std::string name = origin_name; transform(name.begin(), name.end(), name.begin(), ::tolower); JsonType parsed_params = JsonType::parse(parameters); @@ -79,9 +86,6 @@ Engine::CreateIndex(const std::string& origin_name, const std::string& parameter logger::debug("created a diskann index"); return std::make_shared(diskann_params, index_common_params); } else if (name == INDEX_HGRAPH) { - if (allocator == nullptr) { - index_common_params.allocator_ = DefaultAllocator::Instance().get(); - } logger::debug("created a hgraph index"); JsonType hgraph_params; if (parsed_params.contains(INDEX_PARAM)) { diff --git a/src/factory/factory.cpp b/src/factory/factory.cpp index a463779b..5547ad62 100644 --- a/src/factory/factory.cpp +++ b/src/factory/factory.cpp @@ -13,25 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "vsag/factory.h" + #include -#include #include #include #include #include #include #include -#include #include -#include "index/diskann.h" -#include "index/diskann_zparameters.h" -#include "index/hgraph_index.h" -#include "index/hgraph_zparameters.h" -#include "index/hnsw.h" -#include "index/hnsw_zparameters.h" -#include "index/index_common_param.h" -#include "vsag/vsag.h" +#include "ThreadPool.h" +#include "vsag/engine.h" +#include "vsag/options.h" namespace vsag { @@ -39,61 +34,9 @@ tl::expected, Error> Factory::CreateIndex(const std::string& origin_name, const std::string& parameters, Allocator* allocator) { - try { - std::string name = origin_name; - transform(name.begin(), name.end(), name.begin(), ::tolower); - JsonType parsed_params = JsonType::parse(parameters); - auto index_common_params = IndexCommonParam::CheckAndCreate(parsed_params, allocator); - if (name == INDEX_HNSW) { - // read parameters from json, throw exception if not exists - CHECK_ARGUMENT(parsed_params.contains(INDEX_HNSW), - fmt::format("parameters must contains {}", INDEX_HNSW)); - auto& hnsw_param_obj = parsed_params[INDEX_HNSW]; - auto hnsw_params = HnswParameters::FromJson(hnsw_param_obj, index_common_params); - logger::debug("created a hnsw index"); - return std::make_shared(hnsw_params, index_common_params); - } else if (name == INDEX_FRESH_HNSW) { - // read parameters from json, throw exception if not exists - CHECK_ARGUMENT(parsed_params.contains(INDEX_HNSW), - fmt::format("parameters must contains {}", INDEX_HNSW)); - auto& hnsw_param_obj = parsed_params[INDEX_HNSW]; - auto hnsw_params = FreshHnswParameters::FromJson(hnsw_param_obj, index_common_params); - logger::debug("created a fresh-hnsw index"); - return std::make_shared(hnsw_params, index_common_params); - } else if (name == INDEX_DISKANN) { - // read parameters from json, throw exception if not exists - CHECK_ARGUMENT(parsed_params.contains(INDEX_DISKANN), - fmt::format("parameters must contains {}", INDEX_DISKANN)); - auto& diskann_param_obj = parsed_params[INDEX_DISKANN]; - auto diskann_params = - DiskannParameters::FromJson(diskann_param_obj, index_common_params); - logger::debug("created a diskann index"); - return std::make_shared(diskann_params, index_common_params); - } else if (name == INDEX_HGRAPH) { - if (allocator == nullptr) { - index_common_params.allocator_ = DefaultAllocator::Instance().get(); - } - logger::debug("created a hgraph index"); - JsonType hgraph_params; - if (parsed_params.contains(INDEX_PARAM)) { - hgraph_params = std::move(parsed_params[INDEX_PARAM]); - } - HGraphParameters hgraph_param(hgraph_params, index_common_params); - auto hgraph_index = - std::make_shared(hgraph_param.GetJson(), index_common_params); - hgraph_index->Init(); - return hgraph_index; - } else { - LOG_ERROR_AND_RETURNS( - ErrorType::UNSUPPORTED_INDEX, "failed to create index(unsupported): ", name); - } - } catch (const std::invalid_argument& e) { - LOG_ERROR_AND_RETURNS( - ErrorType::INVALID_ARGUMENT, "failed to create index(invalid argument): ", e.what()); - } catch (const std::exception& e) { - LOG_ERROR_AND_RETURNS( - ErrorType::UNSUPPORTED_INDEX, "failed to create index(unknown error): ", e.what()); - } + Resource resource(allocator); + Engine e(&resource); + return e.CreateIndex(origin_name, parameters); } class LocalFileReader : public Reader { diff --git a/src/index/hgraph_index.cpp b/src/index/hgraph_index.cpp index 2de9eac4..24d32737 100644 --- a/src/index/hgraph_index.cpp +++ b/src/index/hgraph_index.cpp @@ -18,5 +18,11 @@ namespace vsag { HGraphIndex::HGraphIndex(const vsag::JsonType& index_param, const vsag::IndexCommonParam& common_param) noexcept { this->hgraph_ = std::make_unique(index_param, common_param); + this->allocator_ = common_param.allocator_; +} + +HGraphIndex::~HGraphIndex() { + this->hgraph_.reset(); + this->allocator_.reset(); } } // namespace vsag diff --git a/src/index/hgraph_index.h b/src/index/hgraph_index.h index 9de9cd39..210bb1d1 100644 --- a/src/index/hgraph_index.h +++ b/src/index/hgraph_index.h @@ -25,6 +25,8 @@ class HGraphIndex : public Index { public: HGraphIndex(const JsonType& index_param, const IndexCommonParam& common_param) noexcept; + ~HGraphIndex() override; + tl::expected Init() { SAFE_CALL(return this->hgraph_->Init()); @@ -141,5 +143,7 @@ class HGraphIndex : public Index { private: std::unique_ptr hgraph_{nullptr}; + + std::shared_ptr allocator_{nullptr}; }; } // namespace vsag diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 9fcd1110..16e86ba1 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -61,11 +61,7 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para conjugate_graph_ = std::make_shared(); } - if (not index_common_param.allocator_) { - allocator_ = std::make_shared(DefaultAllocator::Instance()); - } else { - allocator_ = std::make_shared(index_common_param.allocator_); - } + allocator_ = index_common_param.allocator_; if (!use_static_) { alg_hnsw_ = @@ -276,7 +272,7 @@ HNSW::knn_search(const DatasetPtr& query, results.pop(); } - result->Dim(results.size())->NumElements(1)->Owner(true, allocator_->GetRawAllocator()); + result->Dim(results.size())->NumElements(1)->Owner(true, allocator_.get()); int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * results.size()); result->Ids(ids); @@ -387,7 +383,7 @@ HNSW::range_search(const DatasetPtr& query, if (limited_size >= 1) { target_size = std::min((size_t)limited_size, target_size); } - result->Dim(target_size)->NumElements(1)->Owner(true, allocator_->GetRawAllocator()); + result->Dim(target_size)->NumElements(1)->Owner(true, allocator_.get()); int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * target_size); result->Ids(ids); float* dists = (float*)allocator_->Allocate(sizeof(float) * target_size); @@ -714,7 +710,7 @@ HNSW::brute_force(const DatasetPtr& query, int64_t k) { fmt::format("query.dim({}) must be equal to index.dim({})", query->GetDim(), dim_)); auto result = Dataset::Make(); - result->NumElements(k)->Owner(true, allocator_->GetRawAllocator()); + result->NumElements(k)->Owner(true, allocator_.get()); int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * k); result->Ids(ids); float* dists = (float*)allocator_->Allocate(sizeof(float) * k); diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 89d3252c..acd67864 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -276,7 +276,7 @@ class HNSW : public Index { bool is_init_memory_ = false; DataTypes type_; - std::shared_ptr allocator_; + std::shared_ptr allocator_; mutable std::mutex stats_mutex_; mutable std::map result_queues_; diff --git a/src/index/hnsw_test.cpp b/src/index/hnsw_test.cpp index b0dd5338..59307577 100644 --- a/src/index/hnsw_test.cpp +++ b/src/index/hnsw_test.cpp @@ -27,8 +27,10 @@ #include "vsag/errors.h" #include "vsag/options.h" -vsag::HnswParameters -parse_hnsw_params(vsag::IndexCommonParam index_common_param) { +using namespace vsag; + +HnswParameters +parse_hnsw_params(IndexCommonParam index_common_param) { auto build_parameter_json = R"( { "max_degree": 12, @@ -36,27 +38,30 @@ parse_hnsw_params(vsag::IndexCommonParam index_common_param) { } )"; nlohmann::json parsed_params = nlohmann::json::parse(build_parameter_json); - return vsag::HnswParameters::FromJson(parsed_params, index_common_param); + return HnswParameters::FromJson(parsed_params, index_common_param); } TEST_CASE("build & add", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = allocator; - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); std::vector ids(1); int64_t incorrect_dim = 63; std::vector vectors(incorrect_dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(incorrect_dim) ->NumElements(1) ->Ids(ids.data()) @@ -66,36 +71,37 @@ TEST_CASE("build & add", "[ut][hnsw]") { SECTION("build with incorrect dim") { auto result = index->Build(dataset); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("add with incorrect dim") { auto result = index->Add(dataset); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } } TEST_CASE("build with allocator", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; - vsag::DefaultAllocator allocator; + IndexCommonParam commom_param; + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; - commom_param.allocator_ = &allocator; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = allocator; - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(1) ->Ids(ids.data()) @@ -106,113 +112,115 @@ TEST_CASE("build with allocator", "[ut][hnsw]") { } TEST_CASE("knn_search", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(1) ->Ids(ids.data()) ->Float32Vectors(vectors.data()) ->Owner(false); - auto result = index->Build(dataset); - REQUIRE(result.has_value()); + auto build_result = index->Build(dataset); + REQUIRE(build_result.has_value()); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); int64_t k = 10; - vsag::JsonType params{ + JsonType params{ {"hnsw", {{"ef_search", 100}}}, }; SECTION("invalid parameters k is 0") { auto result = index->KnnSearch(query, 0, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameters k less than 0") { auto result = index->KnnSearch(query, -1, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameters hnsw not found") { - vsag::JsonType invalid_params{}; + JsonType invalid_params{}; auto result = index->KnnSearch(query, k, invalid_params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameters ef_search not found") { - vsag::JsonType invalid_params{ + JsonType invalid_params{ {"hnsw", {}}, }; auto result = index->KnnSearch(query, k, invalid_params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("query length is not 1") { - auto query = vsag::Dataset::Make(); - query->NumElements(2)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); - auto result = index->KnnSearch(query, k, params.dump()); + auto query2 = Dataset::Make(); + query2->NumElements(2)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); + auto result = index->KnnSearch(query2, k, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("dimension not equal") { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim - 1)->Float32Vectors(vectors.data())->Owner(false); - auto result = index->KnnSearch(query, k, params.dump()); + auto query2 = Dataset::Make(); + query2->NumElements(1)->Dim(dim - 1)->Float32Vectors(vectors.data())->Owner(false); + auto result = index->KnnSearch(query2, k, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } } TEST_CASE("range_search", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(num_elements) ->Ids(ids.data()) ->Float32Vectors(vectors.data()) ->Owner(false); - auto result = index->Build(dataset); - REQUIRE(result.has_value()); + auto build_result = index->Build(dataset); + REQUIRE(build_result.has_value()); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); float radius = 9.9f; - vsag::JsonType params{ + JsonType params{ {"hnsw", {{"ef_search", 100}}}, }; @@ -241,67 +249,64 @@ TEST_CASE("range_search", "[ut][hnsw]") { int64_t range_search_limit = 0; auto result = index->RangeSearch(query, 1000, params.dump(), range_search_limit); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameter radius equals to 0") { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); auto result = index->RangeSearch(query, 0, params.dump()); REQUIRE(result.has_value()); } SECTION("invalid parameter radius less than 0") { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); auto result = index->RangeSearch(query, -1, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameters hnsw not found") { - vsag::JsonType invalid_params{}; + JsonType invalid_params{}; auto result = index->RangeSearch(query, radius, invalid_params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid parameters ef_search not found") { - vsag::JsonType invalid_params{ + JsonType invalid_params{ {"hnsw", {}}, }; auto result = index->RangeSearch(query, radius, invalid_params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("query length is not 1") { - auto query = vsag::Dataset::Make(); - query->NumElements(2)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); - auto result = index->RangeSearch(query, radius, params.dump()); + auto query2 = Dataset::Make(); + query2->NumElements(2)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); + auto result = index->RangeSearch(query2, radius, params.dump()); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(result.error().type == ErrorType::INVALID_ARGUMENT); } } TEST_CASE("serialize empty index", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); SECTION("serialize to binaryset") { auto result = index->Serialize(); REQUIRE(result.has_value()); - REQUIRE(result.value().Contains(vsag::BLANK_INDEX)); + REQUIRE(result.value().Contains(BLANK_INDEX)); } SECTION("serialize to fstream") { @@ -309,29 +314,30 @@ TEST_CASE("serialize empty index", "[ut][hnsw]") { std::fstream out_stream(dir.path + "empty_index.bin", std::ios::out | std::ios::binary); auto result = index->Serialize(out_stream); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::INDEX_EMPTY); + REQUIRE(result.error().type == ErrorType::INDEX_EMPTY); } } TEST_CASE("deserialize on not empty index", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; hnsw_obj.use_conjugate_graph = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(1) ->Ids(ids.data()) @@ -346,8 +352,8 @@ TEST_CASE("deserialize on not empty index", "[ut][hnsw]") { auto voidresult = index->Deserialize(binary_set.value()); REQUIRE_FALSE(voidresult.has_value()); - REQUIRE(voidresult.error().type == vsag::ErrorType::INDEX_NOT_EMPTY); - auto another_index = std::make_shared(hnsw_obj, commom_param); + REQUIRE(voidresult.error().type == ErrorType::INDEX_NOT_EMPTY); + auto another_index = std::make_shared(hnsw_obj, commom_param); auto deserialize_result = another_index->Deserialize(binary_set.value()); REQUIRE(deserialize_result.has_value()); } @@ -362,30 +368,31 @@ TEST_CASE("deserialize on not empty index", "[ut][hnsw]") { std::fstream in_stream(dir.path + "index.bin", std::ios::in | std::ios::binary); auto voidresult = index->Deserialize(in_stream); REQUIRE_FALSE(voidresult.has_value()); - REQUIRE(voidresult.error().type == vsag::ErrorType::INDEX_NOT_EMPTY); + REQUIRE(voidresult.error().type == ErrorType::INDEX_NOT_EMPTY); in_stream.close(); } } TEST_CASE("static hnsw", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; hnsw_obj.use_static = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(9) ->Ids(ids.data()) @@ -394,7 +401,7 @@ TEST_CASE("static hnsw", "[ut][hnsw]") { auto result = index->Build(dataset); REQUIRE(result.has_value()); - auto one_vector = vsag::Dataset::Make(); + auto one_vector = Dataset::Make(); one_vector->Dim(dim) ->NumElements(1) ->Ids(ids.data() + 9) @@ -402,9 +409,9 @@ TEST_CASE("static hnsw", "[ut][hnsw]") { ->Owner(false); result = index->Add(one_vector); REQUIRE_FALSE(result.has_value()); - REQUIRE(result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); + REQUIRE(result.error().type == ErrorType::UNSUPPORTED_INDEX_OPERATION); - vsag::JsonType params{ + JsonType params{ {"hnsw", {{"ef_search", 100}}}, }; @@ -413,43 +420,44 @@ TEST_CASE("static hnsw", "[ut][hnsw]") { auto range_result = index->RangeSearch(one_vector, 1, params.dump()); REQUIRE_FALSE(range_result.has_value()); - REQUIRE(range_result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); + REQUIRE(range_result.error().type == ErrorType::UNSUPPORTED_INDEX_OPERATION); SECTION("incorrect dim") { - vsag::IndexCommonParam incorrect_commom_param; + IndexCommonParam incorrect_commom_param; incorrect_commom_param.dim_ = 127; - incorrect_commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - incorrect_commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; - vsag::HnswParameters incorrect_hnsw_obj = parse_hnsw_params(incorrect_commom_param); + incorrect_commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + incorrect_commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + HnswParameters incorrect_hnsw_obj = parse_hnsw_params(incorrect_commom_param); incorrect_hnsw_obj.use_static = true; incorrect_hnsw_obj.max_degree = 12; incorrect_hnsw_obj.ef_construction = 100; - REQUIRE_THROWS(std::make_shared(incorrect_hnsw_obj, incorrect_commom_param)); + REQUIRE_THROWS(std::make_shared(incorrect_hnsw_obj, incorrect_commom_param)); } auto remove_result = index->Remove(ids[0]); REQUIRE_FALSE(remove_result.has_value()); - REQUIRE(remove_result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); + REQUIRE(remove_result.error().type == ErrorType::UNSUPPORTED_INDEX_OPERATION); } TEST_CASE("hnsw add vector with duplicated id", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); std::vector ids{1}; std::vector vectors(dim); - auto first_time = vsag::Dataset::Make(); + auto first_time = Dataset::Make(); first_time->Dim(dim) ->NumElements(1) ->Ids(ids.data()) @@ -460,7 +468,7 @@ TEST_CASE("hnsw add vector with duplicated id", "[ut][hnsw]") { // expect failed id list emtpy REQUIRE(result.value().empty()); - auto second_time = vsag::Dataset::Make(); + auto second_time = Dataset::Make(); second_time->Dim(dim) ->NumElements(1) ->Ids(ids.data()) @@ -474,23 +482,24 @@ TEST_CASE("hnsw add vector with duplicated id", "[ut][hnsw]") { } TEST_CASE("build with reversed edges", "[ut][hnsw]") { - vsag::logger::set_level(vsag::logger::level::debug); + logger::set_level(logger::level::debug); int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 12; hnsw_obj.ef_construction = 100; hnsw_obj.use_reversed_edges = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); const int64_t num_elements = 1000; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); - auto dataset = vsag::Dataset::Make(); + auto dataset = Dataset::Make(); dataset->Dim(dim) ->NumElements(num_elements) ->Ids(ids.data()) @@ -515,7 +524,7 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { in_file.seekg(0, std::ios::end); int64_t length = in_file.tellg(); in_file.seekg(0, std::ios::beg); - auto new_index = std::make_shared(hnsw_obj, commom_param); + auto new_index = std::make_shared(hnsw_obj, commom_param); REQUIRE(new_index->Deserialize(in_file).has_value()); REQUIRE(new_index->CheckGraphIntegrity()); } @@ -527,7 +536,7 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { if (auto bs = index->Serialize(); bs.has_value()) { auto keys = bs->GetKeys(); for (auto key : keys) { - vsag::Binary b = bs->Get(key); + Binary b = bs->Get(key); std::ofstream file(dir.path + "hnsw.index." + key, std::ios::binary); file.write((const char*)b.data.get(), b.size); file.close(); @@ -537,7 +546,7 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { metafile << key << std::endl; } metafile.close(); - } else if (bs.error().type == vsag::ErrorType::NO_ENOUGH_MEMORY) { + } else if (bs.error().type == ErrorType::NO_ENOUGH_MEMORY) { std::cerr << "no enough memory to serialize index" << std::endl; } @@ -549,11 +558,11 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { } metafile.close(); - vsag::BinarySet bs; + BinarySet bs; for (auto key : keys) { std::ifstream file(dir.path + "hnsw.index." + key, std::ios::in); file.seekg(0, std::ios::end); - vsag::Binary b; + Binary b; b.size = file.tellg(); b.data.reset(new int8_t[b.size]); file.seekg(0, std::ios::beg); @@ -561,52 +570,53 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { bs.Set(key, b); } - auto new_index = std::make_shared(hnsw_obj, commom_param); + auto new_index = std::make_shared(hnsw_obj, commom_param); REQUIRE(new_index->Deserialize(bs).has_value()); REQUIRE(new_index->CheckGraphIntegrity()); } } TEST_CASE("feedback with invalid argument", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int64_t num_vectors = 1000; int64_t k = 10; int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 16; hnsw_obj.ef_construction = 200; hnsw_obj.use_conjugate_graph = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); - vsag::JsonType search_parameters{ + JsonType search_parameters{ {"hnsw", {{"ef_search", 200}}}, }; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); SECTION("index feedback with k = 0") { REQUIRE(index->Feedback(query, 0, search_parameters.dump(), -1).error().type == - vsag::ErrorType::INVALID_ARGUMENT); + ErrorType::INVALID_ARGUMENT); REQUIRE(index->Feedback(query, 0, search_parameters.dump()).error().type == - vsag::ErrorType::INVALID_ARGUMENT); + ErrorType::INVALID_ARGUMENT); } SECTION("index feedback with invalid global optimum tag id") { auto feedback_result = index->Feedback(query, k, search_parameters.dump(), -1000); - REQUIRE(feedback_result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(feedback_result.error().type == ErrorType::INVALID_ARGUMENT); } } TEST_CASE("redundant feedback and empty enhancement", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int64_t num_base = 10; @@ -614,19 +624,20 @@ TEST_CASE("redundant feedback and empty enhancement", "[ut][hnsw]") { int64_t k = 10; int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = 128; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 16; hnsw_obj.ef_construction = 200; hnsw_obj.use_conjugate_graph = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); - auto base = vsag::Dataset::Make(); + auto base = Dataset::Make(); base->NumElements(num_base) ->Dim(dim) ->Ids(base_ids.data()) @@ -636,12 +647,12 @@ TEST_CASE("redundant feedback and empty enhancement", "[ut][hnsw]") { auto buildindex = index->Build(base); REQUIRE(buildindex.has_value()); - vsag::JsonType search_parameters{ + JsonType search_parameters{ {"hnsw", {{"ef_search", 200}, {"use_conjugate_graph", true}}}, }; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_query, dim); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); auto search_result = index->KnnSearch(query, k, search_parameters.dump()); @@ -668,7 +679,7 @@ TEST_CASE("redundant feedback and empty enhancement", "[ut][hnsw]") { } TEST_CASE("feedback and pretrain without use conjugate graph", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int64_t num_base = 10; @@ -676,17 +687,18 @@ TEST_CASE("feedback and pretrain without use conjugate graph", "[ut][hnsw]") { int64_t k = 10; int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 16; hnsw_obj.ef_construction = 200; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); - auto base = vsag::Dataset::Make(); + auto base = Dataset::Make(); base->NumElements(num_base) ->Dim(dim) ->Ids(base_ids.data()) @@ -696,25 +708,25 @@ TEST_CASE("feedback and pretrain without use conjugate graph", "[ut][hnsw]") { auto buildindex = index->Build(base); REQUIRE(buildindex.has_value()); - vsag::JsonType search_parameters{ + JsonType search_parameters{ {"hnsw", {{"ef_search", 200}}}, }; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_query, dim); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); auto feedback_result = index->Feedback(query, k, search_parameters.dump()); - REQUIRE(feedback_result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); + REQUIRE(feedback_result.error().type == ErrorType::UNSUPPORTED_INDEX_OPERATION); std::vector base_tag_ids; base_tag_ids.push_back(10000); auto pretrain_result = index->Pretrain(base_tag_ids, 10, search_parameters.dump()); - REQUIRE(pretrain_result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); + REQUIRE(pretrain_result.error().type == ErrorType::UNSUPPORTED_INDEX_OPERATION); } TEST_CASE("feedback and pretrain on empty index", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int64_t dim = 128; @@ -722,19 +734,20 @@ TEST_CASE("feedback and pretrain on empty index", "[ut][hnsw]") { int64_t num_query = 1; int64_t k = 100; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 16; hnsw_obj.ef_construction = 200; hnsw_obj.use_conjugate_graph = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); - auto base = vsag::Dataset::Make(); + auto base = Dataset::Make(); base->NumElements(num_base) ->Dim(dim) ->Ids(base_ids.data()) @@ -744,12 +757,12 @@ TEST_CASE("feedback and pretrain on empty index", "[ut][hnsw]") { auto buildindex = index->Build(base); REQUIRE(buildindex.has_value()); - vsag::JsonType search_parameters{ + JsonType search_parameters{ {"hnsw", {{"ef_search", 200}}}, }; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_query, dim); - auto query = vsag::Dataset::Make(); + auto query = Dataset::Make(); query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data())->Owner(false); auto feedback_result = index->Feedback(query, k, search_parameters.dump()); @@ -762,7 +775,7 @@ TEST_CASE("feedback and pretrain on empty index", "[ut][hnsw]") { } TEST_CASE("invalid pretrain", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int64_t num_base = 10; @@ -770,19 +783,20 @@ TEST_CASE("invalid pretrain", "[ut][hnsw]") { int64_t k = 100; int64_t dim = 128; - vsag::IndexCommonParam commom_param; + IndexCommonParam commom_param; commom_param.dim_ = dim; - commom_param.data_type_ = vsag::DataTypes::DATA_TYPE_FLOAT; - commom_param.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + commom_param.data_type_ = DataTypes::DATA_TYPE_FLOAT; + commom_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + commom_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); - vsag::HnswParameters hnsw_obj = parse_hnsw_params(commom_param); + HnswParameters hnsw_obj = parse_hnsw_params(commom_param); hnsw_obj.max_degree = 16; hnsw_obj.ef_construction = 200; hnsw_obj.use_conjugate_graph = true; - auto index = std::make_shared(hnsw_obj, commom_param); + auto index = std::make_shared(hnsw_obj, commom_param); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); - auto base = vsag::Dataset::Make(); + auto base = Dataset::Make(); base->NumElements(num_base) ->Dim(dim) ->Ids(base_ids.data()) @@ -792,7 +806,7 @@ TEST_CASE("invalid pretrain", "[ut][hnsw]") { auto buildindex = index->Build(base); REQUIRE(buildindex.has_value()); - vsag::JsonType search_parameters{ + JsonType search_parameters{ {"hnsw", {{"ef_search", 200}}}, }; @@ -800,29 +814,29 @@ TEST_CASE("invalid pretrain", "[ut][hnsw]") { std::vector base_tag_ids; base_tag_ids.push_back(10000); auto pretrain_result = index->Pretrain(base_tag_ids, 10, search_parameters.dump()); - REQUIRE(pretrain_result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(pretrain_result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid k") { std::vector base_tag_ids; base_tag_ids.push_back(0); auto pretrain_result = index->Pretrain(base_tag_ids, 0, search_parameters.dump()); - REQUIRE(pretrain_result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(pretrain_result.error().type == ErrorType::INVALID_ARGUMENT); } SECTION("invalid search parameter") { - vsag::JsonType invalid_search_parameters{ + JsonType invalid_search_parameters{ {"hnsw", {{"ef_search", -1}}}, }; std::vector base_tag_ids; base_tag_ids.push_back(0); auto pretrain_result = index->Pretrain(base_tag_ids, 10, invalid_search_parameters.dump()); - REQUIRE(pretrain_result.error().type == vsag::ErrorType::INVALID_ARGUMENT); + REQUIRE(pretrain_result.error().type == ErrorType::INVALID_ARGUMENT); } } TEST_CASE("get distance by label", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int dim = 128; @@ -835,7 +849,7 @@ TEST_CASE("get distance by label", "[ut][hnsw]") { hnswlib::L2Space space(dim); SECTION("hnsw test") { - vsag::DefaultAllocator allocator; + DefaultAllocator allocator; auto* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 100, &allocator); alg_hnsw->init_memory_space(); alg_hnsw->addPoint(base_vectors.data(), 0); @@ -846,7 +860,7 @@ TEST_CASE("get distance by label", "[ut][hnsw]") { } SECTION("static hnsw test") { - vsag::DefaultAllocator allocator; + DefaultAllocator allocator; auto* alg_hnsw_static = new hnswlib::StaticHierarchicalNSW(&space, 100, &allocator); alg_hnsw_static->init_memory_space(); alg_hnsw_static->addPoint(base_vectors.data(), 0); @@ -858,7 +872,7 @@ TEST_CASE("get distance by label", "[ut][hnsw]") { } TEST_CASE("get data by label", "[ut][hnsw]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + Options::Instance().logger()->SetLevel(Logger::Level::kDEBUG); // parameters int dim = 128; @@ -871,7 +885,7 @@ TEST_CASE("get data by label", "[ut][hnsw]") { hnswlib::L2Space space(dim); SECTION("hnsw test") { - vsag::DefaultAllocator allocator; + DefaultAllocator allocator; auto* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 100, &allocator); alg_hnsw->init_memory_space(); alg_hnsw->addPoint(base_vectors.data(), 0); @@ -882,7 +896,7 @@ TEST_CASE("get data by label", "[ut][hnsw]") { } SECTION("static hnsw test") { - vsag::DefaultAllocator allocator; + DefaultAllocator allocator; auto* alg_hnsw_static = new hnswlib::StaticHierarchicalNSW(&space, 100, &allocator); alg_hnsw_static->init_memory_space(); alg_hnsw_static->addPoint(base_vectors.data(), 0); diff --git a/src/index/index_common_param.cpp b/src/index/index_common_param.cpp index 64c81313..cf5cf98c 100644 --- a/src/index/index_common_param.cpp +++ b/src/index/index_common_param.cpp @@ -23,9 +23,9 @@ namespace vsag { IndexCommonParam -IndexCommonParam::CheckAndCreate(JsonType& params, Allocator* allocator) { +IndexCommonParam::CheckAndCreate(JsonType& params, std::shared_ptr allocator) { IndexCommonParam result; - result.allocator_ = allocator; + result.allocator_ = std::move(allocator); // Check DataType CHECK_ARGUMENT(params.contains(PARAMETER_DTYPE), fmt::format("parameters must contains {}", PARAMETER_DTYPE)); diff --git a/src/index/index_common_param.h b/src/index/index_common_param.h index d83cd201..eaeac671 100644 --- a/src/index/index_common_param.h +++ b/src/index/index_common_param.h @@ -29,9 +29,9 @@ class IndexCommonParam { MetricType metric_{MetricType::METRIC_TYPE_L2SQR}; DataTypes data_type_{DataTypes::DATA_TYPE_FLOAT}; int64_t dim_{0}; - Allocator* allocator_{nullptr}; + std::shared_ptr allocator_{nullptr}; static IndexCommonParam - CheckAndCreate(JsonType& params, Allocator* allocator); + CheckAndCreate(JsonType& params, std::shared_ptr allocator); }; } // namespace vsag diff --git a/src/io/memory_block_io.h b/src/io/memory_block_io.h index b1ca6e94..86529d73 100644 --- a/src/io/memory_block_io.h +++ b/src/io/memory_block_io.h @@ -41,7 +41,7 @@ class MemoryBlockIO : public BasicIO { } MemoryBlockIO(const JsonType& io_param, const IndexCommonParam& common_param) - : MemoryBlockIO(common_param.allocator_) { + : MemoryBlockIO(common_param.allocator_.get()) { if (io_param.contains(BLOCK_IO_BLOCK_SIZE_KEY)) { this->block_size_ = io_param[BLOCK_IO_BLOCK_SIZE_KEY]; // TODO(LHT): trans str to uint64_t diff --git a/src/io/memory_block_io_test.cpp b/src/io/memory_block_io_test.cpp index 11657f61..e424cc3d 100644 --- a/src/io/memory_block_io_test.cpp +++ b/src/io/memory_block_io_test.cpp @@ -20,12 +20,14 @@ #include "basic_io_test.h" #include "default_allocator.h" +#include "safe_allocator.h" + using namespace vsag; auto block_memory_io_block_sizes = {64, 1023, 4096, 123123, 1024 * 1024}; TEST_CASE("read&write [ut][memory_block_io]") { - auto allocator = std::make_unique(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); for (auto block_size : block_memory_io_block_sizes) { auto io = std::make_unique(allocator.get(), block_size); TestBasicReadWrite(*io); @@ -33,7 +35,7 @@ TEST_CASE("read&write [ut][memory_block_io]") { } TEST_CASE("serialize&deserialize [ut][memory_block_io]") { - auto allocator = std::make_unique(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); for (auto block_size : block_memory_io_block_sizes) { auto wio = std::make_unique(allocator.get(), block_size); auto rio = std::make_unique(allocator.get(), block_size); diff --git a/src/io/memory_io.h b/src/io/memory_io.h index 3e552ec5..4fec1ff1 100644 --- a/src/io/memory_io.h +++ b/src/io/memory_io.h @@ -36,7 +36,7 @@ class MemoryIO : public BasicIO { } MemoryIO(const JsonType& io_param, const IndexCommonParam& common_param) - : allocator_(common_param.allocator_) { + : allocator_(common_param.allocator_.get()) { start_ = reinterpret_cast(allocator_->Allocate(MIN_SIZE)); current_size_ = MIN_SIZE; } diff --git a/src/io/memory_io_test.cpp b/src/io/memory_io_test.cpp index be4af4bb..cb9b7ed2 100644 --- a/src/io/memory_io_test.cpp +++ b/src/io/memory_io_test.cpp @@ -20,16 +20,18 @@ #include "basic_io_test.h" #include "default_allocator.h" +#include "safe_allocator.h" + using namespace vsag; TEST_CASE("read&write [ut][memory_io]") { - auto allocator = std::make_unique(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto io = std::make_unique(allocator.get()); TestBasicReadWrite(*io); } TEST_CASE("serialize&deserialize [ut][memory_io]") { - auto allocator = std::make_unique(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto wio = std::make_unique(allocator.get()); auto rio = std::make_unique(allocator.get()); TestSerializeAndDeserialize(*wio, *rio); diff --git a/src/quantization/fp32_quantizer.h b/src/quantization/fp32_quantizer.h index 4aff703f..423c3bd1 100644 --- a/src/quantization/fp32_quantizer.h +++ b/src/quantization/fp32_quantizer.h @@ -80,7 +80,7 @@ class FP32Quantizer : public Quantizer> { template FP32Quantizer::FP32Quantizer(const JsonType& quantization_param, const IndexCommonParam& common_param) - : Quantizer>(common_param.dim_, common_param.allocator_) { + : Quantizer>(common_param.dim_, common_param.allocator_.get()) { this->code_size_ = common_param.dim_ * sizeof(float); } diff --git a/src/quantization/fp32_quantizer_test.cpp b/src/quantization/fp32_quantizer_test.cpp index b1cb101d..a03077b0 100644 --- a/src/quantization/fp32_quantizer_test.cpp +++ b/src/quantization/fp32_quantizer_test.cpp @@ -21,6 +21,7 @@ #include "default_allocator.h" #include "fixtures.h" #include "quantizer_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -30,7 +31,7 @@ const auto counts = {10, 101}; template void TestQuantizerEncodeDecodeMetricFP32(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); FP32Quantizer quantizer(dim, allocator.get()); TestQuantizerEncodeDecode(quantizer, dim, count, error); TestQuantizerEncodeDecodeSame(quantizer, dim, count, 65536, error); @@ -50,7 +51,7 @@ TEST_CASE("encode&decode [ut][fp32_quantizer]") { template void TestComputeMetricFP32(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); FP32Quantizer quantizer(dim, allocator.get()); TestComputeCodes, metric>(quantizer, dim, count, error); TestComputeCodesSame, metric>(quantizer, dim, count, 65536); @@ -73,7 +74,7 @@ TEST_CASE("compute [ut][fp32_quantizer]") { template void TestSerializeAndDeserializeMetricFP32(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); FP32Quantizer quantizer1(dim, allocator.get()); FP32Quantizer quantizer2(0, allocator.get()); TestSerializeAndDeserialize, metric>( diff --git a/src/quantization/sq4_quantizer.h b/src/quantization/sq4_quantizer.h index 7e14f0b6..9d7a0d95 100644 --- a/src/quantization/sq4_quantizer.h +++ b/src/quantization/sq4_quantizer.h @@ -90,7 +90,7 @@ SQ4Quantizer::SQ4Quantizer(int dim, Allocator* allocator) template SQ4Quantizer::SQ4Quantizer(const JsonType& quantization_param, const IndexCommonParam& common_param) - : SQ4Quantizer(common_param.dim_, common_param.allocator_){}; + : SQ4Quantizer(common_param.dim_, common_param.allocator_.get()){}; template bool diff --git a/src/quantization/sq4_quantizer_test.cpp b/src/quantization/sq4_quantizer_test.cpp index 0b246cb8..06955ed8 100644 --- a/src/quantization/sq4_quantizer_test.cpp +++ b/src/quantization/sq4_quantizer_test.cpp @@ -21,6 +21,7 @@ #include "default_allocator.h" #include "fixtures.h" #include "quantizer_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -33,7 +34,7 @@ TestQuantizerEncodeDecodeMetricSQ4(uint64_t dim, int count, float error = 1e-5, float error_same = 1e-2) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4Quantizer quantizer(dim, allocator.get()); TestQuantizerEncodeDecode(quantizer, dim, count, error); TestQuantizerEncodeDecodeSame(quantizer, dim, count, 15, error_same); @@ -54,7 +55,7 @@ TEST_CASE("Encode and Decode", "[ut][SQ4Quantizer]") { template void TestComputeMetricSQ4(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4Quantizer quantizer(dim, allocator.get()); TestComputeCodes, metric>(quantizer, dim, count, error); TestComputer, metric>(quantizer, dim, count, error); @@ -77,7 +78,7 @@ TEST_CASE("compute [ut][sq4_quantizer]") { template void TestSerializeAndDeserializeMetricSQ4(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4Quantizer quantizer1(dim, allocator.get()); SQ4Quantizer quantizer2(0, allocator.get()); TestSerializeAndDeserialize, metric>( diff --git a/src/quantization/sq4_uniform_quantizer.h b/src/quantization/sq4_uniform_quantizer.h index b1920c5a..e8d301ca 100644 --- a/src/quantization/sq4_uniform_quantizer.h +++ b/src/quantization/sq4_uniform_quantizer.h @@ -137,7 +137,7 @@ SQ4UniformQuantizer::SQ4UniformQuantizer(int dim, Allocator* allocator) template SQ4UniformQuantizer::SQ4UniformQuantizer(const JsonType& quantization_param, const IndexCommonParam& common_param) - : SQ4UniformQuantizer(common_param.dim_, common_param.allocator_){}; + : SQ4UniformQuantizer(common_param.dim_, common_param.allocator_.get()){}; template bool diff --git a/src/quantization/sq4_uniform_quantizer_test.cpp b/src/quantization/sq4_uniform_quantizer_test.cpp index 3eaa1525..a912f67b 100644 --- a/src/quantization/sq4_uniform_quantizer_test.cpp +++ b/src/quantization/sq4_uniform_quantizer_test.cpp @@ -21,6 +21,7 @@ #include "default_allocator.h" #include "fixtures.h" #include "quantizer_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -33,7 +34,7 @@ TestQuantizerEncodeDecodeMetricSQ4Uniform(uint64_t dim, int count, float error = 1e-5, float error_same = 1e-2) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4UniformQuantizer quantizer(dim, allocator.get()); TestQuantizerEncodeDecode(quantizer, dim, count, error); TestQuantizerEncodeDecodeSame(quantizer, dim, count, 15, error_same); @@ -54,7 +55,7 @@ TEST_CASE("SQ4 Uniform Encode and Decode", "[ut][SQ4UniformQuantizer]") { template void TestComputeMetricSQ4Uniform(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4UniformQuantizer quantizer(dim, allocator.get()); TestComputeCodesSame, metric>(quantizer, dim, count, error); } @@ -73,7 +74,7 @@ TEST_CASE("compute [ut][SQ4UniformQuantizer]") { template void TestSerializeAndDeserializeMetricSQ4Uniform(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ4UniformQuantizer quantizer1(dim, allocator.get()); SQ4UniformQuantizer quantizer2(0, allocator.get()); TestSerializeAndDeserialize, metric, true>( diff --git a/src/quantization/sq8_quantizer.h b/src/quantization/sq8_quantizer.h index b6195214..0d07a228 100644 --- a/src/quantization/sq8_quantizer.h +++ b/src/quantization/sq8_quantizer.h @@ -94,9 +94,9 @@ SQ8Quantizer::SQ8Quantizer(int dim, Allocator* allocator) template SQ8Quantizer::SQ8Quantizer(const JsonType& quantization_param, const IndexCommonParam& common_param) - : Quantizer>(common_param.dim_, common_param.allocator_), - diff_(common_param.allocator_), - lower_bound_(common_param.allocator_) { + : Quantizer>(common_param.dim_, common_param.allocator_.get()), + diff_(common_param.allocator_.get()), + lower_bound_(common_param.allocator_.get()) { // align 64 bytes (512 bits) to avoid illegal memory access in SIMD this->code_size_ = this->dim_; this->diff_.resize(this->dim_, 0); diff --git a/src/quantization/sq8_quantizer_test.cpp b/src/quantization/sq8_quantizer_test.cpp index 8455745e..bb44555a 100644 --- a/src/quantization/sq8_quantizer_test.cpp +++ b/src/quantization/sq8_quantizer_test.cpp @@ -21,6 +21,7 @@ #include "default_allocator.h" #include "fixtures.h" #include "quantizer_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -32,7 +33,7 @@ TestQuantizerEncodeDecodeMetricSQ8(uint64_t dim, int count, float error = 1e-5, float error_same = 1e-2) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8Quantizer quantizer(dim, allocator.get()); TestQuantizerEncodeDecode(quantizer, dim, count, error); TestQuantizerEncodeDecodeSame(quantizer, dim, count, 255, error_same); @@ -54,7 +55,7 @@ TEST_CASE("encode&decode [SQ8Quantizer]") { template void TestComputeMetricSQ8(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8Quantizer quantizer(dim, allocator.get()); TestComputeCodes, metric>(quantizer, dim, count, error); TestComputer, metric>(quantizer, dim, count, error); @@ -77,7 +78,7 @@ TEST_CASE("compute [ut][sq8_quantizer]") { template void TestSerializeAndDeserializeMetricSQ8(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8Quantizer quantizer1(dim, allocator.get()); SQ8Quantizer quantizer2(0, allocator.get()); TestSerializeAndDeserialize, metric>( diff --git a/src/quantization/sq8_uniform_quantizer.h b/src/quantization/sq8_uniform_quantizer.h index cc21d5bb..0b842557 100644 --- a/src/quantization/sq8_uniform_quantizer.h +++ b/src/quantization/sq8_uniform_quantizer.h @@ -123,7 +123,7 @@ SQ8UniformQuantizer::SQ8UniformQuantizer(int dim, Allocator* allocator) template SQ8UniformQuantizer::SQ8UniformQuantizer(const JsonType& quantization_param, const IndexCommonParam& common_param) - : SQ8UniformQuantizer(common_param.dim_, common_param.allocator_){}; + : SQ8UniformQuantizer(common_param.dim_, common_param.allocator_.get()){}; template bool diff --git a/src/quantization/sq8_uniform_quantizer_test.cpp b/src/quantization/sq8_uniform_quantizer_test.cpp index e0d9512e..7491db74 100644 --- a/src/quantization/sq8_uniform_quantizer_test.cpp +++ b/src/quantization/sq8_uniform_quantizer_test.cpp @@ -21,6 +21,7 @@ #include "default_allocator.h" #include "fixtures.h" #include "quantizer_test.h" +#include "safe_allocator.h" using namespace vsag; @@ -33,7 +34,7 @@ TestQuantizerEncodeDecodeMetricSQ8Uniform(uint64_t dim, int count, float error = 1e-5, float error_same = 1e-2) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8UniformQuantizer quantizer(dim, allocator.get()); TestQuantizerEncodeDecode(quantizer, dim, count, error); TestQuantizerEncodeDecodeSame(quantizer, dim, count, 255, error_same); @@ -54,7 +55,7 @@ TEST_CASE("SQ8 Uniform Encode and Decode", "[ut][SQ8UniformQuantizer]") { template void TestComputeMetricSQ8Uniform(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8UniformQuantizer quantizer(dim, allocator.get()); TestComputeCodesSame, metric>(quantizer, dim, count, error); } @@ -73,7 +74,7 @@ TEST_CASE("compute [ut][SQ8UniformQuantizer]") { template void TestSerializeAndDeserializeMetricSQ8Uniform(uint64_t dim, int count, float error = 1e-5) { - auto allocator = std::make_shared(); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); SQ8UniformQuantizer quantizer1(dim, allocator.get()); SQ8UniformQuantizer quantizer2(0, allocator.get()); TestSerializeAndDeserialize, metric, true>( diff --git a/src/resource.cpp b/src/resource.cpp index 247684d8..e0cbce3c 100644 --- a/src/resource.cpp +++ b/src/resource.cpp @@ -21,7 +21,7 @@ namespace vsag { Resource::Resource(Allocator* allocator) { if (allocator == nullptr) { - this->allocator = std::make_shared(new DefaultAllocator(), true); + this->allocator = SafeAllocator::FactoryDefaultAllocator(); } else { this->allocator = std::make_shared(allocator, false); } diff --git a/src/safe_allocator.h b/src/safe_allocator.h index 57fb8af8..d413be8f 100644 --- a/src/safe_allocator.h +++ b/src/safe_allocator.h @@ -17,18 +17,21 @@ #include +#include "default_allocator.h" #include "vsag/allocator.h" namespace vsag { class SafeAllocator : public Allocator { public: - explicit SafeAllocator(Allocator* raw_allocator, bool owner = false) - : raw_allocator_(raw_allocator), owner_(owner) { + static std::shared_ptr + FactoryDefaultAllocator() { + return std::make_shared(new DefaultAllocator(), true); } - explicit SafeAllocator(std::shared_ptr owned_allocator) - : owned_allocator_(owned_allocator), raw_allocator_(owned_allocator.get()) { +public: + explicit SafeAllocator(Allocator* raw_allocator, bool owned = false) + : raw_allocator_(raw_allocator), owned_(owned) { } std::string @@ -64,13 +67,16 @@ class SafeAllocator : public Allocator { } public: - ~SafeAllocator() override = default; + ~SafeAllocator() override { + if (owned_) { + delete raw_allocator_; + } + } private: Allocator* const raw_allocator_ = nullptr; - std::shared_ptr owned_allocator_ = nullptr; - bool owner_{false}; + bool owned_{false}; }; } // namespace vsag diff --git a/tests/test_engine.cpp b/tests/test_engine.cpp new file mode 100644 index 00000000..fc899c62 --- /dev/null +++ b/tests/test_engine.cpp @@ -0,0 +1,82 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "vsag/vsag.h" + +TEST_CASE("index params", "[ft][engine]") { + int dim = 16; + int max_elements = 1000; + int max_degree = 16; + int ef_construction = 100; + int ef_search = 100; + + nlohmann::json hnsw_parameters{ + {"max_degree", max_degree}, + {"ef_construction", ef_construction}, + {"ef_search", ef_search}, + }; + nlohmann::json index_parameters{ + {"dtype", "float32"}, {"metric_type", "l2"}, {"dim", dim}, {"hnsw", hnsw_parameters}}; + + vsag::Engine engine; + + std::shared_ptr hnsw; + auto index = engine.CreateIndex("hnsw", index_parameters.dump()); + REQUIRE(index.has_value()); + hnsw = index.value(); + // Generate random data + std::mt19937 rng(97); + std::uniform_real_distribution distrib_real; + auto* ids = new int64_t[max_elements]; + auto* data = new float[dim * max_elements]; + for (int i = 0; i < max_elements; i++) { + ids[i] = i; + } + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + auto dataset = vsag::Dataset::Make(); + dataset->Dim(dim)->NumElements(max_elements)->Ids(ids)->Float32Vectors(data); + hnsw->Build(dataset); + + // Query the elements for themselves and measure recall 1@1 + float correct = 0; + for (int i = 0; i < max_elements; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(data + i * dim)->Owner(false); + + nlohmann::json parameters{ + {"hnsw", {{"ef_search", ef_search}}}, + }; + int64_t k = 10; + if (auto result = hnsw->KnnSearch(query, k, parameters.dump()); result.has_value()) { + if (result.value()->GetIds()[0] == i) { + correct++; + } + } else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) { + std::cerr << "failed to search on index: internalError" << std::endl; + } + } + float recall = correct / static_cast(max_elements); + + REQUIRE(recall == 1); + + engine.Shutdown(); +}