diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index 10cbc6176..43d7dbcd6 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -149,14 +149,13 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm, uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot, const std::string &save_path, size_t points_to_delete_from_beginning, - size_t start_deletes_after, bool concurrent, const std::string &label_file, - const std::string &universal_label) + size_t start_deletes_after, bool concurrent, diskann::IndexFilterParams &filter_params) { size_t dim, aligned_dim; size_t num_points; diskann::get_bin_metadata(data_path, num_points, dim); aligned_dim = ROUND_UP(dim, 8); - bool has_labels = label_file != ""; + bool has_labels = filter_params.label_file != ""; using TagT = uint32_t; using LabelT = uint32_t; @@ -188,9 +187,9 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa auto index = index_factory.create_instance(); /* remove set_universal_label from here and set it through filter store only*/ - if (universal_label != "") + if (filter_params.universal_label != "") { - index->set_universal_labels(universal_label); + index->set_universal_label(filter_params.universal_label); } if (points_to_skip > num_points) @@ -263,7 +262,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa points_to_delete_from_beginning, last_point_threshold); if (has_labels) { - auto parse_result = diskann::parse_raw_label_file(label_file); + auto parse_result = diskann::parse_raw_label_file(filter_params.label_file); location_to_labels = std::get<0>(parse_result); } @@ -499,21 +498,27 @@ int main(int argc, char **argv) .with_filter_list_size(Lf) .build(); + diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(index_path_prefix) + .build(); + if (data_type == std::string("int8")) build_incremental_index( data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); + points_to_delete_from_beginning, start_deletes_after, concurrent, filter_params); else if (data_type == std::string("uint8")) build_incremental_index( data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); + points_to_delete_from_beginning, start_deletes_after, concurrent, filter_params); else if (data_type == std::string("float")) build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, - start_deletes_after, concurrent, label_file, universal_label); + start_deletes_after, concurrent, filter_params); else std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; } diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index f8994288a..698ad5fea 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -251,7 +251,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con if (universal_label != "") { - index->set_universal_labels(universal_label); + index->set_universal_label(universal_label); } if (max_points_to_insert == 0) diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h index 7d654f8e1..09967b7c0 100644 --- a/include/abstract_filter_store.h +++ b/include/abstract_filter_store.h @@ -36,7 +36,7 @@ template class AbstractFilterStore // TODO: in future we may accept a set or vector of universal labels // DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0; - DISKANN_DLLEXPORT virtual void set_universal_labels(const std::string &universal_labels) = 0; + DISKANN_DLLEXPORT virtual void set_universal_label(const std::string &universal_labels) = 0; DISKANN_DLLEXPORT virtual std::pair get_universal_label() = 0; // takes raw label file and then genrate internal mapping file and keep the info of mapping diff --git a/include/abstract_index.h b/include/abstract_index.h index dcc5167b5..5e4f255c0 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -104,7 +104,7 @@ class AbstractIndex template int get_vector_by_tag(tag_type &tag, data_type *vec); // required for dynamic index (they dont use filter store / data store yet) - virtual void set_universal_labels(const std::string &raw_universal_labels) = 0; + virtual void set_universal_label(const std::string &raw_universal_labels) = 0; private: virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index 55a8c5f10..ff187ccff 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -30,7 +30,7 @@ template class InMemFilterStore : public AbstractFilterSto label_type get_numeric_label(const std::string &raw_label) override; // takes raw universal labels and map them internally. - void set_universal_labels(const std::string &raw_universal_labels) override; + void set_universal_label(const std::string &raw_universal_labels) override; std::pair get_universal_label() override; // ideally takes raw label file and then genrate internal mapping file and keep the info of mapping diff --git a/include/index.h b/include/index.h index 1a2c1b075..4e3dedce0 100644 --- a/include/index.h +++ b/include/index.h @@ -108,12 +108,17 @@ template clas IndexFilterParams &filter_params); // Filtered Support - DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, + DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const size_t num_points_to_load, + const IndexFilterParams &filter_params, + const std::vector &tags = std::vector()); + + // Filtered support streaming index + DISKANN_DLLEXPORT void build_filtered_index(const T *data, const size_t num_points_to_load, + const IndexFilterParams &filter_params, const std::vector &tags = std::vector()); // DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); - DISKANN_DLLEXPORT void set_universal_labels(const std::string &raw_labels); + DISKANN_DLLEXPORT void set_universal_label(const std::string &raw_labels); // Set starting point of an index before inserting any points incrementally. // The data count should be equal to _num_frozen_pts * _aligned_dim. diff --git a/include/index_build_params.h b/include/index_build_params.h index 0233fcec4..e8b40b30c 100644 --- a/include/index_build_params.h +++ b/include/index_build_params.h @@ -3,69 +3,5 @@ namespace diskann { -struct IndexFilterParams -{ - public: - std::string save_path_prefix; - std::string label_file; - std::string tags_file; - std::string universal_label; - uint32_t filter_threshold = 0; - - private: - IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file, - const std::string &universal_label, uint32_t filter_threshold) - : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label), - filter_threshold(filter_threshold) - { - } - - friend class IndexFilterParamsBuilder; -}; -class IndexFilterParamsBuilder -{ - public: - IndexFilterParamsBuilder() = default; - - IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) - { - if (save_path_prefix.empty() || save_path_prefix == "") - throw ANNException("Error: save_path_prefix can't be empty", -1); - this->_save_path_prefix = save_path_prefix; - return *this; - } - - IndexFilterParamsBuilder &with_label_file(const std::string &label_file) - { - this->_label_file = label_file; - return *this; - } - - IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label) - { - this->_universal_label = univeral_label; - return *this; - } - - IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold) - { - this->_filter_threshold = filter_threshold; - return *this; - } - - IndexFilterParams build() - { - return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold); - } - - IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; - IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete; - private: - std::string _save_path_prefix; - std::string _label_file; - std::string _tags_file; - std::string _universal_label; - uint32_t _filter_threshold = 0; -}; } // namespace diskann diff --git a/include/parameters.h b/include/parameters.h index 2bba9aeca..f9fd4663a 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -117,4 +117,69 @@ class IndexWriteParametersBuilder uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; }; +class IndexFilterParams +{ + public: + std::string save_path_prefix; + std::string label_file; + std::string tags_file; + std::string universal_label; + uint32_t filter_threshold = 0; + + private: + IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file, + const std::string &universal_label, uint32_t filter_threshold) + : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label), + filter_threshold(filter_threshold) + { + } + + friend class IndexFilterParamsBuilder; +}; + +class IndexFilterParamsBuilder +{ + public: + IndexFilterParamsBuilder() = default; + + IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) + { + this->_save_path_prefix = save_path_prefix; + return *this; + } + + IndexFilterParamsBuilder &with_label_file(const std::string &label_file) + { + this->_label_file = label_file; + return *this; + } + + IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label) + { + this->_universal_label = univeral_label; + return *this; + } + + IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold) + { + this->_filter_threshold = filter_threshold; + return *this; + } + + IndexFilterParams build() + { + return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold); + } + + IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; + IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete; + + private: + std::string _save_path_prefix; + std::string _label_file; + std::string _tags_file; + std::string _universal_label; + uint32_t _filter_threshold = 0; +}; + } // namespace diskann diff --git a/python/src/builder.cpp b/python/src/builder.cpp index 4bdc1ef6e..b4d2774b5 100644 --- a/python/src/builder.cpp +++ b/python/src/builder.cpp @@ -31,29 +31,6 @@ template void build_disk_index(diskann::Metric, const std::string &, co template void build_disk_index(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t, double, double, uint32_t, uint32_t); -template -std::string prepare_filtered_label_map(diskann::Index &index, const std::string &index_output_path, - const std::string &filter_labels_file, const std::string &universal_label) -{ - std::string labels_file_to_use = index_output_path + "_label_numeric.txt"; - std::string mem_labels_int_map_file = index_output_path + "_labels_map.txt"; - convert_label_to_numeric(filter_labels_file, labels_file_to_use, mem_labels_int_map_file, universal_label); - if (!universal_label.empty()) - { - index.set_universal_label(universal_label); - } - return labels_file_to_use; -} - -template std::string prepare_filtered_label_map(diskann::Index &, const std::string &, - const std::string &, const std::string &); - -template std::string prepare_filtered_label_map(diskann::Index &, - const std::string &, const std::string &, const std::string &); - -template std::string prepare_filtered_label_map(diskann::Index &, - const std::string &, const std::string &, const std::string &); - template void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path, const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity, @@ -95,9 +72,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_ } else { - auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, - universal_label); - index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(filter_labels_file) + .with_save_path_prefix(index_output_path.c_str()) + .build(); + index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params, tags); } } else @@ -108,9 +88,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_ } else { - auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, - universal_label); - index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(filter_labels_file) + .with_save_path_prefix(index_output_path.c_str()) + .build(); + index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params); } } diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 33e6212ba..eba23ed35 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -673,9 +673,14 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr if (universal_label != "") { // indicates no universal label // LabelT unv_label_as_num = 0; - _index.set_universal_labels(universal_label); + _index.set_universal_label(universal_label); } - _index.build_filtered_index(base_file.c_str(), label_file, base_num); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(mem_index_path.c_str()) + .build(); + _index.build_filtered_index(base_file.c_str(), base_num, filter_params); } _index.save(mem_index_path.c_str()); @@ -685,8 +690,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr // file std::remove(disk_labels_to_medoids_file.c_str()); std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt"; + std::string mem_labels_file = mem_index_path + "_labels.txt"; + std::string mem_labels_int_map_file = mem_index_path + "_labels_map.txt"; copy_file(mem_labels_to_medoid_file, disk_labels_to_medoids_file); std::remove(mem_labels_to_medoid_file.c_str()); + std::remove(mem_labels_file.c_str()); + std::remove(mem_labels_int_map_file.c_str()); } std::remove(medoids_file.c_str()); @@ -747,9 +756,14 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr if (universal_label != "") { // indicates no universal label // LabelT unv_label_as_num = 0; - _index.set_universal_labels(universal_label); + _index.set_universal_label(universal_label); } - _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts); + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(shard_labels_file) + .with_save_path_prefix(shard_index_file.c_str()) + .build(); + _index.build_filtered_index(shard_base_file.c_str(), shard_base_pts, filter_params); } _index.save(shard_index_file.c_str()); // copy universal label file from first shard to the final destination @@ -791,11 +805,13 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr { std::string shard_index_label_file = shard_index_file + "_labels.txt"; std::string shard_index_univ_label_file = shard_index_file + "_universal_label.txt"; - std::string shard_index_label_map_file = shard_index_file + "_labels_to_medoids.txt"; + std::string shard_index_label_medoid_file = shard_index_file + "_labels_to_medoids.txt"; + std::string shard_index_label_map_file = shard_index_file + "_labels_map.txt"; std::remove(shard_labels_file.c_str()); std::remove(shard_index_label_file.c_str()); std::remove(shard_index_label_map_file.c_str()); std::remove(shard_index_univ_label_file.c_str()); + std::remove(shard_index_label_medoid_file.c_str()); } } return 0; diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index a96cf4039..a98e5c050 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -64,7 +64,7 @@ void InMemFilterStore::add_label_to_location(const location_t point_ } template -void InMemFilterStore::set_universal_labels(const std::string &raw_universal_label) +void InMemFilterStore::set_universal_label(const std::string &raw_universal_label) { if (raw_universal_label.empty()) { diff --git a/src/index.cpp b/src/index.cpp index c47e08e20..e8985fc0d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -76,6 +76,7 @@ Index::Index(const IndexConfig &index_config, std::shared_ptr>(total_internal_points); } @@ -131,7 +132,6 @@ Index::Index(Metric m, const size_t dim, const size_t max_point .is_filtered(filtered_index) .with_num_pq_chunks(num_pq_chunks) .is_use_opq(use_opq) - .is_filtered(filtered_index) .with_data_type(diskann_type_to_name()) .build(), IndexFactory::construct_datastore( @@ -1737,11 +1737,7 @@ void Index::build(const std::string &data_file, const size_t nu } else { - if (filter_params.universal_label != "") - { - this->set_universal_labels(filter_params.universal_label); - } - this->build_filtered_index(data_file.c_str(), filter_params.label_file, points_to_load); + this->build_filtered_index(data_file.c_str(), points_to_load, filter_params); } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; std::cout << "Indexing time: " << diff.count() << "\n"; @@ -1753,28 +1749,40 @@ void Index::parse_label_file(const std::string &label_file, siz num_points = _filter_store->load_labels(label_file); } -// template -// void Index::set_universal_label(const LabelT &label) -//{ -// //_filter_store->set_universal_label(label); -// } - template -void Index::set_universal_labels(const std::string &raw_label) +void Index::set_universal_label(const std::string &raw_label) { - _filter_store->set_universal_labels(raw_label); + _filter_store->set_universal_label(raw_label); } template -void Index::build_filtered_index(const char *filename, const std::string &raw_label_file, - const size_t num_points_to_load, const std::vector &tags) +void Index::build_filtered_index(const char *filename, const size_t num_points_to_load, + const IndexFilterParams &filter_params, const std::vector &tags) { _filtered_index = true; - size_t num_points_labels = _filter_store->load_raw_labels(raw_label_file, ""); + size_t num_points_labels = _filter_store->load_raw_labels(filter_params.label_file, ""); + if (filter_params.universal_label != "") + { + _filter_store->set_universal_label(filter_params.universal_label); + } calculate_best_medoids(num_points_to_load, 25); this->build(filename, num_points_to_load, tags); } +template +void Index::build_filtered_index(const T *data, const size_t num_points_to_load, + const IndexFilterParams &filter_params, const std::vector &tags) +{ + _filtered_index = true; + size_t num_points_labels = _filter_store->load_raw_labels(filter_params.label_file, ""); + if (filter_params.universal_label != "") + { + _filter_store->set_universal_label(filter_params.universal_label); + } + calculate_best_medoids(num_points_to_load, 25); + this->build(data, num_points_to_load, tags); +} + template std::pair Index::_search(const DataType &query, const size_t K, const uint32_t L, std::any &indices, float *distances)