diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..995dcd0a0 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -17,7 +17,7 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, filter_bf_threshold; float B, M; bool append_reorder_data = false; bool use_opq = false; @@ -74,8 +74,9 @@ int main(int argc, char **argv) optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), program_options_utils::FILTERED_LBUILD); optional_configs.add_options()("filter_threshold,F", po::value(&filter_threshold)->default_value(0), - "Threshold to break up the existing nodes to generate new graph " - "internally where each node has a maximum F labels."); + program_options_utils::FILTER_THRESHOLD_DESCRIPTION); + optional_configs.add_options()("filter_bruteforce_threshold", po::value(&filter_bf_threshold)->default_value(0), + program_options_utils::FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); @@ -139,6 +140,10 @@ int main(int argc, char **argv) std::string(std::to_string(append_reorder_data)) + " " + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); + if (filter_bf_threshold == 0) { + filter_bf_threshold = std::numeric_limits::max(); + } + try { if (label_file != "" && label_type == "ushort") @@ -146,15 +151,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("uint8")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("float")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold); else { diskann::cerr << "Error. Unsupported data type" << std::endl; @@ -166,15 +171,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("uint8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else if (data_type == std::string("float")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, filter_bf_threshold); else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/include/filter_brute_force_index.h b/include/filter_brute_force_index.h new file mode 100644 index 000000000..09db7bf50 --- /dev/null +++ b/include/filter_brute_force_index.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. +#pragma once +#include "common_includes.h" +#include "windows_customizations.h" +#include "filter_utils.h" + +namespace diskann { + + template + class FilterBruteForceIndex { + public : + DISKANN_DLLEXPORT FilterBruteForceIndex(const std::string& disk_index_file); + DISKANN_DLLEXPORT bool brute_force_index_available() const; + DISKANN_DLLEXPORT bool brute_forceable_filter(const std::string& filter) const; + DISKANN_DLLEXPORT int load(); + + private : + diskann::inverted_index_t _bf_filter_index; + bool _is_loaded; + std::string _disk_index_file; + }; +} \ No newline at end of file diff --git a/include/filter_utils.h b/include/filter_utils.h index 3ef1ed1c2..4fcfd533f 100644 --- a/include/filter_utils.h +++ b/include/filter_utils.h @@ -57,6 +57,7 @@ namespace diskann { //CONSTANTS DISKANN_DLLEXPORT extern const char* NO_LABEL_FOR_POINT; DISKANN_DLLEXPORT extern const char FILTERS_LABEL_DELIMITER; + typedef std::map> inverted_index_t; template diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..965a7dc5b 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -43,6 +43,11 @@ const char *FILTERS_FILE_DESCRIPTION = const char *LABEL_TYPE_DESCRIPTION = "Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume memory 4 bytes per " "filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'."; +const char* FILTER_THRESHOLD_DESCRIPTION = "Threshold to break up the existing nodes to generate new graph " +"internally where each node has a maximum F labels."; +const char* FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION = "Use brute force for searching with a filter if it occurs" +" fewer than this many times in the dataset."; + const char *GROUND_TRUTH_FILE_DESCRIPTION = "ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an // entry for every item or just a small subset? I have so many questions about diff --git a/src/distance.cpp b/src/distance.cpp index fc4e43a75..a85f5b622 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. // TODO // CHECK COSINE ON LINUX diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 1de73b76d..7fc6522ae 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -5,7 +5,7 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../par ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp - ../in_mem_filter_store.cpp) + ../in_mem_filter_store.cpp ../filter_brute_force_index.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") diff --git a/src/filter_brute_force_index.cpp b/src/filter_brute_force_index.cpp new file mode 100644 index 000000000..0ba888433 --- /dev/null +++ b/src/filter_brute_force_index.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "filter_brute_force_index.h" + +namespace diskann { + + template + FilterBruteForceIndex::FilterBruteForceIndex(const std::string& disk_index_file) { + _disk_index_file = disk_index_file; + _filter_bf_data_file = _disk_index_file + "_brute_force.txt"; + } + template + bool FilterBruteForceIndex::brute_force_index_available() const {} + + template + bool FilterBruteForceIndex::brute_forceable_filter(const std::string& filter) const {} + + template + int FilterBruteForceIndex::load() { + if (false == file_exists(_filter_bf_data_file)) { + diskann::cerr << "Index does not have brute force support." << std::endl; + return 1; + } + std::ifstream bf_in(_filter_bf_data_file); + if (!bf_in.is_open()) { + std::stringstream ss; + ss << "Could not open " << _filter_bf_data_file << " for reading. " << std::endl; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + std::string line; + std::vector label_and_points; + label_and_points.reserve(2); + std::unordered_set points; + + size_t linenum = 0; + while (getline(bf_in, line)) { + split_string(line, '\t', label_and_points); + if (label_and_points.size() == 2) { + + std::istringstream iss(label_and_points[1]); + std::string pt_str; + while (getline(iss, pt_str, ',')) { + points.insert(strtoul(pt_str)); + } + assert(points.size() > 0); + _bf_filter_index.insert(label_and_points[0], points); + points.clear(); + } else { + std::stringstream ss; + ss << "Error reading brute force data at line: " << line_num + << " found " << label_and_points.size() << " tab separated entries instead of 2" + << std::endl; + diskann::cerr << ss.str(); + throw diskann::ANNException(ss.str(), -1); + } + line_num++; + } + } +} \ No newline at end of file diff --git a/src/filter_utils.cpp b/src/filter_utils.cpp index 62532f7d9..89963c44d 100644 --- a/src/filter_utils.cpp +++ b/src/filter_utils.cpp @@ -325,10 +325,7 @@ parse_formatted_label_file(std::string label_file) { //TODO: This is a test implementation of adding brute force logic while //building a filtered index. Must be cleaned up later. - -typedef std::map> inverted_index_t; - -void get_inv_index(const std::string& label_file, inverted_index_t& inv_index) { +void get_inv_index(const std::string& label_file, const location_t filter_bf_threshold, inverted_index_t& inv_index) { std::ifstream label_in(label_file); if (!label_in.is_open()) { std::stringstream ss; @@ -352,6 +349,16 @@ void get_inv_index(const std::string& label_file, inverted_index_t& inv_index) { line_labels.clear(); line_num++; } + + diskann::cout << "Built inverted index for filters. Label count: " << inv_index.size(); + auto num_bf_labels = 0; + for (auto& label_and_points : inv_index) { + if (label_and_points.second.size() < filter_bf_threshold) { + num_bf_labels++; + } + } + diskann::cout << " number of sparse labels: " << num_bf_labels << std::endl; + } void get_labels_of_point(const inverted_index_t& inv_index, location_t point, std::vector& labels, location_t sparse_threshold) { @@ -378,12 +385,14 @@ void write_new_label_file(const inverted_index_t& inv_index, location_t nrows, c std::vector labels_of_point; labels_of_point.reserve(200); //just assuming, won't affect anything. + location_t num_graph_points = 0; + for (location_t i = 0; i < (location_t)nrows; i++) { get_labels_of_point(inv_index, i, labels_of_point, sparse_threshold); if (labels_of_point.size() == 0) { label_out << NO_LABEL_FOR_POINT << std::endl; - } - else { + } else { + num_graph_points++; for (int i = 0; i < labels_of_point.size() - 1; i++) { label_out << labels_of_point[i] << ","; } @@ -391,6 +400,7 @@ void write_new_label_file(const inverted_index_t& inv_index, location_t nrows, c } labels_of_point.clear(); } + diskann::cout << "New label file: " << new_label_file << ", num graph points: " << num_graph_points << std::endl; label_out.close(); } @@ -420,6 +430,7 @@ void write_brute_force_data(const inverted_index_t& inv_index, const std::string } } } + diskann::cout << "Brute force file: " << bf_data_file << std::endl; bf_out.close(); } @@ -430,12 +441,14 @@ void separate_brute_forceable_points( const std::string& new_lbl_file, const std::string& bf_data_file) { + diskann::cout << "Excluding brute forceable points from the dataset for building the diskann graph" << std::endl; + std::ifstream data_in(base_file, std::ios::binary); uint64_t nrows, ncols; get_bin_metadata_impl(data_in, nrows, ncols); inverted_index_t inv_index; - get_inv_index(label_file, inv_index); + get_inv_index(label_file, filter_bf_threshold, inv_index); write_new_label_file(inv_index, (location_t)nrows, new_lbl_file, filter_bf_threshold); write_brute_force_data(inv_index, bf_data_file, filter_bf_threshold); diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index d5f299063..9b04f61e1 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -66,6 +66,8 @@ bool InMemFilterStore::load(const std::string &disk_index_file) { std::string dummy_map_file = disk_index_file + "_dummy_map.txt"; std::string labels_map_file = disk_index_file + "_labels_map.txt"; std::string univ_label_file = disk_index_file + "_universal_label.txt"; + std::string brute_force_data_file = disk_index_path + "_brute_force.txt"; + std::string bf_excluded_label_file = disk_index_path + "_non_brute_force_labels.txt"; size_t num_pts_in_label_file = 0; @@ -114,6 +116,9 @@ void InMemFilterStore::load_label_file( std::string line; uint32_t line_cnt = 0; + //TODO: This code is very inefficient because it reads the label file twice - + //once for computing stats and then for loading the labels. Must merge the + //two reads. uint32_t num_pts_in_label_file; uint32_t num_total_labels; get_label_file_metadata(label_file_content, num_pts_in_label_file,