From 46b3887220d389f8b9daca568abbd25ca6580619 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 27 Dec 2024 14:18:39 +0800 Subject: [PATCH] add reorder for hgraph (#257) Signed-off-by: LHT129 --- include/vsag/constants.h | 1 + src/algorithm/hgraph.cpp | 53 +++++++++++++++++++++++---- src/algorithm/hgraph.h | 6 ++++ src/constants.cpp | 1 + src/index/hgraph_zparameters.cpp | 7 ++-- tests/fixtures/fixtures.cpp | 13 +++++++ tests/fixtures/fixtures.h | 3 ++ tests/test_hgraph.cpp | 61 ++++++++++++++++++++++++-------- 8 files changed, 121 insertions(+), 24 deletions(-) diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 8ba3bdb2..cbc68825 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -103,5 +103,6 @@ extern const char* const HGRAPH_GRAPH_MAX_DEGREE; extern const char* const HGRAPH_BUILD_EF_CONSTRUCTION; extern const char* const HGRAPH_INIT_CAPACITY; extern const char* const HGRAPH_BUILD_THREAD_COUNT; +extern const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE; } // namespace vsag diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 344f3a96..86e2d6b4 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -189,6 +189,10 @@ HGraph::KnnSearch(const DatasetPtr& query, this->basic_flatten_codes_, search_param); + if (use_reorder_) { + this->reorder(query->GetFloat32Vectors(), this->high_precise_codes_, search_result, k); + } + while (search_result.size() > k) { search_result.pop(); } @@ -467,6 +471,11 @@ HGraph::RangeSearch(const DatasetPtr& query, this->bottom_graph_, this->basic_flatten_codes_, search_param); + if (use_reorder_) { + this->reorder( + query->GetFloat32Vectors(), this->high_precise_codes_, search_result, limited_size); + } + if (limited_size > 0) { while (search_result.size() > limited_size) { search_result.pop(); @@ -792,27 +801,30 @@ HGraph::add_one_point(const float* data, int level, InnerIdType inner_id) { }; std::lock_guard cur_lock(this->neighbors_mutex_[inner_id]); - + auto flatten_codes = basic_flatten_codes_; + if (use_reorder_) { + flatten_codes = high_precise_codes_; + } for (auto j = max_level_ - 1; j > level; --j) { - result = search_one_graph(data, route_graphs_[j], basic_flatten_codes_, param); + result = search_one_graph(data, route_graphs_[j], flatten_codes, param); param.ep_ = result.top().second; } param.ef_ = this->ef_construct_; for (auto j = level; j >= 0; --j) { if (route_graphs_[j]->TotalCount() != 0) { - result = search_one_graph(data, route_graphs_[j], basic_flatten_codes_, param); + result = search_one_graph(data, route_graphs_[j], flatten_codes, param); param.ep_ = this->mutually_connect_new_element( - inner_id, result, route_graphs_[j], basic_flatten_codes_, false); + inner_id, result, route_graphs_[j], flatten_codes, false); } else { route_graphs_[j]->InsertNeighborsById(inner_id, Vector(allocator_)); } route_graphs_[j]->IncreaseTotalCount(1); } if (bottom_graph_->TotalCount() != 0) { - result = search_one_graph(data, this->bottom_graph_, basic_flatten_codes_, param); + result = search_one_graph(data, this->bottom_graph_, flatten_codes, param); this->mutually_connect_new_element( - inner_id, result, this->bottom_graph_, basic_flatten_codes_, false); + inner_id, result, this->bottom_graph_, flatten_codes, false); } else { bottom_graph_->InsertNeighborsById(inner_id, Vector(allocator_)); } @@ -917,4 +929,33 @@ HGraph::split_dataset_by_duplicate_label(const DatasetPtr& dataset, return return_datasets; } +void +HGraph::reorder(const float* query, + const FlattenInterfacePtr& flatten_interface, + MaxHeap& candidate_heap, + int64_t k) const { + uint64_t size = candidate_heap.size(); + if (k <= 0) { + k = static_cast(size); + } + Vector ids(size, allocator_); + Vector dists(size, allocator_); + uint64_t idx = 0; + while (not candidate_heap.empty()) { + ids[idx] = candidate_heap.top().second; + ++idx; + candidate_heap.pop(); + } + auto computer = flatten_interface->FactoryComputer(query); + flatten_interface->Query(dists.data(), computer, ids.data(), size); + for (uint64_t i = 0; i < size; ++i) { + if (candidate_heap.size() < k or dists[i] <= candidate_heap.top().first) { + candidate_heap.emplace(dists[i], ids[i]); + } + if (candidate_heap.size() > k) { + candidate_heap.pop(); + } + } +} + } // namespace vsag diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 756f2c62..cd0b7eba 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -188,6 +188,12 @@ class HGraph { split_dataset_by_duplicate_label(const DatasetPtr& dataset, std::vector& failed_ids) const; + void + reorder(const float* query, + const FlattenInterfacePtr& flatten_interface, + MaxHeap& candidate_heap, + int64_t k) const; + private: FlattenInterfacePtr basic_flatten_codes_{nullptr}; FlattenInterfacePtr high_precise_codes_{nullptr}; diff --git a/src/constants.cpp b/src/constants.cpp index 4c4190ca..d791cc52 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -104,5 +104,6 @@ const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree"; const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction"; const char* const HGRAPH_INIT_CAPACITY = "hgraph_init_capacity"; const char* const HGRAPH_BUILD_THREAD_COUNT = "build_thread_count"; +const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE = "precise_quantization_type"; }; // namespace vsag diff --git a/src/index/hgraph_zparameters.cpp b/src/index/hgraph_zparameters.cpp index 4cb78d91..6f823e17 100644 --- a/src/index/hgraph_zparameters.cpp +++ b/src/index/hgraph_zparameters.cpp @@ -26,6 +26,7 @@ namespace vsag { static const std::unordered_map> EXTERNAL_MAPPING = { {HGRAPH_USE_REORDER, {HGRAPH_USE_REORDER_KEY}}, {HGRAPH_BASE_QUANTIZATION_TYPE, {HGRAPH_BASE_CODES_KEY, QUANTIZATION_TYPE_KEY}}, + {HGRAPH_PRECISE_QUANTIZATION_TYPE, {HGRAPH_PRECISE_CODES_KEY, QUANTIZATION_TYPE_KEY}}, {HGRAPH_GRAPH_MAX_DEGREE, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_MAX_DEGREE}}, {HGRAPH_BUILD_EF_CONSTRUCTION, {BUILD_PARAMS_KEY, BUILD_EF_CONSTRUCTION}}, {HGRAPH_INIT_CAPACITY, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_INIT_MAX_CAPACITY}}, @@ -59,12 +60,12 @@ static const std::string HGRAPH_PARAMS_TEMPLATE = "nbits": 8 } }, - "precise_codes": { - "{IO_TYPE_KEY}": "aio_ssd", + "{HGRAPH_PRECISE_CODES_KEY}": { + "{IO_TYPE_KEY}": "{IO_TYPE_VALUE_BLOCK_MEMORY_IO}", "{IO_PARAMS_KEY}": {}, "codes_type": "flatten_codes", "codes_param": {}, - "{QUANTIZATION_TYPE_KEY}": "{QUANTIZATION_TYPE_VALUE_SQ8}", + "{QUANTIZATION_TYPE_KEY}": "{QUANTIZATION_TYPE_VALUE_FP32}", "{QUANTIZATION_PARAMS_KEY}": {} }, "{BUILD_PARAMS_KEY}": { diff --git a/tests/fixtures/fixtures.cpp b/tests/fixtures/fixtures.cpp index 7bba1403..c7e71f9e 100644 --- a/tests/fixtures/fixtures.cpp +++ b/tests/fixtures/fixtures.cpp @@ -289,4 +289,17 @@ GetFileSize(const std::string& filename) { return static_cast(file.tellg()); } +std::vector +SplitString(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::stringstream ss(s); + + while (std::getline(ss, token, delimiter)) { + tokens.emplace_back(token); + } + + return tokens; +} + } // namespace fixtures diff --git a/tests/fixtures/fixtures.h b/tests/fixtures/fixtures.h index ef334006..dc1e388c 100644 --- a/tests/fixtures/fixtures.h +++ b/tests/fixtures/fixtures.h @@ -225,4 +225,7 @@ generate_one_dataset(int64_t dim, uint64_t count); uint64_t GetFileSize(const std::string& filename); + +std::vector +SplitString(const std::string& s, char delimiter); } // Namespace fixtures diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index a174e330..e6967fc3 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -30,7 +30,7 @@ class HgraphTestIndex : public fixtures::TestIndex { static std::string GenerateHGraphBuildParametersString(const std::string& metric_type, int64_t dim, - const std::string& base_quantization_type = "sq8", + const std::string& quantization_str = "sq8", const int thread_count = 5); static TestDatasetPool pool; @@ -44,6 +44,9 @@ class HgraphTestIndex : public fixtures::TestIndex { "ef_search": {} }} }})"; + + const std::vector> test_cases = { + {"sq8_uniform,fp32", 0.98}, {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; }; TestDatasetPool HgraphTestIndex::pool{}; @@ -52,9 +55,27 @@ std::vector HgraphTestIndex::dims = fixtures::get_common_used_dims(2, Rando std::string HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_type, int64_t dim, - const std::string& base_quantization_type, + const std::string& quantization_str, const int thread_count) { - constexpr auto parameter_temp = R"( + std::string build_parameters_str; + + constexpr auto parameter_temp_reorder = R"( + {{ + "dtype": "float32", + "metric_type": "{}", + "dim": {}, + "index_param": {{ + "use_reorder": {}, + "base_quantization_type": "{}", + "max_degree": 96, + "ef_construction": 500, + "build_thread_count": {}, + "precise_quantization_type": "{}" + }} + }} + )"; + + constexpr auto parameter_temp_origin = R"( {{ "dtype": "float32", "metric_type": "{}", @@ -67,8 +88,23 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t }} }} )"; - std::string build_parameters_str = - fmt::format(parameter_temp, metric_type, dim, base_quantization_type, thread_count); + + auto strs = fixtures::SplitString(quantization_str, ','); + std::string high_quantizer_str; + auto& base_quantizer_str = strs[0]; + if (strs.size() > 1) { + high_quantizer_str = strs[1]; + build_parameters_str = fmt::format(parameter_temp_reorder, + metric_type, + dim, + true, /* reorder */ + base_quantizer_str, + thread_count, + high_quantizer_str); + } else { + build_parameters_str = + fmt::format(parameter_temp_origin, metric_type, dim, base_quantizer_str, thread_count); + } return build_parameters_str; } } // namespace fixtures @@ -189,8 +225,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) { @@ -225,8 +260,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Build", "[ft][hg auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) { @@ -261,8 +295,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) { @@ -297,8 +330,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add", auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) { @@ -382,8 +414,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build" auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::vector> test_cases = { - {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) {