Skip to content

Commit

Permalink
support allocator in conjugate graph
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou committed Jan 21, 2025
1 parent 892cbe0 commit 6f49bac
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 32 deletions.
60 changes: 40 additions & 20 deletions src/impl/conjugate_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

namespace vsag {

ConjugateGraph::ConjugateGraph() {
ConjugateGraph::ConjugateGraph(Allocator* allocator)
: allocator_(allocator), conjugate_graph_(allocator) {
clear();
}

Expand All @@ -26,31 +27,39 @@ ConjugateGraph::AddNeighbor(int64_t from_tag_id, int64_t to_tag_id) {
if (from_tag_id == to_tag_id) {
return false;
}
auto& neighbor_set = conjugate_graph_[from_tag_id];
if (neighbor_set.size() >= MAXIMUM_DEGREE) {

std::shared_ptr<UnorderedSet<int64_t>> neighbor_set;
auto search_key = conjugate_graph_.find(from_tag_id);
if (search_key == conjugate_graph_.end()) {
neighbor_set = std::make_shared<UnorderedSet<int64_t>>(allocator_);
conjugate_graph_.emplace(from_tag_id, neighbor_set);
} else {
neighbor_set = search_key->second;
}

if (neighbor_set->size() >= MAXIMUM_DEGREE) {
return false;
}
auto insert_result = neighbor_set.insert(to_tag_id);
auto insert_result = neighbor_set->insert(to_tag_id);
if (!insert_result.second) {
return false;
} else {
if (neighbor_set.size() == 1) {
if (neighbor_set->size() == 1) {
memory_usage_ += sizeof(from_tag_id);
memory_usage_ += sizeof(neighbor_set.size());
memory_usage_ += sizeof(neighbor_set->size());
}
memory_usage_ += sizeof(to_tag_id);
return true;
}
}

const std::unordered_set<int64_t>&
std::shared_ptr<UnorderedSet<int64_t>>
ConjugateGraph::get_neighbors(int64_t from_tag_id) const {
static const std::unordered_set<int64_t> empty_set;
auto iter = conjugate_graph_.find(from_tag_id);
if (iter != conjugate_graph_.end()) {
return iter->second;
} else {
return empty_set;
return nullptr;
}
}

Expand All @@ -64,8 +73,8 @@ ConjugateGraph::EnhanceResult(std::priority_queue<std::pair<float, LabelType>>&
int64_t k = results.size();
int64_t look_at_k = std::min(LOOK_AT_K, k);
std::priority_queue<std::pair<float, LabelType>> old_results(results);
std::vector<int64_t> to_be_visited(look_at_k);
std::unordered_set<int64_t> visited_set;
Vector<int64_t> to_be_visited(look_at_k, allocator_);
UnorderedSet<int64_t> visited_set(allocator_);
uint32_t successfully_enhanced = 0;
float distance = 0;

Expand All @@ -80,9 +89,13 @@ ConjugateGraph::EnhanceResult(std::priority_queue<std::pair<float, LabelType>>&

// add neighbors in conjugate graph to enhance result
for (int j = 0; j < look_at_k; j++) {
const std::unordered_set<int64_t>& neighbors = get_neighbors(to_be_visited[j]);
auto neighbors = get_neighbors(to_be_visited[j]);

for (auto neighbor_tag_id : neighbors) {
if (neighbors == nullptr) {
continue;
}

for (auto neighbor_tag_id : *neighbors) {
if (not visited_set.insert(neighbor_tag_id).second) {
continue;
}
Expand Down Expand Up @@ -114,6 +127,7 @@ read_var_from_stream(StreamReader& in_stream, uint32_t* offset, T* dest) {
void
ConjugateGraph::clear() {
memory_usage_ = sizeof(memory_usage_) + FOOTER_SIZE;
conjugate_graph_.clear();
footer_.Clear();
}

Expand Down Expand Up @@ -141,7 +155,7 @@ ConjugateGraph::Serialize(std::ostream& out_stream) const {
out_stream.write((char*)&memory_usage_, sizeof(memory_usage_));

for (auto item : conjugate_graph_) {
auto neighbor_set = item.second;
auto neighbor_set = *item.second;
size_t neighbor_set_size = neighbor_set.size();

out_stream.write((char*)&item.first, sizeof(item.first));
Expand Down Expand Up @@ -169,7 +183,8 @@ ConjugateGraph::Deserialize(const Binary& binary) {

int64_t cursor = 0;
ReadFuncStreamReader reader(func, cursor);
return this->Deserialize(reader);
BufferStreamReader buffer_reader(&reader, binary.size, allocator_);
return this->Deserialize(buffer_reader);
}

tl::expected<void, Error>
Expand Down Expand Up @@ -198,10 +213,15 @@ ConjugateGraph::Deserialize(StreamReader& in_stream) {
in_stream.Seek(cur_pos + offset);
while (offset != memory_usage_ - FOOTER_SIZE) {
read_var_from_stream(in_stream, &offset, &from_tag_id);
if (not conjugate_graph_.count(from_tag_id)) {
conjugate_graph_.emplace(from_tag_id,
std::make_shared<UnorderedSet<int64_t>>(allocator_));
}

read_var_from_stream(in_stream, &offset, &neighbor_size);
for (int i = 0; i < neighbor_size; i++) {
read_var_from_stream(in_stream, &offset, &to_tag_id);
conjugate_graph_[from_tag_id].insert(to_tag_id);
conjugate_graph_[from_tag_id]->insert(to_tag_id);
}
}

Expand Down Expand Up @@ -240,10 +260,10 @@ ConjugateGraph::UpdateId(int64_t old_tag_id, int64_t new_tag_id) {

// 2. update neighbors
for (auto& [key, neighbors] : conjugate_graph_) {
auto it_old_neighbor = neighbors.find(old_tag_id);
if (it_old_neighbor != neighbors.end()) {
neighbors.erase(it_old_neighbor);
neighbors.insert(new_tag_id);
auto it_old_neighbor = neighbors->find(old_tag_id);
if (it_old_neighbor != neighbors->end()) {
neighbors->erase(it_old_neighbor);
neighbors->insert(new_tag_id);
updated = true;
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/impl/conjugate_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

#include <nlohmann/json.hpp>
#include <queue>
#include <unordered_map>
#include <unordered_set>

#include "../footer.h"
#include "../logger.h"
Expand All @@ -32,7 +30,7 @@ static const int64_t MAXIMUM_DEGREE = 128;

class ConjugateGraph {
public:
ConjugateGraph();
ConjugateGraph(Allocator* allocator);

tl::expected<bool, Error>
AddNeighbor(int64_t from_tag_id, int64_t to_tag_id);
Expand Down Expand Up @@ -61,7 +59,7 @@ class ConjugateGraph {
GetMemoryUsage() const;

private:
const std::unordered_set<int64_t>&
std::shared_ptr<UnorderedSet<int64_t>>
get_neighbors(int64_t from_tag_id) const;

void
Expand All @@ -73,9 +71,11 @@ class ConjugateGraph {
private:
uint32_t memory_usage_;

std::unordered_map<int64_t, std::unordered_set<int64_t>> conjugate_graph_;
UnorderedMap<int64_t, std::shared_ptr<UnorderedSet<int64_t>>> conjugate_graph_;

SerializationFooter footer_;

Allocator* allocator_;
};

} // namespace vsag
13 changes: 9 additions & 4 deletions src/impl/conjugate_graph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
#include <nlohmann/json.hpp>

#include "fixtures.h"
#include "safe_allocator.h"
#include "stream_reader.h"

TEST_CASE("build, add and memory usage", "[ut][conjugate_graph]") {
auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator();
std::shared_ptr<vsag::ConjugateGraph> conjugate_graph =
std::make_shared<vsag::ConjugateGraph>();
std::make_shared<vsag::ConjugateGraph>(allocator.get());
REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE);
REQUIRE(conjugate_graph->AddNeighbor(0, 0) == false);
REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE);
Expand All @@ -44,8 +46,9 @@ TEST_CASE("build, add and memory usage", "[ut][conjugate_graph]") {
}

TEST_CASE("serialize and deserialize with binary", "[ut][conjugate_graph]") {
auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator();
std::shared_ptr<vsag::ConjugateGraph> conjugate_graph =
std::make_shared<vsag::ConjugateGraph>();
std::make_shared<vsag::ConjugateGraph>(allocator.get());

conjugate_graph->AddNeighbor(0, 2);
conjugate_graph->AddNeighbor(0, 1);
Expand Down Expand Up @@ -145,8 +148,9 @@ TEST_CASE("serialize and deserialize with binary", "[ut][conjugate_graph]") {
}

TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") {
auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator();
std::shared_ptr<vsag::ConjugateGraph> conjugate_graph =
std::make_shared<vsag::ConjugateGraph>();
std::make_shared<vsag::ConjugateGraph>(allocator.get());

conjugate_graph->AddNeighbor(0, 2);
conjugate_graph->AddNeighbor(0, 1);
Expand Down Expand Up @@ -282,8 +286,9 @@ TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") {
}

TEST_CASE("update id", "[ut][conjugate_graph]") {
auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator();
std::shared_ptr<vsag::ConjugateGraph> conjugate_graph =
std::make_shared<vsag::ConjugateGraph>();
std::make_shared<vsag::ConjugateGraph>(allocator.get());

REQUIRE(conjugate_graph->AddNeighbor(0, 1) == true);
REQUIRE(conjugate_graph->AddNeighbor(0, 2) == true);
Expand Down
6 changes: 3 additions & 3 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
throw std::runtime_error(MESSAGE_PARAMETER);
}

allocator_ = index_common_param.allocator_;

if (hnsw_params.use_conjugate_graph) {
conjugate_graph_ = std::make_shared<ConjugateGraph>();
conjugate_graph_ = std::make_shared<ConjugateGraph>(allocator_.get());
}

allocator_ = index_common_param.allocator_;

if (!use_static_) {
alg_hnsw_ =
std::make_shared<hnswlib::HierarchicalNSW>(space_.get(),
Expand Down
3 changes: 3 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class HNSW : public Index {

virtual ~HNSW() {
alg_hnsw_ = nullptr;
if (use_conjugate_graph_) {
conjugate_graph_.reset();
}
allocator_.reset();
}

Expand Down

0 comments on commit 6f49bac

Please sign in to comment.