Skip to content

Commit

Permalink
implement the interfaces for range search, filter in pyramid
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <[email protected]>
  • Loading branch information
jinjiabao.jjb committed Jan 15, 2025
1 parent b663e67 commit 571c9a1
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 48 deletions.
44 changes: 5 additions & 39 deletions src/index/pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ Pyramid::Add(const DatasetPtr& base) {
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid) const {
Pyramid::knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const {
auto path = query->GetPaths(); // TODO(inabao): provide different search modes.

std::string current_path = path[0];
Expand Down Expand Up @@ -209,7 +209,7 @@ Pyramid::KnnSearch(const DatasetPtr& query,
auto node = candidate_indexes.front();
candidate_indexes.pop_front();
if (node->index) {
auto result = node->index->KnnSearch(query, k, parameters, invalid);
auto result = search_func(node->index);
if (result.has_value()) {
DatasetPtr r = result.value();
for (int i = 0; i < r->GetDim(); ++i) {
Expand Down Expand Up @@ -251,40 +251,6 @@ Pyramid::KnnSearch(const DatasetPtr& query,
return result;
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size) const {
return {};
}

tl::expected<BinarySet, Error>
Pyramid::Serialize() const {
BinarySet binary_set;
Expand Down
51 changes: 46 additions & 5 deletions src/index/pyramid.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include <utility>

#include "base_filter_functor.h"
#include "logger.h"
#include "pyramid_zparameters.h"
#include "safe_allocator.h"

Expand Down Expand Up @@ -70,6 +72,8 @@ struct IndexNode {
}
};

using SearchFunc = std::function<tl::expected<DatasetPtr, Error>(IndexPtr)>;

class Pyramid : public Index {
public:
Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param)
Expand All @@ -90,33 +94,64 @@ class Pyramid : public Index {
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid = nullptr) const override;
BitsetPtr invalid = nullptr) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, invalid);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const override;
const std::function<bool(int64_t)>& filter) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, filter);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, limited_size);
};
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, invalid, limited_size);
};
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, filter, limited_size);
};
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

tl::expected<BinarySet, Error>
Serialize() const override;
Expand All @@ -134,6 +169,12 @@ class Pyramid : public Index {
GetMemoryUsage() const override;

private:
tl::expected<DatasetPtr, Error>
knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const;

inline std::shared_ptr<IndexNode>
try_get_node_with_init(UnorderedMap<std::string, std::shared_ptr<IndexNode>>& index_map,
const std::string& key) {
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ TestDataset::TestDataset(uint64_t dim, uint64_t count, std::string metric_str, b
this->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
this->range_radius_[i] =
0.5f * (result.first[i * count + top_k] + result.first[i * count + top_k - 1]);
0.5f * (this->ground_truth_->GetDistances()[i * top_k + top_k - 1] +
this->ground_truth_->GetDistances()[i * top_k + top_k - 2]);
}
delete[] result.first;
delete[] result.second;
Expand Down
6 changes: 4 additions & 2 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ TestIndex::TestRangeSearch(const IndexPtr& index,
query->NumElements(1)
->Dim(dim)
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
->Paths(queries->GetPaths() + i)
->Owner(false);
auto res = index->RangeSearch(query, radius[i], search_param, limited_size);
REQUIRE(res.has_value() == expected_success);
Expand All @@ -291,8 +292,8 @@ TestIndex::TestRangeSearch(const IndexPtr& index,
}
auto result = res.value()->GetIds();
auto gt = gts->GetIds() + gt_topK * i;
auto val = Intersection(gt, gt_topK, result, res.value()->GetDim());
cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
auto val = Intersection(gt, gt_topK - 1, result, res.value()->GetDim());
cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK - 1);
}
REQUIRE(cur_recall > expected_recall * query_count);
}
Expand All @@ -314,6 +315,7 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index,
query->NumElements(1)
->Dim(dim)
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
->Paths(queries->GetPaths() + i)
->Owner(false);
auto res = index->KnnSearch(query, topk, search_param, dataset->filter_function_);
REQUIRE(res.has_value() == expected_success);
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PyramidTestIndex : public fixtures::TestIndex {

static std::vector<int> dims;

constexpr static uint64_t base_count = 1000;
constexpr static uint64_t base_count = 3000;

constexpr static const char* search_param_tmp = R"(
{{
Expand Down Expand Up @@ -82,5 +82,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex,
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type, /*with_path=*/true);
TestContinueAdd(index, dataset, true);
TestKnnSearch(index, dataset, search_param, 0.99, true);
TestFilterSearch(index, dataset, search_param, 0.99, true);
TestRangeSearch(index, dataset, search_param, 0.99, 10, true);
TestRangeSearch(index, dataset, search_param, 0.49, 5, true);
}
}

0 comments on commit 571c9a1

Please sign in to comment.