Skip to content

Commit

Permalink
add reorder for hgraph (#257)
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 authored Dec 27, 2024
1 parent 74fea2e commit 46b3887
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 24 deletions.
1 change: 1 addition & 0 deletions include/vsag/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ extern const char* const HGRAPH_GRAPH_MAX_DEGREE;
extern const char* const HGRAPH_BUILD_EF_CONSTRUCTION;
extern const char* const HGRAPH_INIT_CAPACITY;
extern const char* const HGRAPH_BUILD_THREAD_COUNT;
extern const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE;

} // namespace vsag
53 changes: 47 additions & 6 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ HGraph::KnnSearch(const DatasetPtr& query,
this->basic_flatten_codes_,
search_param);

if (use_reorder_) {
this->reorder(query->GetFloat32Vectors(), this->high_precise_codes_, search_result, k);
}

while (search_result.size() > k) {
search_result.pop();
}
Expand Down Expand Up @@ -467,6 +471,11 @@ HGraph::RangeSearch(const DatasetPtr& query,
this->bottom_graph_,
this->basic_flatten_codes_,
search_param);
if (use_reorder_) {
this->reorder(
query->GetFloat32Vectors(), this->high_precise_codes_, search_result, limited_size);
}

if (limited_size > 0) {
while (search_result.size() > limited_size) {
search_result.pop();
Expand Down Expand Up @@ -792,27 +801,30 @@ HGraph::add_one_point(const float* data, int level, InnerIdType inner_id) {
};

std::lock_guard cur_lock(this->neighbors_mutex_[inner_id]);

auto flatten_codes = basic_flatten_codes_;
if (use_reorder_) {
flatten_codes = high_precise_codes_;
}
for (auto j = max_level_ - 1; j > level; --j) {
result = search_one_graph(data, route_graphs_[j], basic_flatten_codes_, param);
result = search_one_graph(data, route_graphs_[j], flatten_codes, param);
param.ep_ = result.top().second;
}

param.ef_ = this->ef_construct_;
for (auto j = level; j >= 0; --j) {
if (route_graphs_[j]->TotalCount() != 0) {
result = search_one_graph(data, route_graphs_[j], basic_flatten_codes_, param);
result = search_one_graph(data, route_graphs_[j], flatten_codes, param);
param.ep_ = this->mutually_connect_new_element(
inner_id, result, route_graphs_[j], basic_flatten_codes_, false);
inner_id, result, route_graphs_[j], flatten_codes, false);
} else {
route_graphs_[j]->InsertNeighborsById(inner_id, Vector<InnerIdType>(allocator_));
}
route_graphs_[j]->IncreaseTotalCount(1);
}
if (bottom_graph_->TotalCount() != 0) {
result = search_one_graph(data, this->bottom_graph_, basic_flatten_codes_, param);
result = search_one_graph(data, this->bottom_graph_, flatten_codes, param);
this->mutually_connect_new_element(
inner_id, result, this->bottom_graph_, basic_flatten_codes_, false);
inner_id, result, this->bottom_graph_, flatten_codes, false);
} else {
bottom_graph_->InsertNeighborsById(inner_id, Vector<InnerIdType>(allocator_));
}
Expand Down Expand Up @@ -917,4 +929,33 @@ HGraph::split_dataset_by_duplicate_label(const DatasetPtr& dataset,
return return_datasets;
}

void
HGraph::reorder(const float* query,
const FlattenInterfacePtr& flatten_interface,
MaxHeap& candidate_heap,
int64_t k) const {
uint64_t size = candidate_heap.size();
if (k <= 0) {
k = static_cast<int64_t>(size);
}
Vector<InnerIdType> ids(size, allocator_);
Vector<float> dists(size, allocator_);
uint64_t idx = 0;
while (not candidate_heap.empty()) {
ids[idx] = candidate_heap.top().second;
++idx;
candidate_heap.pop();
}
auto computer = flatten_interface->FactoryComputer(query);
flatten_interface->Query(dists.data(), computer, ids.data(), size);
for (uint64_t i = 0; i < size; ++i) {
if (candidate_heap.size() < k or dists[i] <= candidate_heap.top().first) {
candidate_heap.emplace(dists[i], ids[i]);
}
if (candidate_heap.size() > k) {
candidate_heap.pop();
}
}
}

} // namespace vsag
6 changes: 6 additions & 0 deletions src/algorithm/hgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ class HGraph {
split_dataset_by_duplicate_label(const DatasetPtr& dataset,
std::vector<LabelType>& failed_ids) const;

void
reorder(const float* query,
const FlattenInterfacePtr& flatten_interface,
MaxHeap& candidate_heap,
int64_t k) const;

private:
FlattenInterfacePtr basic_flatten_codes_{nullptr};
FlattenInterfacePtr high_precise_codes_{nullptr};
Expand Down
1 change: 1 addition & 0 deletions src/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree";
const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction";
const char* const HGRAPH_INIT_CAPACITY = "hgraph_init_capacity";
const char* const HGRAPH_BUILD_THREAD_COUNT = "build_thread_count";
const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE = "precise_quantization_type";

}; // namespace vsag
7 changes: 4 additions & 3 deletions src/index/hgraph_zparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace vsag {
static const std::unordered_map<std::string, std::vector<std::string>> EXTERNAL_MAPPING = {
{HGRAPH_USE_REORDER, {HGRAPH_USE_REORDER_KEY}},
{HGRAPH_BASE_QUANTIZATION_TYPE, {HGRAPH_BASE_CODES_KEY, QUANTIZATION_TYPE_KEY}},
{HGRAPH_PRECISE_QUANTIZATION_TYPE, {HGRAPH_PRECISE_CODES_KEY, QUANTIZATION_TYPE_KEY}},
{HGRAPH_GRAPH_MAX_DEGREE, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_MAX_DEGREE}},
{HGRAPH_BUILD_EF_CONSTRUCTION, {BUILD_PARAMS_KEY, BUILD_EF_CONSTRUCTION}},
{HGRAPH_INIT_CAPACITY, {HGRAPH_GRAPH_KEY, GRAPH_PARAMS_KEY, GRAPH_PARAM_INIT_MAX_CAPACITY}},
Expand Down Expand Up @@ -59,12 +60,12 @@ static const std::string HGRAPH_PARAMS_TEMPLATE =
"nbits": 8
}
},
"precise_codes": {
"{IO_TYPE_KEY}": "aio_ssd",
"{HGRAPH_PRECISE_CODES_KEY}": {
"{IO_TYPE_KEY}": "{IO_TYPE_VALUE_BLOCK_MEMORY_IO}",
"{IO_PARAMS_KEY}": {},
"codes_type": "flatten_codes",
"codes_param": {},
"{QUANTIZATION_TYPE_KEY}": "{QUANTIZATION_TYPE_VALUE_SQ8}",
"{QUANTIZATION_TYPE_KEY}": "{QUANTIZATION_TYPE_VALUE_FP32}",
"{QUANTIZATION_PARAMS_KEY}": {}
},
"{BUILD_PARAMS_KEY}": {
Expand Down
13 changes: 13 additions & 0 deletions tests/fixtures/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,17 @@ GetFileSize(const std::string& filename) {
return static_cast<uint64_t>(file.tellg());
}

std::vector<std::string>
SplitString(const std::string& s, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::stringstream ss(s);

while (std::getline(ss, token, delimiter)) {
tokens.emplace_back(token);
}

return tokens;
}

} // namespace fixtures
3 changes: 3 additions & 0 deletions tests/fixtures/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,7 @@ generate_one_dataset(int64_t dim, uint64_t count);

uint64_t
GetFileSize(const std::string& filename);

std::vector<std::string>
SplitString(const std::string& s, char delimiter);
} // Namespace fixtures
61 changes: 46 additions & 15 deletions tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class HgraphTestIndex : public fixtures::TestIndex {
static std::string
GenerateHGraphBuildParametersString(const std::string& metric_type,
int64_t dim,
const std::string& base_quantization_type = "sq8",
const std::string& quantization_str = "sq8",
const int thread_count = 5);
static TestDatasetPool pool;

Expand All @@ -44,6 +44,9 @@ class HgraphTestIndex : public fixtures::TestIndex {
"ef_search": {}
}}
}})";

const std::vector<std::pair<std::string, float>> test_cases = {
{"sq8_uniform,fp32", 0.98}, {"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};
};

TestDatasetPool HgraphTestIndex::pool{};
Expand All @@ -52,9 +55,27 @@ std::vector<int> HgraphTestIndex::dims = fixtures::get_common_used_dims(2, Rando
std::string
HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_type,
int64_t dim,
const std::string& base_quantization_type,
const std::string& quantization_str,
const int thread_count) {
constexpr auto parameter_temp = R"(
std::string build_parameters_str;

constexpr auto parameter_temp_reorder = R"(
{{
"dtype": "float32",
"metric_type": "{}",
"dim": {},
"index_param": {{
"use_reorder": {},
"base_quantization_type": "{}",
"max_degree": 96,
"ef_construction": 500,
"build_thread_count": {},
"precise_quantization_type": "{}"
}}
}}
)";

constexpr auto parameter_temp_origin = R"(
{{
"dtype": "float32",
"metric_type": "{}",
Expand All @@ -67,8 +88,23 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t
}}
}}
)";
std::string build_parameters_str =
fmt::format(parameter_temp, metric_type, dim, base_quantization_type, thread_count);

auto strs = fixtures::SplitString(quantization_str, ',');
std::string high_quantizer_str;
auto& base_quantizer_str = strs[0];
if (strs.size() > 1) {
high_quantizer_str = strs[1];
build_parameters_str = fmt::format(parameter_temp_reorder,
metric_type,
dim,
true, /* reorder */
base_quantizer_str,
thread_count,
high_quantizer_str);
} else {
build_parameters_str =
fmt::format(parameter_temp_origin, metric_type, dim, base_quantizer_str, thread_count);
}
return build_parameters_str;
}
} // namespace fixtures
Expand Down Expand Up @@ -189,8 +225,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex,
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
Expand Down Expand Up @@ -225,8 +260,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Build", "[ft][hg
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
Expand Down Expand Up @@ -261,8 +295,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
Expand Down Expand Up @@ -297,8 +330,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add",
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
Expand Down Expand Up @@ -382,8 +414,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build"
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.96}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
Expand Down

0 comments on commit 46b3887

Please sign in to comment.