Skip to content

Commit

Permalink
Build with support for brute force seems to be working
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Nov 22, 2024
1 parent b2485d8 commit 48314e7
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 17 deletions.
23 changes: 14 additions & 9 deletions apps/build_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
label_type;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, filter_bf_threshold;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
Expand Down Expand Up @@ -74,8 +74,9 @@ int main(int argc, char **argv)
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
"Threshold to break up the existing nodes to generate new graph "
"internally where each node has a maximum F labels.");
program_options_utils::FILTER_THRESHOLD_DESCRIPTION);
optional_configs.add_options()("filter_bruteforce_threshold", po::value<uint32_t>(&filter_bf_threshold)->default_value(0),
program_options_utils::FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);

Expand Down Expand Up @@ -139,22 +140,26 @@ int main(int argc, char **argv)
std::string(std::to_string(append_reorder_data)) + " " +
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));

if (filter_bf_threshold == 0) {
filter_bf_threshold = std::numeric_limits<uint32_t>::max();
}

try
{
if (label_file != "" && label_type == "ushort")
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, filter_bf_threshold);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
use_filters, label_file, universal_label, filter_threshold, Lf, filter_bf_threshold);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
Expand All @@ -166,15 +171,15 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, filter_bf_threshold);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, filter_bf_threshold);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, filter_bf_threshold);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
Expand Down
23 changes: 23 additions & 0 deletions include/filter_brute_force_index.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include "common_includes.h"
#include "windows_customizations.h"
#include "filter_utils.h"

namespace diskann {

template<typename T>
class FilterBruteForceIndex {
public :
DISKANN_DLLEXPORT FilterBruteForceIndex(const std::string& disk_index_file);
DISKANN_DLLEXPORT bool brute_force_index_available() const;
DISKANN_DLLEXPORT bool brute_forceable_filter(const std::string& filter) const;
DISKANN_DLLEXPORT int load();

private :
diskann::inverted_index_t _bf_filter_index;
bool _is_loaded;
std::string _disk_index_file;
};
}
1 change: 1 addition & 0 deletions include/filter_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace diskann {
//CONSTANTS
DISKANN_DLLEXPORT extern const char* NO_LABEL_FOR_POINT;
DISKANN_DLLEXPORT extern const char FILTERS_LABEL_DELIMITER;
typedef std::map<std::string, std::unordered_set<location_t>> inverted_index_t;


template <typename T>
Expand Down
5 changes: 5 additions & 0 deletions include/program_options_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ const char *FILTERS_FILE_DESCRIPTION =
const char *LABEL_TYPE_DESCRIPTION =
"Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume memory 4 bytes per "
"filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'.";
const char* FILTER_THRESHOLD_DESCRIPTION = "Threshold to break up the existing nodes to generate new graph "
"internally where each node has a maximum F labels.";
const char* FILTER_BRUTEFORCE_THRESHOLD_DESCRIPTION = "Use brute force for searching with a filter if it occurs"
" fewer than this many times in the dataset.";

const char *GROUND_TRUTH_FILE_DESCRIPTION =
"ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an
// entry for every item or just a small subset? I have so many questions about
Expand Down
2 changes: 2 additions & 0 deletions src/distance.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
// TODO
// CHECK COSINE ON LINUX

Expand Down
2 changes: 1 addition & 1 deletion src/dll/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../par
../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp
../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp
../in_mem_filter_store.cpp)
../in_mem_filter_store.cpp ../filter_brute_force_index.cpp)

set(TARGET_DIR "$<$<CONFIG:Debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:Release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>")

Expand Down
62 changes: 62 additions & 0 deletions src/filter_brute_force_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "filter_brute_force_index.h"

namespace diskann {

template<typename T>
FilterBruteForceIndex<T>::FilterBruteForceIndex(const std::string& disk_index_file) {
_disk_index_file = disk_index_file;
_filter_bf_data_file = _disk_index_file + "_brute_force.txt";
}
template<typename T>
bool FilterBruteForceIndex<T>::brute_force_index_available() const {}

template<typename T>
bool FilterBruteForceIndex<T>::brute_forceable_filter(const std::string& filter) const {}

template<typename T>
int FilterBruteForceIndex<T>::load() {
if (false == file_exists(_filter_bf_data_file)) {
diskann::cerr << "Index does not have brute force support." << std::endl;
return 1;
}
std::ifstream bf_in(_filter_bf_data_file);
if (!bf_in.is_open()) {
std::stringstream ss;
ss << "Could not open " << _filter_bf_data_file << " for reading. " << std::endl;
diskann::cerr << ss.str() << std::endl;
throw diskann::ANNException(ss.str(), -1);
}

std::string line;
std::vector<std::string> label_and_points;
label_and_points.reserve(2);
std::unordered_set<location_t> points;

size_t linenum = 0;
while (getline(bf_in, line)) {
split_string(line, '\t', label_and_points);
if (label_and_points.size() == 2) {

std::istringstream iss(label_and_points[1]);
std::string pt_str;
while (getline(iss, pt_str, ',')) {
points.insert(strtoul(pt_str));
}
assert(points.size() > 0);
_bf_filter_index.insert(label_and_points[0], points);
points.clear();
} else {
std::stringstream ss;
ss << "Error reading brute force data at line: " << line_num
<< " found " << label_and_points.size() << " tab separated entries instead of 2"
<< std::endl;
diskann::cerr << ss.str();
throw diskann::ANNException(ss.str(), -1);
}
line_num++;
}
}
}
27 changes: 20 additions & 7 deletions src/filter_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,7 @@ parse_formatted_label_file(std::string label_file) {

//TODO: This is a test implementation of adding brute force logic while
//building a filtered index. Must be cleaned up later.

typedef std::map<std::string, std::unordered_set<location_t>> inverted_index_t;

void get_inv_index(const std::string& label_file, inverted_index_t& inv_index) {
void get_inv_index(const std::string& label_file, const location_t filter_bf_threshold, inverted_index_t& inv_index) {
std::ifstream label_in(label_file);
if (!label_in.is_open()) {
std::stringstream ss;
Expand All @@ -352,6 +349,16 @@ void get_inv_index(const std::string& label_file, inverted_index_t& inv_index) {
line_labels.clear();
line_num++;
}

diskann::cout << "Built inverted index for filters. Label count: " << inv_index.size();
auto num_bf_labels = 0;
for (auto& label_and_points : inv_index) {
if (label_and_points.second.size() < filter_bf_threshold) {
num_bf_labels++;
}
}
diskann::cout << " number of sparse labels: " << num_bf_labels << std::endl;

}

void get_labels_of_point(const inverted_index_t& inv_index, location_t point, std::vector<std::string>& labels, location_t sparse_threshold) {
Expand All @@ -378,19 +385,22 @@ void write_new_label_file(const inverted_index_t& inv_index, location_t nrows, c

std::vector<std::string> labels_of_point;
labels_of_point.reserve(200); //just assuming, won't affect anything.
location_t num_graph_points = 0;

for (location_t i = 0; i < (location_t)nrows; i++) {
get_labels_of_point(inv_index, i, labels_of_point, sparse_threshold);
if (labels_of_point.size() == 0) {
label_out << NO_LABEL_FOR_POINT << std::endl;
}
else {
} else {
num_graph_points++;
for (int i = 0; i < labels_of_point.size() - 1; i++) {
label_out << labels_of_point[i] << ",";
}
label_out << labels_of_point[labels_of_point.size() - 1] << std::endl;
}
labels_of_point.clear();
}
diskann::cout << "New label file: " << new_label_file << ", num graph points: " << num_graph_points << std::endl;
label_out.close();
}

Expand Down Expand Up @@ -420,6 +430,7 @@ void write_brute_force_data(const inverted_index_t& inv_index, const std::string
}
}
}
diskann::cout << "Brute force file: " << bf_data_file << std::endl;
bf_out.close();
}

Expand All @@ -430,12 +441,14 @@ void separate_brute_forceable_points(
const std::string& new_lbl_file,
const std::string& bf_data_file) {

diskann::cout << "Excluding brute forceable points from the dataset for building the diskann graph" << std::endl;

std::ifstream data_in(base_file, std::ios::binary);
uint64_t nrows, ncols;
get_bin_metadata_impl(data_in, nrows, ncols);

inverted_index_t inv_index;
get_inv_index(label_file, inv_index);
get_inv_index(label_file, filter_bf_threshold, inv_index);

write_new_label_file(inv_index, (location_t)nrows, new_lbl_file, filter_bf_threshold);
write_brute_force_data(inv_index, bf_data_file, filter_bf_threshold);
Expand Down
5 changes: 5 additions & 0 deletions src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ bool InMemFilterStore<LabelT>::load(const std::string &disk_index_file) {
std::string dummy_map_file = disk_index_file + "_dummy_map.txt";
std::string labels_map_file = disk_index_file + "_labels_map.txt";
std::string univ_label_file = disk_index_file + "_universal_label.txt";
std::string brute_force_data_file = disk_index_path + "_brute_force.txt";
std::string bf_excluded_label_file = disk_index_path + "_non_brute_force_labels.txt";

size_t num_pts_in_label_file = 0;

Expand Down Expand Up @@ -114,6 +116,9 @@ void InMemFilterStore<LabelT>::load_label_file(
std::string line;
uint32_t line_cnt = 0;

//TODO: This code is very inefficient because it reads the label file twice -
//once for computing stats and then for loading the labels. Must merge the
//two reads.
uint32_t num_pts_in_label_file;
uint32_t num_total_labels;
get_label_file_metadata(label_file_content, num_pts_in_label_file,
Expand Down

0 comments on commit 48314e7

Please sign in to comment.