From 6f49bac0a61d9f7ea782f9876c78d1a74ec83397 Mon Sep 17 00:00:00 2001 From: "zhongxiaoyao.zxy" Date: Tue, 21 Jan 2025 12:01:05 +0800 Subject: [PATCH] support allocator in conjugate graph Signed-off-by: zhongxiaoyao.zxy --- src/impl/conjugate_graph.cpp | 60 ++++++++++++++++++++----------- src/impl/conjugate_graph.h | 10 +++--- src/impl/conjugate_graph_test.cpp | 13 ++++--- src/index/hnsw.cpp | 6 ++-- src/index/hnsw.h | 3 ++ 5 files changed, 60 insertions(+), 32 deletions(-) diff --git a/src/impl/conjugate_graph.cpp b/src/impl/conjugate_graph.cpp index 01f9d1f1..4a9cd3cc 100644 --- a/src/impl/conjugate_graph.cpp +++ b/src/impl/conjugate_graph.cpp @@ -17,7 +17,8 @@ namespace vsag { -ConjugateGraph::ConjugateGraph() { +ConjugateGraph::ConjugateGraph(Allocator* allocator) + : allocator_(allocator), conjugate_graph_(allocator) { clear(); } @@ -26,31 +27,39 @@ ConjugateGraph::AddNeighbor(int64_t from_tag_id, int64_t to_tag_id) { if (from_tag_id == to_tag_id) { return false; } - auto& neighbor_set = conjugate_graph_[from_tag_id]; - if (neighbor_set.size() >= MAXIMUM_DEGREE) { + + std::shared_ptr> neighbor_set; + auto search_key = conjugate_graph_.find(from_tag_id); + if (search_key == conjugate_graph_.end()) { + neighbor_set = std::make_shared>(allocator_); + conjugate_graph_.emplace(from_tag_id, neighbor_set); + } else { + neighbor_set = search_key->second; + } + + if (neighbor_set->size() >= MAXIMUM_DEGREE) { return false; } - auto insert_result = neighbor_set.insert(to_tag_id); + auto insert_result = neighbor_set->insert(to_tag_id); if (!insert_result.second) { return false; } else { - if (neighbor_set.size() == 1) { + if (neighbor_set->size() == 1) { memory_usage_ += sizeof(from_tag_id); - memory_usage_ += sizeof(neighbor_set.size()); + memory_usage_ += sizeof(neighbor_set->size()); } memory_usage_ += sizeof(to_tag_id); return true; } } -const std::unordered_set& +std::shared_ptr> ConjugateGraph::get_neighbors(int64_t from_tag_id) const { - static const std::unordered_set empty_set; auto iter = conjugate_graph_.find(from_tag_id); if (iter != conjugate_graph_.end()) { return iter->second; } else { - return empty_set; + return nullptr; } } @@ -64,8 +73,8 @@ ConjugateGraph::EnhanceResult(std::priority_queue>& int64_t k = results.size(); int64_t look_at_k = std::min(LOOK_AT_K, k); std::priority_queue> old_results(results); - std::vector to_be_visited(look_at_k); - std::unordered_set visited_set; + Vector to_be_visited(look_at_k, allocator_); + UnorderedSet visited_set(allocator_); uint32_t successfully_enhanced = 0; float distance = 0; @@ -80,9 +89,13 @@ ConjugateGraph::EnhanceResult(std::priority_queue>& // add neighbors in conjugate graph to enhance result for (int j = 0; j < look_at_k; j++) { - const std::unordered_set& neighbors = get_neighbors(to_be_visited[j]); + auto neighbors = get_neighbors(to_be_visited[j]); - for (auto neighbor_tag_id : neighbors) { + if (neighbors == nullptr) { + continue; + } + + for (auto neighbor_tag_id : *neighbors) { if (not visited_set.insert(neighbor_tag_id).second) { continue; } @@ -114,6 +127,7 @@ read_var_from_stream(StreamReader& in_stream, uint32_t* offset, T* dest) { void ConjugateGraph::clear() { memory_usage_ = sizeof(memory_usage_) + FOOTER_SIZE; + conjugate_graph_.clear(); footer_.Clear(); } @@ -141,7 +155,7 @@ ConjugateGraph::Serialize(std::ostream& out_stream) const { out_stream.write((char*)&memory_usage_, sizeof(memory_usage_)); for (auto item : conjugate_graph_) { - auto neighbor_set = item.second; + auto neighbor_set = *item.second; size_t neighbor_set_size = neighbor_set.size(); out_stream.write((char*)&item.first, sizeof(item.first)); @@ -169,7 +183,8 @@ ConjugateGraph::Deserialize(const Binary& binary) { int64_t cursor = 0; ReadFuncStreamReader reader(func, cursor); - return this->Deserialize(reader); + BufferStreamReader buffer_reader(&reader, binary.size, allocator_); + return this->Deserialize(buffer_reader); } tl::expected @@ -198,10 +213,15 @@ ConjugateGraph::Deserialize(StreamReader& in_stream) { in_stream.Seek(cur_pos + offset); while (offset != memory_usage_ - FOOTER_SIZE) { read_var_from_stream(in_stream, &offset, &from_tag_id); + if (not conjugate_graph_.count(from_tag_id)) { + conjugate_graph_.emplace(from_tag_id, + std::make_shared>(allocator_)); + } + read_var_from_stream(in_stream, &offset, &neighbor_size); for (int i = 0; i < neighbor_size; i++) { read_var_from_stream(in_stream, &offset, &to_tag_id); - conjugate_graph_[from_tag_id].insert(to_tag_id); + conjugate_graph_[from_tag_id]->insert(to_tag_id); } } @@ -240,10 +260,10 @@ ConjugateGraph::UpdateId(int64_t old_tag_id, int64_t new_tag_id) { // 2. update neighbors for (auto& [key, neighbors] : conjugate_graph_) { - auto it_old_neighbor = neighbors.find(old_tag_id); - if (it_old_neighbor != neighbors.end()) { - neighbors.erase(it_old_neighbor); - neighbors.insert(new_tag_id); + auto it_old_neighbor = neighbors->find(old_tag_id); + if (it_old_neighbor != neighbors->end()) { + neighbors->erase(it_old_neighbor); + neighbors->insert(new_tag_id); updated = true; } } diff --git a/src/impl/conjugate_graph.h b/src/impl/conjugate_graph.h index 03e0e80a..32c79be3 100644 --- a/src/impl/conjugate_graph.h +++ b/src/impl/conjugate_graph.h @@ -17,8 +17,6 @@ #include #include -#include -#include #include "../footer.h" #include "../logger.h" @@ -32,7 +30,7 @@ static const int64_t MAXIMUM_DEGREE = 128; class ConjugateGraph { public: - ConjugateGraph(); + ConjugateGraph(Allocator* allocator); tl::expected AddNeighbor(int64_t from_tag_id, int64_t to_tag_id); @@ -61,7 +59,7 @@ class ConjugateGraph { GetMemoryUsage() const; private: - const std::unordered_set& + std::shared_ptr> get_neighbors(int64_t from_tag_id) const; void @@ -73,9 +71,11 @@ class ConjugateGraph { private: uint32_t memory_usage_; - std::unordered_map> conjugate_graph_; + UnorderedMap>> conjugate_graph_; SerializationFooter footer_; + + Allocator* allocator_; }; } // namespace vsag diff --git a/src/impl/conjugate_graph_test.cpp b/src/impl/conjugate_graph_test.cpp index 69aba141..4f9bf5bc 100644 --- a/src/impl/conjugate_graph_test.cpp +++ b/src/impl/conjugate_graph_test.cpp @@ -21,11 +21,13 @@ #include #include "fixtures.h" +#include "safe_allocator.h" #include "stream_reader.h" TEST_CASE("build, add and memory usage", "[ut][conjugate_graph]") { + auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); std::shared_ptr conjugate_graph = - std::make_shared(); + std::make_shared(allocator.get()); REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE); REQUIRE(conjugate_graph->AddNeighbor(0, 0) == false); REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE); @@ -44,8 +46,9 @@ TEST_CASE("build, add and memory usage", "[ut][conjugate_graph]") { } TEST_CASE("serialize and deserialize with binary", "[ut][conjugate_graph]") { + auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); std::shared_ptr conjugate_graph = - std::make_shared(); + std::make_shared(allocator.get()); conjugate_graph->AddNeighbor(0, 2); conjugate_graph->AddNeighbor(0, 1); @@ -145,8 +148,9 @@ TEST_CASE("serialize and deserialize with binary", "[ut][conjugate_graph]") { } TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") { + auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); std::shared_ptr conjugate_graph = - std::make_shared(); + std::make_shared(allocator.get()); conjugate_graph->AddNeighbor(0, 2); conjugate_graph->AddNeighbor(0, 1); @@ -282,8 +286,9 @@ TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") { } TEST_CASE("update id", "[ut][conjugate_graph]") { + auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); std::shared_ptr conjugate_graph = - std::make_shared(); + std::make_shared(allocator.get()); REQUIRE(conjugate_graph->AddNeighbor(0, 1) == true); REQUIRE(conjugate_graph->AddNeighbor(0, 2) == true); diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 35d789d2..4b1ae9b6 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -61,12 +61,12 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para throw std::runtime_error(MESSAGE_PARAMETER); } + allocator_ = index_common_param.allocator_; + if (hnsw_params.use_conjugate_graph) { - conjugate_graph_ = std::make_shared(); + conjugate_graph_ = std::make_shared(allocator_.get()); } - allocator_ = index_common_param.allocator_; - if (!use_static_) { alg_hnsw_ = std::make_shared(space_.get(), diff --git a/src/index/hnsw.h b/src/index/hnsw.h index f8e4a474..831416af 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -51,6 +51,9 @@ class HNSW : public Index { virtual ~HNSW() { alg_hnsw_ = nullptr; + if (use_conjugate_graph_) { + conjugate_graph_.reset(); + } allocator_.reset(); }