diff --git a/src/index/pyramid.cpp b/src/index/pyramid.cpp index 96a22349..ced3aee5 100644 --- a/src/index/pyramid.cpp +++ b/src/index/pyramid.cpp @@ -176,10 +176,10 @@ Pyramid::Add(const DatasetPtr& base) { } tl::expected -Pyramid::KnnSearch(const DatasetPtr& query, - int64_t k, - const std::string& parameters, - BitsetPtr invalid) const { +Pyramid::knn_search(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const SearchFunc& search_func) const { auto path = query->GetPaths(); // TODO(inabao): provide different search modes. std::string current_path = path[0]; @@ -209,7 +209,7 @@ Pyramid::KnnSearch(const DatasetPtr& query, auto node = candidate_indexes.front(); candidate_indexes.pop_front(); if (node->index) { - auto result = node->index->KnnSearch(query, k, parameters, invalid); + auto result = search_func(node->index); if (result.has_value()) { DatasetPtr r = result.value(); for (int i = 0; i < r->GetDim(); ++i) { @@ -251,40 +251,6 @@ Pyramid::KnnSearch(const DatasetPtr& query, return result; } -tl::expected -Pyramid::KnnSearch(const DatasetPtr& query, - int64_t k, - const std::string& parameters, - const std::function& filter) const { - return {}; -} - -tl::expected -Pyramid::RangeSearch(const DatasetPtr& query, - float radius, - const std::string& parameters, - int64_t limited_size) const { - return {}; -} - -tl::expected -Pyramid::RangeSearch(const DatasetPtr& query, - float radius, - const std::string& parameters, - BitsetPtr invalid, - int64_t limited_size) const { - return {}; -} - -tl::expected -Pyramid::RangeSearch(const DatasetPtr& query, - float radius, - const std::string& parameters, - const std::function& filter, - int64_t limited_size) const { - return {}; -} - tl::expected Pyramid::Serialize() const { BinarySet binary_set; diff --git a/src/index/pyramid.h b/src/index/pyramid.h index 8cc41fdd..c879a590 100644 --- a/src/index/pyramid.h +++ b/src/index/pyramid.h @@ -17,6 +17,8 @@ #include +#include "base_filter_functor.h" +#include "logger.h" #include "pyramid_zparameters.h" #include "safe_allocator.h" @@ -70,6 +72,8 @@ struct IndexNode { } }; +using SearchFunc = std::function(IndexPtr)>; + class Pyramid : public Index { public: Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param) @@ -90,33 +94,64 @@ class Pyramid : public Index { KnnSearch(const DatasetPtr& query, int64_t k, const std::string& parameters, - BitsetPtr invalid = nullptr) const override; + BitsetPtr invalid = nullptr) const override { + SearchFunc search_func = [&](IndexPtr index) { + return index->KnnSearch(query, k, parameters, invalid); + }; + SAFE_CALL(return this->knn_search(query, k, parameters, search_func);) + } tl::expected KnnSearch(const DatasetPtr& query, int64_t k, const std::string& parameters, - const std::function& filter) const override; + const std::function& filter) const override { + SearchFunc search_func = [&](IndexPtr index) { + return index->KnnSearch(query, k, parameters, filter); + }; + SAFE_CALL(return this->knn_search(query, k, parameters, search_func);) + } tl::expected RangeSearch(const DatasetPtr& query, float radius, const std::string& parameters, - int64_t limited_size = -1) const override; + int64_t limited_size = -1) const override { + SearchFunc search_func = [&](IndexPtr index) { + return index->RangeSearch(query, radius, parameters, limited_size); + }; + int64_t final_limit = + limited_size == -1 ? std::numeric_limits::max() : limited_size; + SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);) + } tl::expected RangeSearch(const DatasetPtr& query, float radius, const std::string& parameters, BitsetPtr invalid, - int64_t limited_size = -1) const override; + int64_t limited_size = -1) const override { + SearchFunc search_func = [&](IndexPtr index) { + return index->RangeSearch(query, radius, parameters, invalid, limited_size); + }; + int64_t final_limit = + limited_size == -1 ? std::numeric_limits::max() : limited_size; + SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);) + } tl::expected RangeSearch(const DatasetPtr& query, float radius, const std::string& parameters, const std::function& filter, - int64_t limited_size = -1) const override; + int64_t limited_size = -1) const override { + SearchFunc search_func = [&](IndexPtr index) { + return index->RangeSearch(query, radius, parameters, filter, limited_size); + }; + int64_t final_limit = + limited_size == -1 ? std::numeric_limits::max() : limited_size; + SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);) + } tl::expected Serialize() const override; @@ -134,6 +169,12 @@ class Pyramid : public Index { GetMemoryUsage() const override; private: + tl::expected + knn_search(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const SearchFunc& search_func) const; + inline std::shared_ptr try_get_node_with_init(UnorderedMap>& index_map, const std::string& key) { diff --git a/tests/fixtures/test_dataset.cpp b/tests/fixtures/test_dataset.cpp index 575f9b3c..b1e35ce1 100644 --- a/tests/fixtures/test_dataset.cpp +++ b/tests/fixtures/test_dataset.cpp @@ -263,7 +263,8 @@ TestDataset::TestDataset(uint64_t dim, uint64_t count, std::string metric_str, b this->range_radius_.resize(query_count); for (uint64_t i = 0; i < query_count; ++i) { this->range_radius_[i] = - 0.5f * (result.first[i * count + top_k] + result.first[i * count + top_k - 1]); + 0.5f * (this->ground_truth_->GetDistances()[i * top_k + top_k - 1] + + this->ground_truth_->GetDistances()[i * top_k + top_k - 2]); } delete[] result.first; delete[] result.second; diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 832b8a0b..3ea6243c 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -280,6 +280,7 @@ TestIndex::TestRangeSearch(const IndexPtr& index, query->NumElements(1) ->Dim(dim) ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->Paths(queries->GetPaths() + i) ->Owner(false); auto res = index->RangeSearch(query, radius[i], search_param, limited_size); REQUIRE(res.has_value() == expected_success); @@ -291,8 +292,8 @@ TestIndex::TestRangeSearch(const IndexPtr& index, } auto result = res.value()->GetIds(); auto gt = gts->GetIds() + gt_topK * i; - auto val = Intersection(gt, gt_topK, result, res.value()->GetDim()); - cur_recall += static_cast(val) / static_cast(gt_topK); + auto val = Intersection(gt, gt_topK - 1, result, res.value()->GetDim()); + cur_recall += static_cast(val) / static_cast(gt_topK - 1); } REQUIRE(cur_recall > expected_recall * query_count); } @@ -314,6 +315,7 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index, query->NumElements(1) ->Dim(dim) ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->Paths(queries->GetPaths() + i) ->Owner(false); auto res = index->KnnSearch(query, topk, search_param, dataset->filter_function_); REQUIRE(res.has_value() == expected_success); diff --git a/tests/test_pyramid.cpp b/tests/test_pyramid.cpp index 2b19b621..a7332673 100644 --- a/tests/test_pyramid.cpp +++ b/tests/test_pyramid.cpp @@ -34,7 +34,7 @@ class PyramidTestIndex : public fixtures::TestIndex { static std::vector dims; - constexpr static uint64_t base_count = 1000; + constexpr static uint64_t base_count = 3000; constexpr static const char* search_param_tmp = R"( {{ @@ -82,5 +82,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex, auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type, /*with_path=*/true); TestContinueAdd(index, dataset, true); TestKnnSearch(index, dataset, search_param, 0.99, true); + TestFilterSearch(index, dataset, search_param, 0.99, true); + TestRangeSearch(index, dataset, search_param, 0.99, 10, true); + TestRangeSearch(index, dataset, search_param, 0.49, 5, true); } }