NOTE(review): recovered from a copy of this patch in which line breaks were collapsed and
every "<...>" span had been stripped. Template arguments, the <utility> include, indentation
and blank context lines are reconstructed from the surrounding code and must be confirmed
against the tree before applying. The "index" hashes are those of the original patch.

Changes relative to the original patch:
  * Pyramid::Deserialize (both overloads) now propagates the sub-index Deserialize error;
    it was previously discarded and success was returned.
  * SubReader::Read bounds check rewritten so that `offset + len` cannot overflow.
  * binaryset_to_binary / binary_to_binaryset take their argument by const reference.
  * reader_to_readerset reads the key straight into the std::string.
  * pyramid.h includes <functional>, <limits> and <stdexcept> for the names it now uses.
  * Second SECTION of "Pyramid Serialize File" renamed (it duplicated the first one's name);
    missing newline at the end of tests/test_pyramid.cpp added.
  * Hunk headers adjusted for the added lines.

diff --git a/src/index/pyramid.cpp b/src/index/pyramid.cpp
index a009dc45..7c2b2f3f 100644
--- a/src/index/pyramid.cpp
+++ b/src/index/pyramid.cpp
@@ -17,6 +17,118 @@
 
 namespace vsag {
 
+Binary
+binaryset_to_binary(const BinarySet& binary_set) {
+    /*
+     * The serialized layout of the Binary data in memory will be as follows:
+     * | key_size_0 | key_0 (L_0 bytes) | binary_size_0 | binary_data_0 (S_0 bytes) |
+     * | key_size_1 | key_1 (L_1 bytes) | binary_size_1 | binary_data_1 (S_1 bytes) |
+     * | ... | ... | ... | ... |
+     * | key_size_(N-1) | key_(N-1) (L_(N-1) bytes) | binary_size_(N-1) | binary_data_(N-1) (S_(N-1) bytes) |
+     * Where:
+     * - `key_size_k`: size of the k-th key (in bytes)
+     * - `key_k`: the actual k-th key data (length L_k)
+     * - `binary_size_k`: size of the binary data associated with the k-th key (in bytes)
+     * - `binary_data_k`: the actual binary data contents (length S_k)
+     * - N: total number of keys in the BinarySet
+     */
+    size_t total_size = 0;
+    auto keys = binary_set.GetKeys();
+
+    for (const auto& key : keys) {
+        total_size += sizeof(size_t) + key.size();
+        total_size += sizeof(size_t);
+        total_size += binary_set.Get(key).size;
+    }
+
+    Binary result;
+    result.data = std::shared_ptr<int8_t[]>(new int8_t[total_size]);
+    result.size = total_size;
+
+    size_t offset = 0;
+
+    for (const auto& key : keys) {
+        size_t key_size = key.size();
+        memcpy(result.data.get() + offset, &key_size, sizeof(size_t));
+        offset += sizeof(size_t);
+        memcpy(result.data.get() + offset, key.data(), key_size);
+        offset += key_size;
+
+        Binary binary = binary_set.Get(key);
+        memcpy(result.data.get() + offset, &binary.size, sizeof(size_t));
+        offset += sizeof(size_t);
+        memcpy(result.data.get() + offset, binary.data.get(), binary.size);
+        offset += binary.size;
+    }
+
+    return result;
+}
+
+BinarySet
+binary_to_binaryset(const Binary& binary) {
+    /*
+     * The Binary structure is serialized in the following layout:
+     * | key_size (sizeof(size_t)) | key (of length key_size) | binary_size (sizeof(size_t)) | binary data (of length binary_size) |
+     * Each key and its associated binary data are sequentially stored in the Binary object's data array,
+     * and this information guides the deserialization process here.
+     */
+    BinarySet binary_set;
+    size_t offset = 0;
+
+    while (offset < binary.size) {
+        size_t key_size;
+        memcpy(&key_size, binary.data.get() + offset, sizeof(size_t));
+        offset += sizeof(size_t);
+
+        std::string key(reinterpret_cast<const char*>(binary.data.get() + offset), key_size);
+        offset += key_size;
+
+        size_t binary_size;
+        memcpy(&binary_size, binary.data.get() + offset, sizeof(size_t));
+        offset += sizeof(size_t);
+
+        Binary new_binary;
+        new_binary.size = binary_size;
+        new_binary.data = std::shared_ptr<int8_t[]>(new int8_t[binary_size]);
+        memcpy(new_binary.data.get(), binary.data.get() + offset, binary_size);
+        offset += binary_size;
+
+        binary_set.Set(key, new_binary);
+    }
+
+    return binary_set;
+}
+
+ReaderSet
+reader_to_readerset(std::shared_ptr<Reader> reader) {
+    /*
+     * Same record layout as binary_to_binaryset(), read through `reader`. Payloads are not
+     * copied: each one is exposed as a SubReader over its byte range of `reader`.
+     */
+    ReaderSet reader_set;
+    size_t offset = 0;
+
+    while (offset < reader->Size()) {
+        size_t key_size;
+        reader->Read(offset, sizeof(size_t), &key_size);
+        offset += sizeof(size_t);
+        std::string key(key_size, '\0');
+        reader->Read(offset, key_size, key.data());
+        offset += key_size;
+
+        size_t binary_size;
+        reader->Read(offset, sizeof(size_t), &binary_size);
+        offset += sizeof(size_t);
+
+        auto new_reader = std::make_shared<SubReader>(reader, offset, binary_size);
+        offset += binary_size;
+
+        reader_set.Set(key, new_reader);
+    }
+
+    return reader_set;
+}
+
 template <typename T>
 using Deque = std::deque<T, vsag::AllocatorWrapper<T>>;
 
@@ -85,10 +197,10 @@ Pyramid::Add(const DatasetPtr& base) {
 }
 
 tl::expected<DatasetPtr, Error>
-Pyramid::KnnSearch(const DatasetPtr& query,
-                   int64_t k,
-                   const std::string& parameters,
-                   BitsetPtr invalid) const {
+Pyramid::knn_search(const DatasetPtr& query,
+                    int64_t k,
+                    const std::string& parameters,
+                    const SearchFunc& search_func) const {
     auto path = query->GetPaths();
     // TODO(inabao): provide different search modes.
     std::string current_path = path[0];
@@ -118,7 +230,7 @@ Pyramid::KnnSearch(const DatasetPtr& query,
         auto node = candidate_indexes.front();
         candidate_indexes.pop_front();
         if (node->index) {
-            auto result = node->index->KnnSearch(query, k, parameters, invalid);
+            auto result = search_func(node->index);
             if (result.has_value()) {
                 DatasetPtr r = result.value();
                 for (int i = 0; i < r->GetDim(); ++i) {
@@ -162,52 +274,69 @@ Pyramid::KnnSearch(const DatasetPtr& query,
     return result;
 }
 
-tl::expected<DatasetPtr, Error>
-Pyramid::KnnSearch(const DatasetPtr& query,
-                   int64_t k,
-                   const std::string& parameters,
-                   const std::function<bool(int64_t)>& filter) const {
-    return {};
-}
-
-tl::expected<DatasetPtr, Error>
-Pyramid::RangeSearch(const DatasetPtr& query,
-                     float radius,
-                     const std::string& parameters,
-                     int64_t limited_size) const {
-    return {};
-}
-
-tl::expected<DatasetPtr, Error>
-Pyramid::RangeSearch(const DatasetPtr& query,
-                     float radius,
-                     const std::string& parameters,
-                     BitsetPtr invalid,
-                     int64_t limited_size) const {
-    return {};
-}
-
-tl::expected<DatasetPtr, Error>
-Pyramid::RangeSearch(const DatasetPtr& query,
-                     float radius,
-                     const std::string& parameters,
-                     const std::function<bool(int64_t)>& filter,
-                     int64_t limited_size) const {
-    return {};
-}
-
 tl::expected<BinarySet, Error>
 Pyramid::Serialize() const {
-    return {};
+    BinarySet binary_set;
+    for (const auto& root_index : indexes_) {
+        std::string path = root_index.first;
+        std::vector<std::pair<std::string, std::shared_ptr<IndexNode>>> need_serialize_indexes;
+        need_serialize_indexes.emplace_back(path, root_index.second);
+        while (not need_serialize_indexes.empty()) {
+            auto [current_path, index_node] = need_serialize_indexes.back();
+            need_serialize_indexes.pop_back();
+            if (index_node->index) {
+                auto serialize_result = index_node->index->Serialize();
+                if (not serialize_result.has_value()) {
+                    return tl::unexpected(serialize_result.error());
+                }
+                binary_set.Set(current_path, binaryset_to_binary(serialize_result.value()));
+            }
+            for (const auto& sub_index_node : index_node->children) {
+                need_serialize_indexes.emplace_back(
+                    current_path + PART_OCTOTHORPE + sub_index_node.first, sub_index_node.second);
+            }
+        }
+    }
+    return binary_set;
 }
 
 tl::expected<void, Error>
 Pyramid::Deserialize(const BinarySet& binary_set) {
+    auto keys = binary_set.GetKeys();
+    for (const auto& path : keys) {
+        const auto& binary = binary_set.Get(path);
+        auto path_slices = split(path, PART_OCTOTHORPE);
+        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
+        for (size_t j = 1; j < path_slices.size(); ++j) {
+            node = try_get_node_with_init(node->children, path_slices[j]);
+        }
+        node->CreateIndex(pyramid_param_.index_builder);
+        // Surface sub-index failures to the caller instead of reporting success.
+        auto result = node->index->Deserialize(binary_to_binaryset(binary));
+        if (not result.has_value()) {
+            return tl::unexpected(result.error());
+        }
+    }
     return {};
 }
 
 tl::expected<void, Error>
 Pyramid::Deserialize(const ReaderSet& reader_set) {
+    auto keys = reader_set.GetKeys();
+    for (const auto& path : keys) {
+        const auto& reader = reader_set.Get(path);
+        auto path_slices = split(path, PART_OCTOTHORPE);
+        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
+        for (size_t j = 1; j < path_slices.size(); ++j) {
+            node = try_get_node_with_init(node->children, path_slices[j]);
+        }
+        node->CreateIndex(pyramid_param_.index_builder);
+        // Surface sub-index failures to the caller instead of reporting success.
+        auto result = node->index->Deserialize(reader_to_readerset(reader));
+        if (not result.has_value()) {
+            return tl::unexpected(result.error());
+        }
+    }
     return {};
 }
 
diff --git a/src/index/pyramid.h b/src/index/pyramid.h
index 353ed96d..50d23de2 100644
--- a/src/index/pyramid.h
+++ b/src/index/pyramid.h
@@ -17,11 +17,56 @@
 
+#include <functional>
+#include <limits>
+#include <stdexcept>
 #include <utility>
 
+#include "base_filter_functor.h"
+#include "logger.h"
 #include "pyramid_zparameters.h"
 #include "safe_allocator.h"
 
 namespace vsag {
 
+// Reader exposing `size` bytes of a parent Reader starting at `start_pos`.
+// Offsets passed to Read() are relative to `start_pos`.
+class SubReader : public Reader {
+public:
+    SubReader(std::shared_ptr<Reader> parent_reader, uint64_t start_pos, uint64_t size)
+        : parent_reader_(std::move(parent_reader)), size_(size), start_pos_(start_pos) {
+    }
+
+    void
+    Read(uint64_t offset, uint64_t len, void* dest) override {
+        // Written so that `offset + len` cannot wrap around and slip past the check.
+        if (offset > size_ || len > size_ - offset) {
+            throw std::out_of_range("Read out of range.");
+        }
+        parent_reader_->Read(offset + start_pos_, len, dest);
+    }
+
+    void
+    AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override {
+        throw std::runtime_error("No support for SubReader AsyncRead");
+    }
+
+    uint64_t
+    Size() const override {
+        return size_;
+    }
+
+private:
+    std::shared_ptr<Reader> parent_reader_;
+    uint64_t size_;
+    uint64_t start_pos_;
+};
+
+Binary
+binaryset_to_binary(const BinarySet& binary_set);
+BinarySet
+binary_to_binaryset(const Binary& binary);
+ReaderSet
+reader_to_readerset(std::shared_ptr<Reader> reader);
+
 struct IndexNode {
     std::shared_ptr<Index> index{nullptr};
     UnorderedMap<std::string, std::shared_ptr<IndexNode>> children;
@@ -35,6 +80,8 @@ struct IndexNode {
     }
 };
 
+using SearchFunc = std::function<tl::expected<DatasetPtr, Error>(IndexPtr)>;
+
 class Pyramid : public Index {
 public:
     Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param)
@@ -55,33 +102,64 @@ class Pyramid : public Index {
     KnnSearch(const DatasetPtr& query,
               int64_t k,
               const std::string& parameters,
-              BitsetPtr invalid = nullptr) const override;
+              BitsetPtr invalid = nullptr) const override {
+        SearchFunc search_func = [&](IndexPtr index) {
+            return index->KnnSearch(query, k, parameters, invalid);
+        };
+        SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
+    }
 
     tl::expected<DatasetPtr, Error>
     KnnSearch(const DatasetPtr& query,
               int64_t k,
               const std::string& parameters,
-              const std::function<bool(int64_t)>& filter) const override;
+              const std::function<bool(int64_t)>& filter) const override {
+        SearchFunc search_func = [&](IndexPtr index) {
+            return index->KnnSearch(query, k, parameters, filter);
+        };
+        SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
+    }
 
     tl::expected<DatasetPtr, Error>
     RangeSearch(const DatasetPtr& query,
                 float radius,
                 const std::string& parameters,
-                int64_t limited_size = -1) const override;
+                int64_t limited_size = -1) const override {
+        SearchFunc search_func = [&](IndexPtr index) {
+            return index->RangeSearch(query, radius, parameters, limited_size);
+        };
+        int64_t final_limit =
+            limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
+        SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
+    }
 
     tl::expected<DatasetPtr, Error>
     RangeSearch(const DatasetPtr& query,
                 float radius,
                 const std::string& parameters,
                 BitsetPtr invalid,
-                int64_t limited_size = -1) const override;
+                int64_t limited_size = -1) const override {
+        SearchFunc search_func = [&](IndexPtr index) {
+            return index->RangeSearch(query, radius, parameters, invalid, limited_size);
+        };
+        int64_t final_limit =
+            limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
+        SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
+    }
 
     tl::expected<DatasetPtr, Error>
     RangeSearch(const DatasetPtr& query,
                 float radius,
                 const std::string& parameters,
                 const std::function<bool(int64_t)>& filter,
-                int64_t limited_size = -1) const override;
+                int64_t limited_size = -1) const override {
+        SearchFunc search_func = [&](IndexPtr index) {
+            return index->RangeSearch(query, radius, parameters, filter, limited_size);
+        };
+        int64_t final_limit =
+            limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
+        SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
+    }
 
     tl::expected<BinarySet, Error>
     Serialize() const override;
@@ -99,6 +177,12 @@ class Pyramid : public Index {
     GetMemoryUsage() const override;
 
 private:
+    tl::expected<DatasetPtr, Error>
+    knn_search(const DatasetPtr& query,
+               int64_t k,
+               const std::string& parameters,
+               const SearchFunc& search_func) const;
+
     inline std::shared_ptr<IndexNode>
     try_get_node_with_init(UnorderedMap<std::string, std::shared_ptr<IndexNode>>& index_map,
                            const std::string& key) {
diff --git a/tests/fixtures/test_dataset.cpp b/tests/fixtures/test_dataset.cpp
index be97b83d..2b6d454e 100644
--- a/tests/fixtures/test_dataset.cpp
+++ b/tests/fixtures/test_dataset.cpp
@@ -300,8 +300,11 @@ TestDataset::CreateTestDataset(uint64_t dim,
     dataset->range_ground_truth_ = dataset->ground_truth_;
     dataset->range_radius_.resize(query_count);
     for (uint64_t i = 0; i < query_count; ++i) {
-        dataset->range_radius_[i] = 0.5f * (result.first[i * count + dataset->top_k] +
-                                            result.first[i * count + dataset->top_k - 1]);
+        dataset->range_radius_[i] =
+            0.5f * (dataset->range_ground_truth_
+                        ->GetDistances()[i * dataset->top_k + dataset->top_k - 1] +
+                    dataset->range_ground_truth_
+                        ->GetDistances()[i * dataset->top_k + dataset->top_k - 2]);
     }
     delete[] result.first;
     delete[] result.second;
diff --git a/tests/test_index.cpp b/tests/test_index.cpp
index 19bec5d1..6e5e5c63 100644
--- a/tests/test_index.cpp
+++ b/tests/test_index.cpp
@@ -338,6 +338,7 @@ TestIndex::TestRangeSearch(const IndexPtr& index,
         query->NumElements(1)
             ->Dim(dim)
             ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
+            ->Paths(queries->GetPaths() + i)
             ->Owner(false);
         auto res = index->RangeSearch(query, radius[i], search_param, limited_size);
         REQUIRE(res.has_value() == expected_success);
@@ -349,8 +350,8 @@ TestIndex::TestRangeSearch(const IndexPtr& index,
         }
         auto result = res.value()->GetIds();
         auto gt = gts->GetIds() + gt_topK * i;
-        auto val = Intersection(gt, gt_topK, result, res.value()->GetDim());
-        cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
+        auto val = Intersection(gt, gt_topK - 1, result, res.value()->GetDim());
+        cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK - 1);
     }
     if (cur_recall <= expected_recall * query_count) {
         WARN(fmt::format("cur_result({}) <= expected_recall * query_count({})",
@@ -377,6 +378,7 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index,
         query->NumElements(1)
             ->Dim(dim)
             ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
+            ->Paths(queries->GetPaths() + i)
             ->Owner(false);
         auto res = index->KnnSearch(query, topk, search_param, dataset->filter_function_);
         REQUIRE(res.has_value() == expected_success);
@@ -446,6 +448,7 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from,
         auto query = vsag::Dataset::Make();
         query->NumElements(1)
             ->Dim(dim)
+            ->Paths(queries->GetPaths() + i)
             ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
             ->Owner(false);
         auto res_from = index_from->KnnSearch(query, topk, search_param);
@@ -529,6 +532,7 @@ TestIndex::TestSerializeBinarySet(const IndexPtr& index_from,
         auto query = vsag::Dataset::Make();
         query->NumElements(1)
             ->Dim(dim)
+            ->Paths(queries->GetPaths() + i)
             ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
             ->Owner(false);
         auto res_from = index_from->KnnSearch(query, topk, search_param);
@@ -567,6 +571,7 @@ TestIndex::TestSerializeReaderSet(const IndexPtr& index_from,
         auto query = vsag::Dataset::Make();
         query->NumElements(1)
             ->Dim(dim)
+            ->Paths(queries->GetPaths() + i)
             ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
             ->Owner(false);
         auto res_from = index_from->KnnSearch(query, topk, search_param);
diff --git a/tests/test_pyramid.cpp b/tests/test_pyramid.cpp
index 896eedfd..bbf6f1f4 100644
--- a/tests/test_pyramid.cpp
+++ b/tests/test_pyramid.cpp
@@ -34,7 +34,7 @@ class PyramidTestIndex : public fixtures::TestIndex {
 
     static std::vector<int> dims;
 
-    constexpr static uint64_t base_count = 1000;
+    constexpr static uint64_t base_count = 3000;
 
     constexpr static const char* search_param_tmp = R"(
         {{
@@ -82,5 +82,35 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex,
         auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type, /*with_path=*/true);
         TestContinueAdd(index, dataset, true);
         TestKnnSearch(index, dataset, search_param, 0.99, true);
+        TestFilterSearch(index, dataset, search_param, 0.99, true);
+        TestRangeSearch(index, dataset, search_param, 0.99, 10, true);
+        TestRangeSearch(index, dataset, search_param, 0.49, 5, true);
     }
 }
+
+TEST_CASE_PERSISTENT_FIXTURE(fixtures::PyramidTestIndex,
+                             "Pyramid Serialize File",
+                             "[ft][pyramid]") {
+    auto origin_size = vsag::Options::Instance().block_size_limit();
+    auto size = GENERATE(1024 * 1024 * 2);
+    auto metric_type = GENERATE("l2", "ip", "cosine");
+    const std::string name = "pyramid";
+    auto search_param = fmt::format(search_param_tmp, 200);
+
+    for (auto& dim : dims) {
+        vsag::Options::Instance().set_block_size_limit(size);
+        auto param = GeneratePyramidBuildParametersString(metric_type, dim);
+        auto index = TestFactory(name, param, true);
+        auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type, /*with_path=*/true);
+        TestBuildIndex(index, dataset, true);
+        SECTION("serialize/deserialize by binary") {
+            auto index2 = TestFactory(name, param, true);
+            TestSerializeBinarySet(index, index2, dataset, search_param, true);
+        }
+        SECTION("serialize/deserialize by readerset") {
+            auto index2 = TestFactory(name, param, true);
+            TestSerializeReaderSet(index, index2, dataset, search_param, name, true);
+        }
+    }
+    vsag::Options::Instance().set_block_size_limit(origin_size);
+}