From 0165bcf16799a7c5ba1d4d1d63c6657baaea80b7 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Mon, 6 Jan 2025 16:19:59 +0800 Subject: [PATCH 01/14] Add a batch interface for getDistanceByLabel --- include/vsag/index.h | 17 +++++++++++++++ src/algorithm/hnswlib/algorithm_interface.h | 6 ++++++ src/algorithm/hnswlib/hnswalg.cpp | 22 ++++++++++++++++++++ src/algorithm/hnswlib/hnswalg.h | 6 ++++++ src/algorithm/hnswlib/hnswalg_static.h | 23 +++++++++++++++++++++ src/index/hnsw.h | 8 +++++++ 6 files changed, 82 insertions(+) diff --git a/include/vsag/index.h b/include/vsag/index.h index b9068cc1..019df8f1 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -249,6 +249,23 @@ class Index { throw std::runtime_error("Index doesn't support get distance by id"); }; + /** + * @brief Calculate the distance between the query and the vector of the given ID for batch. + * + * @param count is the count of vids + * @param vids is the unique identifier of the vector to be calculated in the index. + * @param vector is the embedding of query + * @param distances is the distances between the query and the vector of the given ID + * @return result is valid distance of input vids. + */ + virtual tl::expected + CalcBatchDistanceById(int64_t count, + int64_t *vids, + const float* vector, + float *&distances) const { + throw std::runtime_error("Index doesn't support get distance by id"); + }; + /** * @brief Checks if the specified feature is supported by the index. * diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 75860146..f5751e07 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -69,6 +69,12 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; + virtual int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) = 0; + virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index e473fa51..7bb24f08 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,6 +171,28 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } +int64_t +HierarchicalNSW::getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) { + std::shared_lock lock_table(label_lookup_lock_); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; +} + bool HierarchicalNSW::isValidLabel(LabelType label) { std::shared_lock lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index f880e746..a5c2eb9b 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -146,6 +146,12 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; + int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) override; + bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 4d37eb38..11ee2f5d 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ + mutable std::shared_mutex shared_label_lookup_lock; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -262,6 +263,28 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } + int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) override { + std::shared_lock lock_table(shared_label_lookup_lock); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vid[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; + } + bool isValidLabel(LabelType label) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 9a06e245..2b58f9a5 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -145,6 +145,14 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; + virtual tl::expected + CalcBatchDistanceById(int64_t count, + int64_t *vids, + const float* vector, + float *&distances) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vid, vector, distances)); + }; + [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; From ff71e3461f61ee65e325ece1b851e37255352f88 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Mon, 6 Jan 2025 17:14:48 +0800 Subject: [PATCH 02/14] add vids and const --- include/vsag/index.h | 2 +- src/algorithm/hnswlib/algorithm_interface.h | 4 ++-- src/algorithm/hnswlib/hnswalg.cpp | 2 +- src/algorithm/hnswlib/hnswalg.h | 2 +- src/algorithm/hnswlib/hnswalg_static.h | 4 ++-- src/index/hnsw.h | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 019df8f1..0c6502d4 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -260,7 +260,7 @@ class Index { */ virtual tl::expected CalcBatchDistanceById(int64_t count, - int64_t *vids, + const int64_t *vids, const float* vector, float *&distances) const { throw std::runtime_error("Index doesn't support get distance by id"); diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index f5751e07..3a67c5ad 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -71,8 +71,8 @@ class AlgorithmInterface { virtual int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, - const void* data_point, + const int64_t *vids, + const void *data_point, float *&distances) = 0; virtual const float* diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 7bb24f08..a2a09e55 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -173,7 +173,7 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { int64_t HierarchicalNSW::getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) { std::shared_lock lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index a5c2eb9b..7c533ac4 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -148,7 +148,7 @@ class HierarchicalNSW : public AlgorithmInterface { int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 11ee2f5d..190a8998 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -265,14 +265,14 @@ class StaticHierarchicalNSW : public AlgorithmInterface { int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) override { std::shared_lock lock_table(shared_label_lookup_lock); int64_t ret_cnt = 0; distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { - auto search = label_lookup_.find(vid[i]); + auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { distances[i] = -1; } else { diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 2b58f9a5..d3bbecf5 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -147,10 +147,10 @@ class HNSW : public Index { virtual tl::expected CalcBatchDistanceById(int64_t count, - int64_t *vids, + const int64_t *vids, const float* vector, float *&distances) const override { - SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vid, vector, distances)); + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector, distances)); }; [[nodiscard]] bool From d32e60aca477458d4881195baeef03c7b92de567 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Mon, 6 Jan 2025 16:19:59 +0800 Subject: [PATCH 03/14] Add a batch interface for getDistanceByLabel --- include/vsag/index.h | 17 +++++++++++++++ src/algorithm/hnswlib/algorithm_interface.h | 6 ++++++ src/algorithm/hnswlib/hnswalg.cpp | 22 ++++++++++++++++++++ src/algorithm/hnswlib/hnswalg.h | 6 ++++++ src/algorithm/hnswlib/hnswalg_static.h | 23 +++++++++++++++++++++ src/index/hnsw.h | 8 +++++++ 6 files changed, 82 insertions(+) diff --git a/include/vsag/index.h b/include/vsag/index.h index b9068cc1..0c6502d4 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -249,6 +249,23 @@ class Index { throw std::runtime_error("Index doesn't support get distance by id"); }; + /** + * @brief Calculate the distance between the query and the vector of the given ID for batch. + * + * @param count is the count of vids + * @param vids is the unique identifier of the vector to be calculated in the index. + * @param vector is the embedding of query + * @param distances is the distances between the query and the vector of the given ID + * @return result is valid distance of input vids. + */ + virtual tl::expected + CalcBatchDistanceById(int64_t count, + const int64_t *vids, + const float* vector, + float *&distances) const { + throw std::runtime_error("Index doesn't support get distance by id"); + }; + /** * @brief Checks if the specified feature is supported by the index. * diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 75860146..3a67c5ad 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -69,6 +69,12 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; + virtual int64_t + getBatchDistanceByLabel(int64_t count, + const int64_t *vids, + const void *data_point, + float *&distances) = 0; + virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index e473fa51..a2a09e55 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,6 +171,28 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } +int64_t +HierarchicalNSW::getBatchDistanceByLabel(int64_t count, + const int64_t *vids, + const void* data_point, + float *&distances) { + std::shared_lock lock_table(label_lookup_lock_); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; +} + bool HierarchicalNSW::isValidLabel(LabelType label) { std::shared_lock lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index f880e746..7c533ac4 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -146,6 +146,12 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; + int64_t + getBatchDistanceByLabel(int64_t count, + const int64_t *vids, + const void* data_point, + float *&distances) override; + bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 4d37eb38..190a8998 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ + mutable std::shared_mutex shared_label_lookup_lock; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -262,6 +263,28 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } + int64_t + getBatchDistanceByLabel(int64_t count, + const int64_t *vids, + const void* data_point, + float *&distances) override { + std::shared_lock lock_table(shared_label_lookup_lock); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; + } + bool isValidLabel(LabelType label) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 9a06e245..d3bbecf5 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -145,6 +145,14 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; + virtual tl::expected + CalcBatchDistanceById(int64_t count, + const int64_t *vids, + const float* vector, + float *&distances) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector, distances)); + }; + [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; From b720efc0555dd40d49ff4580fb0825612276502b Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Fri, 10 Jan 2025 17:53:11 +0800 Subject: [PATCH 04/14] add Dataset and Modify review comments --- include/vsag/index.h | 7 +++---- src/algorithm/hnswlib/algorithm_interface.h | 8 +++++--- src/algorithm/hnswlib/hnswalg.cpp | 16 +++++++++------- src/algorithm/hnswlib/hnswalg.h | 6 +++--- src/algorithm/hnswlib/hnswalg_static.h | 19 ++++++++++--------- src/index/hnsw.h | 7 +++---- 6 files changed, 33 insertions(+), 30 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 0c6502d4..039815dd 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -256,13 +256,12 @@ class Index { * @param vids is the unique identifier of the vector to be calculated in the index. * @param vector is the embedding of query * @param distances is the distances between the query and the vector of the given ID - * @return result is valid distance of input vids. + * @return result is valid distance of input vids. '-1' indicates an invalid distance. */ - virtual tl::expected + virtual tl::expected CalcBatchDistanceById(int64_t count, const int64_t *vids, - const float* vector, - float *&distances) const { + const float* vector) const { throw std::runtime_error("Index doesn't support get distance by id"); }; diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 3a67c5ad..dbe9ff83 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -24,6 +24,9 @@ #include "space_interface.h" #include "stream_reader.h" #include "typing.h" +#include "vsag/dataset.h" +#include "vsag/expected.hpp" +#include "vsag/errors.h" namespace hnswlib { @@ -69,11 +72,10 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; - virtual int64_t + virtual tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void *data_point, - float *&distances) = 0; + const void *data_point) = 0; virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index a2a09e55..826ec2f3 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,14 +171,14 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } -int64_t +tl::expected HierarchicalNSW::getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) { + const void* data_point) { std::shared_lock lock_table(label_lookup_lock_); - int64_t ret_cnt = 0; - distances = (float *)allocator_->Allocate(sizeof(float) * count); + int64_t valid_cnt = 0; + auto result = vsag::Dataset::Make(); + auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -187,10 +187,12 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, InnerIdType internal_id = search->second; float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); distances[i] = dist; - ret_cnt++; + valid_cnt++; } } - return ret_cnt; + result->NumElements(valid_cnt)->Owner(true, allocator_); + result->Distances(distances); + return std::move(result); } bool diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 7c533ac4..7412d270 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -38,6 +38,7 @@ #include "algorithm_interface.h" #include "block_manager.h" #include "visited_list_pool.h" +#include "vsag/dataset.h" namespace hnswlib { using InnerIdType = vsag::InnerIdType; using linklistsizeint = unsigned int; @@ -146,11 +147,10 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; - int64_t + tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) override; + const void* data_point) override; bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 190a8998..a0172b27 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -81,7 +81,6 @@ class StaticHierarchicalNSW : public AlgorithmInterface { void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ - mutable std::shared_mutex shared_label_lookup_lock; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -263,14 +262,14 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } - int64_t + tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) override { - std::shared_lock lock_table(shared_label_lookup_lock); - int64_t ret_cnt = 0; - distances = (float *)allocator_->Allocate(sizeof(float) * count); + const void* data_point) override { + std::unique_lock lock_table(label_lookup_lock); + int64_t valid_cnt = 0; + auto result = vsag::Dataset::Make(); + auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -279,10 +278,12 @@ class StaticHierarchicalNSW : public AlgorithmInterface { InnerIdType internal_id = search->second; float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); distances[i] = dist; - ret_cnt++; + valid_cnt++; } } - return ret_cnt; + result->NumElements(valid_cnt)->Owner(true, allocator_); + result->Distances(distances); + return std::move(result); } bool diff --git a/src/index/hnsw.h b/src/index/hnsw.h index d3bbecf5..1bb79766 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -145,12 +145,11 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; - virtual tl::expected + virtual tl::expected CalcBatchDistanceById(int64_t count, const int64_t *vids, - const float* vector, - float *&distances) const override { - SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector, distances)); + const float* vector) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector)); }; [[nodiscard]] bool From 47a2fea777e230904dc972c85bb99b7780ae1052 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Wed, 15 Jan 2025 11:54:27 +0800 Subject: [PATCH 05/14] Modify the potential memory leak risk of getBatchDistanceByLabel Signed-off-by: zourunxin.zrx --- src/algorithm/hnswlib/hnswalg.cpp | 5 +++-- src/algorithm/hnswlib/hnswalg_static.h | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 826ec2f3..0f10d49b 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -178,7 +178,9 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, std::shared_lock lock_table(label_lookup_lock_); int64_t valid_cnt = 0; auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); + result->Distances(distances); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -190,8 +192,7 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, valid_cnt++; } } - result->NumElements(valid_cnt)->Owner(true, allocator_); - result->Distances(distances); + result->NumElements(valid_cnt); return std::move(result); } diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index a0172b27..d50bad82 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -269,7 +269,9 @@ class StaticHierarchicalNSW : public AlgorithmInterface { std::unique_lock lock_table(label_lookup_lock); int64_t valid_cnt = 0; auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); + result->Distances(distances); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -281,8 +283,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { valid_cnt++; } } - result->NumElements(valid_cnt)->Owner(true, allocator_); - result->Distances(distances); + result->NumElements(valid_cnt); return std::move(result); } From 01d8d31178cc115573d9b3791c2582ed3d31491d Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Mon, 6 Jan 2025 16:19:59 +0800 Subject: [PATCH 06/14] Add a batch interface for getDistanceByLabel --- include/vsag/index.h | 17 +++++++++++++++ src/algorithm/hnswlib/algorithm_interface.h | 6 ++++++ src/algorithm/hnswlib/hnswalg.cpp | 22 ++++++++++++++++++++ src/algorithm/hnswlib/hnswalg.h | 6 ++++++ src/algorithm/hnswlib/hnswalg_static.h | 23 +++++++++++++++++++++ src/index/hnsw.h | 8 +++++++ 6 files changed, 82 insertions(+) diff --git a/include/vsag/index.h b/include/vsag/index.h index b9068cc1..019df8f1 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -249,6 +249,23 @@ class Index { throw std::runtime_error("Index doesn't support get distance by id"); }; + /** + * @brief Calculate the distance between the query and the vector of the given ID for batch. + * + * @param count is the count of vids + * @param vids is the unique identifier of the vector to be calculated in the index. + * @param vector is the embedding of query + * @param distances is the distances between the query and the vector of the given ID + * @return result is valid distance of input vids. + */ + virtual tl::expected + CalcBatchDistanceById(int64_t count, + int64_t *vids, + const float* vector, + float *&distances) const { + throw std::runtime_error("Index doesn't support get distance by id"); + }; + /** * @brief Checks if the specified feature is supported by the index. * diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 75860146..f5751e07 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -69,6 +69,12 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; + virtual int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) = 0; + virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index e473fa51..7bb24f08 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,6 +171,28 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } +int64_t +HierarchicalNSW::getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) { + std::shared_lock lock_table(label_lookup_lock_); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vids[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; +} + bool HierarchicalNSW::isValidLabel(LabelType label) { std::shared_lock lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index f880e746..a5c2eb9b 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -146,6 +146,12 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; + int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) override; + bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 4d37eb38..11ee2f5d 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ + mutable std::shared_mutex shared_label_lookup_lock; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -262,6 +263,28 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } + int64_t + getBatchDistanceByLabel(int64_t count, + int64_t *vids, + const void* data_point, + float *&distances) override { + std::shared_lock lock_table(shared_label_lookup_lock); + int64_t ret_cnt = 0; + distances = (float *)allocator_->Allocate(sizeof(float) * count); + for (int i = 0; i < count; i++) { + auto search = label_lookup_.find(vid[i]); + if (search == label_lookup_.end()) { + distances[i] = -1; + } else { + InnerIdType internal_id = search->second; + float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); + distances[i] = dist; + ret_cnt++; + } + } + return ret_cnt; + } + bool isValidLabel(LabelType label) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 9a06e245..2b58f9a5 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -145,6 +145,14 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; + virtual tl::expected + CalcBatchDistanceById(int64_t count, + int64_t *vids, + const float* vector, + float *&distances) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vid, vector, distances)); + }; + [[nodiscard]] bool CheckFeature(IndexFeature feature) const override; From c0c2cbb0f3ff2b326ce6ee1cb1d59ea4130f135f Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Mon, 6 Jan 2025 17:14:48 +0800 Subject: [PATCH 07/14] add vids and const --- include/vsag/index.h | 2 +- src/algorithm/hnswlib/algorithm_interface.h | 4 ++-- src/algorithm/hnswlib/hnswalg.cpp | 2 +- src/algorithm/hnswlib/hnswalg.h | 2 +- src/algorithm/hnswlib/hnswalg_static.h | 4 ++-- src/index/hnsw.h | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 019df8f1..0c6502d4 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -260,7 +260,7 @@ class Index { */ virtual tl::expected CalcBatchDistanceById(int64_t count, - int64_t *vids, + const int64_t *vids, const float* vector, float *&distances) const { throw std::runtime_error("Index doesn't support get distance by id"); diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index f5751e07..3a67c5ad 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -71,8 +71,8 @@ class AlgorithmInterface { virtual int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, - const void* data_point, + const int64_t *vids, + const void *data_point, float *&distances) = 0; virtual const float* diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 7bb24f08..a2a09e55 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -173,7 +173,7 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { int64_t HierarchicalNSW::getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) { std::shared_lock lock_table(label_lookup_lock_); diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index a5c2eb9b..7c533ac4 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -148,7 +148,7 @@ class HierarchicalNSW : public AlgorithmInterface { int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 11ee2f5d..190a8998 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -265,14 +265,14 @@ class StaticHierarchicalNSW : public AlgorithmInterface { int64_t getBatchDistanceByLabel(int64_t count, - int64_t *vids, + const int64_t *vids, const void* data_point, float *&distances) override { std::shared_lock lock_table(shared_label_lookup_lock); int64_t ret_cnt = 0; distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { - auto search = label_lookup_.find(vid[i]); + auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { distances[i] = -1; } else { diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 2b58f9a5..d3bbecf5 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -147,10 +147,10 @@ class HNSW : public Index { virtual tl::expected CalcBatchDistanceById(int64_t count, - int64_t *vids, + const int64_t *vids, const float* vector, float *&distances) const override { - SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vid, vector, distances)); + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector, distances)); }; [[nodiscard]] bool From 5610733270eb28fd6d7c7fa3267056edec8497ec Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Mon, 6 Jan 2025 21:45:43 +0800 Subject: [PATCH 08/14] fix lint check issue on github ci (#297) Signed-off-by: wxy407827 --- .github/workflows/lcov.yml | 7 +++---- .github/workflows/lint.yml | 6 ++++-- CMakeLists.txt | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/lcov.yml b/.github/workflows/lcov.yml index 0d7771a2..b1cbd5ff 100644 --- a/.github/workflows/lcov.yml +++ b/.github/workflows/lcov.yml @@ -14,10 +14,9 @@ jobs: image: vsaglib/vsag:ubuntu steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: '3.10' + - name: Link Python3.10 as Python + run: | + ln -s /usr/bin/python3 /usr/bin/python - name: Install run: | python -m pip install --upgrade pip diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c59e665f..fb05df3a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,10 +10,12 @@ jobs: clang-tidy-check: name: Lint runs-on: ubuntu-latest + container: + image: vsaglib/vsag:ubuntu steps: - name: Checkout code uses: actions/checkout@v4 - name: Install clang-tidy - run: sudo apt install clang-tidy-15 -y + run: sudo apt install clang-tidy-15 -y && sudo ln -s /usr/bin/clang-tidy-15 /usr/bin/clang-tidy - name: Run lint - run: make lint + run: make debug && make lint diff --git a/CMakeLists.txt b/CMakeLists.txt index 905dddea..6018296c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,6 +195,7 @@ if (ENABLE_TOOLS AND ENABLE_CXX11_ABI) endif () set (CMAKE_CXX_STANDARD 17) +set (CMAKE_CXX_STANDARD_REQUIRED ON) if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") if (CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 17 OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 17) From 9059ef19e4224fec84c748682637c9ebb2ab9882 Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Tue, 7 Jan 2025 00:39:20 +0800 Subject: [PATCH 09/14] replace logger inside diskann with vsag logger (#245) Signed-off-by: wxy407827 --- extern/diskann/DiskANN/include/logger_impl.h | 7 +- extern/diskann/DiskANN/src/logger.cpp | 21 +++-- extern/diskann/diskann.cmake | 4 +- src/bitset_impl_test.cpp | 8 +- src/index/diskann_test.cpp | 2 +- src/vsag.cpp | 8 ++ tests/CMakeLists.txt | 7 +- tests/fixtures/test_logger.cpp | 24 ++++++ tests/fixtures/test_logger.h | 86 ++++++++++++++++++++ tests/test_main.cpp | 31 +++++++ tests/test_multi_thread.cpp | 35 +++++++- 11 files changed, 213 insertions(+), 20 deletions(-) create mode 100644 tests/fixtures/test_logger.cpp create mode 100644 tests/fixtures/test_logger.h create mode 100644 tests/test_main.cpp diff --git a/extern/diskann/DiskANN/include/logger_impl.h b/extern/diskann/DiskANN/include/logger_impl.h index 03c65e0c..a6858ac6 100644 --- a/extern/diskann/DiskANN/include/logger_impl.h +++ b/extern/diskann/DiskANN/include/logger_impl.h @@ -23,9 +23,9 @@ class ANNStreamBuf : public std::basic_streambuf return true; // because stdout and stderr are always open. } DISKANN_DLLEXPORT void close(); - DISKANN_DLLEXPORT virtual int underflow(); - DISKANN_DLLEXPORT virtual int overflow(int c); - DISKANN_DLLEXPORT virtual int sync(); + DISKANN_DLLEXPORT virtual int underflow() override; + DISKANN_DLLEXPORT virtual int overflow(int c) override; + DISKANN_DLLEXPORT virtual int sync() override; private: FILE *_fp; @@ -33,6 +33,7 @@ class ANNStreamBuf : public std::basic_streambuf int _bufIndex; std::mutex _mutex; LogLevel _logLevel; + std::function g_logger; int flush(); void logImpl(char *str, int numchars); diff --git a/extern/diskann/DiskANN/src/logger.cpp b/extern/diskann/DiskANN/src/logger.cpp index 052f5487..e9769a44 100644 --- a/extern/diskann/DiskANN/src/logger.cpp +++ b/extern/diskann/DiskANN/src/logger.cpp @@ -7,6 +7,12 @@ #include "logger_impl.h" #include "windows_customizations.h" +namespace vsag +{ +extern std::function +vsag_get_logger(); +} // namespace vsag + namespace diskann { @@ -16,13 +22,7 @@ DISKANN_DLLEXPORT ANNStreamBuf cerrBuff(stderr); DISKANN_DLLEXPORT std::basic_ostream cout(&coutBuff); DISKANN_DLLEXPORT std::basic_ostream cerr(&cerrBuff); -std::function g_logger; -void SetCustomLogger(std::function logger) -{ - g_logger = logger; - diskann::cout << "Set Custom Logger" << std::endl; -} ANNStreamBuf::ANNStreamBuf(FILE *fp) { @@ -40,10 +40,14 @@ ANNStreamBuf::ANNStreamBuf(FILE *fp) std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char)); setp(_buf, _buf + BUFFER_SIZE - 1); + + g_logger = vsag::vsag_get_logger(); + g_logger(_logLevel, "diskann switch logger"); } ANNStreamBuf::~ANNStreamBuf() { + g_logger = nullptr; sync(); _fp = nullptr; // we'll not close because we can't. delete[] _buf; @@ -80,8 +84,13 @@ int ANNStreamBuf::flush() pbump(-num); return num; } + void ANNStreamBuf::logImpl(char *str, int num) { + // remove the newline at the end of str, 'cause logger provides + if (num > 0 and str[num - 1] == '\n') { + --num; + } str[num] = '\0'; // Safe. See the c'tor. // Invoke the OLS custom logging function. if (g_logger) diff --git a/extern/diskann/diskann.cmake b/extern/diskann/diskann.cmake index 84a6f4eb..3626a2e6 100644 --- a/extern/diskann/diskann.cmake +++ b/extern/diskann/diskann.cmake @@ -37,9 +37,9 @@ set(DISKANN_SOURCES add_library(diskann STATIC ${DISKANN_SOURCES}) # work if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") - target_compile_options(diskann PRIVATE -mavx -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors) + target_compile_options(diskann PRIVATE -mavx -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DENABLE_CUSTOM_LOGGER=1) else () - target_compile_options(diskann PRIVATE -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors) + target_compile_options(diskann PRIVATE -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DENABLE_CUSTOM_LOGGER=1) endif () set_property(TARGET diskann PROPERTY CXX_STANDARD 17) add_dependencies(diskann boost openblas) diff --git a/src/bitset_impl_test.cpp b/src/bitset_impl_test.cpp index df448b3d..0e3b4683 100644 --- a/src/bitset_impl_test.cpp +++ b/src/bitset_impl_test.cpp @@ -70,14 +70,14 @@ TEST_CASE("roaringbitmap example", "[ut][bitset]") { r1.setCopyOnWrite(true); uint32_t compact_size = r1.getSizeInBytes(); - std::cout << "size before run optimize " << size << " bytes, and after " << compact_size - << " bytes." << std::endl; + // std::cout << "size before run optimize " << size << " bytes, and after " << compact_size + // << " bytes." << std::endl; // create a new bitmap with varargs Roaring r2 = Roaring::bitmapOf(5, 1, 2, 3, 5, 6); - r2.printf(); - printf("\n"); + // r2.printf(); + // printf("\n"); // create a new bitmap with initializer list Roaring r2i = Roaring::bitmapOfList({1, 2, 3, 5, 6}); diff --git a/src/index/diskann_test.cpp b/src/index/diskann_test.cpp index 1a2c9b79..81dbb207 100644 --- a/src/index/diskann_test.cpp +++ b/src/index/diskann_test.cpp @@ -534,6 +534,6 @@ TEST_CASE("split building process", "[diskann][ut]") { } } float recall_full = correct / 1000; - std::cout << "Recall: " << recall_full << std::endl; + vsag::logger::debug("Recall: " + std::to_string(recall_full)); REQUIRE(recall_full == recall_partial); } diff --git a/src/vsag.cpp b/src/vsag.cpp index 03cf889b..38b401ed 100644 --- a/src/vsag.cpp +++ b/src/vsag.cpp @@ -15,6 +15,7 @@ #include "vsag/vsag.h" +#include <../extern/diskann/DiskANN/include/logger.h> #include #include @@ -23,6 +24,13 @@ #include "simd/simd.h" #include "version.h" +namespace vsag { +std::function +vsag_get_logger() { + return [](diskann::LogLevel, const char* msg) { vsag::logger::debug(msg); }; +} +} // namespace vsag + namespace vsag { std::string diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 72212c38..e6300b61 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,7 +2,9 @@ # unittests file (GLOB_RECURSE UNIT_TESTS "../src/*_test.cpp") add_executable (unittests ${UNIT_TESTS} + test_main.cpp fixtures/fixtures.cpp + fixtures/test_logger.cpp ) if (DIST_CONTAINS_SSE) target_compile_definitions (unittests PRIVATE ENABLE_SSE=1) @@ -17,7 +19,7 @@ if (DIST_CONTAINS_AVX512) target_compile_definitions (unittests PRIVATE ENABLE_AVX512=1) endif () target_include_directories (unittests PRIVATE "./fixtures") -target_link_libraries (unittests PRIVATE Catch2::Catch2WithMain vsag simd) +target_link_libraries (unittests PRIVATE Catch2::Catch2 vsag simd) add_dependencies (unittests spdlog Catch2) # function tests @@ -27,9 +29,10 @@ add_executable (functests fixtures/fixtures.cpp fixtures/test_dataset.cpp fixtures/test_dataset_pool.cpp + fixtures/test_logger.cpp ) target_include_directories (functests PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/spdlog/install/include ${HDF5_INCLUDE_DIRS}) -target_link_libraries (functests PRIVATE Catch2::Catch2WithMain vsag simd) +target_link_libraries (functests PRIVATE Catch2::Catch2 vsag simd) add_dependencies (functests spdlog Catch2) diff --git a/tests/fixtures/test_logger.cpp b/tests/fixtures/test_logger.cpp new file mode 100644 index 00000000..0e2bf8f5 --- /dev/null +++ b/tests/fixtures/test_logger.cpp @@ -0,0 +1,24 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test_logger.h" + +#include + +namespace fixtures { + +TestLogger logger; + +} // namespace fixtures diff --git a/tests/fixtures/test_logger.h b/tests/fixtures/test_logger.h new file mode 100644 index 00000000..78219441 --- /dev/null +++ b/tests/fixtures/test_logger.h @@ -0,0 +1,86 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "vsag/vsag.h" + +namespace fixtures { + +class TestLogger : public vsag::Logger { +public: + inline void + SetLevel(Level log_level) override { + std::lock_guard lock(mutex_); + level_ = log_level - vsag::Logger::Level::kTRACE; + } + + inline void + Trace(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 0) { + UNSCOPED_INFO("[test-logger]::[trace] " + msg); + } + } + + inline void + Debug(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 1) { + UNSCOPED_INFO("[test-logger]::[debug] " + msg); + } + } + + inline void + Info(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 2) { + UNSCOPED_INFO("[test-logger]::[info] " + msg); + } + } + + inline void + Warn(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 3) { + UNSCOPED_INFO("[test-logger]::[warn] " + msg); + } + } + + inline void + Error(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 4) { + UNSCOPED_INFO("[test-logger]::[error] " + msg); + } + } + + void + Critical(const std::string& msg) override { + std::lock_guard lock(mutex_); + if (level_ <= 5) { + UNSCOPED_INFO("[test-logger]::[critical] " + msg); + } + } + +private: + int64_t level_ = 0; + std::mutex mutex_; +}; + +extern TestLogger logger; + +} // namespace fixtures diff --git a/tests/test_main.cpp b/tests/test_main.cpp new file mode 100644 index 00000000..45700c80 --- /dev/null +++ b/tests/test_main.cpp @@ -0,0 +1,31 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "./fixtures/test_logger.h" +#include "vsag/vsag.h" + +int +main(int argc, char** argv) { + // your setup ... + vsag::Options::Instance().set_logger(&fixtures::logger); + + int result = Catch::Session().run(argc, argv); + + // your clean-up... + + return result; +} diff --git a/tests/test_multi_thread.cpp b/tests/test_multi_thread.cpp index fb931e27..9cc1541e 100644 --- a/tests/test_multi_thread.cpp +++ b/tests/test_multi_thread.cpp @@ -17,8 +17,11 @@ #include #include #include +#include #include +#include "default_logger.h" +#include "fixtures/test_logger.h" #include "fixtures/thread_pool.h" #include "vsag/options.h" #include "vsag/vsag.h" @@ -34,8 +37,10 @@ query_knn(std::shared_ptr index, if (result.value()->GetDim() != 0 && result.value()->GetIds()[0] == id) { return 1.0; } else { - std::cout << result.value()->GetDim() << " " << result.value()->GetIds()[0] << " " << id - << std::endl; + std::stringstream ss; + ss << "recall failure: dim " << result.value()->GetDim() << ", id " + << result.value()->GetIds()[0] << ", expected_id " << id; + fixtures::logger.Debug(ss.str()); } } else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) { std::cerr << "failed to perform knn search on index" << std::endl; @@ -43,7 +48,27 @@ query_knn(std::shared_ptr index, return 0.0; } +// catch2 logger is NOT supported to be used in multi-threading tests, so +// we need to replace it at the start of all the test cases in this file +class LoggerReplacer { +public: + LoggerReplacer() { + origin_logger_ = vsag::Options::Instance().logger(); + vsag::Options::Instance().set_logger(&logger_); + } + + ~LoggerReplacer() { + vsag::Options::Instance().set_logger(origin_logger_); + } + +private: + vsag::Logger* origin_logger_; + vsag::DefaultLogger logger_; +}; + TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") { + LoggerReplacer _; + int dim = 65; // Dimension of the elements int max_elements = 1000; // Maximum number of elements, should be known beforehand int max_degree = 16; // Tightly connected with internal dimensionality of the data @@ -116,6 +141,8 @@ TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") { } TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") { + LoggerReplacer _; + int dim = 16; // Dimension of the elements int max_elements = 1000; // Maximum number of elements, should be known beforehand int max_degree = 16; // Tightly connected with internal dimensionality of the data @@ -185,6 +212,8 @@ TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") { } TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { + LoggerReplacer _; + // avoid too much slow task logs vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); @@ -278,6 +307,8 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { } TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hnsw]") { + LoggerReplacer _; + // avoid too much slow task logs vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); From be4fbd26459117dea4c018a82ac28d6e45c61e30 Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Tue, 7 Jan 2025 14:59:47 +0800 Subject: [PATCH 10/14] enable examples in the compilation by default (#301) Signed-off-by: wxy407827 --- CMakeLists.txt | 2 +- examples/cpp/custom_memory_allocator.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6018296c..4fcbbb5c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,7 +43,7 @@ option (ENABLE_INTEL_MKL "Enable intel-mkl (x86 platform only)" ON) option (ENABLE_CXX11_ABI "Use CXX11 ABI" ON) option (ENABLE_LIBCXX "Use libc++ instead of libstdc++" OFF) # only support in clang option (ENABLE_TOOLS "Whether compile vsag tools" ON) -option (ENABLE_EXAMPLES "Whether compile examples" OFF) +option (ENABLE_EXAMPLES "Whether compile examples" ON) option (ENABLE_TESTS "Whether compile vsag tests" ON) option (DISABLE_SSE_FORCE "Force disable sse and higher instructions" OFF) option (DISABLE_AVX_FORCE "Force disable avx and higher instructions" OFF) diff --git a/examples/cpp/custom_memory_allocator.cpp b/examples/cpp/custom_memory_allocator.cpp index e46901d3..926c9174 100644 --- a/examples/cpp/custom_memory_allocator.cpp +++ b/examples/cpp/custom_memory_allocator.cpp @@ -64,7 +64,7 @@ main() { // vsag::Options::Instance().logger()->SetLevel(vsag::Logger::kDEBUG); ExampleAllocator allocator; - vsag::Resource resource(&allocator); + vsag::Resource resource(&allocator, nullptr); vsag::Engine engine(&resource); auto paramesters = R"( From f4a330a2358ed2983a67ce2045f802669256d765 Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Wed, 8 Jan 2025 20:45:51 +0800 Subject: [PATCH 11/14] add logger in tests (#302) Signed-off-by: wxy407827 --- src/index/hnsw.cpp | 6 +- tests/fixtures/test_logger.cpp | 22 +++++- tests/fixtures/test_logger.h | 122 +++++++++++++++++++++++++++++++-- tests/test_cpuinfo.cpp | 4 +- tests/test_diskann.cpp | 3 +- tests/test_engine.cpp | 3 +- tests/test_hnsw.cpp | 49 +++++-------- tests/test_index.cpp | 7 +- tests/test_main.cpp | 2 +- tests/test_multi_thread.cpp | 35 +++------- tests/test_random_index.cpp | 26 +++---- 11 files changed, 189 insertions(+), 90 deletions(-) diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 916234b0..fad62461 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -615,7 +615,7 @@ HNSW::update_id(int64_t old_id, int64_t new_id) { std::reinterpret_pointer_cast(alg_hnsw_)->updateLabel(old_id, new_id); } catch (const std::runtime_error& e) { - spdlog::warn( + logger::warn( "update error for replace old_id {} to new_id {}: {}", old_id, new_id, e.what()); return false; } @@ -641,7 +641,7 @@ HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) std::reinterpret_pointer_cast(alg_hnsw_)->updateVector( id, new_base_vec); } catch (const std::runtime_error& e) { - spdlog::warn("update error for replace vector of id {}: {}", id, e.what()); + logger::warn("update error for replace vector of id {}: {}", id, e.what()); return false; } @@ -663,7 +663,7 @@ HNSW::remove(int64_t id) { std::reinterpret_pointer_cast(alg_hnsw_)->markDelete(id); } } catch (const std::runtime_error& e) { - spdlog::warn("mark delete error for id {}: {}", id, e.what()); + logger::warn("mark delete error for id {}: {}", id, e.what()); return false; } diff --git a/tests/fixtures/test_logger.cpp b/tests/fixtures/test_logger.cpp index 0e2bf8f5..a1976e3b 100644 --- a/tests/fixtures/test_logger.cpp +++ b/tests/fixtures/test_logger.cpp @@ -16,9 +16,25 @@ #include "test_logger.h" #include +#include -namespace fixtures { +#include "vsag/logger.h" -TestLogger logger; +namespace fixtures::logger { -} // namespace fixtures +TestLogger test_logger; + +LoggerStream trace_buff(&test_logger, vsag::Logger::kTRACE); +LoggerStream debug_buff(&test_logger, vsag::Logger::kDEBUG); +LoggerStream info_buff(&test_logger, vsag::Logger::kINFO); +LoggerStream warn_buff(&test_logger, vsag::Logger::kWARN); +LoggerStream error_buff(&test_logger, vsag::Logger::kERR); +LoggerStream critical_buff(&test_logger, vsag::Logger::kCRITICAL); + +std::basic_ostream trace(&trace_buff); +std::basic_ostream debug(&debug_buff); +std::basic_ostream info(&info_buff); +std::basic_ostream warn(&warn_buff); +std::basic_ostream error(&error_buff); + +} // namespace fixtures::logger diff --git a/tests/fixtures/test_logger.h b/tests/fixtures/test_logger.h index 78219441..d97df2c7 100644 --- a/tests/fixtures/test_logger.h +++ b/tests/fixtures/test_logger.h @@ -14,13 +14,51 @@ // limitations under the License. #include -#include +#include +#include +#include "default_logger.h" +#include "vsag/logger.h" #include "vsag/vsag.h" -namespace fixtures { +namespace fixtures::logger { class TestLogger : public vsag::Logger { +public: + inline void + Log(const std::string& msg, Level level) { + switch (level) { + case Level::kTRACE: { + Trace(msg); + break; + } + case Level::kDEBUG: { + Debug(msg); + break; + } + case Level::kINFO: { + Info(msg); + break; + } + case Level::kWARN: { + Warn(msg); + break; + } + case Level::kERR: { + Error(msg); + break; + } + case Level::kCRITICAL: { + Critical(msg); + break; + } + default: { + // will not run into here + break; + } + } + } + public: inline void SetLevel(Level log_level) override { @@ -81,6 +119,82 @@ class TestLogger : public vsag::Logger { std::mutex mutex_; }; -extern TestLogger logger; +class LoggerStream : public std::basic_streambuf { +public: + explicit LoggerStream(TestLogger* logger, + vsag::Logger::Level level, + uint64_t buffer_size = 1024) + : logger_(logger), level_(level), buffer_(buffer_size + 1) { + auto base = &buffer_.front(); + this->setp(base, base + buffer_size); + } + + virtual ~LoggerStream() { + logger_ = nullptr; + } + +public: + virtual int + overflow(int ch) override { + std::lock_guard lock(mutex_); + if (ch != EOF) { + *this->pptr() = (char)ch; + this->pbump(1); + } + this->flush(); + return ch; + } + + virtual int + sync() override { + std::lock_guard lock(mutex_); + this->flush(); + return 0; + } + +private: + void + flush() { + std::ptrdiff_t n = this->pptr() - this->pbase(); + std::string msg(this->pbase(), n); + this->pbump(-n); + if (logger_) { + logger_->Log(msg, level_); + } + } + +private: + TestLogger* logger_ = nullptr; + vsag::Logger::Level level_; + std::mutex mutex_; + std::vector buffer_; + uint64_t size_; +}; + +extern TestLogger test_logger; +extern std::basic_ostream trace; +extern std::basic_ostream debug; +extern std::basic_ostream info; +extern std::basic_ostream warn; +extern std::basic_ostream error; +extern std::basic_ostream critical; + +// catch2 logger is NOT supported to be used in multi-threading tests, so +// we need to replace it at the start of all the test cases in this file +class LoggerReplacer { +public: + LoggerReplacer() { + origin_logger_ = vsag::Options::Instance().logger(); + vsag::Options::Instance().set_logger(&logger_); + } + + ~LoggerReplacer() { + vsag::Options::Instance().set_logger(origin_logger_); + } + +private: + vsag::Logger* origin_logger_; + vsag::DefaultLogger logger_; +}; -} // namespace fixtures +} // namespace fixtures::logger diff --git a/tests/test_cpuinfo.cpp b/tests/test_cpuinfo.cpp index 3c1983d3..11aa21d2 100644 --- a/tests/test_cpuinfo.cpp +++ b/tests/test_cpuinfo.cpp @@ -18,8 +18,10 @@ #include #include +#include "fixtures/test_logger.h" + TEST_CASE("CPU info", "[ft][cpuinfo]") { cpuinfo_initialize(); - std::cout << cpuinfo_get_processors_count() << std::endl; + fixtures::logger::debug << cpuinfo_get_processors_count() << std::endl; cpuinfo_deinitialize(); } diff --git a/tests/test_diskann.cpp b/tests/test_diskann.cpp index 6f65787b..e7854d6d 100644 --- a/tests/test_diskann.cpp +++ b/tests/test_diskann.cpp @@ -21,6 +21,7 @@ #include #include "fixtures/test_dataset_pool.h" +#include "fixtures/test_logger.h" #include "test_index.h" #include "vsag/errors.h" #include "vsag/vsag.h" @@ -35,8 +36,6 @@ TestDatasetPool DiskANNTestIndex::pool{}; } // namespace fixtures TEST_CASE_METHOD(fixtures::DiskANNTestIndex, "diskann build test", "[ft][index][diskann]") { - vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); - auto test_dim_count = 3; auto dims = fixtures::get_common_used_dims(3); auto metric_type = GENERATE("l2", "ip"); diff --git a/tests/test_engine.cpp b/tests/test_engine.cpp index fc899c62..03408808 100644 --- a/tests/test_engine.cpp +++ b/tests/test_engine.cpp @@ -17,6 +17,7 @@ #include #include +#include "fixtures/test_logger.h" #include "vsag/vsag.h" TEST_CASE("index params", "[ft][engine]") { @@ -71,7 +72,7 @@ TEST_CASE("index params", "[ft][engine]") { correct++; } } else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) { - std::cerr << "failed to search on index: internalError" << std::endl; + fixtures::logger::error << "failed to search on index: internalError" << std::endl; } } float recall = correct / static_cast(max_elements); diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp index 93987293..84ee9cfb 100644 --- a/tests/test_hnsw.cpp +++ b/tests/test_hnsw.cpp @@ -13,8 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include #include #include @@ -24,6 +22,7 @@ #include #include +#include "fixtures/test_logger.h" #include "vsag/errors.h" #include "vsag/vsag.h" @@ -31,21 +30,17 @@ using namespace std; template static void -writeBinaryPOD(std::ostream& out, const T& podRef) { +write_binary_pod(std::ostream& out, const T& podRef) { out.write((char*)&podRef, sizeof(T)); } template static void -readBinaryPOD(std::istream& in, T& podRef) { +read_binary_pod(std::istream& in, T& podRef) { in.read((char*)&podRef, sizeof(T)); } -const std::string tmp_dir = "/tmp/"; - TEST_CASE("HNSW range search", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); - int dim = 71; int max_elements = 10000; int max_degree = 16; @@ -120,8 +115,6 @@ TEST_CASE("HNSW range search", "[ft][hnsw]") { } TEST_CASE("HNSW Filtering Test", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); - int dim = 17; int max_elements = 1000; int max_degree = 16; @@ -259,8 +252,6 @@ TEST_CASE("HNSW Filtering Test", "[ft][hnsw]") { } TEST_CASE("HNSW small dimension", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); - int dim = 3; int max_elements = 1000; int max_degree = 24; @@ -322,8 +313,6 @@ TEST_CASE("HNSW small dimension", "[ft][hnsw]") { } TEST_CASE("HNSW Random Id", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); - int dim = 128; int max_elements = 1000; int max_degree = 64; @@ -405,8 +394,6 @@ TEST_CASE("HNSW Random Id", "[ft][hnsw]") { } TEST_CASE("pq infer knn search time recall", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); - int dim = 128; int max_elements = 1000; int max_degree = 64; @@ -467,7 +454,7 @@ TEST_CASE("pq infer knn search time recall", "[ft][hnsw]") { } TEST_CASE("hnsw serialize", "[ft][hnsw]") { - spdlog::set_level(spdlog::level::debug); + const std::string tmp_dir = "/tmp/"; int dim = 128; int max_elements = 1000; @@ -523,7 +510,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { for (auto key : keys) { // [len][data...][len][data...]... vsag::Binary b = bs->Get(key); - writeBinaryPOD(file, b.size); + write_binary_pod(file, b.size); file.write((const char*)b.data.get(), b.size); offsets.push_back(offset); offset += sizeof(b.size) + b.size; @@ -533,13 +520,13 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { // [len][key...][offset][len][key...][offset]... const auto& key = keys[i]; int64_t len = key.length(); - writeBinaryPOD(file, len); + write_binary_pod(file, len); file.write(key.c_str(), key.length()); - writeBinaryPOD(file, offsets[i]); + write_binary_pod(file, offsets[i]); } // [num_keys][footer_offset]$ - writeBinaryPOD(file, keys.size()); - writeBinaryPOD(file, offset); + write_binary_pod(file, keys.size()); + write_binary_pod(file, offset); file.close(); } else if (bs.error().type == vsag::ErrorType::NO_ENOUGH_MEMORY) { std::cerr << "no enough memory to serialize index" << std::endl; @@ -551,8 +538,8 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { std::ifstream file(tmp_dir + "hnsw.index", std::ios::in); file.seekg(-sizeof(uint64_t) * 2, std::ios::end); uint64_t num_keys, footer_offset; - readBinaryPOD(file, num_keys); - readBinaryPOD(file, footer_offset); + read_binary_pod(file, num_keys); + read_binary_pod(file, footer_offset); // std::cout << "num_keys: " << num_keys << std::endl; // std::cout << "footer_offset: " << footer_offset << std::endl; file.seekg(footer_offset, std::ios::beg); @@ -561,7 +548,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { std::vector offsets; for (uint64_t i = 0; i < num_keys; ++i) { int64_t key_len; - readBinaryPOD(file, key_len); + read_binary_pod(file, key_len); // std::cout << "key_len: " << key_len << std::endl; char key_buf[key_len + 1]; memset(key_buf, 0, key_len + 1); @@ -570,7 +557,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { keys.push_back(key_buf); uint64_t offset; - readBinaryPOD(file, offset); + read_binary_pod(file, offset); // std::cout << "offset: " << offset << std::endl; offsets.push_back(offset); } @@ -579,7 +566,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { for (uint64_t i = 0; i < num_keys; ++i) { file.seekg(offsets[i], std::ios::beg); vsag::Binary b; - readBinaryPOD(file, b.size); + read_binary_pod(file, b.size); // std::cout << "len: " << b.size << std::endl; b.data.reset(new int8_t[b.size]); file.read((char*)b.data.get(), b.size); @@ -601,8 +588,8 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { std::ifstream file(tmp_dir + "hnsw.index", std::ios::in); file.seekg(-sizeof(uint64_t) * 2, std::ios::end); uint64_t num_keys, footer_offset; - readBinaryPOD(file, num_keys); - readBinaryPOD(file, footer_offset); + read_binary_pod(file, num_keys); + read_binary_pod(file, footer_offset); // std::cout << "num_keys: " << num_keys << std::endl; // std::cout << "footer_offset: " << footer_offset << std::endl; file.seekg(footer_offset, std::ios::beg); @@ -611,7 +598,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { std::vector offsets; for (uint64_t i = 0; i < num_keys; ++i) { int64_t key_len; - readBinaryPOD(file, key_len); + read_binary_pod(file, key_len); // std::cout << "key_len: " << key_len << std::endl; char key_buf[key_len + 1]; memset(key_buf, 0, key_len + 1); @@ -620,7 +607,7 @@ TEST_CASE("hnsw serialize", "[ft][hnsw]") { keys.push_back(key_buf); uint64_t offset; - readBinaryPOD(file, offset); + read_binary_pod(file, offset); // std::cout << "offset: " << offset << std::endl; offsets.push_back(offset); } diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 378291a7..ef79cd47 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -15,6 +15,7 @@ #include "test_index.h" +#include "fixtures/test_logger.h" #include "fixtures/thread_pool.h" #include "simd/fp32_simd.h" @@ -390,6 +391,8 @@ void TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset, bool expected_success) { + fixtures::logger::LoggerReplacer _; + auto base_count = dataset->base_->GetNumElements(); int64_t temp_count = base_count / 2; auto dim = dataset->base_->GetDim(); @@ -433,6 +436,8 @@ TestIndex::TestConcurrentKnnSearch(const TestIndex::IndexPtr& index, const std::string& search_param, float expected_recall, bool expected_success) { + fixtures::logger::LoggerReplacer _; + auto queries = dataset->query_; auto query_count = queries->GetNumElements(); auto dim = queries->GetDim(); @@ -538,4 +543,4 @@ TestIndex::TestDuplicateAdd(const TestIndex::IndexPtr& index, const TestDatasetP check_func(add_index_2.value()); } -} // namespace fixtures \ No newline at end of file +} // namespace fixtures diff --git a/tests/test_main.cpp b/tests/test_main.cpp index 45700c80..bb6d664d 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -21,7 +21,7 @@ int main(int argc, char** argv) { // your setup ... - vsag::Options::Instance().set_logger(&fixtures::logger); + vsag::Options::Instance().set_logger(&fixtures::logger::test_logger); int result = Catch::Session().run(argc, argv); diff --git a/tests/test_multi_thread.cpp b/tests/test_multi_thread.cpp index 9cc1541e..dce603b5 100644 --- a/tests/test_multi_thread.cpp +++ b/tests/test_multi_thread.cpp @@ -20,7 +20,6 @@ #include #include -#include "default_logger.h" #include "fixtures/test_logger.h" #include "fixtures/thread_pool.h" #include "vsag/options.h" @@ -37,10 +36,10 @@ query_knn(std::shared_ptr index, if (result.value()->GetDim() != 0 && result.value()->GetIds()[0] == id) { return 1.0; } else { - std::stringstream ss; - ss << "recall failure: dim " << result.value()->GetDim() << ", id " - << result.value()->GetIds()[0] << ", expected_id " << id; - fixtures::logger.Debug(ss.str()); + fixtures::logger::debug << "recall failure: dim " << result.value()->GetDim() << ", id " + << result.value()->GetIds()[0] << ", expected_id " << id + << std::endl; + ; } } else if (result.error().type == vsag::ErrorType::INTERNAL_ERROR) { std::cerr << "failed to perform knn search on index" << std::endl; @@ -48,26 +47,8 @@ query_knn(std::shared_ptr index, return 0.0; } -// catch2 logger is NOT supported to be used in multi-threading tests, so -// we need to replace it at the start of all the test cases in this file -class LoggerReplacer { -public: - LoggerReplacer() { - origin_logger_ = vsag::Options::Instance().logger(); - vsag::Options::Instance().set_logger(&logger_); - } - - ~LoggerReplacer() { - vsag::Options::Instance().set_logger(origin_logger_); - } - -private: - vsag::Logger* origin_logger_; - vsag::DefaultLogger logger_; -}; - TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") { - LoggerReplacer _; + fixtures::logger::LoggerReplacer _; int dim = 65; // Dimension of the elements int max_elements = 1000; // Maximum number of elements, should be known beforehand @@ -141,7 +122,7 @@ TEST_CASE("DiskAnn Multi-threading", "[ft][diskann]") { } TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") { - LoggerReplacer _; + fixtures::logger::LoggerReplacer _; int dim = 16; // Dimension of the elements int max_elements = 1000; // Maximum number of elements, should be known beforehand @@ -212,7 +193,7 @@ TEST_CASE("HNSW Multi-threading", "[ft][hnsw]") { } TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { - LoggerReplacer _; + fixtures::logger::LoggerReplacer _; // avoid too much slow task logs vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); @@ -307,7 +288,7 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") { } TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hnsw]") { - LoggerReplacer _; + fixtures::logger::LoggerReplacer _; // avoid too much slow task logs vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN); diff --git a/tests/test_random_index.cpp b/tests/test_random_index.cpp index 9f83cd75..2aa54c0f 100644 --- a/tests/test_random_index.cpp +++ b/tests/test_random_index.cpp @@ -16,7 +16,9 @@ #include #include +#include +#include "fixtures/test_logger.h" #include "vsag/vsag.h" using namespace std; @@ -57,22 +59,14 @@ TEST_CASE("Random Index Test", "[ft][random]") { int seed = seed_random(rng); rng.seed(seed); - spdlog::info( - "seed: {}, dim: {}, max_elements: {}, max_degree: {}, ef_construction: {}, ef_search: {}, " - "k: {}, " - "io_limit: {}, threshold: {}, pq_dims: {}, use_pq_search: {}, mold: {}", - seed, - dim, - max_elements, - max_degree, - ef_construction, - ef_search, - k, - io_limit, - threshold, - pq_dims, - use_pq_search, - mold); + fixtures::logger::info << "seed: " << seed << ", dim: " << dim + << ", max_elements: " << max_elements << ", max_degree: " << max_degree + << ", ef_construction: " << ef_construction + << ", ef_search: " << ef_search << ", k: " << k + << ",io_limit: " << io_limit << ", threshold: " << threshold + << ", pq_dims: " << pq_dims << ", use_pq_search: " << use_pq_search + << ", mold: " << mold << std::endl; + ; float pq_sample_rate = 0.5; // Initing index nlohmann::json hnsw_parameters{ From a41a874fe184053b7a692fd01972b983a81e3a3e Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:11:03 +0800 Subject: [PATCH 12/14] remove unuseful log for ut and ft (#303) Signed-off-by: jinjiabao.jjb --- CMakeLists.txt | 1 + src/index/diskann.cpp | 4 ++++ src/index/hnsw.cpp | 12 ++++++++++-- tests/test_random_allocator.cpp | 2 ++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4fcbbb5c..49fa1e27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -262,6 +262,7 @@ endif () # tests if (ENABLE_TESTS) + target_compile_definitions(vsag PRIVATE ENABLE_TESTS) add_subdirectory (tests) endif () diff --git a/src/index/diskann.cpp b/src/index/diskann.cpp index ed2c9149..eff33767 100644 --- a/src/index/diskann.cpp +++ b/src/index/diskann.cpp @@ -300,7 +300,9 @@ DiskANN::knn_search(const DatasetPtr& query, int64_t k, const std::string& parameters, const std::function& filter) const { +#ifndef ENABLE_TESTS SlowTaskTimer t("diskann knnsearch", 200); +#endif // cannot perform search on empty index if (empty_index_) { @@ -461,7 +463,9 @@ DiskANN::range_search(const DatasetPtr& query, const std::string& parameters, const std::function& filter, int64_t limited_size) const { +#ifndef ENABLE_TESTS SlowTaskTimer t("diskann rangesearch", 200); +#endif // cannot perform search on empty index if (empty_index_) { diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index fad62461..9b200c54 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -141,7 +141,9 @@ HNSW::build(const DatasetPtr& base) { tl::expected, Error> HNSW::add(const DatasetPtr& base) { +#ifndef ENABLE_TESTS SlowTaskTimer t("hnsw add", 20); +#endif if (use_static_) { LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION, "static index does not support add"); @@ -192,8 +194,9 @@ HNSW::knn_search(const DatasetPtr& query, int64_t k, const std::string& parameters, BaseFilterFunctor* filter_ptr) const { +#ifndef ENABLE_TESTS SlowTaskTimer t_total("hnsw knnsearch", 20); - +#endif try { // cannot perform search on empty index if (empty_index_) { @@ -315,8 +318,9 @@ HNSW::range_search(const DatasetPtr& query, const std::string& parameters, BaseFilterFunctor* filter_ptr, int64_t limited_size) const { +#ifndef ENABLE_TESTS SlowTaskTimer t("hnsw rangesearch", 20); - +#endif try { // cannot perform search on empty index if (empty_index_) { @@ -615,8 +619,10 @@ HNSW::update_id(int64_t old_id, int64_t new_id) { std::reinterpret_pointer_cast(alg_hnsw_)->updateLabel(old_id, new_id); } catch (const std::runtime_error& e) { +#ifndef ENABLE_TESTS logger::warn( "update error for replace old_id {} to new_id {}: {}", old_id, new_id, e.what()); +#endif return false; } @@ -641,7 +647,9 @@ HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) std::reinterpret_pointer_cast(alg_hnsw_)->updateVector( id, new_base_vec); } catch (const std::runtime_error& e) { +#ifndef ENABLE_TESTS logger::warn("update error for replace vector of id {}: {}", id, e.what()); +#endif return false; } diff --git a/tests/test_random_allocator.cpp b/tests/test_random_allocator.cpp index 81ec31c6..650d5efd 100644 --- a/tests/test_random_allocator.cpp +++ b/tests/test_random_allocator.cpp @@ -18,8 +18,10 @@ #include "fixtures/random_allocator.h" #include "vsag//factory.h" +#include "vsag/options.h" TEST_CASE("Random Alocator Test", "[ft][hnsw]") { + vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kOFF); fixtures::RandomAllocator allocator; auto paramesters = R"( From cd24b671691827da135bc09fbfc122a57e080502 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Fri, 10 Jan 2025 17:53:11 +0800 Subject: [PATCH 13/14] add Dataset and Modify review comments --- include/vsag/index.h | 7 +++---- src/algorithm/hnswlib/algorithm_interface.h | 8 +++++--- src/algorithm/hnswlib/hnswalg.cpp | 16 +++++++++------- src/algorithm/hnswlib/hnswalg.h | 6 +++--- src/algorithm/hnswlib/hnswalg_static.h | 19 ++++++++++--------- src/index/hnsw.h | 7 +++---- 6 files changed, 33 insertions(+), 30 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 0c6502d4..039815dd 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -256,13 +256,12 @@ class Index { * @param vids is the unique identifier of the vector to be calculated in the index. * @param vector is the embedding of query * @param distances is the distances between the query and the vector of the given ID - * @return result is valid distance of input vids. + * @return result is valid distance of input vids. '-1' indicates an invalid distance. */ - virtual tl::expected + virtual tl::expected CalcBatchDistanceById(int64_t count, const int64_t *vids, - const float* vector, - float *&distances) const { + const float* vector) const { throw std::runtime_error("Index doesn't support get distance by id"); }; diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 3a67c5ad..dbe9ff83 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -24,6 +24,9 @@ #include "space_interface.h" #include "stream_reader.h" #include "typing.h" +#include "vsag/dataset.h" +#include "vsag/expected.hpp" +#include "vsag/errors.h" namespace hnswlib { @@ -69,11 +72,10 @@ class AlgorithmInterface { virtual float getDistanceByLabel(LabelType label, const void* data_point) = 0; - virtual int64_t + virtual tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void *data_point, - float *&distances) = 0; + const void *data_point) = 0; virtual const float* getDataByLabel(LabelType label) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index a2a09e55..826ec2f3 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -171,14 +171,14 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) { return dist; } -int64_t +tl::expected HierarchicalNSW::getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) { + const void* data_point) { std::shared_lock lock_table(label_lookup_lock_); - int64_t ret_cnt = 0; - distances = (float *)allocator_->Allocate(sizeof(float) * count); + int64_t valid_cnt = 0; + auto result = vsag::Dataset::Make(); + auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -187,10 +187,12 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, InnerIdType internal_id = search->second; float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); distances[i] = dist; - ret_cnt++; + valid_cnt++; } } - return ret_cnt; + result->NumElements(valid_cnt)->Owner(true, allocator_); + result->Distances(distances); + return std::move(result); } bool diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 7c533ac4..7412d270 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -38,6 +38,7 @@ #include "algorithm_interface.h" #include "block_manager.h" #include "visited_list_pool.h" +#include "vsag/dataset.h" namespace hnswlib { using InnerIdType = vsag::InnerIdType; using linklistsizeint = unsigned int; @@ -146,11 +147,10 @@ class HierarchicalNSW : public AlgorithmInterface { float getDistanceByLabel(LabelType label, const void* data_point) override; - int64_t + tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) override; + const void* data_point) override; bool isValidLabel(LabelType label) override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 190a8998..a0172b27 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -81,7 +81,6 @@ class StaticHierarchicalNSW : public AlgorithmInterface { void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ - mutable std::shared_mutex shared_label_lookup_lock; std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -263,14 +262,14 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return dist; } - int64_t + tl::expected getBatchDistanceByLabel(int64_t count, const int64_t *vids, - const void* data_point, - float *&distances) override { - std::shared_lock lock_table(shared_label_lookup_lock); - int64_t ret_cnt = 0; - distances = (float *)allocator_->Allocate(sizeof(float) * count); + const void* data_point) override { + std::unique_lock lock_table(label_lookup_lock); + int64_t valid_cnt = 0; + auto result = vsag::Dataset::Make(); + auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -279,10 +278,12 @@ class StaticHierarchicalNSW : public AlgorithmInterface { InnerIdType internal_id = search->second; float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_); distances[i] = dist; - ret_cnt++; + valid_cnt++; } } - return ret_cnt; + result->NumElements(valid_cnt)->Owner(true, allocator_); + result->Distances(distances); + return std::move(result); } bool diff --git a/src/index/hnsw.h b/src/index/hnsw.h index d3bbecf5..1bb79766 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -145,12 +145,11 @@ class HNSW : public Index { SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector)); }; - virtual tl::expected + virtual tl::expected CalcBatchDistanceById(int64_t count, const int64_t *vids, - const float* vector, - float *&distances) const override { - SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector, distances)); + const float* vector) const override { + SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(count, vids, vector)); }; [[nodiscard]] bool From a519cd46dccbca59d4aae28aaa5fe1d3bebcd697 Mon Sep 17 00:00:00 2001 From: "zourunxin.zrx" Date: Wed, 15 Jan 2025 11:54:27 +0800 Subject: [PATCH 14/14] Modify the potential memory leak risk of getBatchDistanceByLabel Signed-off-by: zourunxin.zrx --- src/algorithm/hnswlib/hnswalg.cpp | 5 +++-- src/algorithm/hnswlib/hnswalg_static.h | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 826ec2f3..0f10d49b 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -178,7 +178,9 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, std::shared_lock lock_table(label_lookup_lock_); int64_t valid_cnt = 0; auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); + result->Distances(distances); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -190,8 +192,7 @@ HierarchicalNSW::getBatchDistanceByLabel(int64_t count, valid_cnt++; } } - result->NumElements(valid_cnt)->Owner(true, allocator_); - result->Distances(distances); + result->NumElements(valid_cnt); return std::move(result); } diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index a0172b27..d50bad82 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -269,7 +269,9 @@ class StaticHierarchicalNSW : public AlgorithmInterface { std::unique_lock lock_table(label_lookup_lock); int64_t valid_cnt = 0; auto result = vsag::Dataset::Make(); + result->Owner(true, allocator_); auto *distances = (float *)allocator_->Allocate(sizeof(float) * count); + result->Distances(distances); for (int i = 0; i < count; i++) { auto search = label_lookup_.find(vids[i]); if (search == label_lookup_.end()) { @@ -281,8 +283,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { valid_cnt++; } } - result->NumElements(valid_cnt)->Owner(true, allocator_); - result->Distances(distances); + result->NumElements(valid_cnt); return std::move(result); }