From 09dda841fd016f293cb24b14c1f11f5d48eee39b Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Sat, 19 Oct 2024 16:45:43 +0800 Subject: [PATCH] Add filter streaming interface --- src/index.cpp | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 0b01afa20..4daf2839c 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1806,9 +1806,46 @@ void Index::build(const std::string &data_file, const size_t nu size_t points_to_load = num_points_to_load == 0 ? _max_points : num_points_to_load; auto s = std::chrono::high_resolution_clock::now(); + + std::vector tags; + + if (_enable_tags) + { + if (filter_params.tags_file.empty()) + { + throw ANNException("Tag filename isn't set, while _enable_tags is set", -1, __FUNCSIG__, __FILE__, __LINE__); + } + else + { + if (file_exists(filter_params.tags_file)) + { + diskann::cout << "Loading tags from " << filter_params.tags_file << " for vamana index build" << std::endl; + TagT* tag_data = nullptr; + size_t npts, ndim; + diskann::load_bin(filter_params.tags_file, tag_data, npts, ndim); + if (npts < num_points_to_load) + { + std::stringstream sstream; + sstream << "Loaded " << npts << " tags, insufficient to populate tags for " << num_points_to_load + << " points to load"; + throw diskann::ANNException(sstream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + tags.resize(num_points_to_load); + memcpy(tags.data(), tag_data, sizeof(TagT) * num_points_to_load); + + delete[] tag_data; + } + else + { + throw diskann::ANNException(std::string("Tag file") + filter_params.tags_file + " does not exist", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + } + } + if (filter_params.label_file == "") { - this->build(data_file.c_str(), points_to_load); + this->build(data_file.c_str(), points_to_load, tags); } else { @@ -1823,7 +1860,7 @@ void Index::build(const std::string &data_file, const size_t nu // LabelT unv_label_as_num = 0; 
this->set_universal_label(unv_label_as_num); } - this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load); + this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load, tags); } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; std::cout << "Indexing time: " << diff.count() << "\n";